class Packet(object):
    """Length-prefixed packet framing on top of a stream socket.

    Every packet on the wire is ``Header.pack(len(payload)) + payload``.
    Optional ``encode`` / ``decode`` callables convert between application
    objects and the raw payload.
    """
    # TODO blocksize should always be max_packet_size.
    # TODO Never read more data than necessary.
    # TODO Read and write should return packet events.

    def __init__(self, socket, max_packet_size=16384, blocksize=4096,
                 encode=None, decode=None):
        """Wrap *socket*.

        :param socket: underlying socket; its ``env`` attribute is reused.
        :param max_packet_size: largest payload size (in bytes) that may be
            sent (measured after ``encode``) or received.
        :param blocksize: number of bytes requested per ``socket.read`` call.
        :param encode: optional callable applied to outgoing packets.
        :param decode: optional callable applied to incoming payloads.
        """
        self.env = socket.env
        self.socket = socket
        self.max_packet_size = max_packet_size
        self.blocksize = blocksize
        self.encode = encode
        self.decode = decode

        # Pending read event, bytes received but not yet consumed, and the
        # total size (header + payload) of the packet currently being read.
        self._read_ev = None
        self._read_buf = b''
        self._read_size = None

        # Pending write event and bytes not yet written to the socket.
        self._write_ev = None
        self._write_buf = b''

    def _wrap(self, event):
        """Callback for accept(): wrap the accepted socket in a Packet.

        The new Packet uses the same settings as this one.
        """
        if event.ok:
            event._value = Packet(event._value, self.max_packet_size,
                                  self.blocksize, self.encode, self.decode)

    def accept(self):
        """Accept a connection; the event's value is wrapped in a Packet."""
        event = self.socket.accept()
        event.callbacks.append(self._wrap)
        return event

    def bind(self, address):
        """Bind the underlying socket to *address*."""
        return self.socket.bind(address)

    def listen(self, backlog=5):
        """Start listening on the underlying socket."""
        return self.socket.listen(backlog)

    def connect(self, address):
        """Connect the underlying socket to *address*."""
        return self.socket.connect(address)

    @property
    def address(self):
        """Local address of the underlying socket."""
        return self.socket.address

    @property
    def peer_address(self):
        """Remote address of the underlying socket."""
        # Previously returned ``self.socket.address`` (the local address).
        return self.socket.peer_address

    def read(self):
        """Return an event that triggers with the next complete packet.

        Only one read may be pending at a time. The event fails with
        ``ValueError`` if the announced packet size exceeds
        ``max_packet_size``. It also fails with whatever ``decode`` raises.

        :raises RuntimeError: if a read is already in progress.
        """
        if self._read_ev is not None:
            raise RuntimeError('Already reading')
        event = self._read_ev = Event(self.env)
        if self._read_buf:
            # Leftover bytes from a previous read may already hold a
            # (partial) packet; process them before touching the socket.
            self._read_data(Event(self.env).succeed(b''))
        else:
            self.socket.read(self.blocksize).callbacks.append(self._read_data)
        return event

    def _read_data(self, event):
        """Socket-read callback: buffer data and emit complete packets."""
        if not event.ok:
            event.defused = True
            self._read_ev.fail(event.value)
            self._read_ev = None
            return

        self._read_buf += event.value

        if self._read_size is None and len(self._read_buf) >= Header.size:
            size = Header.unpack_from(self._read_buf)[0]
            if size > self.max_packet_size:
                # Fail the pending read rather than raising from inside an
                # event callback, where the exception would never reach the
                # reader. The header is left in the buffer, so later reads
                # fail the same way instead of misparsing the stream.
                read_ev, self._read_ev = self._read_ev, None
                read_ev.fail(ValueError(
                    'Packet too large. Allowed %d bytes but '
                    'got %d bytes' % (self.max_packet_size, size)))
                return
            self._read_size = size + Header.size

        if (self._read_size is not None
                and len(self._read_buf) >= self._read_size):
            packet = self._read_buf[Header.size:self._read_size]
            # Consume the packet and reset state *before* triggering the
            # event, so a decode failure cannot leave stale state behind.
            self._read_buf = self._read_buf[self._read_size:]
            self._read_size = None
            read_ev, self._read_ev = self._read_ev, None

            if self.decode is None:
                read_ev.succeed(packet)
                return
            try:
                decoded = self.decode(packet)
            except Exception as e:
                # decode is user-supplied; report its errors to the reader.
                read_ev.fail(e)
            else:
                read_ev.succeed(decoded)
            return

        # Not enough data for a complete packet yet.
        self.socket.read(self.blocksize).callbacks.append(self._read_data)

    def write(self, packet):
        """Send *packet* (after ``encode``) with a length header.

        Returns an event that triggers once all bytes have been written.

        :raises RuntimeError: if a write is already in progress.
        :raises ValueError: if the encoded packet exceeds ``max_packet_size``.
        """
        if self._write_ev is not None:
            raise RuntimeError('Already writing')

        if self.encode is not None:
            packet = self.encode(packet)

        if len(packet) > self.max_packet_size:
            raise ValueError('Packet too large. Allowed %d bytes but '
                             'got %d bytes' % (self.max_packet_size,
                                               len(packet)))

        self._write_ev = Event(self.env)
        self._write_buf = Header.pack(len(packet)) + packet
        self.socket.write(self._write_buf).callbacks.append(self._write_data)
        return self._write_ev

    def _write_data(self, event):
        """Socket-write callback: keep writing until the buffer is drained.

        The socket-write event's value is treated as the number of bytes
        written.
        """
        if not event.ok:
            event.defused = True
            self._write_ev.fail(event.value)
            self._write_ev = None
            return

        self._write_buf = self._write_buf[event.value:]

        if not self._write_buf:
            self._write_ev.succeed()
            self._write_ev = None
        else:
            self.socket.write(self._write_buf).callbacks.append(
                self._write_data)

    def close(self):
        """Close the underlying socket."""
        self.socket.close()
