def test_tunnel_protocol_ignores_warnings(self):
    """
    When ssh is used to set up the tunnel, we typically receive an
    unknown host key warning message on stderr. The tunnel protocol
    filters these messages, so nothing is logged at ERROR level.
    """
    log = self.capture_logging(level=logging.ERROR)
    protocol = ClientTunnelProtocol(None, None)
    protocol.errReceived("Warning: Permanently added")
    # The captured ERROR-level log must stay empty for a filtered warning.
    self.assertEqual(log.getvalue(), "")
def test_tunnel_protocol_notes_connection_refused(self):
    """
    A stderr message containing "Connection refused" fails the tunnel
    deferred with a NoConnection error whose message is
    "Connection refused". The mocked client expects close() to be called.
    """
    client = self.mocker.mock()
    client.close()
    self.mocker.replay()

    tunnel_deferred = Deferred()
    protocol = ClientTunnelProtocol(client, tunnel_deferred)
    message = "blah blah blah Connection refused blah blah"
    protocol.errReceived(message)

    expected_message = "Connection refused"

    def verify_failure(error):
        self.assertEqual(error.args[0], expected_message)

    # assertFailure (as used by the sibling tests) turns the expected
    # NoConnection failure into a success carrying the exception instance,
    # which is then handed to verify_failure.
    self.assertFailure(tunnel_deferred, NoConnection)
    tunnel_deferred.addCallback(verify_failure)
    return tunnel_deferred
def test_tunnel_protocol_notes_invalid_key(self):
    """
    A "Permission denied" message from ssh is reported as an invalid SSH
    key. "Invalid SSH key" is logged at ERROR level and the tunnel deferred
    fails with NoConnection carrying the same message. The mocked client
    expects close() to be called.
    """
    log = self.capture_logging(level=logging.ERROR)
    client = self.mocker.mock()
    client.close()
    self.mocker.replay()

    tunnel_deferred = Deferred()
    protocol = ClientTunnelProtocol(client, tunnel_deferred)
    message = "Permission denied"
    protocol.errReceived(message)
    expected_message = "Invalid SSH key"

    def verify_failure(error):
        self.assertEqual(error.args[0], expected_message)
        self.assertEqual(log.getvalue().strip(), expected_message)

    # assertFailure passes the NoConnection instance on to verify_failure.
    self.assertFailure(tunnel_deferred, NoConnection)
    tunnel_deferred.addCallback(verify_failure)
    return tunnel_deferred
def test_tunnel_protocol_notes_invalid_host(self):
    """
    An unresolvable-hostname message from ssh is logged at ERROR level and
    fails the tunnel deferred with a NoConnection describing the invalid
    host. The mocked client expects close() to be called.
    """
    log = self.capture_logging(level=logging.ERROR)

    # Record the single expected interaction with the mocked client.
    client = self.mocker.mock()
    client.close()
    self.mocker.replay()

    tunnel_deferred = Deferred()
    protocol = ClientTunnelProtocol(client, tunnel_deferred)
    stderr_text = "ssh: Could not resolve hostname magicbean"
    protocol.errReceived(stderr_text)
    expected = "Invalid host for SSH forwarding: %s" % stderr_text

    def check_error(error):
        # assertFailure hands over the NoConnection instance itself.
        self.assertEqual(error.args[0], expected)
        self.assertEqual(log.getvalue().strip(), expected)

    self.assertFailure(tunnel_deferred, NoConnection)
    tunnel_deferred.addCallback(check_error)
    return tunnel_deferred
def test_tunnel_protocol_closes_on_error(self):
    """
    A generic error message on stderr (here "badness") is reported as
    "SSH forwarding error: <message>". It is logged at ERROR level, the
    tunnel deferred fails with NoConnection carrying that text, and the
    mocked client expects close() to be called.
    """
    log = self.capture_logging(level=logging.ERROR)
    client = self.mocker.mock()
    # The protocol is built before close() is recorded; keep this order,
    # since anything done to the mock before replay() is recorded.
    tunnel_deferred = Deferred()
    protocol = ClientTunnelProtocol(client, tunnel_deferred)
    client.close()
    self.mocker.replay()

    protocol.errReceived("badness")
    expected_message = "SSH forwarding error: badness"

    def verify_failure(error):
        self.assertEqual(error.args[0], expected_message)
        self.assertEqual(log.getvalue().strip(), expected_message)

    self.assertFailure(tunnel_deferred, NoConnection)
    tunnel_deferred.addCallback(verify_failure)
    return tunnel_deferred
def _internal_connect(self, server, timeout, share=False):
    """Connect to the remote host through an ssh port forward.

    An ssh process is started (via ``forward_port``) that forwards a local
    port, obtained from ``get_open_port()``, to ``hostname:port`` on the
    remote side. The zookeeper client then connects to
    ``localhost:<local port>``.

    This is a generator that yields Deferreds and finishes with
    ``returnValue``, so it must be driven by Twisted's ``inlineCallbacks``.
    NOTE(review): the decorator is not visible in this chunk; confirm it is
    applied.

    :param server: Remote host to connect to, as ``"hostname:port"``. When
        falsy, ``self._servers`` is used instead.
    :type server: str
    :param timeout: Timeout in seconds covering both waiting for the
        tunnel and the subsequent client connect.
    :type timeout: float
    :param share: Passed through unchanged to ``forward_port``.
    :type share: bool
    :returns: ``self`` once connected.
    :raises ConnectionTimeoutException: if waiting for the local port raises
        ``socket.error``, or if no time remains for the client connect.

    If the tunnel protocol reports an error, the tunnel error deferred is
    yielded so its outcome propagates. Any exception from the underlying
    connect is re-raised after ``self.close()``.
    """
    hostname, port = self._parse_servers(server or self._servers)
    start_time = time.time()

    # Determine which local port the tunnel will listen on.
    local_port = get_open_port()
    port_watcher = PortWatcher("localhost", local_port, timeout)
    tunnel_error = Deferred()
    # On a tunnel error, stop the port watch early and bail with error.
    tunnel_error.addErrback(port_watcher.stop)
    # If a tunnel error happens now or later, close the connection.
    # NOTE(review): this errback only runs if port_watcher.stop returns or
    # raises a failure; a plain return value switches the chain to the
    # callback path (and `yield tunnel_error` below would not raise).
    # Confirm against PortWatcher.stop.
    tunnel_error.addErrback(lambda x: self.close())

    # Set up the tunnel via an ssh process for port forwarding.
    protocol = ClientTunnelProtocol(self, tunnel_error)
    self._process = forward_port(
        self.remote_user, local_port, hostname, int(port),
        process_protocol=protocol, share=share)

    # Wait for the tunnelled port to open.
    try:
        yield port_watcher.async_wait()
    except socket.error:
        self.close()  # Stop the tunnel process.
        raise ConnectionTimeoutException("could not connect")
    else:
        # If we stopped because of a tunnel error, raise it.
        if protocol.error:
            yield tunnel_error

    # Only the remainder of the caller's timeout is left for the connect.
    new_timeout = timeout - (time.time() - start_time)
    if new_timeout <= 0:
        self.close()
        raise ConnectionTimeoutException(
            "could not connect before timeout")

    # Connect the client through the tunnel.
    try:
        yield super(SSHClient, self).connect(
            "localhost:%d" % local_port, new_timeout)
    except:
        # Deliberately bare: always tear down the tunnel, then re-raise
        # the original exception unchanged.
        self.close()
        raise

    returnValue(self)