def setUp(self):
    """Create a mocked connection (with a mocked logger) and a strategy bound to it."""
    super(ConnectionStrategyTest, self).setUp()

    self.connection = conn = mock()
    conn.logger = mock()
    self.strategy = ConnectionStrategy(conn, 'localhost')
 def test_init_with_reconnect_cb(self):
     """A reconnect_cb passed to the constructor becomes the sole reconnect callback."""
     callback = 'my_reconnect_callback'
     strategy = ConnectionStrategy(self.connection, 'localhost',
                                   reconnect_cb=callback)
     self.assertEqual([callback], strategy.reconnect_callbacks)
# Example #3
# 0
  def __init__(self, **kwargs):
    '''
    Initialize the connection and ask the strategy to start connecting.

    Keyword arguments (default in parentheses; None when not listed):
      debug (False), logger (root_logger), user ('guest'),
      password ('guest'), host ('localhost'), vhost ('/'),
      connect_timeout (5), sock_opts, heartbeat, reconnect_cb, close_cb,
      login_method ('AMQPLAIN'), locale ('en_US'), client_properties,
      connection_strategy (a new ConnectionStrategy for `host` if omitted).
    '''
    self._debug = kwargs.get('debug', False)
    self._logger = kwargs.get('logger', root_logger)

    self._user = kwargs.get('user', 'guest')
    self._password = kwargs.get('password', 'guest')
    self._host = kwargs.get('host', 'localhost')
    self._vhost = kwargs.get('vhost', '/')

    self._connect_timeout = kwargs.get('connect_timeout', 5)
    self._sock_opts = kwargs.get('sock_opts')
    self._sock = None
    self._heartbeat = kwargs.get('heartbeat')
    self._reconnect_cb = kwargs.get('reconnect_cb')
    self._close_cb = kwargs.get('close_cb')

    self._login_method = kwargs.get('login_method', 'AMQPLAIN')
    self._locale = kwargs.get('locale', 'en_US')
    self._client_properties = kwargs.get('client_properties')

    # Library defaults, overridden by any user-supplied client properties.
    self._properties = LIBRARY_PROPERTIES.copy()
    if self._client_properties:
      self._properties.update( self._client_properties )

    self._closed = False
    self._connected = False
    self._close_info = {
      'reply_code'    : 0,
      'reply_text'    : 'first connect',
      'class_id'      : 0,
      'method_id'     : 0
    }

    # Channel 0 is always present and handled by a ConnectionChannel.
    self._channels = {
      0 : ConnectionChannel(self, 0)
    }

    # The login response is the encoded LOGIN/PASSWORD table with its
    # leading 4-byte length field stripped off.
    login_response = Writer()
    login_response.write_table({'LOGIN': self._user, 'PASSWORD': self._password})
    self._login_response = login_response.buffer()[4:]

    self._channel_counter = 0
    self._channel_max = 65535
    self._frame_max = 65535

    self._frames_read = 0
    self._frames_written = 0

    # Create the output buffer *before* connecting.  Anything triggered by
    # the strategy during connect() must find it in place, and resetting it
    # afterwards would silently drop frames buffered in the meantime.
    self._output_frame_buffer = []

    self._strategy = kwargs.get('connection_strategy')
    if not self._strategy:
      self._strategy = ConnectionStrategy( self, self._host, reconnect_cb = self._reconnect_cb )
    self._strategy.connect()
class ConnectionStrategyTest(Chai):
    '''
    Unit tests for ConnectionStrategy against a mocked connection whose
    logger and connect/disconnect calls are asserted via Chai expectations.
    '''

    def setUp(self):
        super(ConnectionStrategyTest, self).setUp()

        self.connection = mock()
        self.connection.logger = mock()
        self.strategy = ConnectionStrategy(self.connection, 'localhost')

    def test_init_is_doin_it_right(self):
        self.assertEqual(self.strategy._connection, self.connection)
        # The strategy is expected to record 'localhost' as this machine's
        # hostname.
        self.assertEqual(Host(socket.gethostname()), self.strategy._orig_host)
        self.assertEqual([self.strategy._orig_host],
                         self.strategy._known_hosts)
        self.assertEqual(self.strategy._orig_host, self.strategy._cur_host)
        self.assertFalse(self.strategy._reconnecting)
        self.assertEqual([], self.strategy.reconnect_callbacks)

    def test_init_with_reconnect_cb(self):
        strategy = ConnectionStrategy(self.connection,
                                      'localhost',
                                      reconnect_cb='my_reconnect_callback')
        self.assertEqual(['my_reconnect_callback'],
                         strategy.reconnect_callbacks)

    def test_set_known_hosts_is_single_entry(self):
        self.assertEqual([self.strategy._orig_host],
                         self.strategy._known_hosts)
        self.strategy.set_known_hosts(socket.gethostname())
        self.assertEqual([self.strategy._orig_host],
                         self.strategy._known_hosts)

    def test_set_known_hosts_updates_list_correctly(self):
        self.assertEqual([self.strategy._orig_host],
                         self.strategy._known_hosts)
        self.strategy.set_known_hosts('localhost:4200,localhost,foo:1234')

        self.assertEqual(
            [Host('localhost'),
             Host('localhost:4200'),
             Host('foo:1234')], self.strategy._known_hosts)

    def test_set_known_hosts_handles_misconfigured_cluster(self):
        # The current host is not among the advertised hosts, so the
        # strategy should warn and reconnect to the original host.
        self.strategy._cur_host = Host('bar')
        self.strategy._orig_host = Host('foo:5678')
        self.strategy._known_hosts = [self.strategy._orig_host]

        expect(self.connection.logger.warning).args(
            "current host %s not in known hosts %s, reconnecting to %s in %ds!",
            self.strategy._cur_host,
            [Host('foo:5678'), Host('foo:1234')], self.strategy._orig_host, 5)

        expect(self.strategy.connect).args(5)

        self.strategy.set_known_hosts('foo:1234')

    def test_next_host_handles_simple_base_case(self):
        self.strategy._cur_host = Host('localhost')
        self.strategy._known_hosts = [Host('localhost'), Host('foo')]

        expect(self.strategy.connect)

        self.strategy.next_host()
        self.assertEqual(Host('foo'), self.strategy._cur_host)
        self.assertFalse(self.strategy._reconnecting)

    def test_next_host_finds_first_unconnected_host(self):
        self.strategy._cur_host = Host('localhost')
        self.strategy._known_hosts = [
            Host('localhost'), Host('foo'),
            Host('bar')
        ]
        self.strategy._known_hosts[0].state = CONNECTED
        self.strategy._known_hosts[1].state = CONNECTED

        expect(self.strategy.connect)

        self.strategy.next_host()
        self.assertEqual(Host('bar'), self.strategy._cur_host)
        self.assertFalse(self.strategy._reconnecting)

    def test_next_host_searches_for_unfailed_hosts_if_all_hosts_not_unconnected(
            self):
        self.strategy._cur_host = Host('foo')
        self.strategy._known_hosts = [
            Host('localhost'),
            Host('foo'),
            Host('bar'),
            Host('cat')
        ]
        self.strategy._known_hosts[0].state = CONNECTED
        self.strategy._known_hosts[1].state = CONNECTED
        self.strategy._known_hosts[2].state = CONNECTED
        self.strategy._known_hosts[3].state = CONNECTED

        expect(self.strategy.connect)

        self.strategy.next_host()
        self.assertEqual(Host('localhost'), self.strategy._cur_host)

    def test_next_host_searches_for_unfailed_hosts_even_if_orig_host_is_failed(
            self):
        self.strategy._cur_host = Host('foo')
        self.strategy._known_hosts = [
            Host('localhost'),
            Host('foo'),
            Host('bar'),
            Host('cat')
        ]
        self.strategy._known_hosts[0].state = FAILED
        self.strategy._known_hosts[1].state = CONNECTED
        self.strategy._known_hosts[2].state = CONNECTED
        self.strategy._known_hosts[3].state = CONNECTED

        expect(self.strategy.connect)

        self.strategy.next_host()
        self.assertEqual(Host('foo'), self.strategy._cur_host)

    def test_next_host_defaults_to_original_with_delay_if_all_hosts_failed(
            self):
        self.strategy._orig_host = Host('foo')
        self.strategy._cur_host = Host('bar')
        self.strategy._known_hosts = [
            Host('foo'), Host('bar'),
            Host('cat'), Host('dog')
        ]
        self.strategy._known_hosts[0].state = FAILED
        self.strategy._known_hosts[1].state = FAILED
        self.strategy._known_hosts[2].state = FAILED
        self.strategy._known_hosts[3].state = FAILED

        expect(self.connection.logger.warning).args(
            'Failed to connect to any of %s, will retry %s in %d seconds',
            self.strategy._known_hosts, self.strategy._orig_host, 5)
        expect(self.strategy.connect).args(5)

        self.strategy.next_host()
        self.assertEqual(Host('foo'), self.strategy._cur_host)
        self.assertTrue(self.strategy._reconnecting)

    def test_fail_is_not_stoopehd(self):
        self.strategy._cur_host = Host('foo')
        self.assertEqual(UNCONNECTED, self.strategy._cur_host.state)

        stub(self.strategy.connect)
        stub(self.strategy.next_host)

        self.strategy.fail()
        self.assertEqual(FAILED, self.strategy._cur_host.state)


