class SocketTests(unittest.TestCase):
    '''Test cases for micro behaviors related to socket operations.

    Some cases are covered as part of other tests, but in this fixture
    we check more details of specific methods related to socket
    operations, with the help of mock classes to avoid expensive overhead.

    '''
    # NOTE(review): another class named SocketTests appears later in this
    # source.  If both live in the same module, the later definition rebinds
    # the name and the tests defined here (test_do_select in particular) are
    # never collected -- confirm, and rename or merge if so.

    class MockSocket():
        '''A mock socket used instead of standard socket objects.

        Only the methods the tests below rely on are provided.  The
        behavior of send() and recv() is steered through the public
        attributes set in __init__.

        '''
        def __init__(self):
            self.ex_on_send = None  # raised from send() if not None
            self.recv_result = b'test'  # dummy data or exception
            self.blockings = []  # history of setblocking() params

        def setblocking(self, on):
            '''Record the requested blocking mode instead of applying it.'''
            self.blockings.append(on)

        def send(self, data):
            '''Raise the configured exception, or pretend 10 bytes went out.

            The data itself is ignored.
            '''
            if self.ex_on_send is not None:
                raise self.ex_on_send
            return 10  # arbitrary choice

        def recv(self, len):
            '''Raise or return recv_result; later calls return empty data.

            The requested length is ignored.  If recv_result is an
            exception it is raised (and stays configured, so it is raised
            again on the next call).
            '''
            if isinstance(self.recv_result, Exception):
                raise self.recv_result
            ret = self.recv_result
            self.recv_result = b''  # if called again, return empty data
            return ret

        def fileno(self):
            '''Return a fixed descriptor number; setUp registers it.'''
            return 42  # arbitrary choice

    class LoggerWrapper():
        '''A simple wrapper of logger to inspect log messages.

        Each call is counted per level and then forwarded unchanged to
        the wrapped logger.

        '''
        def __init__(self, logger):
            self.error_called = 0
            self.warn_called = 0
            self.debug_called = 0
            self.orig_logger = logger  # restored by tearDown

        def error(self, *args):
            self.error_called += 1
            self.orig_logger.error(*args)

        def warn(self, *args):
            self.warn_called += 1
            self.orig_logger.warn(*args)

        def debug(self, *args):
            self.debug_called += 1
            self.orig_logger.debug(*args)

    def mock_kill_socket(self, fileno, sock):
        '''A replacement of MsgQ.kill_socket method for inspection.

        Remembers the (fileno, sock) pair of the latest call and removes
        the descriptor from the MsgQ socket table if it is present.
        '''
        self.__killed_socket = (fileno, sock)
        if fileno in self.__msgq.sockets:
            del self.__msgq.sockets[fileno]

    def setUp(self):
        '''Build a MsgQ wired to the mock socket and the counting logger.'''
        self.__msgq = MsgQ()
        self.__msgq.kill_socket = self.mock_kill_socket
        self.__sock = self.MockSocket()
        self.__data = b'dummy'
        # 42 is the value MockSocket.fileno() returns.
        self.__msgq.sockets[42] = self.__sock
        self.__msgq.sendbuffs[42] = (None, b'testdata')
        # The errno of this exception is filled in by the individual tests.
        self.__sock_error = socket.error()
        self.__killed_socket = None
        # Module-level objects replaced here (or by test_do_select) are put
        # back in tearDown.
        self.__logger = self.LoggerWrapper(msgq.logger)
        msgq.logger = self.__logger
        self.__orig_select = msgq.select.select

    def tearDown(self):
        '''Restore the module-level logger and select() replaced earlier.'''
        msgq.logger = self.__logger.orig_logger
        msgq.select.select = self.__orig_select

    def test_send_data(self):
        # Successful case: _send_data() returns the hardcoded value, and
        # setblocking() is called twice with the expected parameters
        self.assertEqual(10, self.__msgq._send_data(self.__sock, self.__data))
        self.assertEqual([0, 1], self.__sock.blockings)
        self.assertIsNone(self.__killed_socket)

    def test_send_data_interrupt(self):
        '''send() is interrupted. send_data() returns 0, sock isn't killed.'''
        expected_blockings = []
        for eno in [errno.EAGAIN, errno.EWOULDBLOCK, errno.EINTR]:
            self.__sock_error.errno = eno
            self.__sock.ex_on_send = self.__sock_error
            self.assertEqual(0, self.__msgq._send_data(self.__sock,
                                                       self.__data))
            # The mock keeps the whole setblocking() history, so the
            # expectation grows by one [0, 1] pair per iteration.
            expected_blockings.extend([0, 1])
            self.assertEqual(expected_blockings, self.__sock.blockings)
            self.assertIsNone(self.__killed_socket)

    def test_send_data_error(self):
        '''Unexpected error happens on send().  The socket is killed.

        If the error is EPIPE, it's logged at the warn level; otherwise
        an error message is logged.

        '''
        expected_blockings = []
        expected_errors = 0
        expected_warns = 0
        for eno in [errno.EPIPE, errno.ECONNRESET, errno.ENOBUFS]:
            self.__sock_error.errno = eno
            self.__sock.ex_on_send = self.__sock_error
            self.__killed_socket = None  # clear any previous value
            self.assertIsNone(self.__msgq._send_data(self.__sock,
                                                     self.__data))
            self.assertEqual((42, self.__sock), self.__killed_socket)
            expected_blockings.extend([0, 1])
            self.assertEqual(expected_blockings, self.__sock.blockings)

            # The logger counters are cumulative across iterations.
            if eno == errno.EPIPE:
                expected_warns += 1
            else:
                expected_errors += 1
            self.assertEqual(expected_errors, self.__logger.error_called)
            self.assertEqual(expected_warns, self.__logger.warn_called)

    def test_process_packet(self):
        '''Check some failure cases in handling an incoming message.'''
        expected_errors = 0
        expected_debugs = 0

        # if socket.recv() fails due to socket.error, it will be logged
        # as error and the socket will be killed regardless of errno.
        for eno in [errno.ENOBUFS, errno.ECONNRESET]:
            self.__sock_error.errno = eno
            self.__sock.recv_result = self.__sock_error
            self.__killed_socket = None  # clear any previous value
            self.__msgq.process_packet(42, self.__sock)
            self.assertEqual((42, self.__sock), self.__killed_socket)
            expected_errors += 1
            self.assertEqual(expected_errors, self.__logger.error_called)
            self.assertEqual(expected_debugs, self.__logger.debug_called)

        # if socket.recv() returns empty data, the result depends on whether
        # there's any preceding data; in the second case below, at least
        # 6 bytes of data will be expected, and the second call to our faked
        # recv() returns empty data.  In that case it will be logged as error.
        for recv_data in [b'', b'short']:
            self.__sock.recv_result = recv_data
            self.__killed_socket = None
            self.__msgq.process_packet(42, self.__sock)
            self.assertEqual((42, self.__sock), self.__killed_socket)
            if not recv_data:
                expected_debugs += 1
            else:
                expected_errors += 1
            self.assertEqual(expected_errors, self.__logger.error_called)
            self.assertEqual(expected_debugs, self.__logger.debug_called)

    def test_do_select(self):
        """
        Check the behaviour of the run_select method.

        In particular, check that we skip writing to the sockets we read,
        because a read may have side effects (like closing the socket) and
        we want to prevent strange behavior.
        """
        self.__read_called = []
        self.__write_called = []
        self.__reads = None
        self.__writes = None

        # Both handlers record the descriptor they were given and clear the
        # running flag so that run_select() returns after this one pass.
        def do_read(fd, socket):
            self.__read_called.append(fd)
            self.__msgq.running = False

        def do_write(fd):
            self.__write_called.append(fd)
            self.__msgq.running = False

        self.__msgq.process_packet = do_read
        self.__msgq._process_write = do_write
        self.__msgq.fd_to_lname = {42: 'lname', 44: 'other', 45: 'unused'}
        # The do_select does index it, but just passes the value. So reuse
        # the dict to save typing in the test.
        self.__msgq.sockets = self.__msgq.fd_to_lname
        self.__msgq.sendbuffs = {42: 'data', 43: 'data'}

        def my_select(reads, writes, errors):
            # Capture what run_select() asked for, then report 42 as both
            # readable and writable, 44 as readable and 43 as writable.
            self.__reads = reads
            self.__writes = writes
            self.assertEqual([], errors)
            return ([42, 44], [42, 43], [])

        msgq.select.select = my_select  # put back by tearDown()
        self.__msgq.listen_socket = DummySocket

        self.__msgq.running = True
        self.__msgq.run_select()
        self.assertEqual([42, 44], self.__read_called)
        # 42 was read in this pass, so only 43 may have been written.
        self.assertEqual([43], self.__write_called)
        self.assertEqual({42, 44, 45}, set(self.__reads))
        self.assertEqual({42, 43}, set(self.__writes))
