def test_send_and_receive_message(self):
    """Round-trip MessageHello through BasicProtocol and check timestamps.

    Sending fails while the connection is not yet made. After
    connectionMade() the message is written to the transport. Feeding the
    written bytes back through dataReceived() yields a MessageHello with the
    same timestamp. connectionLost() removes the session attribute.
    """
    p = BasicProtocol()
    p.transport = Transport()
    session_factory = SessionFactory(ASession)
    p.set_session_factory(session_factory)

    # Connection not opened yet: every send must be refused.
    self.assertFalse(p.send_message("123"))
    msg = MessageHello()
    self.assertFalse(p.send_message(msg))

    p.connectionMade()
    self.assertTrue(p.send_message(msg))
    self.assertEqual(len(p.transport.buff), 1)

    # Loop the written bytes back in as if received from the peer.
    p.dataReceived(p.transport.buff[0])
    self.assertIsInstance(p.session.msgs[0], MessageHello)
    self.assertEqual(msg.timestamp, p.session.msgs[0].timestamp)

    # Wait so a freshly created message gets a different timestamp
    # (presumably timestamps have ~1s resolution - confirm).
    time.sleep(1)
    msg = MessageHello()
    self.assertNotEqual(msg.timestamp, p.session.msgs[0].timestamp)
    self.assertTrue(p.send_message(msg))
    self.assertEqual(len(p.transport.buff), 2)

    db = DataBuffer()
    db.append_string(p.transport.buff[1])
    m = Message.deserialize(db)[0]
    self.assertEqual(m.timestamp, msg.timestamp)

    p.connectionLost()
    self.assertNotIn('session', p.__dict__)
def test_get_resource(self):
    """request_resource() on a TaskSession emits a deserializable message.

    The session's ``send`` is swapped for a function that serializes each
    outgoing message into a local DataBuffer. The captured bytes must then
    deserialize to a truthy result.
    """
    protocol = BasicProtocol()
    protocol.transport = Mock()
    protocol.server = Mock()
    captured = DataBuffer()
    task_session = TaskSession(protocol)

    def capture_send(outgoing):
        return captured.append_string(outgoing.serialize())

    task_session.send = capture_send
    # Force _can_send to always allow sending (presumably a send-permission
    # check inside TaskSession).
    task_session._can_send = lambda *_: True
    resource_id = str(uuid.uuid4())
    task_session.request_resource(resource_id, TaskResourceHeader("tmp"))
    assert Message.deserialize_message(captured.buffered_data)
def _prepare_msg_to_send(self, msg):
    """Sign, serialize and encrypt ``msg``, then frame it with a length prefix.

    :param msg: message to send
    :return: the framed bytes read back from a DataBuffer, or None (after
        logging an error) when there is no session or signing returns a
        falsy value
    """
    if self.session is not None:
        signed = self.session.sign(msg)
        if signed:
            encrypted = self.session.encrypt(signed.serialize())
            frame = DataBuffer()
            frame.append_len_prefixed_string(encrypted)
            return frame.read_all()
    # Reached both when there is no session and when signing produced nothing.
    logger.error("Wrong session, not sending message")
    return None