# FIXME: These tests need to be fixed and converted to Chai
#  def test_connect_basics(self):
#    self.strategy._pending_connect = None

#    self.mox.StubOutWithMock( self.connection, 'disconnect' )
#    self.connection.disconnect()

#    connection_strategy.event.timeout(0, self.strategy._connect_cb).AndReturn('foo')

#    self.mox.ReplayAll()
#    self.strategy.connect()
#    self.assertTrue( self.strategy._pending_connect, 'foo' )

#  def test_connect_handles_disconnect_errors(self):
#    self.strategy._pending_connect = None

#    self.mox.StubOutWithMock( self.connection, 'disconnect' )
#    self.connection.disconnect().AndRaise( Exception("can'na do it cap'n") )
#
#    self.mox.StubOutWithMock( self.connection, 'log' )
#    self.connection.log( 'error while disconnecting', logging.ERROR )

#    connection_strategy.event.timeout(0, self.strategy._connect_cb)

#    self.mox.ReplayAll()
#    self.strategy.connect()

#  def test_connect_honors_delay(self):
#    self.strategy._pending_connect = None

#    self.mox.StubOutWithMock( self.connection, 'disconnect' )
#    self.connection.disconnect()
#
#    connection_strategy.event.timeout(42, self.strategy._connect_cb).AndReturn('foo')

#    self.mox.ReplayAll()
#    self.strategy.connect( 42 )
#    self.assertTrue( self.strategy._pending_connect, 'foo' )

    def test_connect_had_single_pending_event(self):
        self.strategy._pending_connect = 'foo'

        expect(self.connection.logger.debug).args("disconnecting connection")
        expect(self.connection.disconnect)
        expect(self.connection.logger.debug).args("Pending connect: %s", 'foo')

        self.strategy.connect()
        # The existing pending connect must be left in place.  (assertTrue
        # with a second argument only treats it as the failure message, so
        # the value has to be compared explicitly.)
        self.assertEqual('foo', self.strategy._pending_connect)

    def test_connect_cb_when_successful_and_not_reconnecting(self):
        self.strategy._pending_connect = 'foo'
        self.strategy._cur_host = Host('bar')
        self.strategy._reconnecting = False

        expect(self.connection.logger.debug).args("Connecting to %s on %s",
                                                  'bar', 5672)
        expect(self.connection.connect).args('bar', 5672)
        expect(self.connection.logger.debug).args('Connected to %s',
                                                  self.strategy._cur_host)

        self.strategy._connect_cb()
        self.assertIsNone(self.strategy._pending_connect)
        self.assertEqual(CONNECTED, self.strategy._cur_host.state)

    def test_connect_cb_when_successful_and_reconnecting(self):
        reconnect_cb = mock()
        self.strategy._pending_connect = 'foo'
        self.strategy._cur_host = Host('bar')
        self.strategy._reconnecting = True
        self.strategy.reconnect_callbacks = [reconnect_cb]

        expect(self.connection.logger.debug).args("Connecting to %s on %s",
                                                  'bar', 5672)
        expect(self.connection.connect).args('bar', 5672)
        expect(self.connection.logger.info).args('Connected to %s',
                                                 self.strategy._cur_host)
        expect(reconnect_cb)

        self.strategy._connect_cb()
        self.assertIsNone(self.strategy._pending_connect)
        self.assertEqual(CONNECTED, self.strategy._cur_host.state)
        self.assertFalse(self.strategy._reconnecting)

    def test_connect_cb_on_fail_and_first_connect_attempt(self):
        self.strategy._cur_host = Host('bar')

        expect(self.connection.logger.debug).args("Connecting to %s on %s",
                                                  'bar', 5672)
        expect(self.connection.connect).args('bar', 5672).raises(
            socket.error('fail sauce'))

        expect(self.connection.logger.exception).args(
            "Failed to connect to %s, will try again in %d seconds",
            self.strategy._cur_host, 2)
        expect(self.strategy.connect).args(2)

        self.strategy._connect_cb()
        self.assertEqual(FAILED, self.strategy._cur_host.state)

    def test_connect_cb_on_fail_and_second_connect_attempt(self):
        # A host that already failed once is abandoned for the next host.
        self.strategy._cur_host = Host('bar')
        self.strategy._cur_host.state = FAILED

        expect(self.connection.logger.debug).args("Connecting to %s on %s",
                                                  'bar', 5672)
        expect(self.connection.connect).args('bar', 5672).raises(
            socket.error('fail sauce'))
        expect(self.connection.logger.critical).args("Failed to connect to %s",
                                                     self.strategy._cur_host)
        expect(self.strategy.next_host)

        self.strategy._connect_cb()
