コード例 #1
0
ファイル: test_network.py プロジェクト: scorpilix/Golemtest
 def test_send_and_receive_message(self):
     """Round-trip MessageHello objects through BasicProtocol and a fake transport.

     Checks that sending is refused before connectionMade, that a sent message
     is written to the transport buffer and can be fed back via dataReceived,
     that timestamps survive serialization, and that connectionLost removes the
     ``session`` attribute from the protocol instance.
     """
     p = BasicProtocol()
     p.transport = Transport()
     session_factory = SessionFactory(ASession)
     p.set_session_factory(session_factory)
     # Connection is not opened yet: every send must be refused.
     self.assertFalse(p.send_message("123"))
     msg = MessageHello()
     self.assertFalse(p.send_message(msg))
     p.connectionMade()
     self.assertTrue(p.send_message(msg))
     self.assertEqual(len(p.transport.buff), 1)
     # Feed the written bytes back in; the session should receive the message.
     p.dataReceived(p.transport.buff[0])
     self.assertIsInstance(p.session.msgs[0], MessageHello)
     # assertEqual / assertNotEqual: the *Equals aliases are deprecated and
     # were removed in Python 3.12.
     self.assertEqual(msg.timestamp, p.session.msgs[0].timestamp)
     # Wait so that a newly created message gets a different timestamp
     # (presumably timestamps have ~1 s resolution - TODO confirm).
     time.sleep(1)
     msg = MessageHello()
     self.assertNotEqual(msg.timestamp, p.session.msgs[0].timestamp)
     self.assertTrue(p.send_message(msg))
     self.assertEqual(len(p.transport.buff), 2)
     # Decode the second written frame by hand and compare timestamps.
     db = DataBuffer()
     db.append_string(p.transport.buff[1])
     m = Message.deserialize(db)[0]
     self.assertEqual(m.timestamp, msg.timestamp)
     p.connectionLost()
     self.assertNotIn('session', p.__dict__)
コード例 #2
0
    def test_get_resource(self):
        """Requesting a resource must emit a message that deserializes back.

        The session's ``send`` is redirected into a local DataBuffer so the
        produced bytes can be checked with ``Message.deserialize_message``.
        """
        protocol = BasicProtocol()
        protocol.transport = Mock()
        protocol.server = Mock()

        sink = DataBuffer()

        def capture(outgoing):
            # Mirror the original lambda, including its return value.
            return sink.append_string(outgoing.serialize())

        session = TaskSession(protocol)
        session.send = capture
        session._can_send = lambda *_: True
        session.request_resource(str(uuid.uuid4()), TaskResourceHeader("tmp"))

        assert Message.deserialize_message(sink.buffered_data)
コード例 #3
0
    def _prepare_msg_to_send(self, msg):
        """Sign, serialize, encrypt and length-prefix ``msg`` for sending.

        :param msg: message to send
        :return: the framed, encrypted bytes, or None when there is no session
            or the session failed to sign the message (the reason is logged)
        """
        if self.session is None:
            logger.error("Wrong session, not sending message")
            return None

        msg = self.session.sign(msg)
        if not msg:
            # sign() returned a falsy value: the failure is in signing, not a
            # missing session, so report it distinctly.
            logger.error("Failed to sign message, not sending message")
            return None
        ser_msg = msg.serialize()
        enc_msg = self.session.encrypt(ser_msg)

        db = DataBuffer()
        db.append_len_prefixed_string(enc_msg)
        return db.read_all()
