예제 #1
0
    def test_tunnel_protocol_ignores_warnings(self):
        """
        ssh commonly writes an unknown-host-key notice ("Warning: Permanently
        added ...") to stderr while setting up the tunnel. The tunnel protocol
        filters such warnings, so nothing is logged at ERROR level.
        """
        captured = self.capture_logging(level=logging.ERROR)
        # Both collaborators are None: handling a filtered warning must not
        # touch the client or the error deferred at all.
        tunnel_protocol = ClientTunnelProtocol(None, None)

        ssh_warning = "Warning: Permanently added"
        tunnel_protocol.errReceived(ssh_warning)

        self.assertEqual(captured.getvalue(), "")
예제 #2
0
    def test_tunnel_protocol_notes_connection_refused(self):
        """
        When "Connection refused" appears in ssh's stderr output, the tunnel
        protocol closes the client and fails the tunnel deferred with a
        NoConnection whose message is exactly "Connection refused".
        """
        client = self.mocker.mock()
        # Record the expectation that the protocol closes the client.
        client.close()
        self.mocker.replay()

        tunnel_deferred = Deferred()
        protocol = ClientTunnelProtocol(client, tunnel_deferred)

        # The marker is deliberately embedded in surrounding noise, so this
        # exercises detection anywhere in the output, not only as a prefix.
        message = "blah blah blah Connection refused blah blah"
        protocol.errReceived(message)

        expected_message = "Connection refused"

        def verify_failure(error):
            # assertFailure hands the trapped exception instance to callbacks.
            self.assertEqual(error.args[0], expected_message)

        # Use assertFailure (not the legacy failUnlessFailure alias) to match
        # the other tunnel protocol tests.
        self.assertFailure(tunnel_deferred, NoConnection)
        tunnel_deferred.addCallback(verify_failure)
        return tunnel_deferred
예제 #3
0
    def test_tunnel_protocol_notes_invalid_key(self):
        """
        When ssh reports "Permission denied", the tunnel protocol closes the
        client, logs "Invalid SSH key" at ERROR level and fails the tunnel
        deferred with a NoConnection carrying that same message.
        """
        log = self.capture_logging(level=logging.ERROR)
        client = self.mocker.mock()
        # Record the expectation that the protocol closes the client.
        client.close()
        self.mocker.replay()

        tunnel_deferred = Deferred()
        protocol = ClientTunnelProtocol(client, tunnel_deferred)

        message = "Permission denied"
        protocol.errReceived(message)

        expected_message = "Invalid SSH key"

        def verify_failure(error):
            # assertFailure hands the trapped exception instance to callbacks.
            self.assertEqual(error.args[0], expected_message)
            self.assertEqual(log.getvalue().strip(), expected_message)

        # Use assertFailure (not the legacy failUnlessFailure alias) to match
        # the other tunnel protocol tests.
        self.assertFailure(tunnel_deferred, NoConnection)
        tunnel_deferred.addCallback(verify_failure)
        return tunnel_deferred
예제 #4
0
    def test_tunnel_protocol_notes_invalid_host(self):
        """
        If ssh cannot resolve the remote hostname, the tunnel protocol closes
        the client, logs "Invalid host for SSH forwarding: <ssh output>" at
        ERROR level and fails the tunnel deferred with a NoConnection that
        carries the same message.
        """
        error_log = self.capture_logging(level=logging.ERROR)
        mock_client = self.mocker.mock()
        mock_client.close()  # recorded expectation: the client gets closed
        self.mocker.replay()

        error_deferred = Deferred()
        tunnel_protocol = ClientTunnelProtocol(mock_client, error_deferred)

        stderr_output = "ssh: Could not resolve hostname magicbean"
        tunnel_protocol.errReceived(stderr_output)

        wanted = "Invalid host for SSH forwarding: %s" % (stderr_output,)

        def verify_failure(error):
            self.assertEqual(error.args[0], wanted)
            self.assertEqual(error_log.getvalue().strip(), wanted)

        # assertFailure returns the same deferred it was given, so the
        # verification callback can be chained directly.
        return self.assertFailure(error_deferred, NoConnection).addCallback(
            verify_failure)