# Example #5
# 0
class Connection(object):
  '''
  Client side of an AMQP connection.  Owns the socket (an EventSocket), the
  table of channels keyed by id (channel 0 is a ConnectionChannel), and a
  buffer of frames waiting for the connection to become usable.  Choosing a
  host and reconnecting are delegated to a ConnectionStrategy.
  '''

  class TooManyChannels(ConnectionError): '''This connection has too many channels open.  Non-fatal.'''
  class InvalidChannel(ConnectionError): '''The channel id does not correspond to an existing channel.  Non-fatal.'''

  def __init__(self, **kwargs):
    '''
    Initialize the connection and ask the strategy to start connecting.

    Keyword arguments (default in parentheses; None when not listed):
      debug (False), logger (root_logger), user ('guest'),
      password ('guest'), host ('localhost'), vhost ('/'),
      connect_timeout (5), sock_opts, heartbeat, reconnect_cb, close_cb,
      login_method ('AMQPLAIN'), locale ('en_US'), client_properties,
      connection_strategy (a new ConnectionStrategy for `host` if omitted).
    '''
    self._debug = kwargs.get('debug', False)
    self._logger = kwargs.get('logger', root_logger)

    self._user = kwargs.get('user', 'guest')
    self._password = kwargs.get('password', 'guest')
    self._host = kwargs.get('host', 'localhost')
    self._vhost = kwargs.get('vhost', '/')

    self._connect_timeout = kwargs.get('connect_timeout', 5)
    # Mapping of (level, optname) -> value, applied with setsockopt() in
    # connect().
    self._sock_opts = kwargs.get('sock_opts')
    self._sock = None
    self._heartbeat = kwargs.get('heartbeat')
    self._reconnect_cb = kwargs.get('reconnect_cb')
    self._close_cb = kwargs.get('close_cb')

    self._login_method = kwargs.get('login_method', 'AMQPLAIN')
    self._locale = kwargs.get('locale', 'en_US')
    self._client_properties = kwargs.get('client_properties')

    # Library defaults, overridden by any user-supplied client properties.
    self._properties = LIBRARY_PROPERTIES.copy()
    if self._client_properties:
      self._properties.update( self._client_properties )

    self._closed = False
    self._connected = False
    self._close_info = {
      'reply_code'    : 0,
      'reply_text'    : 'first connect',
      'class_id'      : 0,
      'method_id'     : 0
    }

    # Channel 0 is always present and handled by a ConnectionChannel.
    self._channels = {
      0 : ConnectionChannel(self, 0)
    }

    # The login response is the encoded LOGIN/PASSWORD table with its
    # leading 4-byte length field stripped off.
    login_response = Writer()
    login_response.write_table({'LOGIN': self._user, 'PASSWORD': self._password})
    self._login_response = login_response.buffer()[4:]

    self._channel_counter = 0
    self._channel_max = 65535
    self._frame_max = 65535

    self._frames_read = 0
    self._frames_written = 0

    # Create the output buffer *before* connecting.  Anything triggered by
    # the strategy during connect() (e.g. send_frame() buffering while
    # there is no socket yet) must find it in place, and resetting it
    # afterwards would silently drop frames buffered in the meantime.
    self._output_frame_buffer = []

    self._strategy = kwargs.get('connection_strategy')
    if not self._strategy:
      self._strategy = ConnectionStrategy( self, self._host, reconnect_cb = self._reconnect_cb )
    self._strategy.connect()

  @property
  def logger(self):
    return self._logger

  @property
  def debug(self):
    return self._debug

  @property
  def frame_max(self):
    return self._frame_max

  @property
  def channel_max(self):
    return self._channel_max

  @property
  def frames_read(self):
    '''Number of frames read in the lifetime of this connection.'''
    return self._frames_read

  @property
  def frames_written(self):
    '''Number of frames written in the lifetime of this connection.'''
    return self._frames_written

  @property
  def close_info(self):
    '''Return dict with information on why this connection is closed.  Will
    return None if the connection is open.'''
    return self._close_info if self._closed else None

  def reconnect(self):
    '''Reconnect to the configured host and port.'''
    self._strategy.connect()

  def connect(self, host, port):
    '''
    Connect to a host and port. Can be called directly, or is called by the
    strategy as it tries to find and connect to hosts.
    '''
    # Clear the connect state immediately since we're no longer connected
    # at this point.
    self._connected = False

    # NOTE: purposefully leave output_frame_buffer alone so that pending writes can
    # still occur.  this allows the reconnect to occur silently without
    # completely breaking any pending data on, say, a channel that was just
    # opened.
    self._sock = EventSocket( read_cb=self._sock_read_cb,
      close_cb=self._sock_close_cb, error_cb=self._sock_error_cb,
      debug=self._debug, logger=self._logger )
    self._sock.settimeout( self._connect_timeout )
    if self._sock_opts:
      # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
      for (level, optname), value in self._sock_opts.items():
        self._sock.setsockopt(level, optname, value)
    self._sock.connect( (host,port) )
    self._sock.setblocking( False )

    # Only after the socket has connected do we clear this state; closed must
    # be False so that writes can be buffered in writePacket().  The closed
    # state might have been set to True due to a socket error or a redirect.
    self._host = "%s:%d"%(host,port)
    self._closed = False
    self._close_info = {
      'reply_code'    : 0,
      'reply_text'    : 'failed to connect to %s'%(self._host),
      'class_id'      : 0,
      'method_id'     : 0
    }

    self._sock.write( PROTOCOL_HEADER )

  def disconnect(self):
    '''
    Disconnect from the current host, but otherwise leave this object "open"
    so that it can be reconnected.
    '''
    self._connected = False
    if self._sock is not None:
      # Detach the close callback so this deliberate close is not reported
      # as an unexpected one.
      self._sock.close_cb = None
      try:
        self._sock.close()
      except Exception:
        self.logger.error("Failed to disconnect socket to %s", self._host, exc_info=True)
      self._sock = None

  def add_reconnect_callback(self, callback):
    '''Adds a reconnect callback to the strategy.  This can be used to
    resubscribe to exchanges, etc.'''
    self._strategy.reconnect_callbacks.append(callback)

  ###
  ### EventSocket callbacks
  ###
  def _sock_read_cb(self, sock):
    '''
    Callback when there's data to read on the socket.  Any error while
    reading or processing frames closes the connection with reply code 501.
    '''
    try:
      self._read_frames()
    except Exception:
      self.logger.error("Failed to read frames from %s", self._host, exc_info=True)
      self.close( reply_code=501, reply_text='Error parsing frames' )

  def _sock_close_cb(self, sock):
    """
    Callback when socket closed.  This is intended to be the callback when the
    closure is unexpected.
    """
    self.logger.warning( 'socket to %s closed unexpectedly', self._host )
    self._close_info = {
      'reply_code'    : 0,
      'reply_text'    : 'socket closed unexpectedly to %s'%(self._host),
      'class_id'      : 0,
      'method_id'     : 0
    }

    # We're not connected any more (we're not closed but we're definitely not
    # connected)
    self._connected = False
    self._sock = None

    # Call back to a user-provided close function
    self._callback_close()

    # Fail and do nothing. If you haven't configured permissions and that's
    # why the socket is closing, this keeps us from looping.
    self._strategy.fail()

  def _sock_error_cb(self, sock, msg, exception=None):
    """
    Callback when there's an error on the socket.
    """
    self.logger.error( 'error on connection to %s: %s', self._host, msg)
    self._close_info = {
      'reply_code'    : 0,
      'reply_text'    : 'socket error on host %s: %s'%(self._host, msg),
      'class_id'      : 0,
      'method_id'     : 0
    }

    # we're not connected any more (we're not closed but we're definitely not
    # connected)
    self._connected = False
    self._sock = None

    # Call back to a user-provided close function
    self._callback_close()

    # Fail and try to reconnect, because this is expected to be a transient error.
    self._strategy.fail()
    self._strategy.next_host()

  ###
  ### Connection methods
  ###
  def _next_channel_id(self):
    '''
    Return the next possible channel id.  Is a circular enumeration over
    1..channel_max inclusive; 0 is never returned because it belongs to the
    connection channel.
    '''
    self._channel_counter += 1
    # Wrap only once we pass channel_max.  Wrapping at channel_max itself
    # (the previous behaviour) made only channel_max-1 ids reachable, while
    # channel() admits up to channel_max open channels.  Its search loop
    # could then spin forever once every reachable id was taken.
    if self._channel_counter > self._channel_max:
      self._channel_counter = 1
    return self._channel_counter

  def channel(self, channel_id=None):
    """
    Fetch a Channel object identified by the numeric channel_id, or
    create that object if it doesn't already exist.  If channel_id is not
    None but no channel exists for that id, will raise InvalidChannel.  If
    there are already too many channels open, will raise TooManyChannels.
    """
    if channel_id is None:
      # adjust for channel 0
      open_channels = len(self._channels) - 1
      if open_channels >= self._channel_max:
        raise Connection.TooManyChannels( "%d channels already open, max %d" %
          (open_channels, self._channel_max) )
      channel_id = self._next_channel_id()
      while channel_id in self._channels:
        channel_id = self._next_channel_id()
    elif channel_id in self._channels:
      return self._channels[channel_id]
    else:
      raise Connection.InvalidChannel( "%s is not a valid channel id" % (channel_id,) )

    # Call open() here so that ConnectionChannel doesn't have it called.  Could
    # also solve this other ways, but it's a HACK regardless.
    rval = Channel(self, channel_id)
    self._channels[ channel_id ] = rval
    rval.open()
    return rval

  def close(self, reply_code=0, reply_text='', class_id=0, method_id=0):
    '''
    Close this connection by recording the close reason and closing
    channel 0.
    '''
    self._close_info = {
      'reply_code'    : reply_code,
      'reply_text'    : reply_text,
      'class_id'      : class_id,
      'method_id'     : method_id
    }
    self._channels[0].close()

  def _close_socket(self):
    '''Close the socket.'''
    # The assumption here is that we don't want auto-reconnect to kick in if
    # the socket is purposefully closed.
    self._closed = True

    # By the time we hear about the protocol-level closure, the socket may
    # have already gone away.
    if self._sock is not None:
      self._sock.close_cb = None
      try:
        self._sock.close()
      except Exception:
        self.logger.error( 'error closing socket', exc_info=True )
      self._sock = None

  def _callback_close(self):
    '''Callback to any close handler, logging (not raising) its errors.'''
    if self._close_cb:
      try:
        self._close_cb()
      except SystemExit:
        raise
      except Exception:
        self.logger.error( 'error calling close callback', exc_info=True )

  def _read_frames(self):
    '''
    Read frames from the socket.
    '''
    # Because of the timer callback to dataRead when we re-buffered, there's a
    # chance that in between we've lost the socket.  If that's the case, just
    # silently return as some code elsewhere would have already notified us.
    # That bug could be fixed by improving the message reading so that we consume
    # all possible messages and ensure that only a partial message was rebuffered,
    # so that we can rely on the next read event to read the subsequent message.
    if self._sock is None:
      return

    data = self._sock.read()
    reader = Reader( data )
    p_channels = set()

    for frame in Frame.read_frames( reader ):
      if self._debug > 1:
        self.logger.debug( "READ: %s", frame )
      self._frames_read += 1
      ch = self.channel( frame.channel_id )
      ch.buffer_frame( frame )
      p_channels.add( ch )

    # Still not clear on what's the best approach here. It seems there's a
    # slight speedup by calling this directly rather than delaying, but the
    # delay allows for pending IO with higher priority to execute.
    self._process_channels( p_channels )

    # HACK: read the buffer contents and re-buffer.  Would prefer to pass
    # buffer back, but there's no good way of asking the total size of the
    # buffer, comparing to tell(), and then re-buffering.  There's also no
    # ability to clear the buffer up to the current position.
    # NOTE: This will be cleared up once eventsocket supports the
    # uber-awesome buffering scheme that will utilize mmap.
    if reader.tell() < len(data):
      self._sock.buffer( data[reader.tell():] )

  def _process_channels(self, channels):
    '''
    Walk through a set of channels and have each one process its buffered
    frames.
    '''
    for channel in channels:
      channel.process_frames()

  def _flush_buffered_frames(self):
    '''Re-send every frame that send_frame() previously buffered.'''
    # In the rare case (a bug) where this is called but send_frame thinks
    # they should be buffered, don't clobber.
    frames = self._output_frame_buffer
    self._output_frame_buffer = []
    for frame in frames:
      self.send_frame( frame )

  def send_frame(self, frame):
    '''
    Send a single frame.  Raises ConnectionClosed if the connection is closed.
    Buffers the frame when there is no socket, or when the connection is not
    yet established and the frame is not for channel 0; otherwise writes it
    to the socket immediately.
    '''
    if self._closed:
      if self._close_info and len(self._close_info['reply_text'])>0:
        raise ConnectionClosed("connection is closed: %s : %s"%\
          (self._close_info['reply_code'],self._close_info['reply_text']) )
      raise ConnectionClosed("connection is closed")

    # Channel 0 frames are allowed through before the connection is marked
    # connected; all other traffic waits in the output buffer.
    if self._sock is None or (not self._connected and frame.channel_id!=0):
      self._output_frame_buffer.append( frame )
      return

    if self._debug > 1:
      self.logger.debug( "WRITE: %s", frame )

    buf = bytearray()
    frame.write_frame(buf)
    self._sock.write( buf )

    self._frames_written += 1