コード例 #4
0
ファイル: test_message.py プロジェクト: scorpilix/Golemtest
    def test_decrypt_and_deserialize(self):
        """Exercise Message.decrypt_and_deserialize under several outcomes.

        Covers: identity decryption (all messages returned), deserialize_message
        patched to return None (nothing returned), decrypt raising
        AssertionError (all messages returned, none flagged ``encrypted``) and
        decrypt raising a generic Exception (nothing returned).
        """
        db = DataBuffer()
        server = mock.Mock()
        n_messages = 10

        def serialize_messages(_b):
            # range instead of xrange: same n_messages values, and the test
            # also runs on Python 3 where xrange does not exist.
            for m in [message.MessageHello() for _ in range(0, n_messages)]:
                m.serialize_to_buffer(_b)

        serialize_messages(db)
        server.decrypt = lambda x: x
        assert len(message.Message.decrypt_and_deserialize(db, server)) == n_messages

        patch_method = 'golem.network.transport.message.Message.deserialize_message'
        with mock.patch(patch_method, side_effect=lambda x: None):
            serialize_messages(db)
            assert len(message.Message.decrypt_and_deserialize(db, server)) == 0

        def raise_assertion(*_):
            raise AssertionError()

        def raise_error(*_):
            raise Exception()

        server.decrypt = raise_assertion
        serialize_messages(db)

        result = message.Message.decrypt_and_deserialize(db, server)

        assert len(result) == n_messages
        assert all(not m.encrypted for m in result)

        server.decrypt = raise_error
        serialize_messages(db)

        result = message.Message.decrypt_and_deserialize(db, server)

        assert len(result) == 0
コード例 #5
0
ファイル: test_message.py プロジェクト: scorpilix/Golemtest
    def test_timestamp_and_timezones(self):
        """A message timestamp must denote the same instant in any local TZ.

        The message is serialized under Europe/Warsaw and deserialized under
        US/Eastern. Local times must differ while UTC times match. The
        process-wide TZ setting is restored afterwards so later tests are not
        affected.
        """
        epoch_t = 1475238345.0
        # Remember the original TZ (None means it was unset) for restoration.
        original_tz = os.environ.get('TZ')

        def set_tz(tz):
            os.environ['TZ'] = tz
            try:
                time.tzset()
            except AttributeError:
                # time.tzset is not available on every platform.
                raise unittest.SkipTest("tzset required")

        def restore_tz():
            if original_tz is None:
                os.environ.pop('TZ', None)
            else:
                os.environ['TZ'] = original_tz
            if hasattr(time, 'tzset'):
                time.tzset()

        try:
            set_tz('Europe/Warsaw')
            warsaw_time = time.localtime(epoch_t)
            m = message.MessageHello(timestamp=epoch_t)
            db = DataBuffer()
            m.serialize_to_buffer(db)
            set_tz('US/Eastern')
            server = mock.Mock()
            server.decrypt = lambda x: x
            msgs = message.Message.decrypt_and_deserialize(db, server)
            assert len(msgs) == 1
            newyork_time = time.localtime(msgs[0].timestamp)
            assert warsaw_time != newyork_time
            assert time.gmtime(epoch_t) == time.gmtime(msgs[0].timestamp)
        finally:
            # Undo the TZ change even on failure or skip.
            restore_tz()
コード例 #6
0
    def _prepare_msg_to_send(self, msg):
        """Serialize ``msg`` and frame it with a length prefix.

        :param msg: object exposing ``serialize()``
        :return: the framed data read back from a fresh DataBuffer
        """
        payload = msg.serialize()
        frame = DataBuffer()
        frame.append_len_prefixed_string(payload)
        framed = frame.read_all()
        return framed
コード例 #7
0
 def __init__(self):
     """Start closed, with an empty receive buffer and a new lock.

     The attributes are assigned before SessionProtocol.__init__ runs.
     """
     self.opened, self.db, self.lock = False, DataBuffer(), Lock()
     SessionProtocol.__init__(self)