예제 #5
0
    def test_tunnel_protocol_closes_on_error(self):
        """
        Unrecognized ssh stderr output (here "badness") is treated as a
        generic forwarding error. The tunnel protocol logs it at ERROR level,
        closes the client, and fails the tunnel deferred with a NoConnection
        whose message is "SSH forwarding error: <output>".
        """
        log = self.capture_logging(level=logging.ERROR)
        client = self.mocker.mock()
        # Record the expectation that the protocol closes the client, and
        # switch to replay mode before wiring up the protocol, as the sibling
        # tests do.
        client.close()
        self.mocker.replay()

        tunnel_deferred = Deferred()
        protocol = ClientTunnelProtocol(client, tunnel_deferred)

        protocol.errReceived("badness")

        expected_message = "SSH forwarding error: badness"

        def verify_failure(error):
            # assertFailure hands the trapped exception instance to callbacks.
            self.assertEqual(error.args[0], expected_message)
            self.assertEqual(log.getvalue().strip(), expected_message)

        self.assertFailure(tunnel_deferred, NoConnection)
        tunnel_deferred.addCallback(verify_failure)
        return tunnel_deferred
예제 #6
0
    def _internal_connect(self, server, timeout, share=False):
        """Connect to the remote host provided via an ssh port forward.

        An SSH process is started that forwards a free local port on
        localhost (picked at call time by ``get_open_port()``, not a fixed
        port) to ``hostname:port`` on the remote side. The zookeeper client
        then connects to that local port.

        The body yields Deferreds and finishes with ``returnValue``, so it is
        meant to run under Twisted's ``inlineCallbacks``. The decorator is not
        visible in this excerpt; confirm it is applied.

        :param server: Remote host to connect to, specified as hostname:port.
            When falsy, ``self._servers`` is used instead.
        :type server: str

        :param timeout: A timeout interval in seconds. It is handed to the
            PortWatcher that waits for the tunnelled port. Whatever time
            remains afterwards is passed to the zookeeper connect.
        :type timeout: float

        :param share: Passed straight through to ``forward_port``. Presumably
            this enables ssh connection sharing; TODO confirm.
        :type share: bool

        Returns a connected client (``self``) or error.

        :raises ConnectionTimeoutException: if waiting for the tunnelled port
            raises ``socket.error``, or if no time is left once the tunnel is
            up.

        Any exception from the underlying zookeeper connect is re-raised
        unchanged after the tunnel has been closed.
        """
        hostname, port = self._parse_servers(server or self._servers)
        start_time = time.time()

        # Determine which port we'll be using.
        local_port = get_open_port()
        port_watcher = PortWatcher("localhost", local_port, timeout)

        tunnel_error = Deferred()
        # On a tunnel error, stop the port watch early and bail with error.
        tunnel_error.addErrback(port_watcher.stop)
        # If a tunnel error happens now or later, close the connection.
        # NOTE(review): in Twisted, an errback that returns anything other
        # than a Failure (and does not raise) switches the chain back to the
        # callback path. This errback only runs if port_watcher.stop passed
        # the Failure on, and the lambda returns self.close()'s result, not
        # the Failure. So the Failure is most likely consumed here. That keeps
        # a *later* tunnel error from surfacing as an unhandled Deferred
        # error, but it also means `yield tunnel_error` below may not re-raise
        # it. Confirm the return values of PortWatcher.stop and close().
        tunnel_error.addErrback(lambda x: self.close())

        # Setup tunnel via an ssh process for port forwarding.
        # The protocol is given this client and the error Deferred; it fails
        # that Deferred when ssh reports an error on stderr.
        protocol = ClientTunnelProtocol(self, tunnel_error)
        self._process = forward_port(self.remote_user,
                                     local_port,
                                     hostname,
                                     int(port),
                                     process_protocol=protocol,
                                     share=share)

        # Wait for the tunneled port to open.
        try:
            yield port_watcher.async_wait()
        except socket.error:
            self.close()  # Stop the tunnel process.
            raise ConnectionTimeoutException("could not connect")
        else:
            # If we stopped because of a tunnel error, raise it.
            # protocol.error is presumably set once the protocol has reported
            # a failure (confirm). See the NOTE(review) above on whether
            # tunnel_error still carries the Failure at this point.
            if protocol.error:
                yield tunnel_error

        # Check timeout: only the budget left after the tunnel came up is
        # given to the zookeeper connect.
        new_timeout = timeout - (time.time() - start_time)
        if new_timeout <= 0:
            self.close()
            raise ConnectionTimeoutException(
                "could not connect before timeout")

        # Connect the client through the local end of the tunnel.
        try:
            yield super(SSHClient, self).connect("localhost:%d" % local_port,
                                                 new_timeout)
        except:
            # Catches everything (not only Exception) purely to tear the
            # tunnel down; the original exception is re-raised unchanged.
            self.close()  # Stop the tunnel
            raise

        returnValue(self)