# Example #6
# 0
    def __init__(self, **kwargs):
        '''
        Initialize the connection and ask the strategy to start connecting.

        Keyword arguments (default in parentheses; None when not listed):
          debug (False), logger (root_logger), user ('guest'),
          password ('guest'), host ('localhost'), vhost ('/'),
          connect_timeout (5), sock_opts, heartbeat, reconnect_cb,
          close_cb, login_method ('AMQPLAIN'), locale ('en_US'),
          client_properties, connection_strategy (a new
          ConnectionStrategy for `host` if omitted).
        '''
        self._debug = kwargs.get('debug', False)
        self._logger = kwargs.get('logger', root_logger)

        self._user = kwargs.get('user', 'guest')
        self._password = kwargs.get('password', 'guest')
        self._host = kwargs.get('host', 'localhost')
        self._vhost = kwargs.get('vhost', '/')

        self._connect_timeout = kwargs.get('connect_timeout', 5)
        self._sock_opts = kwargs.get('sock_opts')
        self._sock = None
        self._heartbeat = kwargs.get('heartbeat')
        self._reconnect_cb = kwargs.get('reconnect_cb')
        self._close_cb = kwargs.get('close_cb')

        self._login_method = kwargs.get('login_method', 'AMQPLAIN')
        self._locale = kwargs.get('locale', 'en_US')
        self._client_properties = kwargs.get('client_properties')

        # Library defaults, overridden by any user-supplied client properties.
        self._properties = LIBRARY_PROPERTIES.copy()
        if self._client_properties:
            self._properties.update(self._client_properties)

        self._closed = False
        self._connected = False
        self._close_info = {
            'reply_code': 0,
            'reply_text': 'first connect',
            'class_id': 0,
            'method_id': 0
        }

        # Channel 0 is always present and handled by a ConnectionChannel.
        self._channels = {0: ConnectionChannel(self, 0)}

        # The login response is the encoded LOGIN/PASSWORD table with its
        # leading 4-byte length field stripped off.
        login_response = Writer()
        login_response.write_table({
            'LOGIN': self._user,
            'PASSWORD': self._password
        })
        self._login_response = login_response.buffer()[4:]

        self._channel_counter = 0
        self._channel_max = 65535
        self._frame_max = 65535

        self._frames_read = 0
        self._frames_written = 0

        # Create the output buffer *before* connecting.  Anything triggered
        # by the strategy during connect() must find it in place, and
        # resetting it afterwards would silently drop frames buffered in the
        # meantime.
        self._output_frame_buffer = []

        self._strategy = kwargs.get('connection_strategy')
        if not self._strategy:
            self._strategy = ConnectionStrategy(
                self, self._host, reconnect_cb=self._reconnect_cb)
        self._strategy.connect()