class Message(object):
    """Request/reply messaging over a socket.

    Outgoing requests are tracked by id until the peer answers with a
    SUCCESS or FAILURE message. Incoming requests are handed to ``recv()``
    callers; their replies are sent back once they trigger. At most
    ``message_limit`` incoming and outgoing messages may be pending.
    """
    # TODO Rename class and module (channel)?

    def __init__(self, env, socket, codec=None, message_limit=1024):
        """Start the reader and writer processes for *socket*.

        :param env: the environment the processes run in.
        :param socket: transport; ``read()`` is called without a size, so it
            presumably yields whole packets (e.g. a ``Packet``) — confirm.
        :param codec: encodes/decodes ``(type, id, content)`` tuples;
            defaults to ``JSON()``.
        :param message_limit: maximum number of pending messages per direction.
        """
        self.env = env
        self.socket = socket
        if codec is None:
            codec = JSON()
        self.codec = codec
        self.message_limit = message_limit

        self._message_id = count()
        self._in_queue = []
        self._out_queue = []
        # Maps incoming message objects to ids.
        self._in_messages = {}
        # Maps outgoing message ids to objects.
        self._out_messages = {}
        self._send_ev = None
        self._recv_ev = None
        # First error that shut the channel down (set by _handle_error).
        # send()/recv() re-raise it even if their own process is still alive.
        self._error = None

        self.reader = Process(self.env, self._reader())
        self.writer = Process(self.env, self._writer())

    def _reader(self):
        """Process: read, decode and dispatch incoming messages."""
        try:
            while True:
                data = yield self.socket.read()
                msg_type, msg_id, content = self.codec.decode(data)

                if msg_type == REQUEST:
                    message = InMessage(self, msg_id, content)
                    message.callbacks.append(self._reply)

                    if len(self._in_messages) >= self.message_limit:
                        # Close the connection if the maximum number of
                        # incoming messages is reached.
                        self.close()
                        raise MessageOverflowError(
                            'Incoming message limit of %d has '
                            'been exceeded' % self.message_limit)

                    self._in_messages[message] = msg_id

                    if self._recv_ev is not None:
                        self._recv_ev.succeed(message)
                        self._recv_ev = None
                    else:
                        self._in_queue.append(message)
                elif msg_type == SUCCESS:
                    self._out_messages.pop(msg_id).succeed(content)
                elif msg_type == FAILURE:
                    self._out_messages.pop(msg_id).fail(
                        RemoteException(self, content))
                else:
                    raise RuntimeError('Invalid message type %d' % msg_type)
        except BaseException as e:
            self._handle_error(self.reader, e)

    def _writer(self):
        """Process: write queued outgoing data, sleeping while the queue is empty."""
        try:
            while True:
                if not self._out_queue:
                    self._send_ev = Event(self.env)
                    yield self._send_ev

                yield self.socket.write(self._out_queue.pop(0))
        except BaseException as e:
            self._handle_error(self.writer, e)

    def _handle_error(self, process, err):
        """Shut the channel down after *process* failed with *err*.

        All pending events are failed and the queues are set to ``None``.
        Then *err* is re-raised. Socket errors whose errno is in
        ``UNCRITICAL_ERRORS`` are marked defused.
        """
        # FIXME Should I really ignore errors?
        if isinstance(err, socket.error) and err.errno in UNCRITICAL_ERRORS:
            uncritical = True
        else:
            uncritical = False

        # Remember the root cause; a later failure of the other process is
        # usually a consequence of this one.
        if self._error is None:
            self._error = err

        process.defused = uncritical

        if self._send_ev is not None:
            # FIXME Is this safe? Is it impossible, that socket.write has been
            # triggered but not yet been processed?
            self._send_ev.defused = uncritical
            self._send_ev.fail(err)

        if self._out_messages is not None:
            for msg_id, event in self._out_messages.items():
                event.defused = uncritical
                event.fail(err)

        if self._recv_ev is not None:
            self._recv_ev.defused = uncritical
            self._recv_ev.fail(err)

        self._in_messages = None
        self._out_messages = None
        self._in_queue = None
        self._out_queue = None
        self._recv_ev = None
        self._send_ev = None

        raise err

    def _reply(self, event):
        """Callback for incoming messages: queue their reply for the writer."""
        if self._in_messages is None:
            # Channel has been closed. Ignore the event.
            event.defused = True
            return
        message_id = self._in_messages.pop(event)

        failure = None
        if event.ok:
            try:
                self._out_queue.append(
                    self.codec.encode((SUCCESS, message_id, event._value)))
            except Exception as e:
                # The result could not be encoded; report it as a failure.
                failure = e
        else:
            failure = event._value

        if failure is not None:
            # Failure is handled on the remote side.
            event.defused = True
            # FIXME Ugly hack for python < 3.3
            if hasattr(failure, '__traceback__'):
                stacktrace = traceback.format_exception(
                    failure.__class__, failure, failure.__traceback__)
            else:
                stacktrace = traceback.format_exception_only(
                    failure.__class__, failure)
            self._out_queue.append(
                self.codec.encode((FAILURE, message_id, ''.join(stacktrace))))

        # Wake the writer process.
        if self._send_ev is not None:
            self._send_ev.succeed()
            self._send_ev = None

    def send(self, content):
        """Send a request with *content* and return its OutMessage event.

        :raises MessageOverflowError: if ``message_limit`` outgoing messages
            are already pending.
        Raises the error that shut the channel down if it is closed.
        """
        if self._out_queue is None:
            raise self._error

        if len(self._out_messages) >= self.message_limit:
            raise MessageOverflowError('Outgoing message limit of %d has been '
                                       'exceeded' % self.message_limit)

        message_id = next(self._message_id)
        data = self.codec.encode((REQUEST, message_id, content))
        message = OutMessage(self, message_id, content)
        self._out_queue.append(data)
        self._out_messages[message_id] = message

        # Wake the writer process.
        if self._send_ev is not None:
            self._send_ev.succeed()
            self._send_ev = None

        return message

    def recv(self):
        """Return an event that triggers with the next incoming message.

        :raises RuntimeError: if another receive is already waiting.
        Raises the error that shut the channel down if it is closed.
        """
        if self._in_queue is None:
            raise self._error

        # Enqueue reads if there are no pending incoming messages.
        if not self._in_queue:
            if self._recv_ev is not None:
                raise RuntimeError('Concurrent receive attempt')
            self._recv_ev = Event(self.env)
            return self._recv_ev

        return Event(self.env).succeed(self._in_queue.pop(0))

    def close(self):
        """Close the underlying socket."""
        self.socket.close()