def test_decrypt_and_deserialize(self):
    """Exercise Message.decrypt_and_deserialize with four buffer scenarios.

    Covers: plain messages with an identity decrypt, messages that fail to
    deserialize, a decrypt that raises AssertionError, and a decrypt that
    raises a generic Exception.
    """
    db = DataBuffer()
    server = mock.Mock()
    n_messages = 10

    def serialize_messages(_b):
        # ``range`` replaces the Python 2-only ``xrange`` (NameError on Py3).
        for m in [message.MessageHello() for _ in range(n_messages)]:
            m.serialize_to_buffer(_b)

    # Identity decrypt: every serialized message comes back.
    serialize_messages(db)
    server.decrypt = lambda x: x
    assert len(message.Message.decrypt_and_deserialize(db, server)) == n_messages

    # deserialize_message returning None: no messages are returned.
    patch_method = 'golem.network.transport.message.Message.deserialize_message'
    with mock.patch(patch_method, side_effect=lambda x: None):
        serialize_messages(db)
        assert len(message.Message.decrypt_and_deserialize(db, server)) == 0

    def raise_assertion(*_):
        raise AssertionError()

    def raise_error(*_):
        raise Exception()

    # decrypt raising AssertionError: all messages are still returned and
    # none of them is flagged as encrypted.
    server.decrypt = raise_assertion
    serialize_messages(db)
    result = message.Message.decrypt_and_deserialize(db, server)
    assert len(result) == n_messages
    assert all(not m.encrypted for m in result)

    # decrypt raising any other exception: nothing is returned.
    server.decrypt = raise_error
    serialize_messages(db)
    result = message.Message.decrypt_and_deserialize(db, server)
    assert len(result) == 0
def test_timestamp_and_timezones(self):
    """A message timestamp survives serialization across local timezones.

    The message is created and serialized under TZ=Europe/Warsaw and
    deserialized under TZ=US/Eastern. Local wall-clock times must differ,
    while UTC (gmtime) values must be identical. The process TZ is restored
    afterwards so later tests are unaffected.
    """
    epoch_t = 1475238345.0
    # Remember the caller's TZ so it can be restored in ``finally``;
    # previously the test left TZ=US/Eastern set for the rest of the run.
    original_tz = os.environ.get('TZ')

    def set_tz(tz):
        os.environ['TZ'] = tz
        try:
            time.tzset()
        except AttributeError:
            # time.tzset() is only available on Unix.
            raise unittest.SkipTest("tzset required")

    def restore_tz():
        if original_tz is None:
            os.environ.pop('TZ', None)
        else:
            os.environ['TZ'] = original_tz
        if hasattr(time, 'tzset'):
            time.tzset()

    try:
        set_tz('Europe/Warsaw')
        warsaw_time = time.localtime(epoch_t)
        m = message.MessageHello(timestamp=epoch_t)
        db = DataBuffer()
        m.serialize_to_buffer(db)

        set_tz('US/Eastern')
        server = mock.Mock()
        # Identity "decryption": the buffer holds plain serialized messages.
        server.decrypt = lambda x: x
        msgs = message.Message.decrypt_and_deserialize(db, server)
        assert len(msgs) == 1
        newyork_time = time.localtime(msgs[0].timestamp)
        assert warsaw_time != newyork_time
        assert time.gmtime(epoch_t) == time.gmtime(msgs[0].timestamp)
    finally:
        restore_tz()
def _prepare_msg_to_send(self, msg):
    """Serialize ``msg`` and wrap it in a length-prefixed frame.

    :param msg: message exposing ``serialize()``
    :return: the framed bytes as returned by ``DataBuffer.read_all()``
    """
    payload = msg.serialize()
    frame = DataBuffer()
    frame.append_len_prefixed_string(payload)
    return frame.read_all()
def __init__(self):
    """Start closed, with an empty receive buffer and a lock, then run the
    SessionProtocol initializer."""
    self.opened, self.db, self.lock = False, DataBuffer(), Lock()
    SessionProtocol.__init__(self)