コード例 #8
0
class BasicProtocol(SessionProtocol):
    """Connection-oriented basic protocol for twisted; supports message serialization.

    Received data is accumulated in ``self.db`` and converted into ``Message``
    objects, which are handed one by one to ``self.session.interpret``.
    """
    def __init__(self):
        self.opened = False  # set by connectionMade, cleared by close_now/connectionLost
        self.db = DataBuffer()  # buffer for incoming, possibly partial, data
        self.lock = Lock()  # guards self.db while data is appended and parsed
        SessionProtocol.__init__(self)

    def send_message(self, msg):
        """
        Serialize and send message
        :param Message msg: message to send
        :return bool: True if the message has been written to the transport,
            False if the connection is closed or the message could not be
            prepared for sending
        """
        if not self.opened:
            # One log record carrying the message, formatted lazily by logging.
            logger.error("Send message %s failed - connection closed.", msg)
            return False

        msg_to_send = self._prepare_msg_to_send(msg)

        # Overrides of _prepare_msg_to_send may return None to refuse sending.
        if msg_to_send is None:
            return False

        # NOTE(review): the return value is unused; kept as in the original.
        self.transport.getHandle()
        self.transport.write(msg_to_send)

        return True

    def close(self):
        """
        Close connection after writing all pending data (flush the write buffer
        and wait for the producer to finish).
        :return None:
        """
        self.transport.loseConnection()

    def close_now(self):
        """
        Close connection ASAP; doesn't flush the write buffer or wait for the
        producer to finish.
        :return None:
        """
        self.opened = False
        self.transport.abortConnection()

    # Protocol functions
    def connectionMade(self):
        """Called when new connection is successfully opened"""
        SessionProtocol.connectionMade(self)
        self.opened = True

    def dataReceived(self, data):
        """Called when additional chunk of data is received from another peer"""
        if not self._can_receive():
            return None

        if not self.session:
            logger.warning("No session argument in connection state")
            return None

        self._interpret(data)

    def connectionLost(self, reason=connectionDone):
        """Called when connection is lost (for whatever reason)"""
        self.opened = False
        if self.session:
            self.session.dropped()

        SessionProtocol.connectionLost(self, reason)

    # Protected functions
    def _prepare_msg_to_send(self, msg):
        """Return ``msg`` serialized and framed with a length prefix."""
        ser_msg = msg.serialize()

        db = DataBuffer()
        db.append_len_prefixed_string(ser_msg)
        return db.read_all()

    def _can_receive(self):
        """True when the connection is open and the receive buffer exists."""
        return self.opened and isinstance(self.db, DataBuffer)

    def _interpret(self, data):
        """Buffer ``data`` and pass every message decoded from the buffer to the session."""
        with self.lock:
            self.db.append_string(data)
            mess = self._data_to_messages()

        # Interpret messages (outside the lock)
        if mess:
            for m in mess:
                self.session.interpret(m)
        elif data:
            logger.info("Deserialization of messages from %s:%s failed, maybe it's still "
                        "too short?", self.session.address, self.session.port)

    def _data_to_messages(self):
        """Decode as many messages as possible from ``self.db``."""
        return Message.deserialize(self.db)
コード例 #9
0
ファイル: tcpnetwork.py プロジェクト: U0001F3A2/golem
    def _prepare_msg_to_send(self, msg):
        """Dump ``msg`` with golem_messages and prepend its length.

        :param msg: golem message to send
        :return: the length-prefixed bytes read back from a fresh DataBuffer
        """
        raw = golem_messages.dump(msg, None, None)
        frame = DataBuffer()
        frame.append_len_prefixed_bytes(raw)
        framed = frame.read_all()
        return framed
コード例 #10
0
ファイル: tcpnetwork.py プロジェクト: U0001F3A2/golem
 def __init__(self):
     """Run the base initializer, then start closed with fresh helpers.

     Sets a closed connection flag, an empty receive buffer and a new
     SpamProtector instance.
     """
     super().__init__()
     self.opened, self.db, self.spam_protector = False, DataBuffer(), SpamProtector()
