예제 #1
0
    async def _reconnect(self):
        """
        Restart the connection after the remote server became inaccessible.

        Steps, in order:
          1. abort every pending transaction: for each message class that has a
             transaction table, the abort coroutine registered for it in
             ``self._msgs_registered`` (when not None) is scheduled once per
             pending call, then that class's transaction table is emptied;
          2. await ``self._on_disconnect()``;
          3. cancel every task in ``self._restartable_tasks`` and reset the list;
          4. disconnect the socket from ``self._router_addr``;
          5. await ``self.client_start()`` to run the start sequence again.
        """
        # Step 1: abort and forget all pending transactions.
        for msg_class, pending in self._transactions.items():
            _, _, _, abort_coroutine, _ = self._msgs_registered[msg_class]
            if abort_coroutine is not None:
                for key, calls in pending.items():
                    for args, kwargs in calls:
                        create_safe_task(self._loop, self._logger, abort_coroutine(key, *args, **kwargs))
            # Re-binding an existing key keeps the dict's size, so the iteration stays valid.
            self._transactions[msg_class] = {}

        # Step 2: notify the disconnection hook.
        await self._on_disconnect()

        # Step 3: stop the restartable tasks and drop the references to them.
        for running in self._restartable_tasks:
            running.cancel()
        self._restartable_tasks = []

        # Steps 4 and 5: drop the router connection, then run the start sequence again.
        self._socket.disconnect(self._router_addr)
        await self.client_start()
예제 #2
0
 async def _run_socket(self):
     """
     Task that runs this client: receive messages from the socket forever and dispatch them.

     Every received message resets ``self._ping_count`` to 0 and is then routed:

     * to the handler registered for its ``__msgtype__`` in ``self._handlers_registered``, if any;
     * otherwise, if its type has a transaction table in ``self._transactions``, the key is
       extracted with the ``get_key`` registered in ``self._msgs_registered``. The receive
       coroutine is scheduled once per pending call stored under that key. The key's entries
       are then deleted from the table of every message class listed as ``responsible``;
     * otherwise it is reported as an unknown message.

     Handlers and coroutines are scheduled with ``create_safe_task`` rather than awaited.
     Errors raised while handling one message are logged and the loop continues with the next
     message. The task returns on cancellation or KeyboardInterrupt.
     """
     while True:
         try:
             message = await ZMQUtils.recv(self._socket)
             self._ping_count = 0  # restart ping count
             msg_class = message.__msgtype__
             if msg_class in self._handlers_registered:
                 # A handler is registered for this type: give the message to it
                 create_safe_task(self._loop, self._logger, self._handlers_registered[msg_class](message))
             elif msg_class in self._transactions:
                 # Transactions are associated with this type: check that the key is known
                 _1, get_key, coroutine_recv, _2, responsible = self._msgs_registered[msg_class]
                 key = get_key(message)
                 if key in self._transactions[msg_class]:
                     # Key exists: schedule the receive coroutine for every pending call
                     for args, kwargs in self._transactions[msg_class][key]:
                         create_safe_task(self._loop, self._logger, coroutine_recv(message, *args, **kwargs))
                     # Remove all the parts of this transaction
                     for key2 in responsible:
                         del self._transactions[key2][key]
                 else:
                     # Exception() does not %-format extra arguments; format the text explicitly
                     raise Exception("Received message %s for an unknown transaction %s" % (msg_class, key))
             else:
                 raise Exception("Received unknown message %s" % (msg_class,))
         except (asyncio.CancelledError, KeyboardInterrupt):
             return
         except Exception:
             # Narrower than a bare except so that GeneratorExit / SystemExit still propagate
             # (swallowing GeneratorExit would keep the closed coroutine looping).
             self._logger.exception("Exception while handling a message")
예제 #3
0
 async def handle_agent_message(self, agent_addr, message):
     """
     Dispatch a message received from an agent to the right handler.

     The handler is chosen by the exact class of ``message`` (subclasses are not matched).
     It is called as ``handler(agent_addr, message)`` and the resulting coroutine is
     scheduled with ``create_safe_task`` rather than awaited.

     :param agent_addr: address of the agent that sent the message; forwarded to the handler.
     :param message: the received message object.
     :raises TypeError: if no handler is registered for ``message.__class__``.
     """
     message_handlers = {
         AgentHello: self.handle_agent_hello,
         AgentJobStarted: self.handle_agent_job_started,
         AgentJobDone: self.handle_agent_job_done,
         AgentJobSSHDebug: self.handle_agent_job_ssh_debug,
         Pong: self._handle_pong
     }
     try:
         func = message_handlers[message.__class__]
     except KeyError:
         # Only a missing entry means "unknown type"; a bare except would also mask
         # CancelledError/KeyboardInterrupt. The lookup's KeyError adds nothing, so drop it.
         raise TypeError("Unknown message type %s" % message.__class__) from None
     create_safe_task(self._loop, self._logger, func(agent_addr, message))
예제 #4
0
    async def handle_client_message(self, client_addr, message):
        """
        Dispatch a message received from a client to the right handler.

        If the sender is not in ``self._registered_clients`` and the message is not a
        ``ClientHello``, an ``Unknown`` message is sent back to it and nothing else happens.
        Otherwise the handler is chosen by the exact class of ``message``. It is called as
        ``handler(client_addr, message)`` and the resulting coroutine is scheduled with
        ``create_safe_task`` rather than awaited.

        :param client_addr: address of the client that sent the message; forwarded to the handler.
        :param message: the received message object.
        :raises TypeError: if no handler is registered for ``message.__class__``.
        """

        # Verify that the client is registered
        if message.__class__ != ClientHello and client_addr not in self._registered_clients:
            await ZMQUtils.send_with_addr(self._client_socket, client_addr, Unknown())
            return

        message_handlers = {
            ClientHello: self.handle_client_hello,
            ClientNewJob: self.handle_client_new_job,
            ClientKillJob: self.handle_client_kill_job,
            ClientGetQueue: self.handle_client_get_queue,
            Ping: self.handle_client_ping
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            # Only a missing entry means "unknown type"; a bare except would also mask
            # CancelledError/KeyboardInterrupt. The lookup's KeyError adds nothing, so drop it.
            raise TypeError("Unknown message type %s" % message.__class__) from None
        create_safe_task(self._loop, self._logger, func(client_addr, message))