class SocketTests(unittest.TestCase):
    '''Test cases for micro behaviors related to socket operations.

    Some cases are covered as part of other tests, but in this fixture
    we check more details of specific methods related to socket
    operations, with the help of mock classes to avoid expensive overhead.

    '''
    # NOTE(review): an earlier class named SocketTests appears in this
    # source.  If both live in the same module, this definition replaces
    # the earlier one and the tests that exist only there are never
    # collected -- confirm, and rename or merge if so.

    class MockSocket():
        '''A mock socket used instead of standard socket objects.

        Only the methods the tests below rely on are provided.  The
        behavior of send() and recv() is steered through the public
        attributes set in __init__.

        '''
        def __init__(self):
            self.ex_on_send = None  # raised from send() if not None
            self.recv_result = b'test'  # dummy data or exception
            self.blockings = []  # history of setblocking() params

        def setblocking(self, on):
            '''Record the requested blocking mode instead of applying it.'''
            self.blockings.append(on)

        def send(self, data):
            '''Raise the configured exception, or pretend 10 bytes went out.

            The data itself is ignored.
            '''
            if self.ex_on_send is not None:
                raise self.ex_on_send
            return 10  # arbitrary choice

        def recv(self, len):
            '''Raise or return recv_result; later calls return empty data.

            The requested length is ignored.  If recv_result is an
            exception it is raised (and stays configured, so it is raised
            again on the next call).
            '''
            if isinstance(self.recv_result, Exception):
                raise self.recv_result
            ret = self.recv_result
            self.recv_result = b''  # if called again, return empty data
            return ret

        def fileno(self):
            '''Return a fixed descriptor number; setUp registers it.'''
            return 42  # arbitrary choice

    class LoggerWrapper():
        '''A simple wrapper of logger to inspect log messages.

        Each call is counted per level and then forwarded unchanged to
        the wrapped logger.

        '''
        def __init__(self, logger):
            self.error_called = 0
            self.warn_called = 0
            self.debug_called = 0
            self.orig_logger = logger  # restored by tearDown

        def error(self, *args):
            self.error_called += 1
            self.orig_logger.error(*args)

        def warn(self, *args):
            self.warn_called += 1
            self.orig_logger.warn(*args)

        def debug(self, *args):
            self.debug_called += 1
            self.orig_logger.debug(*args)

    def mock_kill_socket(self, fileno, sock):
        '''A replacement of MsgQ.kill_socket method for inspection.

        Remembers the (fileno, sock) pair of the latest call and removes
        the descriptor from the MsgQ socket table if it is present.
        '''
        self.__killed_socket = (fileno, sock)
        if fileno in self.__msgq.sockets:
            del self.__msgq.sockets[fileno]

    def setUp(self):
        '''Build a MsgQ wired to the mock socket and the counting logger.'''
        self.__msgq = MsgQ()
        self.__msgq.kill_socket = self.mock_kill_socket
        self.__sock = self.MockSocket()
        self.__data = b'dummy'
        # 42 is the value MockSocket.fileno() returns.
        self.__msgq.sockets[42] = self.__sock
        self.__msgq.sendbuffs[42] = (None, b'testdata')
        # The errno of this exception is filled in by the individual tests.
        self.__sock_error = socket.error()
        self.__killed_socket = None
        # The module-level logger is replaced here and put back in tearDown.
        self.__logger = self.LoggerWrapper(msgq.logger)
        msgq.logger = self.__logger

    def tearDown(self):
        '''Restore the module-level logger replaced in setUp.'''
        msgq.logger = self.__logger.orig_logger

    def test_send_data(self):
        # Successful case: _send_data() returns the hardcoded value, and
        # setblocking() is called twice with the expected parameters
        self.assertEqual(10, self.__msgq._send_data(self.__sock, self.__data))
        self.assertEqual([0, 1], self.__sock.blockings)
        self.assertIsNone(self.__killed_socket)

    def test_send_data_interrupt(self):
        '''send() is interrupted. send_data() returns 0, sock isn't killed.'''
        expected_blockings = []
        for eno in [errno.EAGAIN, errno.EWOULDBLOCK, errno.EINTR]:
            self.__sock_error.errno = eno
            self.__sock.ex_on_send = self.__sock_error
            self.assertEqual(0, self.__msgq._send_data(self.__sock,
                                                       self.__data))
            # The mock keeps the whole setblocking() history, so the
            # expectation grows by one [0, 1] pair per iteration.
            expected_blockings.extend([0, 1])
            self.assertEqual(expected_blockings, self.__sock.blockings)
            self.assertIsNone(self.__killed_socket)

    def test_send_data_error(self):
        '''Unexpected error happens on send().  The socket is killed.

        If the error is EPIPE, it's logged at the warn level; otherwise
        an error message is logged.

        '''
        expected_blockings = []
        expected_errors = 0
        expected_warns = 0
        for eno in [errno.EPIPE, errno.ECONNRESET, errno.ENOBUFS]:
            self.__sock_error.errno = eno
            self.__sock.ex_on_send = self.__sock_error
            self.__killed_socket = None  # clear any previous value
            self.assertIsNone(self.__msgq._send_data(self.__sock,
                                                     self.__data))
            self.assertEqual((42, self.__sock), self.__killed_socket)
            expected_blockings.extend([0, 1])
            self.assertEqual(expected_blockings, self.__sock.blockings)

            # The logger counters are cumulative across iterations.
            if eno == errno.EPIPE:
                expected_warns += 1
            else:
                expected_errors += 1
            self.assertEqual(expected_errors, self.__logger.error_called)
            self.assertEqual(expected_warns, self.__logger.warn_called)

    def test_process_fd_read_after_bad_write(self):
        '''Check the specific case of write fail followed by read attempt.

        The write failure results in kill_socket, then the read shouldn't
        be tried.

        '''
        self.__sock_error.errno = errno.EPIPE
        self.__sock.ex_on_send = self.__sock_error
        # NOTE(review): the only read handler exercised elsewhere in this
        # class is process_packet.  Confirm that process_socket is the
        # attribute _process_fd would actually call for a readable fd;
        # otherwise this guard has no effect.
        self.__msgq.process_socket = None  # if called, trigger an exception
        self.__msgq._process_fd(42, True, True, False)  # shouldn't crash

        # check the socket is deleted from the fileno=>sock dictionary
        self.assertEqual({}, self.__msgq.sockets)

    def test_process_fd_close_after_bad_write(self):
        '''Similar to the previous, but for checking dup'ed kill attempt'''
        self.__sock_error.errno = errno.EPIPE
        self.__sock.ex_on_send = self.__sock_error
        self.__msgq._process_fd(42, True, False, True)  # shouldn't crash
        self.assertEqual({}, self.__msgq.sockets)

    def test_process_fd_writer_after_close(self):
        '''Emulate a "writable" socket has been already closed and killed.'''
        # This just shouldn't crash.  4200 is a descriptor that setUp never
        # registered with the MsgQ.
        self.__msgq._process_fd(4200, True, False, False)

    def test_process_packet(self):
        '''Check some failure cases in handling an incoming message.'''
        expected_errors = 0
        expected_debugs = 0

        # if socket.recv() fails due to socket.error, it will be logged
        # as error and the socket will be killed regardless of errno.
        for eno in [errno.ENOBUFS, errno.ECONNRESET]:
            self.__sock_error.errno = eno
            self.__sock.recv_result = self.__sock_error
            self.__killed_socket = None  # clear any previous value
            self.__msgq.process_packet(42, self.__sock)
            self.assertEqual((42, self.__sock), self.__killed_socket)
            expected_errors += 1
            self.assertEqual(expected_errors, self.__logger.error_called)
            self.assertEqual(expected_debugs, self.__logger.debug_called)

        # if socket.recv() returns empty data, the result depends on whether
        # there's any preceding data; in the second case below, at least
        # 6 bytes of data will be expected, and the second call to our faked
        # recv() returns empty data.  In that case it will be logged as error.
        for recv_data in [b'', b'short']:
            self.__sock.recv_result = recv_data
            self.__killed_socket = None
            self.__msgq.process_packet(42, self.__sock)
            self.assertEqual((42, self.__sock), self.__killed_socket)
            if not recv_data:
                expected_debugs += 1
            else:
                expected_errors += 1
            self.assertEqual(expected_errors, self.__logger.error_called)
            self.assertEqual(expected_debugs, self.__logger.debug_called)