Ejemplo n.º 1
0
class AmqpSocket(object):
    """
    Associates a transport with a connection and a socket and can be
    used in an io loop to track the io for an AMQP 1.0 connection.
    """
    def __init__(self, conn, sock, events, heartbeat=None):
        """Bind a new Transport to *conn* and prepare *sock* for I/O.

        conn      -- connection object the new Transport is bound to
        sock      -- socket; switched to non-blocking mode with TCP_NODELAY
        events    -- stored as ``self.events`` (not used by this class)
        heartbeat -- if truthy, assigned to the transport's ``idle_timeout``
        """
        self.events = events
        self.conn = conn
        self.transport = Transport()
        if heartbeat:
            self.transport.idle_timeout = heartbeat
        self.transport.bind(self.conn)
        self.socket = sock
        self.socket.setblocking(0)
        # TCP_NODELAY disables Nagle's algorithm on the socket.
        self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.write_done = False
        self.read_done = False
        self._closed = False
        # Replaced by accept()/connect() when an ssl_domain is supplied;
        # initialised here so the attribute always exists.
        self.ssl = None

    def accept(self, force_sasl=True, ssl_domain=None):
        """Configure the transport for the accepting (server) side.

        force_sasl -- if true, act as a SASL server offering only ANONYMOUS
                      and report ``SASL.OK`` via ``done()``
        ssl_domain -- if given, layer SSL over the transport using it

        Returns self so the call can be chained.
        """
        if ssl_domain:
            self.ssl = SSL(self.transport, ssl_domain)
        if force_sasl:
            sasl = self.transport.sasl()
            sasl.mechanisms("ANONYMOUS")
            sasl.server()
            sasl.done(SASL.OK)
        # TODO: use SASL anyway if requested by peer
        return self

    def connect(self,
                host,
                port=None,
                username=None,
                password=None,
                force_sasl=True,
                ssl_domain=None):
        """Configure the transport for the client side and start connecting.

        host       -- peer host name; also used as the SSL peer hostname
        port       -- peer port; defaults to 5672 (the standard AMQP port)
        username   -- with *password*, selects SASL PLAIN
        password   -- see *username*
        force_sasl -- without credentials, use SASL ANONYMOUS as a client
        ssl_domain -- if given, layer SSL over the transport using it

        Returns self so the call can be chained.
        Raises ConnectionException if *host* cannot be resolved.
        """
        if ssl_domain:
            self.ssl = SSL(self.transport, ssl_domain)
            self.ssl.peer_hostname = host
        if username and password:
            sasl = self.transport.sasl()
            sasl.plain(username, password)
        elif force_sasl:
            sasl = self.transport.sasl()
            sasl.mechanisms('ANONYMOUS')
            sasl.client()
        try:
            # connect_ex() returns connect()'s error code instead of raising
            # (e.g. EINPROGRESS on a non-blocking socket); it is ignored here,
            # presumably so the io loop can observe completion -- confirm.
            # Name-resolution failures still raise socket.gaierror.
            self.socket.connect_ex((host, port or 5672))
        except socket.gaierror as e:
            raise ConnectionException("Cannot resolve '%s': %s" % (host, e))
        return self
Ejemplo n.º 2
0
    def _connect(self, connection):
        """Aim *connection* at the next address and bind a fresh Transport.

        Pulls the next URL from ``self.address`` and stores ``host:port`` as
        the connection's hostname. When ``self.sasl_enabled`` is set, it
        applies SASL settings and credentials, with URL credentials preferred
        over the configured ``user``/``password``. It then binds the
        transport, applies the heartbeat as ``idle_timeout``, and adds SSL
        for ``amqps`` URLs.

        Raises SSLUnavailable for an ``amqps`` URL when ``self.ssl_domain``
        is not set.
        """
        url = self.address.next()
        # The IoHandler reads this hostname to decide where to connect.
        connection.hostname = "%s:%s" % (url.host, url.port)
        logging.info("connecting to %s..." % connection.hostname)

        transport = Transport()
        if self.sasl_enabled:
            sasl = transport.sasl()
            sasl.allow_insecure_mechs = self.allow_insecure_mechs
            chosen_user = url.username or self.user
            if chosen_user:
                connection.user = chosen_user
            chosen_password = url.password or self.password
            if chosen_password:
                connection.password = chosen_password
            if self.allowed_mechs:
                sasl.allowed_mechs(self.allowed_mechs)
        transport.bind(connection)
        if self.heartbeat:
            transport.idle_timeout = self.heartbeat

        if url.scheme != 'amqps':
            return
        if not self.ssl_domain:
            raise SSLUnavailable("amqps: SSL libraries not found")
        self.ssl = SSL(transport, self.ssl_domain)
        self.ssl.peer_hostname = url.host