# Example #7
# 0
class Connection(object):
    '''
    A client connection to an AMQP broker.

    Owns the socket, the table of open channels (channel 0 is a
    ConnectionChannel created in __init__), and the frame read/write path.
    Choosing a host and (re)connecting is delegated to a connection strategy.
    '''

    class TooManyChannels(ConnectionError):
        '''This connection has too many channels open.  Non-fatal.'''

    class InvalidChannel(ConnectionError):
        '''The channel id does not correspond to an existing channel.  Non-fatal.'''

    def __init__(self, **kwargs):
        '''
        Initialize the connection and immediately ask the connection strategy
        to connect.

        Recognized keyword arguments (all optional):
          debug                 -- values > 1 log every frame read/written
          logger                -- defaults to root_logger
          user, password        -- login credentials (default 'guest'/'guest')
          host, vhost           -- default 'localhost' and '/'
          connect_timeout       -- socket timeout while connecting (default 5)
          sock_opts             -- {(level, optname): value} applied on connect
          heartbeat             -- stored as-is
          reconnect_cb          -- handed to the default ConnectionStrategy
          close_cb              -- called with no arguments on socket close/error
          login_method, locale  -- default 'AMQPLAIN' and 'en_US'
          client_properties     -- merged over LIBRARY_PROPERTIES
          connection_strategy   -- used instead of a new ConnectionStrategy
        '''
        self._debug = kwargs.get('debug', False)
        self._logger = kwargs.get('logger', root_logger)

        self._user = kwargs.get('user', 'guest')
        self._password = kwargs.get('password', 'guest')
        self._host = kwargs.get('host', 'localhost')
        self._vhost = kwargs.get('vhost', '/')

        self._connect_timeout = kwargs.get('connect_timeout', 5)
        self._sock_opts = kwargs.get('sock_opts')
        self._sock = None
        self._heartbeat = kwargs.get('heartbeat')
        self._reconnect_cb = kwargs.get('reconnect_cb')
        self._close_cb = kwargs.get('close_cb')

        self._login_method = kwargs.get('login_method', 'AMQPLAIN')
        self._locale = kwargs.get('locale', 'en_US')
        self._client_properties = kwargs.get('client_properties')

        self._properties = LIBRARY_PROPERTIES.copy()
        if self._client_properties:
            self._properties.update(self._client_properties)

        self._closed = False
        self._connected = False
        self._close_info = {
            'reply_code': 0,
            'reply_text': 'first connect',
            'class_id': 0,
            'method_id': 0
        }

        self._channels = {0: ConnectionChannel(self, 0)}

        login_response = Writer()
        login_response.write_table({
            'LOGIN': self._user,
            'PASSWORD': self._password
        })
        # Skip the length at the beginning of the encoded table.
        self._login_response = login_response.buffer()[4:]

        self._channel_counter = 0
        self._channel_max = 65535
        self._frame_max = 65535

        self._frames_read = 0
        self._frames_written = 0

        # Create the output buffer before asking the strategy to connect, so a
        # strategy that connects (and sends) synchronously finds it in place.
        self._output_frame_buffer = []

        self._strategy = kwargs.get('connection_strategy')
        if not self._strategy:
            self._strategy = ConnectionStrategy(
                self, self._host, reconnect_cb=self._reconnect_cb)
        self._strategy.connect()

    @property
    def logger(self):
        return self._logger

    @property
    def debug(self):
        return self._debug

    @property
    def frame_max(self):
        return self._frame_max

    @property
    def channel_max(self):
        return self._channel_max

    @property
    def frames_read(self):
        '''Number of frames read in the lifetime of this connection.'''
        return self._frames_read

    @property
    def frames_written(self):
        '''Number of frames written in the lifetime of this connection.'''
        return self._frames_written

    @property
    def close_info(self):
        '''
        Return a dict with information on why this connection is closed, or
        None if the connection is open.
        '''
        return self._close_info if self._closed else None

    def reconnect(self):
        '''Reconnect to the configured host and port.'''
        self._strategy.connect()

    def connect(self, host, port):
        '''
        Connect to a host and port. Can be called directly, or is called by the
        strategy as it tries to find and connect to hosts.
        '''
        # Clear the connect state immediately since we're no longer connected
        # at this point.
        self._connected = False

        # NOTE: purposefully leave output_frame_buffer alone so that pending
        # writes can still occur.  This allows the reconnect to occur silently
        # without completely breaking any pending data on, say, a channel that
        # was just opened.
        self._sock = EventSocket(read_cb=self._sock_read_cb,
                                 close_cb=self._sock_close_cb,
                                 error_cb=self._sock_error_cb,
                                 debug=self._debug,
                                 logger=self._logger)
        self._sock.settimeout(self._connect_timeout)
        if self._sock_opts:
            # Keys are (level, optname) pairs passed straight to setsockopt.
            # items() works on both Python 2 and 3, unlike iteritems().
            for (level, optname), value in self._sock_opts.items():
                self._sock.setsockopt(level, optname, value)
        self._sock.connect((host, port))
        self._sock.setblocking(False)

        # Only after the socket has connected do we clear this state; closed
        # must be False so that writes can be buffered in send_frame().  The
        # closed state might have been set to True due to a socket error or a
        # redirect.
        self._host = "%s:%d" % (host, port)
        self._closed = False
        self._close_info = {
            'reply_code': 0,
            'reply_text': 'failed to connect to %s' % (self._host),
            'class_id': 0,
            'method_id': 0
        }

        self._sock.write(PROTOCOL_HEADER)

    def disconnect(self):
        '''
        Disconnect from the current host, but otherwise leave this object
        "open" so that it can be reconnected.
        '''
        self._connected = False
        if self._sock is not None:
            # Detach the close callback first: _sock_close_cb is meant only for
            # unexpected closures, and this one is deliberate.
            self._sock.close_cb = None
            try:
                self._sock.close()
            except Exception:
                self.logger.error("Failed to disconnect socket to %s",
                                  self._host,
                                  exc_info=True)
            self._sock = None

    def add_reconnect_callback(self, callback):
        '''
        Add a reconnect callback to the strategy.  This can be used to
        resubscribe to exchanges, etc.
        '''
        self._strategy.reconnect_callbacks.append(callback)

    ###
    ### EventSocket callbacks
    ###
    def _sock_read_cb(self, sock):
        '''
        Callback when there's data to read on the socket.  Any parsing error
        closes the connection with reply code 501.
        '''
        try:
            self._read_frames()
        except Exception:
            self.logger.error("Failed to read frames from %s",
                              self._host,
                              exc_info=True)
            self.close(reply_code=501, reply_text='Error parsing frames')

    def _sock_close_cb(self, sock):
        '''
        Callback when the socket closes.  This is intended to be the callback
        when the closure is unexpected.
        '''
        self.logger.warning('socket to %s closed unexpectedly', self._host)
        self._close_info = {
            'reply_code': 0,
            'reply_text': 'socket closed unexpectedly to %s' % (self._host),
            'class_id': 0,
            'method_id': 0
        }

        # We're not connected any more (we're not closed but we're definitely
        # not connected).
        self._connected = False
        self._sock = None

        # Call back to a user-provided close function.
        self._callback_close()

        # Fail and do nothing. If you haven't configured permissions and that's
        # why the socket is closing, this keeps us from looping.
        self._strategy.fail()

    def _sock_error_cb(self, sock, msg, exception=None):
        '''
        Callback when there's an error on the socket.
        '''
        self.logger.error('error on connection to %s: %s', self._host, msg)
        self._close_info = {
            'reply_code': 0,
            'reply_text': 'socket error on host %s: %s' % (self._host, msg),
            'class_id': 0,
            'method_id': 0
        }

        # We're not connected any more (we're not closed but we're definitely
        # not connected).
        self._connected = False
        self._sock = None

        # Call back to a user-provided close function.
        self._callback_close()

        # Fail and try to reconnect, because this is expected to be a
        # transient error.
        self._strategy.fail()
        self._strategy.next_host()

    ###
    ### Connection methods
    ###
    def _next_channel_id(self):
        '''
        Return the next candidate channel id, cycling through
        1..channel_max inclusive (0 is reserved for the connection channel).
        '''
        self._channel_counter += 1
        # Wrap only once we have gone *past* channel_max.  Wrapping at
        # channel_max skipped that id, and channel() could then spin forever
        # when channel_max - 1 channels were open.
        if self._channel_counter > self._channel_max:
            self._channel_counter = 1
        return self._channel_counter

    def channel(self, channel_id=None):
        '''
        Fetch a Channel object identified by the numeric channel_id, or
        create a new channel with the next free id when channel_id is None.

        Raises InvalidChannel if channel_id is not None but no channel exists
        for that id.  Raises TooManyChannels if a new channel is requested
        while channel_max channels are already open.
        '''
        if channel_id is None:
            # Channel 0 is the connection's own channel; don't count it.
            open_channels = len(self._channels) - 1
            if open_channels >= self._channel_max:
                raise Connection.TooManyChannels(
                    "%d channels already open, max %d" %
                    (open_channels, self._channel_max))
            channel_id = self._next_channel_id()
            while channel_id in self._channels:
                channel_id = self._next_channel_id()
        elif channel_id in self._channels:
            return self._channels[channel_id]
        else:
            raise Connection.InvalidChannel(
                "%s is not a valid channel id" % (channel_id,))

        # Call open() here so that ConnectionChannel doesn't have it called.
        # Could also solve this other ways, but it's a HACK regardless.
        rval = Channel(self, channel_id)
        self._channels[channel_id] = rval
        rval.open()
        return rval

    def close(self, reply_code=0, reply_text='', class_id=0, method_id=0):
        '''
        Close this connection by recording the close reason and closing
        channel 0.
        '''
        self._close_info = {
            'reply_code': reply_code,
            'reply_text': reply_text,
            'class_id': class_id,
            'method_id': method_id
        }
        self._channels[0].close()

    def _close_socket(self):
        '''Mark the connection closed and close the socket, if any.'''
        # The assumption here is that we don't want auto-reconnect to kick in
        # if the socket is purposefully closed.
        self._closed = True

        # By the time we hear about the protocol-level closure, the socket may
        # have already gone away.
        if self._sock is not None:
            self._sock.close_cb = None
            try:
                self._sock.close()
            except Exception:
                self.logger.error('error closing socket', exc_info=True)
            self._sock = None

    def _callback_close(self):
        '''Call the user's close handler, if any, logging any error it raises.'''
        if self._close_cb:
            try:
                self._close_cb()
            except Exception:
                self.logger.error('error calling close callback',
                                  exc_info=True)

    def _read_frames(self):
        '''
        Read frames from the socket, buffer them on their channels, process
        those channels, and re-buffer any trailing partial data.
        '''
        # Because of the timer callback to dataRead when we re-buffered,
        # there's a chance that in between we've lost the socket.  If that's
        # the case, just silently return as some code elsewhere would have
        # already notified us.  That bug could be fixed by improving the
        # message reading so that we consume all possible messages and ensure
        # that only a partial message was rebuffered, so that we can rely on
        # the next read event to read the subsequent message.
        if self._sock is None:
            return

        data = self._sock.read()
        reader = Reader(data)
        p_channels = set()

        for frame in Frame.read_frames(reader):
            if self._debug > 1:
                self.logger.debug("READ: %s", frame)
            self._frames_read += 1
            ch = self.channel(frame.channel_id)
            ch.buffer_frame(frame)
            p_channels.add(ch)

        # Still not clear on what's the best approach here. It seems there's a
        # slight speedup by calling this directly rather than delaying, but the
        # delay allows for pending IO with higher priority to execute.
        self._process_channels(p_channels)
        #event.timeout(0, self._process_channels, p_channels)

        # HACK: read the buffer contents and re-buffer.  Would prefer to pass
        # buffer back, but there's no good way of asking the total size of the
        # buffer, comparing to tell(), and then re-buffering.  There's also no
        # ability to clear the buffer up to the current position.
        # NOTE: This will be cleared up once eventsocket supports the
        # uber-awesome buffering scheme that will utilize mmap.
        if reader.tell() < len(data):
            self._sock.buffer(data[reader.tell():])

    def _process_channels(self, channels):
        '''
        Walk through a set of channels and process their frame buffers.
        '''
        for channel in channels:
            channel.process_frames()

    def _flush_buffered_frames(self):
        '''Resend every frame that send_frame() previously buffered.'''
        # In the rare case (a bug) where this is called but send_frame thinks
        # they should be buffered, don't clobber: swap the list out first so
        # re-buffered frames land in the fresh list.
        frames = self._output_frame_buffer
        self._output_frame_buffer = []
        for frame in frames:
            self.send_frame(frame)

    def send_frame(self, frame):
        '''
        Send a single frame.  While there is no socket, or the connection is
        not yet established and the frame is not for channel 0, the frame is
        buffered; otherwise it is written to the socket immediately.

        Raises ConnectionClosed if the connection is closed.
        '''
        if self._closed:
            reply_text = self._close_info and self._close_info.get('reply_text')
            if reply_text:
                raise ConnectionClosed("connection is closed: %s : %s" %
                                       (self._close_info['reply_code'],
                                        reply_text))
            raise ConnectionClosed("connection is closed")

        if self._sock is None or (not self._connected
                                  and frame.channel_id != 0):
            self._output_frame_buffer.append(frame)
            return

        if self._debug > 1:
            self.logger.debug("WRITE: %s", frame)

        buf = bytearray()
        frame.write_frame(buf)
        self._sock.write(buf)

        self._frames_written += 1
  def setUp(self):
    '''Build a fresh mocked connection (with a mocked logger) and a strategy
    pointed at localhost.'''
    super(ConnectionStrategyTest, self).setUp()

    conn = mock()
    conn.logger = mock()
    self.connection = conn
    self.strategy = ConnectionStrategy(conn, 'localhost')