class BasicProtocol(SessionProtocol):
    """ Connection-oriented basic protocol for twisted, supports message
    serialization"""

    def __init__(self):
        # Set True by connectionMade, False by close_now/connectionLost.
        self.opened = False
        # Receive buffer: _interpret appends incoming data and decodes
        # messages from it.
        self.db = DataBuffer()
        # Guards access to self.db in _interpret.
        self.lock = Lock()
        SessionProtocol.__init__(self)

    def send_message(self, msg):
        """ Serialize and send message
        :param Message msg: message to send
        :return bool: True if the message has been written to the transport,
            False if the connection is closed or _prepare_msg_to_send
            returned None
        """
        if not self.opened:
            logger.error("Send message %s failed - connection closed.", msg)
            return False
        msg_to_send = self._prepare_msg_to_send(msg)
        if msg_to_send is None:
            return False
        # NOTE(review): return value unused; kept in case getHandle() has
        # side effects on the transport - confirm before removing.
        self.transport.getHandle()
        self.transport.write(msg_to_send)
        return True

    def close(self):
        """ Close connection, after writing all pending (flush the write
        buffer and wait for producer to finish).
        :return None:
        """
        self.transport.loseConnection()

    def close_now(self):
        """ Close connection ASAP, doesn't flush the write buffer or wait for
        the producer to finish
        :return None:
        """
        self.opened = False
        self.transport.abortConnection()

    # Protocol functions
    def connectionMade(self):
        """Called when new connection is successfully opened"""
        SessionProtocol.connectionMade(self)
        self.opened = True

    def dataReceived(self, data):
        """Called when additional chunk of data is received from another
        peer"""
        if not self._can_receive():
            return
        if not self.session:
            logger.warning("No session argument in connection state")
            return
        self._interpret(data)

    def connectionLost(self, reason=connectionDone):
        """Called when connection is lost (for whatever reason)"""
        self.opened = False
        if self.session:
            self.session.dropped()
        SessionProtocol.connectionLost(self, reason)

    # Protected functions
    def _prepare_msg_to_send(self, msg):
        """Serialize ``msg`` and return it framed with a length prefix."""
        ser_msg = msg.serialize()
        db = DataBuffer()
        db.append_len_prefixed_string(ser_msg)
        return db.read_all()

    def _can_receive(self):
        """True when the connection is open and the receive buffer exists."""
        return self.opened and isinstance(self.db, DataBuffer)

    def _interpret(self, data):
        """Append ``data`` to the receive buffer and dispatch decoded messages.

        Only buffer access happens under ``self.lock``. Decoded messages are
        handed to ``self.session.interpret`` after the lock is released.
        """
        with self.lock:
            self.db.append_string(data)
            mess = self._data_to_messages()

        # Interpret messages
        if mess:
            for m in mess:
                self.session.interpret(m)
        elif data:
            # Bytes arrived but nothing decoded - possibly a partial message
            # still waiting for more data.
            logger.info("Deserialization of messages from %s:%s failed, "
                        "maybe it's still too short?",
                        self.session.address, self.session.port)

    def _data_to_messages(self):
        """Decode as many messages as possible from the receive buffer."""
        return Message.deserialize(self.db)
def _prepare_msg_to_send(self, msg):
    """Dump ``msg`` with golem_messages and wrap it in a length-prefixed frame.

    ``golem_messages.dump`` is called with None for both extra arguments
    (presumably no signing/encryption keys - verify against golem_messages).

    :return: the framed bytes as returned by ``DataBuffer.read_all()``
    """
    payload = golem_messages.dump(msg, None, None)
    frame = DataBuffer()
    frame.append_len_prefixed_bytes(payload)
    return frame.read_all()
def __init__(self):
    """Run the base initializers, then start closed with an empty receive
    buffer and a fresh spam protector."""
    super().__init__()
    self.opened = False
    self.db, self.spam_protector = DataBuffer(), SpamProtector()