Ejemplo n.º 3
0
    def _connect(self, connection, reactor):
        """Register the next address with *reactor* and bind a new Transport.

        The next URL from ``self.address`` becomes the connection's host and
        port via ``reactor.set_connection_host``. Without a configured
        ``virtual_host``, the URL host is used as ``connection.hostname``.

        SASL settings and credentials are applied when ``self.sasl_enabled``
        is set, with URL credentials preferred over the configured ones. The
        heartbeat becomes the transport's ``idle_timeout``. ``amqps`` URLs
        get an SSL layer whose peer hostname prefers ``ssl_sni``, then
        ``virtual_host``, then the URL host. A truthy ``max_frame_size`` is
        copied onto the transport.

        Raises SSLUnavailable for an ``amqps`` URL when ``self.ssl_domain``
        is not set.
        """
        assert reactor is not None
        url = self.address.next()
        reactor.set_connection_host(connection, url.host, str(url.port))
        # Default the virtual host to the address's host when none is set.
        if self.virtual_host is None:
            connection.hostname = url.host
        log.debug("connecting to %s..." % url)

        transport = Transport()
        if self.sasl_enabled:
            sasl = transport.sasl()
            sasl.allow_insecure_mechs = self.allow_insecure_mechs
            chosen_user = url.username or self.user
            if chosen_user:
                connection.user = chosen_user
            chosen_password = url.password or self.password
            if chosen_password:
                connection.password = chosen_password
            if self.allowed_mechs:
                sasl.allowed_mechs(self.allowed_mechs)
        transport.bind(connection)
        if self.heartbeat:
            transport.idle_timeout = self.heartbeat

        secure = url.scheme == 'amqps'
        if secure and not self.ssl_domain:
            raise SSLUnavailable("amqps: SSL libraries not found")
        if secure:
            self.ssl = SSL(transport, self.ssl_domain)
            self.ssl.peer_hostname = self.ssl_sni or self.virtual_host or url.host

        if self.max_frame_size:
            transport.max_frame_size = self.max_frame_size
Ejemplo n.º 4
0
    def _connect(self, connection, reactor):
        """Register the next address with *reactor* and bind a new Transport.

        The next URL from ``self.address`` becomes the connection's host and
        port via ``reactor.set_connection_host``. SASL settings and
        credentials are applied when ``self.sasl_enabled`` is set, with URL
        credentials preferred over the configured ones. The heartbeat becomes
        the transport's ``idle_timeout``, and ``amqps`` URLs get an SSL layer
        whose peer hostname is the URL host.

        Raises SSLUnavailable for an ``amqps`` URL when ``self.ssl_domain``
        is not set.
        """
        assert reactor is not None
        url = self.address.next()
        reactor.set_connection_host(connection, url.host, str(url.port))
        logging.debug("connecting to %s..." % url)

        transport = Transport()
        if self.sasl_enabled:
            sasl = transport.sasl()
            sasl.allow_insecure_mechs = self.allow_insecure_mechs
            chosen_user = url.username or self.user
            if chosen_user:
                connection.user = chosen_user
            chosen_password = url.password or self.password
            if chosen_password:
                connection.password = chosen_password
            if self.allowed_mechs:
                sasl.allowed_mechs(self.allowed_mechs)
        transport.bind(connection)
        if self.heartbeat:
            transport.idle_timeout = self.heartbeat

        if url.scheme != 'amqps':
            return
        if not self.ssl_domain:
            raise SSLUnavailable("amqps: SSL libraries not found")
        self.ssl = SSL(transport, self.ssl_domain)
        self.ssl.peer_hostname = url.host