class ConnectionStrategyTest(Chai):

  def setUp(self):
    '''Build a fresh mocked connection (with a mocked logger) and a strategy
    pointed at localhost.'''
    super(ConnectionStrategyTest, self).setUp()

    conn = mock()
    conn.logger = mock()
    self.connection = conn
    self.strategy = ConnectionStrategy(conn, 'localhost')

  def test_init_is_doin_it_right(self):
    '''A new strategy keeps its connection and starts on the original host
    (expected to equal Host(socket.gethostname()) for 'localhost'), which is
    also the only known host; it is not reconnecting and has no callbacks.'''
    self.assertEqual(self.strategy._connection, self.connection)
    self.assertEqual(Host(socket.gethostname()), self.strategy._orig_host)
    self.assertEqual([self.strategy._orig_host], self.strategy._known_hosts)
    self.assertEqual(self.strategy._orig_host, self.strategy._cur_host)
    self.assertFalse(self.strategy._reconnecting)
    self.assertEqual([], self.strategy.reconnect_callbacks)

  def test_init_with_reconnect_cb(self):
    '''A reconnect_cb passed to the constructor becomes the sole initial
    entry in reconnect_callbacks.'''
    callback = 'my_reconnect_callback'
    strategy = ConnectionStrategy(
      self.connection, 'localhost', reconnect_cb=callback)
    self.assertEqual([callback], strategy.reconnect_callbacks)

  def test_set_known_hosts_is_single_entry(self):
    '''Setting the known hosts to the already-known local host does not add
    a duplicate entry.'''
    self.assertEqual([self.strategy._orig_host], self.strategy._known_hosts)
    self.strategy.set_known_hosts(socket.gethostname())
    self.assertEqual([self.strategy._orig_host], self.strategy._known_hosts)
  
  def test_set_known_hosts_updates_list_correctly(self):
    '''A comma-separated host list is merged into the known hosts: the
    existing entry stays first and 'localhost' is not duplicated.'''
    self.assertEqual([self.strategy._orig_host], self.strategy._known_hosts)
    self.strategy.set_known_hosts('localhost:4200,localhost,foo:1234')

    self.assertEqual(
      [Host('localhost'), Host('localhost:4200'), Host('foo:1234')],
      self.strategy._known_hosts)

  def test_set_known_hosts_handles_misconfigured_cluster(self):
    '''If the current host is missing from the merged known-hosts list, a
    warning is logged and a reconnect to the original host is scheduled
    after 5 seconds.'''
    strategy = self.strategy
    strategy._cur_host = Host('bar')
    strategy._orig_host = Host('foo:5678')
    strategy._known_hosts = [strategy._orig_host]

    merged_hosts = [Host('foo:5678'), Host('foo:1234')]
    expect(self.connection.logger.warning).args(
      "current host %s not in known hosts %s, reconnecting to %s in %ds!",
      strategy._cur_host, merged_hosts, strategy._orig_host, 5)

    expect(strategy.connect).args(5)

    strategy.set_known_hosts('foo:1234')

  def test_next_host_handles_simple_base_case(self):
    '''With every host UNCONNECTED, next_host() moves to the next host and
    connects without entering reconnect mode.'''
    self.strategy._cur_host = Host('localhost')
    self.strategy._known_hosts = [Host('localhost'), Host('foo')]

    expect(self.strategy.connect)

    self.strategy.next_host()
    self.assertEqual(Host('foo'), self.strategy._cur_host)
    self.assertFalse(self.strategy._reconnecting)

  def test_next_host_finds_first_unconnected_host(self):
    '''next_host() skips CONNECTED hosts and picks the first UNCONNECTED one.'''
    self.strategy._cur_host = Host('localhost')
    self.strategy._known_hosts = [Host('localhost'), Host('foo'), Host('bar')]
    self.strategy._known_hosts[0].state = CONNECTED
    self.strategy._known_hosts[1].state = CONNECTED

    expect(self.strategy.connect)

    self.strategy.next_host()
    self.assertEqual(Host('bar'), self.strategy._cur_host)
    self.assertFalse(self.strategy._reconnecting)

  def test_next_host_searches_for_unfailed_hosts_if_all_hosts_not_unconnected(self):
    '''When no host is UNCONNECTED, next_host() falls back to the first host
    that has not FAILED.'''
    self.strategy._cur_host = Host('foo')
    self.strategy._known_hosts = [Host('localhost'), Host('foo'), Host('bar'), Host('cat')]
    self.strategy._known_hosts[0].state = CONNECTED
    self.strategy._known_hosts[1].state = CONNECTED
    self.strategy._known_hosts[2].state = CONNECTED
    self.strategy._known_hosts[3].state = CONNECTED

    expect(self.strategy.connect)

    self.strategy.next_host()
    self.assertEqual(Host('localhost'), self.strategy._cur_host)
  
  def test_next_host_searches_for_unfailed_hosts_even_if_orig_host_is_failed(self):
    '''A FAILED first host is skipped; the first non-FAILED host (here the
    current one) is chosen.'''
    self.strategy._cur_host = Host('foo')
    self.strategy._known_hosts = [Host('localhost'), Host('foo'), Host('bar'), Host('cat')]
    self.strategy._known_hosts[0].state = FAILED
    self.strategy._known_hosts[1].state = CONNECTED
    self.strategy._known_hosts[2].state = CONNECTED
    self.strategy._known_hosts[3].state = CONNECTED

    expect(self.strategy.connect)

    self.strategy.next_host()
    self.assertEqual(Host('foo'), self.strategy._cur_host)

  def test_next_host_defaults_to_original_with_delay_if_all_hosts_failed(self):
    '''When every known host has FAILED, next_host() logs a warning, returns
    to the original host, schedules a connect after 5 seconds and enters
    reconnect mode.'''
    self.strategy._orig_host = Host('foo')
    self.strategy._cur_host = Host('bar')
    self.strategy._known_hosts = [Host('foo'), Host('bar'), Host('cat'), Host('dog')]
    self.strategy._known_hosts[0].state = FAILED
    self.strategy._known_hosts[1].state = FAILED
    self.strategy._known_hosts[2].state = FAILED
    self.strategy._known_hosts[3].state = FAILED

    expect(self.connection.logger.warning).args(
      'Failed to connect to any of %s, will retry %s in %d seconds',
      self.strategy._known_hosts, self.strategy._orig_host, 5)
    expect(self.strategy.connect).args(5)

    self.strategy.next_host()
    self.assertEqual(Host('foo'), self.strategy._cur_host)
    self.assertTrue(self.strategy._reconnecting)

  def test_fail_is_not_stoopehd(self):
    '''fail() moves the current host from UNCONNECTED to FAILED.'''
    self.strategy._cur_host = Host('foo')
    self.assertEqual(UNCONNECTED, self.strategy._cur_host.state)

    # Stubbed so that any reconnect fail() might trigger is a no-op here.
    stub(self.strategy.connect)
    stub(self.strategy.next_host)

    self.strategy.fail()
    self.assertEqual(FAILED, self.strategy._cur_host.state)