class BasicProtocol(SessionProtocol):
    """Connection-oriented basic protocol for twisted, supports message
    serialization
    """

    def __init__(self):
        super().__init__()
        # Flipped to True by connectionMade, back to False on close/loss.
        self.opened = False
        # Receive buffer that incoming bytes are appended to before decoding.
        self.db = DataBuffer()
        # Consulted for every received frame before it is loaded.
        self.spam_protector = SpamProtector()

    def send_message(self, msg):
        """Serialize ``msg`` and write it to the transport.

        :param Message msg: message to send
        :return bool: True if the message was written, False if the
            connection is closed or nothing was produced to send
        :raises golem_messages.exceptions.SerializationError: re-raised after
            being logged when ``msg`` cannot be serialized
        """
        if self.opened:
            try:
                payload = self._prepare_msg_to_send(msg)
            except golem_messages.exceptions.SerializationError:
                logger.exception('Cannot serialize message: %s', msg)
                raise
            if payload is None:
                return False
            self.transport.getHandle()
            self.transport.write(payload)
            return True
        logger.warning("Send message %s failed - connection closed", msg)
        return False

    def close(self):
        """Close the connection gracefully: pending writes are flushed and
        the producer is allowed to finish first.

        :return None:
        """
        self.transport.loseConnection()

    def close_now(self):
        """Close the connection immediately without flushing the write
        buffer or waiting for the producer.

        :return None:
        """
        self.opened = False
        self.transport.abortConnection()

    # Protocol functions
    def connectionMade(self):
        """Twisted hook: a new connection has been opened."""
        SessionProtocol.connectionMade(self)
        self.opened = True

    def dataReceived(self, data):
        """Twisted hook: another chunk of bytes arrived from the peer."""
        if self._can_receive():
            if self.session:
                self._interpret(data)
            else:
                logger.warning("No session argument in connection state")

    def connectionLost(self, reason=connectionDone):
        """Twisted hook: the connection went away, for any reason."""
        self.opened = False
        if self.session:
            self.session.dropped()
        SessionProtocol.connectionLost(self, reason)

    # Protected functions
    def _prepare_msg_to_send(self, msg):
        """Dump ``msg`` and wrap it in a length-prefixed frame."""
        payload = golem_messages.dump(msg, None, None)
        frame = DataBuffer()
        frame.append_len_prefixed_bytes(payload)
        return frame.read_all()

    def _can_receive(self) -> bool:
        """Receiving is allowed only while open and with a DataBuffer."""
        return self.opened and isinstance(self.db, DataBuffer)

    def _interpret(self, data):
        """Stamp activity time, buffer ``data`` and dispatch decoded messages."""
        self.session.last_message_time = time.time()
        self.db.append_bytes(data)
        for decoded in self._data_to_messages():
            self.session.interpret(decoded)

    def _load_message(self, data):
        """Load a single message from raw frame bytes and debug-log it."""
        loaded = golem_messages.load(data, None, None)
        logger.debug('BasicProtocol._load_message(): received %r', loaded)
        return loaded

    def _data_to_messages(self):
        """Decode every complete length-prefixed frame in the receive buffer.

        Oversized frames, frames rejected by the spam protector, frames with
        bad headers and otherwise undecodable frames are skipped. On a
        protocol version mismatch a Disconnect is sent, the connection is
        closed and an empty list is returned.
        """
        decoded = []
        for raw in self.db.get_len_prefixed_bytes():
            size = len(raw)
            if size > MAX_MESSAGE_SIZE:
                logger.info(
                    'Ignoring huge message %dB from %r',
                    size,
                    self.transport.getPeer(),
                )
                continue
            try:
                if not self.spam_protector.check_msg(raw):
                    continue
                loaded = self._load_message(raw)
            except golem_messages.exceptions.HeaderError as exc:
                logger.debug(
                    "Invalid message header: %s from %s. Ignoring.",
                    exc,
                    self.transport.getPeer(),
                )
                continue
            except golem_messages.exceptions.VersionMismatchError as exc:
                logger.debug(
                    "Message version mismatch: %s from %s. Closing.",
                    exc,
                    self.transport.getPeer(),
                )
                disconnect_msg = message.base.Disconnect(
                    reason=message.base.Disconnect.REASON.ProtocolVersion,
                )
                self.send_message(disconnect_msg)
                self.close()
                return []
            except golem_messages.exceptions.MessageError as exc:
                logger.info("Failed to deserialize message (%r) %r", exc, raw)
                logger.debug(
                    "BasicProtocol._data_to_messages() failed %r",
                    raw,
                    exc_info=True,
                )
                continue
            else:
                decoded.append(loaded)
        return decoded