Ejemplo n.º 5
0
    def _connect(self, connection):
        """Aim *connection* at the next address and bind a fresh Transport.

        Pulls the next URL from ``self.address`` and stores ``host:port`` as
        the connection's hostname. When ``self.sasl_enabled`` is set, it
        applies SASL settings and credentials, with URL credentials preferred
        over the configured ``user``/``password``. It then binds the
        transport, applies the heartbeat as ``idle_timeout``, and adds SSL
        for ``amqps`` URLs.

        Raises SSLUnavailable if the URL scheme is ``amqps`` but no SSL
        domain is configured, rather than silently connecting without TLS.
        """
        url = self.address.next()
        # IoHandler uses the hostname to determine where to try to connect to
        connection.hostname = "%s:%s" % (url.host, url.port)
        logging.info("connecting to %s..." % connection.hostname)

        transport = Transport()
        if self.sasl_enabled:
            sasl = transport.sasl()
            sasl.allow_insecure_mechs = self.allow_insecure_mechs
            if url.username:
                connection.user = url.username
            elif self.user:
                connection.user = self.user
            if url.password:
                connection.password = url.password
            elif self.password:
                connection.password = self.password
            if self.allowed_mechs:
                sasl.allowed_mechs(self.allowed_mechs)
        transport.bind(connection)
        if self.heartbeat:
            transport.idle_timeout = self.heartbeat
        if url.scheme == 'amqps':
            # Never downgrade an amqps URL to a plaintext connection.
            if not self.ssl_domain:
                from proton import SSLUnavailable
                raise SSLUnavailable("amqps: SSL libraries not found")
            self.ssl = SSL(transport, self.ssl_domain)
            self.ssl.peer_hostname = url.host
Ejemplo n.º 6
0
    def _connect(self, connection):
        """Aim *connection* at the next address and bind a fresh Transport.

        Pulls the next URL from ``self.address`` and stores ``host:port`` as
        the connection's hostname. A SASL layer is created when the URL
        carries a username or when ``self.allowed_mechs`` is set. URL
        credentials are copied onto the connection. The transport is then
        bound, the heartbeat is applied as ``idle_timeout``, and SSL is
        added for ``amqps`` URLs.

        Raises SSLUnavailable if the URL scheme is ``amqps`` but no SSL
        domain is configured, rather than silently connecting without TLS.
        """
        url = self.address.next()
        # IoHandler uses the hostname to determine where to try to connect to
        connection.hostname = "%s:%s" % (url.host, url.port)
        logging.info("connecting to %s..." % connection.hostname)

        transport = Transport()
        sasl = None
        if url.username:
            connection.user = url.username
            sasl = transport.sasl()
            # NOTE(review): allow_insecure_mechs is only applied on this path,
            # not when SASL is created below for allowed_mechs alone -- confirm.
            sasl.allow_insecure_mechs = self.allow_insecure_mechs
        if url.password:
            connection.password = url.password
        if self.allowed_mechs:
            if sasl is None:
                sasl = transport.sasl()
            sasl.allowed_mechs(self.allowed_mechs)
        transport.bind(connection)
        if self.heartbeat:
            transport.idle_timeout = self.heartbeat
        if url.scheme == 'amqps':
            # Never downgrade an amqps URL to a plaintext connection.
            if not self.ssl_domain:
                from proton import SSLUnavailable
                raise SSLUnavailable("amqps: SSL libraries not found")
            self.ssl = SSL(transport, self.ssl_domain)
            self.ssl.peer_hostname = url.host