# FIXME: These tests need to be fixed and converted to Chai
#  def test_connect_basics(self):
#    self.strategy._pending_connect = None

#    self.mox.StubOutWithMock( self.connection, 'disconnect' )
#    self.connection.disconnect()

#    connection_strategy.event.timeout(0, self.strategy._connect_cb).AndReturn('foo')

#    self.mox.ReplayAll()
#    self.strategy.connect()
#    self.assertTrue( self.strategy._pending_connect, 'foo' )

#  def test_connect_handles_disconnect_errors(self):
#    self.strategy._pending_connect = None

#    self.mox.StubOutWithMock( self.connection, 'disconnect' )
#    self.connection.disconnect().AndRaise( Exception("can'na do it cap'n") )
#    
#    self.mox.StubOutWithMock( self.connection, 'log' )
#    self.connection.log( 'error while disconnecting', logging.ERROR )

#    connection_strategy.event.timeout(0, self.strategy._connect_cb)

#    self.mox.ReplayAll()
#    self.strategy.connect()

#  def test_connect_honors_delay(self):
#    self.strategy._pending_connect = None

#    self.mox.StubOutWithMock( self.connection, 'disconnect' )
#    self.connection.disconnect()
#    
#    connection_strategy.event.timeout(42, self.strategy._connect_cb).AndReturn('foo')