コード例 #11
0
ファイル: tcpnetwork.py プロジェクト: U0001F3A2/golem
class BasicProtocol(SessionProtocol):
    """Twisted protocol exchanging length-prefixed golem messages.

    Outgoing messages are dumped with ``golem_messages.dump`` and written to
    the transport. Incoming bytes are appended to ``self.db``, split into
    length-prefixed chunks, filtered by size and by the spam protector,
    loaded, and forwarded to ``self.session.interpret``.
    """
    def __init__(self):
        super().__init__()
        self.opened = False
        self.db = DataBuffer()
        self.spam_protector = SpamProtector()

    def send_message(self, msg):
        """Serialize ``msg`` and write it to the transport.

        :param Message msg: message to send
        :return bool: True when the message was written, False otherwise
        :raises golem_messages.exceptions.SerializationError: when ``msg``
            cannot be serialized (logged before being re-raised)
        """
        if not self.opened:
            logger.warning("Send message %s failed - connection closed", msg)
            return False

        try:
            payload = self._prepare_msg_to_send(msg)
        except golem_messages.exceptions.SerializationError:
            logger.exception('Cannot serialize message: %s', msg)
            raise

        if payload is None:
            return False

        self.transport.getHandle()
        self.transport.write(payload)
        return True

    def close(self):
        """Close gracefully: pending writes are flushed and the producer is
        allowed to finish before the connection goes away.

        :return None:
        """
        self.transport.loseConnection()

    def close_now(self):
        """Close immediately, without flushing the write buffer or waiting
        for the producer.

        :return None:
        """
        self.opened = False
        self.transport.abortConnection()

    # Protocol functions
    def connectionMade(self):
        """Mark the connection as open once twisted reports it established."""
        SessionProtocol.connectionMade(self)
        self.opened = True

    def dataReceived(self, data):
        """Handle a chunk of bytes received from the peer."""
        if self._can_receive():
            if self.session:
                self._interpret(data)
            else:
                logger.warning("No session argument in connection state")

    def connectionLost(self, reason=connectionDone):
        """Mark the connection closed and notify the session, if any."""
        self.opened = False
        if self.session:
            self.session.dropped()

        SessionProtocol.connectionLost(self, reason)

    # Protected functions
    def _prepare_msg_to_send(self, msg):
        """Return ``msg`` dumped by golem_messages and prefixed with its length."""
        raw = golem_messages.dump(msg, None, None)
        frame = DataBuffer()
        frame.append_len_prefixed_bytes(raw)
        return frame.read_all()

    def _can_receive(self) -> bool:
        """True when the connection is open and the receive buffer exists."""
        return self.opened and isinstance(self.db, DataBuffer)

    def _interpret(self, data):
        """Record activity, buffer ``data`` and dispatch decoded messages."""
        self.session.last_message_time = time.time()
        self.db.append_bytes(data)
        for decoded in self._data_to_messages():
            self.session.interpret(decoded)

    def _load_message(self, data):
        """Load a single message from ``data`` via golem_messages."""
        loaded = golem_messages.load(data, None, None)
        logger.debug(
            'BasicProtocol._load_message(): received %r',
            loaded,
        )
        return loaded

    def _data_to_messages(self):
        """Decode every length-prefixed chunk currently held in ``self.db``.

        Oversized, spam-flagged or undecodable chunks are skipped. On a
        protocol version mismatch a Disconnect is sent, the connection is
        closed and an empty list is returned.
        """
        parsed = []

        for chunk in self.db.get_len_prefixed_bytes():
            size = len(chunk)
            if size > MAX_MESSAGE_SIZE:
                logger.info(
                    'Ignoring huge message %dB from %r',
                    size,
                    self.transport.getPeer(),
                )
                continue

            # The spam check sits inside the try on purpose: exceptions it
            # raises are handled exactly like load failures.
            try:
                if not self.spam_protector.check_msg(chunk):
                    continue
                loaded = self._load_message(chunk)
            except golem_messages.exceptions.HeaderError as e:
                logger.debug(
                    "Invalid message header: %s from %s. Ignoring.",
                    e,
                    self.transport.getPeer(),
                )
            except golem_messages.exceptions.VersionMismatchError as e:
                logger.debug(
                    "Message version mismatch: %s from %s. Closing.",
                    e,
                    self.transport.getPeer(),
                )
                farewell = message.base.Disconnect(
                    reason=message.base.Disconnect.REASON.ProtocolVersion,
                )
                self.send_message(farewell)
                self.close()
                return []
            except golem_messages.exceptions.MessageError as e:
                logger.info("Failed to deserialize message (%r) %r", e, chunk)
                logger.debug(
                    "BasicProtocol._data_to_messages() failed %r",
                    chunk,
                    exc_info=True,
                )
            else:
                # Reached only when the chunk passed the spam check and loaded.
                parsed.append(loaded)

        return parsed