Ejemplo n.º 7
0
    def _connect(self, connection):
        """Aim *connection* at the next address and bind a fresh Transport.

        Pulls the next URL from ``self.address`` and stores ``host:port`` as
        the connection's hostname. It binds a new transport, applies the
        heartbeat as ``idle_timeout``, and adds SSL for ``amqps`` URLs. If
        the URL carries a username, it sets up SASL: ANONYMOUS for the user
        ``anonymous``, PLAIN with the URL credentials otherwise.

        Raises SSLUnavailable if the URL scheme is ``amqps`` but no SSL
        domain is configured, rather than silently connecting without TLS.
        """
        url = self.address.next()
        # IoHandler uses the hostname to determine where to try to connect to.
        # NOTE(review): "%i" requires url.port to be an integer (not None).
        connection.hostname = "%s:%i" % (url.host, url.port)
        logging.info("connecting to %s..." % connection.hostname)

        transport = Transport()
        transport.bind(connection)
        if self.heartbeat:
            transport.idle_timeout = self.heartbeat
        if url.scheme == 'amqps':
            # Never downgrade an amqps URL to a plaintext connection.
            if not self.ssl_domain:
                from proton import SSLUnavailable
                raise SSLUnavailable("amqps: SSL libraries not found")
            self.ssl = SSL(transport, self.ssl_domain)
            self.ssl.peer_hostname = url.host
        if url.username:
            sasl = transport.sasl()
            if url.username == 'anonymous':
                sasl.mechanisms('ANONYMOUS')
            else:
                sasl.plain(url.username, url.password)
Ejemplo n.º 8
0
class Connection(object):
    """Provides network I/O for a Proton connection via a socket-like object.
    """
    def __init__(self, socket, eventHandler, name=None):
        """Create a Proton transport/connection pair for *socket* and open it.

        socket       -- Python socket. Expected to be configured and connected.
        eventHandler -- handler object, stored as ``self._handler``
        name         -- optional name for this Connection
        """
        self._name = name
        self._socket = socket
        self._pn_transport = Transport()
        # Qualify with the proton module: in this module the bare name
        # ``Connection`` is rebound by the class statement to this wrapper,
        # so ``Connection()`` would call this constructor without its
        # required arguments and raise TypeError.
        self._pn_connection = proton.Connection()
        self._pn_transport.bind(self._pn_connection)
        # Frame-level tracing is enabled unconditionally.
        self._pn_transport.trace(proton.Transport.TRACE_FRM)
        self._handler = eventHandler
        self._read_done = False
        self._write_done = False
        self._next_tick = 0

        self._sasl_done = False
        self._pn_connection.open()


    def fileno(self):
        """Allows use of a Connection by the python select() call.
        """
        return self._socket.fileno()

    @property
    # @todo - hopefully remove
    def transport(self):
        return self._pn_transport

    @property
    # @todo - hopefully remove
    def connection(self):
        return self._pn_connection

    @property
    def socket(self):
        return self._socket

    @property
    def name(self):
        return self._name

    @property
    # @todo - think about server side use of this!
    def sasl(self):
        return self._pn_transport.sasl()

    def tick(self, now):
        """Invoke the transport's tick method.  'now' is seconds since Epoch.
        Returns the timestamp for the next tick, or zero if no next tick.
        """
        self._next_tick = self._pn_transport.tick(now)
        return self._next_tick

    @property
    def next_tick(self):
        """Timestamp for next call to tick()
        """
        return self._next_tick

    @property
    def need_read(self):
        """True when the Transport requires data from the network layer.
        """
        return (not self._read_done) and self._pn_transport.capacity() > 0

    @property
    def need_write(self):
        """True when the Transport has data to write to the network layer.
        """
        return (not self._write_done) and self._pn_transport.pending() > 0

    def read_input(self):
        """Read from the network layer and processes all data read.  Can
        support both blocking and non-blocking sockets.
        """
        if self._read_done:
            return None

        c = self._pn_transport.capacity()
        if c < 0:
            try:
                self._socket.shutdown(socket.SHUT_RD)
            except:
                pass
            self._read_done = True
            return None

        if c > 0:
            try:
                buf = self._socket.recv(c)
                if buf:
                    self._pn_transport.push(buf)
                    return len(buf)
                # else socket closed
                self._pn_transport.close_tail()
                self._read_done = True
                return None
            except socket.timeout, e:
                raise  # let the caller handle this
            except socket.error, e:
                err = e.args[0]
                if (err != errno.EAGAIN and
                    err != errno.EWOULDBLOCK and
                    err != errno.EINTR):
                    # otherwise, unrecoverable:
                    self._pn_transport.close_tail()
                    self._read_done = True
                raise
            except:  # beats me...