#    self.mox.ReplayAll()
#    self.strategy.connect( 42 )
#    self.assertTrue( self.strategy._pending_connect, 'foo' )
  
  def test_connect_had_single_pending_event(self):
    '''connect() with a connect already pending disconnects, logs the pending
    event and is expected to leave that pending event in place.'''
    self.strategy._pending_connect = 'foo'

    expect(self.connection.logger.debug).args("disconnecting connection")
    expect(self.connection.disconnect)
    expect(self.connection.logger.debug).args("Pending connect: %s", 'foo')

    self.strategy.connect()
    # assertTrue(x, 'foo') only checked truthiness ('foo' was the failure
    # message); compare the value explicitly.
    self.assertEqual('foo', self.strategy._pending_connect)

  def test_connect_cb_when_successful_and_not_reconnecting(self):
    '''A successful connect outside reconnect mode logs at debug level,
    clears the pending event and marks the host CONNECTED.'''
    self.strategy._pending_connect = 'foo'
    self.strategy._cur_host = Host('bar')
    self.strategy._reconnecting = False

    expect(self.connection.logger.debug).args("Connecting to %s on %s", 'bar', 5672)
    expect(self.connection.connect).args('bar', 5672)
    expect(self.connection.logger.debug).args('Connected to %s', self.strategy._cur_host)

    self.strategy._connect_cb()
    self.assertIsNone(self.strategy._pending_connect)
    self.assertEqual(CONNECTED, self.strategy._cur_host.state)

  def test_connect_cb_when_successful_and_reconnecting(self):
    '''A successful connect in reconnect mode logs at info level, runs the
    reconnect callbacks, marks the host CONNECTED and leaves reconnect mode.'''
    reconnect_cb = mock()
    self.strategy._pending_connect = 'foo'
    self.strategy._cur_host = Host('bar')
    self.strategy._reconnecting = True
    self.strategy.reconnect_callbacks = [reconnect_cb]

    expect(self.connection.logger.debug).args("Connecting to %s on %s", 'bar', 5672)
    expect(self.connection.connect).args('bar', 5672)
    expect(self.connection.logger.info).args('Connected to %s', self.strategy._cur_host)
    expect(reconnect_cb)

    self.strategy._connect_cb()
    self.assertIsNone(self.strategy._pending_connect)
    self.assertEqual(CONNECTED, self.strategy._cur_host.state)
    self.assertFalse(self.strategy._reconnecting)

  def test_connect_cb_on_fail_and_first_connect_attempt(self):
    '''A first failed connect to a host logs the exception, schedules a retry
    after 2 seconds and marks the host FAILED.'''
    self.strategy._cur_host = Host('bar')

    expect(self.connection.logger.debug).args("Connecting to %s on %s", 'bar', 5672)
    expect(self.connection.connect).args('bar', 5672).raises(socket.error('fail sauce'))

    expect(self.connection.logger.exception).args(
      "Failed to connect to %s, will try again in %d seconds", self.strategy._cur_host, 2)
    expect(self.strategy.connect).args(2)

    self.strategy._connect_cb()
    self.assertEqual(FAILED, self.strategy._cur_host.state)
  
  def test_connect_cb_on_fail_and_second_connect_attempt(self):
    self.strategy._cur_host = Host('bar')
    self.strategy._cur_host.state = FAILED
    
    expect(self.connection.logger.debug).args( "Connecting to %s on %s", 'bar', 5672 )
    expect(self.connection.connect).args( 'bar', 5672 ).raises( socket.error('fail sauce') )
    expect(self.connection.logger.critical).args( "Failed to connect to %s", self.strategy._cur_host )
    expect(self.strategy.next_host)
    
    self.strategy._connect_cb()