Esempio n. 1
0
class AsyncioAuthenticator(Authenticator):
    """ZAP authentication for use in the asyncio IO loop.

    ``start()`` registers the ZAP socket on an asyncio ``Poller`` and schedules
    a background task that passes every incoming ZAP request to
    ``handle_zap_message``. ``stop()`` cancels that task and unregisters the
    socket.

    :param context: zmq context forwarded to ``Authenticator``.
    :param loop: stored as ``self.loop``; defaults to ``asyncio.get_event_loop()``.
    """

    def __init__(self, context=None, loop=None):
        super().__init__(context)
        self.loop = loop or asyncio.get_event_loop()
        self.__poller = None
        self.__task = None

    async def __handle_zap(self):
        """Forever: wait until the ZAP socket is readable, then handle one request."""
        # Native coroutine: the former @asyncio.coroutine generator decorator
        # was deprecated in Python 3.8 and removed in 3.11.
        while True:
            events = await self.__poller.poll()
            if self.zap_socket in dict(events):
                msg = await self.zap_socket.recv_multipart()
                self.handle_zap_message(msg)

    def start(self):
        """Start ZAP authentication"""
        super().start()
        self.__poller = Poller()
        self.__poller.register(self.zap_socket, zmq.POLLIN)
        self.__task = asyncio.ensure_future(self.__handle_zap())

    def stop(self):
        """Stop ZAP authentication"""
        if self.__task:
            self.__task.cancel()
            # Forget the cancelled task so a repeated stop() is a no-op for it.
            self.__task = None
        if self.__poller:
            self.__poller.unregister(self.zap_socket)
            self.__poller = None
        super().stop()
Esempio n. 2
0
def run_worker(context):
    """Paranoid-pirate worker: answer requests and exchange heartbeats with the queue.

    Generator-based coroutine (driven with ``yield from``).
    - A 3-frame message is a request; its frames are sent straight back.
    - A 1-frame ``PPP_HEARTBEAT`` is a queue heartbeat.
    - Anything else is reported as invalid.
    If ``HEARTBEAT_LIVENESS`` consecutive polls (each ``HEARTBEAT_INTERVAL`` s)
    see nothing, the socket is replaced after sleeping ``interval`` seconds.
    ``interval`` starts at ``INTERVAL_INIT``, doubles while below
    ``INTERVAL_MAX``, and resets whenever a message arrives. After more than
    3 requests, each request has a 1-in-6 chance of a simulated crash (the
    coroutine returns) and a 1-in-6 chance of a 3 s simulated overload.

    :param context: zmq context handed to ``worker_socket`` to build sockets.
    """
    poller = Poller()

    liveness = HEARTBEAT_LIVENESS
    interval = INTERVAL_INIT

    heartbeat_at = time.time() + HEARTBEAT_INTERVAL

    worker = yield from worker_socket(context, poller)
    cycles = 0
    try:
        while True:
            socks = yield from poller.poll(HEARTBEAT_INTERVAL * 1000)
            socks = dict(socks)

            # Handle worker activity on backend
            if socks.get(worker) == zmq.POLLIN:
                #  Get message
                #  - 3-part envelope + content -> request
                #  - 1-part HEARTBEAT -> heartbeat
                frames = yield from worker.recv_multipart()
                if not frames:
                    break    # Interrupted

                if len(frames) == 3:
                    # Simulate various problems, after a few cycles
                    cycles += 1
                    if cycles > 3 and randint(0, 5) == 0:
                        print("I: Simulating a crash")
                        break
                    if cycles > 3 and randint(0, 5) == 0:
                        print("I: Simulating CPU overload")
                        yield from asyncio.sleep(3)
                    print("I: Normal reply")
                    yield from worker.send_multipart(frames)
                    liveness = HEARTBEAT_LIVENESS
                    yield from asyncio.sleep(1)  # Do some heavy work
                elif len(frames) == 1 and frames[0] == PPP_HEARTBEAT:
                    print("I: Queue heartbeat")
                    liveness = HEARTBEAT_LIVENESS
                else:
                    print("E: Invalid message: %s" % frames)
                interval = INTERVAL_INIT
            else:
                liveness -= 1
                if liveness == 0:
                    print("W: Heartbeat failure, can't reach queue")
                    print("W: Reconnecting in %0.2fs..." % interval)
                    yield from asyncio.sleep(interval)

                    if interval < INTERVAL_MAX:
                        interval *= 2
                    poller.unregister(worker)
                    worker.setsockopt(zmq.LINGER, 0)
                    worker.close()
                    worker = yield from worker_socket(context, poller)
                    liveness = HEARTBEAT_LIVENESS
            if time.time() > heartbeat_at:
                heartbeat_at = time.time() + HEARTBEAT_INTERVAL
                print("I: Worker heartbeat")
                yield from worker.send(PPP_HEARTBEAT)
    finally:
        # However we leave (simulated crash, interrupt, cancellation), release
        # the current socket: an unclosed socket keeps context.term() waiting.
        # close() is a no-op on a socket that is already closed.
        worker.close(linger=0)
Esempio n. 3
0
class AsyncioAuthenticator(Authenticator):
    """ZAP authentication for use in the asyncio IO loop.

    ``start()`` registers the ZAP socket on an asyncio ``Poller`` and schedules
    a background task that passes every incoming ZAP request to
    ``handle_zap_message``. ``stop()`` cancels that task and unregisters the
    socket.

    :param context: zmq context forwarded to ``Authenticator``.
    :param loop: stored as ``self.loop``; defaults to ``asyncio.get_event_loop()``.
    """
    def __init__(self, context=None, loop=None):
        super().__init__(context)
        self.loop = loop or asyncio.get_event_loop()
        self.__poller = None
        self.__task = None

    async def __handle_zap(self):
        """Forever: wait until the ZAP socket is readable, then handle one request."""
        # Native coroutine: the former @asyncio.coroutine generator decorator
        # was deprecated in Python 3.8 and removed in 3.11.
        while True:
            events = await self.__poller.poll()
            if self.zap_socket in dict(events):
                msg = await self.zap_socket.recv_multipart()
                self.handle_zap_message(msg)

    def start(self):
        """Start ZAP authentication"""
        super().start()
        self.__poller = Poller()
        self.__poller.register(self.zap_socket, zmq.POLLIN)
        self.__task = asyncio.ensure_future(self.__handle_zap())

    def stop(self):
        """Stop ZAP authentication"""
        if self.__task:
            self.__task.cancel()
            # Forget the cancelled task so a repeated stop() is a no-op for it.
            self.__task = None
        if self.__poller:
            self.__poller.unregister(self.zap_socket)
            self.__poller = None
        super().stop()
Esempio n. 4
0
class Server(Killable):
    """Receive messages on a bound zmq PULL socket and collect them in an asyncio.Queue.

    ``start()`` binds ``tcp://<host>:<port>`` and launches a background task
    that moves received messages into ``self.queue``. ``iter_messages()``
    hands them out while the server is alive. ``shutdown()`` stops the task
    and closes the socket.

    :param address: 2-tuple interpolated as ``'tcp://%s:%s' % address``.
    """

    _event_class = asyncio.Event

    def __init__(self, address):
        super(Server, self).__init__()
        self.address = address
        self.queue = asyncio.Queue()

        self.context = Context.instance()
        self.socket = self.context.socket(zmq.PULL)
        self.poller = Poller()
        self._listen_future = None

    async def _receive_into_queue(self):
        """Copy every message received on the socket into ``self.queue`` until killed."""
        while self.alive:
            try:
                # NOTE(review): Poller timeouts are in milliseconds, so 1e-4
                # makes this an almost-busy poll; confirm that is intended.
                events = await self.poller.poll(timeout=1e-4)
                if self.socket in dict(events):
                    data = await self.socket.recv()
                    await self.queue.put(data)
            except zmq.error.ZMQError:
                # Pause briefly and try again.
                await asyncio.sleep(1e-4)

    async def start(self):
        """Bind the socket and start the background receive task."""
        self.socket.bind('tcp://%s:%s' % self.address)
        self.poller.register(self.socket, zmq.POLLIN)
        self._listen_future = asyncio.ensure_future(self._receive_into_queue())
        await asyncio.sleep(0)

    async def shutdown(self):
        """Stop receiving and close the socket.

        Safe to call before ``start()`` or more than once.
        """
        self.kill()
        try:
            self.poller.unregister(self.socket)
        except KeyError:
            # Never registered (start() not called) or already unregistered.
            pass
        if self._listen_future is not None and not self._listen_future.done():
            self._listen_future.cancel()
        self.socket.close(linger=0)
        await asyncio.sleep(0)

    def iter_messages(self):
        """Return an ``AsyncGenerator`` over ``wrapped``.

        ``wrapped`` waits (in 1 µs sleeps) for the next queued message while
        the server is alive and returns it, or returns ``None`` once killed.
        NOTE(review): presumably ``AsyncGenerator`` re-invokes ``wrapped`` per
        item; confirm against its definition.
        """

        async def wrapped():
            while self.alive:
                try:
                    result = self.queue.get_nowait()
                except asyncio.QueueEmpty:
                    await asyncio.sleep(1e-6)
                else:
                    await asyncio.sleep(0)
                    return result
        return AsyncGenerator(wrapped)
Esempio n. 5
0
def run_client(context):
    """Lazy-pirate client: send numbered requests, resending on timeout.

    Generator-based coroutine (driven with ``yield from``). Each request is
    the next sequence number as a string, and the reply must parse to that
    number. If nothing arrives within ``REQUEST_TIMEOUT`` ms:
    - the REQ socket is closed;
    - a new one is connected to ``SERVER_ENDPOINT``;
    - the same request is resent.
    The client gives up after ``REQUEST_RETRIES`` consecutive failures.

    :param context: zmq context used to create the REQ sockets.
    """
    print("I: Connecting to server...")
    client = context.socket(zmq.REQ)
    client.connect(SERVER_ENDPOINT)
    poll = Poller()
    poll.register(client, zmq.POLLIN)
    sequence = 0
    retries_left = REQUEST_RETRIES
    while retries_left:
        sequence += 1
        request = str(sequence)
        print("I: Sending (%s)" % request)
        yield from client.send_string(request)
        expect_reply = True
        while expect_reply:
            socks = yield from poll.poll(REQUEST_TIMEOUT)
            socks = dict(socks)
            if socks.get(client) == zmq.POLLIN:
                reply = yield from client.recv()
                if not reply:
                    break
                try:
                    matches = int(reply) == sequence
                except ValueError:
                    # A non-numeric reply is malformed, not fatal.
                    matches = False
                if matches:
                    print("I: Server replied OK (%s)" % reply)
                    retries_left = REQUEST_RETRIES
                    expect_reply = False
                else:
                    print("E: Malformed reply from server: %s" % reply)
            else:
                print("W: No response from server, retrying...")
                # Socket is confused. Close and remove it.
                print('W: confused')
                poll.unregister(client)
                # Close the socket itself (LINGER=0 drops the unanswered
                # request). Merely unbinding the endpoint leaked one open
                # socket per retry, and open sockets keep context.term()
                # from returning.
                client.setsockopt(zmq.LINGER, 0)
                client.close()
                retries_left -= 1
                if retries_left == 0:
                    print("E: Server seems to be offline, abandoning")
                    return
                print("I: Reconnecting and resending (%s)" % request)
                # Create new connection
                client = context.socket(zmq.REQ)
                client.connect(SERVER_ENDPOINT)
                poll.register(client, zmq.POLLIN)
                yield from client.send_string(request)
Esempio n. 6
0
def run_client(context):
    """Lazy-pirate client: send numbered requests, resending on timeout.

    Generator-based coroutine (driven with ``yield from``). Each request is
    the next sequence number as a string, and the reply must parse to that
    number. If nothing arrives within ``REQUEST_TIMEOUT`` ms:
    - the REQ socket is closed;
    - a new one is connected to ``SERVER_ENDPOINT``;
    - the same request is resent.
    The client gives up after ``REQUEST_RETRIES`` consecutive failures.

    :param context: zmq context used to create the REQ sockets.
    """
    print("I: Connecting to server...")
    client = context.socket(zmq.REQ)
    client.connect(SERVER_ENDPOINT)
    poll = Poller()
    poll.register(client, zmq.POLLIN)
    sequence = 0
    retries_left = REQUEST_RETRIES
    while retries_left:
        sequence += 1
        request = str(sequence)
        print("I: Sending (%s)" % request)
        yield from client.send_string(request)
        expect_reply = True
        while expect_reply:
            socks = yield from poll.poll(REQUEST_TIMEOUT)
            socks = dict(socks)
            if socks.get(client) == zmq.POLLIN:
                reply = yield from client.recv()
                if not reply:
                    break
                try:
                    matches = int(reply) == sequence
                except ValueError:
                    # A non-numeric reply is malformed, not fatal.
                    matches = False
                if matches:
                    print("I: Server replied OK (%s)" % reply)
                    retries_left = REQUEST_RETRIES
                    expect_reply = False
                else:
                    print("E: Malformed reply from server: %s" % reply)
            else:
                print("W: No response from server, retrying...")
                # Socket is confused. Close and remove it.
                print('W: confused')
                poll.unregister(client)
                # Close the socket itself (LINGER=0 drops the unanswered
                # request). Merely unbinding the endpoint leaked one open
                # socket per retry, and open sockets keep context.term()
                # from returning.
                client.setsockopt(zmq.LINGER, 0)
                client.close()
                retries_left -= 1
                if retries_left == 0:
                    print("E: Server seems to be offline, abandoning")
                    return
                print("I: Reconnecting and resending (%s)" % request)
                # Create new connection
                client = context.socket(zmq.REQ)
                client.connect(SERVER_ENDPOINT)
                poll.register(client, zmq.POLLIN)
                yield from client.send_string(request)
Esempio n. 7
0
class QWeatherClient:
    """Client class for the QWeather messaging framework.

    A DEALER socket talks to the QWeatherStation broker at ``<ip>:<port>``, and a
    SUB socket at ``<ip>:<port + SUBSOCKET>`` receives server broadcasts. Every
    server announced by the broker becomes an attribute (a ``serverclass``
    instance) whose attributes are callables that send requests to that server.
    Requests are synchronous until ``run()`` sets ``self.running``; after that
    they return asyncio tasks.
    """
    class serverclass:
        """Support class to represent the available servers as objects, with their exposed functions as callable attributes. The __repr__ makes it look like they are server objects"""
        def __init__(self,name,addr,methods,client):
            self.name = name
            self.addr = addr
            self.client = client
            # Each entry of ``methods`` is a (method name, docstring) pair.
            for amethod in methods:
                setattr(self,amethod[0],self.bindingfunc(amethod[0],amethod[1]))


        def bindingfunc(self,methodname,methoddoc):
            """Ensures that "calling" the attribute of the "server"object with the name of a server function, sends a request to the server to execute that function and return the response"""
            def func(*args,**kwargs):
                timeout = kwargs.pop('timeout',CSYNCTIMEOUT) # This pops the value for timeout if it exists in kwargs, or returns the default timeout value. So this saves a line of code on logic check
                return self.client.send_request([self.name.encode(),methodname.encode(),pickle.dumps([args,kwargs])],timeout=timeout)
            func.__name__ = methodname
            func.__doc__ = methoddoc
            func.__repr__ = lambda: methoddoc
            func.is_remote_server_method = True
            return func


        def __repr__(self):
            msg = ""
            lst = [getattr(self,method) for method in dir(self) if getattr(getattr(self,method),'is_remote_server_method',False)]
            if len(lst) == 0:
                return 'No servers connected'
            else:
                for amethod in lst:
                    msg += amethod.__name__ +"\n"
            return msg.strip()


    context = None
    socket = None
    subsocket = None
    poller = None

    def __init__(self,QWeatherStationIP,name = None,loop = None,debug=False,verbose=False):
        """Connect to the broker and fetch the available servers.

        :param QWeatherStationIP: broker address ``tcp://<host>:<4-digit port>``.
        :param name: client name sent to the broker; defaults to the host name.
        :param loop: asyncio event loop; ``asyncio.get_event_loop()`` when omitted.
        :param debug: configure root logging at DEBUG level.
        :param verbose: configure root logging at INFO level.
        """
        # Pending async requests keyed by messageid + servername. This must be
        # per instance: a class-level dict is shared by every client.
        self.futureobjectdict = {}
        IpAndPort = re.search(IPREPATTERN,QWeatherStationIP)
        assert IpAndPort is not None, 'Ip not understood (tcp://xxx.xxx.xxx.xxx:XXXX or txp://localhost:XXXX)'
        self.QWeatherStationIP = IpAndPort.group(1)
        self.QWeatherStationSocket = IpAndPort.group(2)
        assert self.QWeatherStationIP[:6] == 'tcp://', 'Ip not understood (tcp://xxx.xxx.xxx.xxx:XXXX or txp://localhost:XXXX)'
        assert len(self.QWeatherStationSocket) == 4, 'Port not understood (tcp://xxx.xxx.xxx.xxx:XXXX or txp://localhost:XXXX)'
        if loop is None:
            self.loop = asyncio.get_event_loop()
        else:
            self.loop = loop

        if name is None:
            import socket
            name = socket.gethostname()

        formatting = '{:}: %(levelname)s: %(message)s'.format(name)
        if debug:
            logging.basicConfig(format=formatting,level=logging.DEBUG)
        if verbose:
            logging.basicConfig(format=formatting,level=logging.INFO)
        self.name = name.encode()
        self.reconnect()
        self.loop.run_until_complete(self.get_server_info())
        self.running = False
        self.messageid = 0
        atexit.register(self.close)


    def _close_sockets(self):
        """Unregister the DEALER and SUB sockets from the poller and close them (idempotent)."""
        for sock in (self.socket, self.subsocket):
            if sock is None:
                continue
            if self.poller is not None:
                try:
                    self.poller.unregister(sock)
                except KeyError:
                    pass  # not (or no longer) registered
            sock.close()

    def reconnect(self):
        '''connects or reconnects to the broker'''
        # Release both sockets of any previous connection; the SUB socket was
        # previously left open.
        self._close_sockets()
        self.context = Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.connect(self.QWeatherStationIP + ':' + self.QWeatherStationSocket)
        self.subsocket = self.context.socket(zmq.SUB)
        self.subsocket.connect(self.QWeatherStationIP + ':' + str(int(self.QWeatherStationSocket) + SUBSOCKET))

        self.poller = Poller()
        self.poller.register(self.socket,zmq.POLLIN)
        self.poller.register(self.subsocket,zmq.POLLIN)

    def subscribe(self,servername,function):
        """Subscribe to a server with a callback function"""
        self.subsocket.setsockopt(zmq.SUBSCRIBE,servername.encode())
        self.subscribers[servername] = function

    def unsubscribe(self,servername):
        """Unsubscribe from a server"""
        self.subsocket.setsockopt(zmq.UNSUBSCRIBE,servername.encode())
        self.subscribers.pop(servername)


    async def get_server_info(self):
        """Get information about servers from the broker.

        Raises ``Exception`` with the broker's text if the handshake fails;
        otherwise creates one ``serverclass`` attribute per announced server.
        """
        msg = [b'',b'C',CREADY,PCLIENT,self.name]
        self.send_message(msg)
        msg = await self.recieve_message()
        empty = msg.pop(0)
        assert empty == b''
        command = msg.pop(0)
        self.serverlist = []
        self.subscribers = {}
        if command == CREADY + CFAIL:
            raise Exception(msg.pop(0).decode())
        else:
            # NOTE: pickle.loads executes arbitrary code; only trust a known broker.
            serverdict = pickle.loads(msg.pop(0))
            servermethoddict = pickle.loads(msg.pop(0))
            for addr,name in serverdict.items():
                methods = servermethoddict[addr]
                server = self.serverclass(name,addr,methods,self)
                server.is_remote_server = True
                setattr(self,name,server)
                self.serverlist.append(server)



    def send_request(self,body,timeout):
        """Send a request. If the client is running (i.e. in async mode) send an async request, else send a synchronous request\n
        Attach a messageID to each request. (0-255)"""
        self.messageid+=1
        if self.messageid > 255:
            self.messageid = 0
        if self.running:
            result =  asyncio.get_event_loop().create_task(self.async_send_request(body,self.messageid.to_bytes(1,'big')))
        else:
            result = self.sync_send_request(body,self.messageid.to_bytes(1,'big'),timeout)
        return result

    def ping_broker(self):
        """Ping the broker; raise if it does not answer within 2 s or answers wrongly."""
        self.send_message([b'',b'P'])
        try:
            if len(self.loop.run_until_complete(self.poller.poll(timeout=2000))) == 0: #wait 2 seconds for a ping from the broker
                raise Exception('QWeatherStation not found')
            else:
                msg =  self.loop.run_until_complete(self.recieve_message())
                empty = msg.pop(0)
                pong = msg.pop(0)

                logging.debug('Recieved Pong: {:}'.format(pong))
                if pong != b'b':
                    raise Exception('QWeatherStation sent wrong Pong')

        except Exception:
            self.poller.unregister(self.socket)
            self.socket.close()
            raise


    def sync_send_request(self,body,ident,timeout):
        """Synchronously send a request and wait up to ``timeout`` ms for the reply.

        Returns the unpickled answer. On timeout an ``Exception`` instance is
        *returned* (not raised), so callers must check for it.
        """
        msg = [b'',b'C',CREQUEST,ident]  + body
        self.send_message(msg)
        # NOTE(review): the poller also watches the SUB socket, so a broadcast
        # ends this wait early and the recv below then waits without a timeout.
        if len(self.loop.run_until_complete(self.poller.poll(timeout=timeout))) == 0:
            return Exception('Synchronous request timed out. Try adding following keyword to function call: "timeout=XX" in ms')
        else:
            msg = self.loop.run_until_complete(self.recieve_message())
            empty = msg.pop(0)
            assert empty == b''
            command = msg.pop(0)
            ident = msg.pop(0)
            server = msg.pop(0)
            answ = pickle.loads(msg[0])
            return answ

    async def async_send_request(self,body,ident):
        """Ansynchronously send request. No explicit timeout on the client side for this. Relies on the "servertimeout" on the broker side"""
        server = body[0]
        key = ident + server
        msg = [b'',b'C',CREQUEST,ident]  + body
        # Register the future before sending, so no ordering assumption is needed.
        future = self.recieve_future_message(key)
        try:
            self.send_message(msg)
            return await future  # completed by handle_request_success/fail
        finally:
            # Always drop the entry, including on failure or cancellation.
            self.futureobjectdict.pop(key, None)

    def send_message(self,msg):
        """Send a multi-frame-message over the ZMQ socket"""
        self.socket.send_multipart(msg)


    def recieve_future_message(self,id):
        """Create a future for the async request, add it to the dict of futures (id = messageid+server"""
        tmp = self.loop.create_future()
        self.futureobjectdict[id] = tmp
        return tmp

    async def recieve_message(self):
        """Recieve a multi-frame-message over the zmq socket"""
        msg = await self.socket.recv_multipart()
        return msg

    def handle_message(self,msg):
        """First step of handling an incoming message\n
        First asserts that the first frame is empty\n
        Then sorts the message into either request+success, request+fail or ping"""
        empty = msg.pop(0)
        assert empty == b''
        command = msg.pop(0)

        if command == CREQUEST + CSUCCESS:
            messageid = msg.pop(0)
            servername = msg.pop(0)
            msg = pickle.loads(msg[0])
            self.handle_request_success(messageid,servername,msg)

        elif command == CREQUEST + CFAIL:
            messageid = msg.pop(0)
            servername = msg.pop(0)
            # Any remaining frame presumably carries a pickled exception
            # (the broker's timeout reply sends one).
            self.handle_request_fail(messageid,servername,msg)

        elif command == CPING:
            ping = msg.pop(0)
            if ping != b'P':
                raise Exception('QWeatherStation sent wrong ping')
            logging.debug('Recieved Ping from QWeatherStation')
            self.send_message([b'',b'b'])

    def _pending_future(self,messageid,servername):
        """Return the still-pending future for messageid+servername, or None (logged) if there is none."""
        future = self.futureobjectdict.get(messageid + servername)
        if future is None or future.done():
            logging.debug('Ignoring reply for unknown or finished request {:}'.format(messageid + servername))
            return None
        return future

    def handle_request_success(self,messageid,servername,msg):
        """Handle successful request by setting the result of the future (manually finishing the future)"""
        future = self._pending_future(messageid,servername)
        if future is not None:
            future.set_result(msg)

    def handle_request_fail(self,messageid,servername,msg=None):
        """Handle a failed request by setting the future to an exception.

        :param msg: remaining frames of the failure message. ``msg[0]``, if
            present, is unpickled and used as the exception; a non-exception
            payload is wrapped in ``Exception``.
        """
        future = self._pending_future(messageid,servername)
        if future is None:
            return
        # NOTE: pickle.loads executes arbitrary code; only trust a known broker.
        error = pickle.loads(msg[0]) if msg else None
        if not isinstance(error, Exception):
            if error is None:
                error = 'Request to {:} failed'.format(servername.decode(errors='replace'))
            error = Exception(error)
        future.set_exception(error)

    def handle_broadcast(self,msg):
        """Handle a message on the broadcast socket by calling the callback function connected to the relevant server"""
        server= msg.pop(0).decode()
        msg = pickle.loads(msg.pop(0))
        self.subscribers[server](msg)

    async def run(self):
        """Asynchronously run the client by repeatedly polling the DEALER and SUB sockets"""
        self.running = True
        while True:
            try:
                socks = await self.poller.poll(1000)
                socks = dict(socks)
                if self.socket in socks:
                    msg = await self.recieve_message()
                    self.handle_message(msg)

                if self.subsocket in socks:
                    # Broadcasts arrive on the SUB socket, not the DEALER socket.
                    msg = await self.subsocket.recv_multipart()
                    self.handle_broadcast(msg)

            except KeyboardInterrupt:
                self.close()
                break



    def close(self):
        """Closing function. Tells the broker that it disconnects. Is not called if the terminal is closed or the process is force-killed.

        Safe to call repeatedly: it is registered with atexit and also called
        from run() on KeyboardInterrupt.
        """
        if self.socket is None or self.socket.closed:
            return
        self.send_message([b'',b'C',CDISCONNECT])
        self._close_sockets()



    def __repr__(self):
        msg = ""
        if len(self.serverlist) == 0:
            return 'No servers connected'
        else:
            for aserver in self.serverlist:
                msg += aserver.name + "\n"
        return msg.strip()

    def __iter__(self):
        return (aserv for aserv in self.serverlist)

    def __getitem__(self,key):
        return self.serverlist[key]
Esempio n. 8
0
 async def _cothread_user_socket_proxy(self, router_info: RouterInfo,
                                       lease: Lease, socket: Socket):
     """Forward multipart messages between ``socket`` and the lease's physical socket.

     A message read from either side is sent to the other side, with a
     timeout of 5 seconds per frame. Messages whose send times out are kept
     in ``buffer`` and retried, in order, before anything new is polled.
     When the physical address reports it is no longer connected, a
     replacement is requested from ``self.get_socket_for_lease``. If no
     address (or one without a socket) is available, ``socket`` is closed
     and the coroutine returns; otherwise it loops forever.
     """
     current_physical_address: Optional[
         PhysicalAddress] = await self.get_socket_for_lease(
             router_info, lease)
     if (not current_physical_address
             or not current_physical_address.socket):
         socket.close()
         return
     # Remember exactly which physical socket is registered, so that this
     # socket (and not whatever the address exposes later) gets unregistered.
     physical_socket = current_physical_address.socket
     poller = Poller()
     # POLLIN only. The loop ignores other events, and Poller's default flags
     # (POLLIN | POLLOUT) make poll() return whenever the socket is writable,
     # which turns the loop into a busy spin.
     poller.register(physical_socket, zmq.POLLIN)
     poller.register(socket, zmq.POLLIN)
     buffer: List[Tuple[Socket, List[Frame]]] = []
     while True:
         if not current_physical_address.is_connected:
             poller.unregister(physical_socket)
             current_physical_address = await self.get_socket_for_lease(
                 router_info, lease)
             if (not current_physical_address
                     or not current_physical_address.socket):
                 socket.close()
                 return
             # Watch the new physical socket. The user socket is still
             # registered and must not be registered again.
             physical_socket = current_physical_address.socket
             poller.register(physical_socket, zmq.POLLIN)
         if not current_physical_address.socket:
             socket.close()
             return
         if buffer:
             delivered = 0
             for target_socket, frames in buffer:
                 # Entries aimed at an earlier physical socket are redirected
                 # to the current one.
                 target_socket = (current_physical_address.socket
                                  if target_socket != socket else socket)
                 try:
                     await asyncio.wait_for(
                         target_socket.send_multipart(frames,
                                                      copy=False,
                                                      track=True),
                         len(frames) * 5,
                     )
                 except asyncio.TimeoutError:
                     break
                 delivered += 1
             # Keep exactly the undelivered tail, in order.
             buffer = buffer[delivered:]
         else:
             pevents: List[Tuple[Socket, int]] = await poller.poll()
             for sock, ev in pevents:
                 target_sock = (current_physical_address.socket
                                if sock == socket else socket)
                 if ev & zmq.POLLIN:
                     frames = await sock.recv_multipart(copy=False)
                     try:
                         await asyncio.wait_for(
                             target_sock.send_multipart(frames,
                                                        copy=False,
                                                        track=True),
                             len(frames) * 5,
                         )
                     except asyncio.TimeoutError:
                         buffer.append((target_sock, frames))
Esempio n. 9
0
class QWeatherStation:
    """Central broker for the communcation done in QWeather"""
    def __init__(self, IP, loop=None, verbose=False, debug=False):
        """Bind the broker's ROUTER socket and start the broadcast proxy.

        :param IP: address ``tcp://<host>:<4-digit port>``. The ROUTER socket
            binds there. An XSUB/XPUB ``ThreadProxy`` binds its input on
            ``port + PUBLISHSOCKET`` and its output on ``port + SUBSOCKET``.
        :param loop: asyncio event loop; ``asyncio.get_event_loop()`` when omitted.
        :param verbose: configure root logging at INFO level.
        :param debug: configure root logging at DEBUG level.
        """
        self.loop = asyncio.get_event_loop() if loop is None else loop

        ip_usage = 'Ip not understood (tcp://xxx.xxx.xxx.xxx:XXXX or txp://*:XXXX)'
        match = re.search(IPREPATTERN, IP)
        assert match is not None, ip_usage
        self.StationIP, self.StationSocket = match.group(1, 2)
        assert self.StationIP[:6] == 'tcp://', ip_usage
        assert len(self.StationSocket) == 4, \
            'Port not understood (tcp://xxx.xxx.xxx.xxx:XXXX or txp://*:XXXX)'

        log_format = '{:}: %(levelname)s: %(message)s'.format('QWeatherStation')
        if debug:
            logging.basicConfig(format=log_format, level=logging.DEBUG)
        if verbose:
            logging.basicConfig(format=log_format, level=logging.INFO)

        # Peer registries keyed by ROUTER identity (bytes).
        self.servers = {}  # identity -> server name (str)
        self.clients = {}  # identity -> client name (str)
        self.servermethods = {}
        self.serverjobs = {}
        self.pinged = []
        self.requesttimeoutdict = {}

        self.cnx = Context()
        self.socket = self.cnx.socket(zmq.ROUTER)
        self.socket.bind(self.StationIP + ':' + self.StationSocket)

        self.proxy = ThreadProxy(zmq.XSUB, zmq.XPUB)
        base_port = int(self.StationSocket)
        self.proxy.bind_in(self.StationIP + ':' + str(base_port + PUBLISHSOCKET))
        self.proxy.bind_out(self.StationIP + ':' + str(base_port + SUBSOCKET))
        self.proxy.start()

        self.poller = Poller()
        self.poller.register(self.socket, zmq.POLLIN)

        logging.info('Ready to run on IP: {:}'.format(self.get_own_ip()))

    def get_own_ip(self):
        import socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            # doesn't even have to be reachable
            s.connect(('10.255.255.255', 1))
            IP = s.getsockname()[0]
        except:
            IP = '127.0.0.1'
        finally:
            s.close()
        return IP

    async def async_run(self):
        """Serve until interrupted: poll the ROUTER socket (1 s timeout) and dispatch each message.

        A KeyboardInterrupt raised while awaiting the poll closes the broker
        and ends the loop.
        """
        while True:
            try:
                ready = await self.poller.poll(1000)
            except KeyboardInterrupt:
                self.close()
                break
            if not ready:
                continue
            incoming = await self.recieve_message()
            self.handle_message(incoming)

    def run(self):
        """Block on the broker's event loop until async_run() finishes."""
        serve = self.async_run()
        self.loop.run_until_complete(serve)

    def close(self):
        """Stop polling the ROUTER socket and close it (also invoked by async_run on KeyboardInterrupt)."""
        router = self.socket
        self.poller.unregister(router)
        router.close()

    def handle_message(self, msg):
        """The first step of message handling.\n
        First assert that the second frame is empty\n
        Then process either [S]erver, [C]lient, [P]ing [b]pong, or [#] for executing broker functions
        """
        sender = msg.pop(0)
        if sender in self.clients.keys():
            logging.debug('Recieved message from {:}:\n{:}'.format(
                self.clients[sender], msg, '\n\n'))
        else:
            logging.debug('Recieved message from ID:{:}:\n{:}'.format(
                int.from_bytes(sender, byteorder='big'), msg, '\n\n'))
        empty = msg.pop(0)
        assert empty == b''
        SenderType = msg.pop(0)

        #Server
        if SenderType == b'S':
            command = msg.pop(0)  # 0xF? for server and 0x0? for client
            self.process_server(sender, command, msg)

        #Client
        elif (SenderType == b'C'):
            command = msg.pop(0)  # 0xF? for server and 0x0? for client
            self.process_client(sender, command, msg)

        #Ping
        elif SenderType == b'P':
            if sender in self.clients.keys():
                logging.debug('Recieved Ping from "{:}"'.format(
                    self.clients[sender]))
            else:
                logging.debug('Recieved Ping from ID:{:}'.format(
                    int.from_bytes(sender, byteorder='big')))

            self.socket.send_multipart(
                [sender, b'',
                 b'b'])  #Sending an upside down P (b) to indicate a pong

        #Pong
        elif SenderType == b'b':
            print('got a pong')
            logging.debug('Recieved Pong from ID:{:}'.format(
                int.from_bytes(sender, byteorder='big')))
            print(sender, self.pinged, sender in self.pinged)
            if sender in self.pinged:
                print('before', self.pinged)
                self.pinged.remove(sender)
                print('after', self.pinged)

        #Execute command
        elif SenderType == b'#':
            command = msg.pop(0)
            if command == b'P':  #request broker to ping all servers and remove old ones
                logging.debug('Ping of all servers requested')
                self.loop.create_task(self.ping_connections())
            elif command == b'R':  #requests the broker to "restart" by removing all connections
                for atask in self.requesttimeoutdict.items():
                    atask.cancel()
                self.requesttimeoutdict = {}
                self.servers = {}
                self.clients = {}

            if sender in self.clients.keys():
                logging.debug('Recieved Ping from "{:}"'.format(
                    self.clients[sender]))
            else:
                logging.debug('Recieved Ping from ID:{:}'.format(
                    int.from_bytes(sender, byteorder='big')))

        #SenderType not understood
        else:
            logging.info('Invalid message')

    def process_client(self, sender, command, msg):
        """Second stage of the message handling. Messages go here if they came from a client.

        :param sender: ROUTER identity of the client.
        :param command: ``CREADY``, ``CREQUEST`` or ``CDISCONNECT``; other
            values are ignored.
        :param msg: remaining frames, consumed in place.
        """
        if command == CREADY:
            version = msg.pop(0)
            self.handle_client_ready(sender, version, msg)

        elif command == CREQUEST:
            messageid = msg.pop(0)
            servername = msg.pop(0).decode()
            self.handle_client_request(sender, messageid, servername, msg)

        elif command == CDISCONNECT:
            # handle_client_disconnect needs to know which client left;
            # calling it without arguments raised TypeError.
            self.handle_client_disconnect(sender)

    def handle_client_ready(self, sender, version, msg):
        """Answer a client's READY handshake.

        If ``version`` differs from ``PCLIENT``, reply ``CREADY + CFAIL`` with an
        error text. Otherwise reply ``CREADY + CSUCCESS`` followed by the pickled
        ``servers`` and ``servermethods`` dicts, and record the client under the
        name carried in the next frame (UTF-8 decoded).
        """
        if version != PCLIENT:
            reply = [sender, b'', CREADY + CFAIL,
                     'Mismatch in protocol between client and broker'.encode()]
        else:
            reply = [sender, b'', CREADY + CSUCCESS,
                     pickle.dumps(self.servers),
                     pickle.dumps(self.servermethods)]
            self.clients[sender] = msg.pop(0).decode()
            client_id = int.from_bytes(sender, byteorder='big')
            logging.info('Client ready at ID:{:} name:{:}'.format(
                client_id, self.clients[sender]))
        self.send_message(reply)

    def handle_client_request(self, sender, messageid, servername, msg):
        """Forward a client request to the server registered under ``servername``.

        A timeout callback is scheduled ``B_SERVERRESPONSE_TIMEOUT`` seconds ahead.
        It sends the client a ``CREQUEST + CFAIL`` reply carrying a pickled
        ``Exception('Timeout error')``. Its handle is stored in
        ``self.requesttimeoutdict`` under ``messageid + sender`` so that the server
        reply handler can cancel it. When the callback fires, it removes its own
        entry first, so entries do not accumulate.

        If no server with that name is registered, the request is dropped and a
        debug message is logged.
        """
        # self.servers maps {address: name}; find the address by name.
        serveraddr = next((addr for addr, name in self.servers.items()
                           if name == servername), None)
        if serveraddr is None:
            logging.debug('Trying to contact a server that does not exist')
            return

        timeoutkey = messageid + sender
        failmsg = [
            sender, b'', CREQUEST + CFAIL, messageid,
            servername.encode(),
            pickle.dumps((Exception('Timeout error')))
        ]

        def _on_timeout():
            # Drop the bookkeeping entry before replying: otherwise it leaks, and a
            # late server reply would still be forwarded after the failure reply.
            self.requesttimeoutdict.pop(timeoutkey, None)
            self.send_message(failmsg)

        self.requesttimeoutdict[timeoutkey] = self.loop.call_later(
            B_SERVERRESPONSE_TIMEOUT, _on_timeout)

        msg = [serveraddr, b'', CREQUEST, messageid, sender] + msg
        # Send right away if nothing is queued for this server, otherwise queue it.
        # NOTE(review): within this view a job is only appended when the queue is
        # already non-empty, so the queue appears to never fill -- verify the
        # intended one-job-at-a-time bookkeeping.
        if not self.serverjobs[serveraddr]:
            self.send_message(msg)
            logging.debug('Client request from "{:}":\n{:}'.format(
                self.clients.get(sender, sender), msg))
        else:
            self.serverjobs[serveraddr].append(msg)

    def handle_client_disconnect(self, sender):
        """Remove ``sender`` from the client dictionary.

        A disconnect from an address that is not registered (e.g. a repeated
        disconnect) is logged and ignored instead of raising ``KeyError``.
        """
        if sender not in self.clients:
            logging.debug('Disconnect from unknown client ID:{:}'.format(
                int.from_bytes(sender, byteorder='big')))
            return
        logging.debug('Client "{:}" disconnecting'.format(
            self.clients[sender]))
        del self.clients[sender]

    def process_server(self, sender, command, msg):
        """Second stage of the message handling, for messages that came from a server.

        * ``CREADY``      -- ``msg`` starts with the protocol version; the rest is
                             passed to ``handle_server_ready``.
        * ``CREPLY``      -- ``msg`` holds the message id, the requesting client's
                             address and the answer, which is forwarded by
                             ``handle_server_reply``.
        * ``SDISCONNECT`` -- the server ``sender`` is removed from the registries.

        A ``CREPLY`` from an address that is not in ``self.servers`` is logged and
        dropped instead of raising ``KeyError``.
        """
        if command == CREADY:
            version = msg.pop(0)
            self.handle_server_ready(sender, version, msg)

        elif command == CREPLY:
            servername = self.servers.get(sender)
            if servername is None:
                # E.g. a server removed by the ping check that answers late.
                logging.info('Reply from unregistered server ID:{:}'.format(
                    int.from_bytes(sender, byteorder='big')))
                return
            messageid = msg.pop(0)
            clientaddr = msg.pop(0)
            answ = msg.pop(0)
            self.handle_server_reply(sender, messageid, servername, clientaddr,
                                     answ)

        elif command == SDISCONNECT:
            self.handle_server_disconnect(sender)

    def handle_server_ready(self, sender, version, msg):
        """Answer a server's READY handshake.

        On a protocol mismatch (``version`` differs from ``PSERVER``) the server
        gets a ``CREADY + CFAIL`` reply. Otherwise the next two frames of ``msg``
        are read as the server name and the pickled method dictionary. Three
        registries keyed by the ZMQ routing id ``sender`` are then filled:
        ``self.servers`` (name), ``self.servermethods`` (methods) and
        ``self.serverjobs`` (a new empty job list). The server gets a
        ``CREADY + CSUCCESS`` reply.
        """
        if version != PSERVER:
            self.send_message([
                sender, b'', CREADY + CFAIL,
                'Mismatch in protocol between server and broker'.encode()
            ])
            return

        name = msg.pop(0).decode()
        # NOTE(review): pickle.loads on a frame received from the network can run
        # arbitrary code if the peer is malicious; only safe with trusted servers.
        methods = pickle.loads(msg.pop(0))
        self.servers[sender] = name
        self.servermethods[sender] = methods
        self.serverjobs[sender] = []
        ack = [sender, b'', CREADY + CSUCCESS]
        logging.info('Server "{:}" ready at: {:}'.format(
            name, int.from_bytes(sender, byteorder='big')))
        self.send_message(ack)

    def handle_server_reply(self, sender, messageid, servername, clientaddr,
                            answer):
        """Forward the server reply to the client that requested it.\n
        Also cancel the timeout callback now that the server has replied in time\n
        If there are more jobs in the serverjob list for this server, send the oldest one to the server"""
        msg = [
            clientaddr, b'', CREQUEST + CSUCCESS, messageid,
            servername.encode(), answer
        ]
        try:
            #Cancel the timeout callback created when the request was sent ot the server
            timeouttask = self.requesttimeoutdict.pop(messageid + clientaddr)
            timeouttask.cancel()
            self.send_message(msg)
            logging.debug('Server answer to Client "{:}":\n{:}'.format(
                self.clients[clientaddr], msg))
            #If there are more requests in queue for the server, send the oldest one
            if len(self.serverjobs[sender]) > 0:
                self.send_message(self.serverjobs[sender].pop(0))
        except KeyError:
            print("Trying to send answer to client that does not exist")

    def handle_server_disconnect(self, sender):
        """Remove ``sender`` from the servers, serverjobs and servermethods dictionaries.

        A disconnect from an unregistered address is logged and ignored instead
        of raising ``KeyError``. Requests still queued for this server in
        ``self.serverjobs`` are discarded.
        """
        if sender not in self.servers:
            logging.debug('Disconnect from unknown server ID:{:}'.format(
                int.from_bytes(sender, byteorder='big')))
            return
        logging.debug('Server  "{:}" disconnecting'.format(
            self.servers[sender]))
        self.servers.pop(sender)
        self.serverjobs.pop(sender, None)
        self.servermethods.pop(sender, None)

    def send_message(self, msg):
        """Write ``msg`` (a list of frames) to the broker socket as one multipart message."""
        sock = self.socket
        sock.send_multipart(msg)

    async def recieve_message(self):
        """Await the next multipart message on the broker socket and return its frames."""
        return await self.socket.recv_multipart()

    async def ping_connections(self):
        """Run one ping round: ping every server, wait 2 seconds, then run the ping check.

        The check step (``__check_ping``) works on whatever addresses are still
        listed in ``self.pinged`` after the wait.
        """
        grace_period = 2  # seconds allowed for ping replies
        self.__ping()
        await asyncio.sleep(grace_period)
        self.__check_ping()

    def __ping(self):
        """Send ``[address, b'', CPING, b'P']`` to every registered server.

        ``self.pinged`` is replaced by a new list of the addresses pinged, in the
        order they were sent.
        """
        self.pinged = []
        for server_address in self.servers:
            frames = [server_address, b'', CPING, b'P']
            self.socket.send_multipart(frames)
            self.pinged.append(server_address)

    def __check_ping(self):
        """Remove every server whose address is still listed in ``self.pinged``.

        ``__ping`` records each pinged address in ``self.pinged``. An address still
        present when this runs is treated as unresponsive (presumably the ping
        reply handler removes addresses that answered -- TODO confirm). It is
        removed from ``self.servers``, ``self.servermethods`` and
        ``self.serverjobs``. ``self.pinged`` is reset afterwards.
        """
        for address in self.pinged:
            # self.servers maps {address: name}, so the address is the key itself.
            # The old lookup compared it with name[0] (never equal), then deleted
            # whichever key the loop ended on, and raised UnboundLocalError when
            # self.servers was empty.
            name = self.servers.pop(address, None)
            if name is not None:
                logging.info('Server "{:}" did not answer ping, removing it'.format(name))
            self.servermethods.pop(address, None)
            self.serverjobs.pop(address, None)
        logging.debug('servers: {}'.format(self.servers))
        self.pinged = []

    def get_servers(self):
        """Return the live server registry ``{address: name}`` itself, not a copy."""
        registry = self.servers
        return registry

    def get_clients(self):
        """Return the live client registry ``{address: name}`` itself, not a copy."""
        registry = self.clients
        return registry
Esempio n. 10
0
class SchedulerConnection(object):
    """Connection to a single scheduler over a ZMQ PULL socket.

    Check definitions arrive on ``pull``. When ``settings.SCHEDULER_MONITOR`` is
    enabled, a monitor socket reporting ``EVENT_DISCONNECTED`` for ``pull`` is
    opened as well. ``receive()`` reconnects once no check has arrived for longer
    than ``settings.SCHEDULER_LIVENESS_IN_MINUTES``.

    ``receive`` and ``receive_event`` are native coroutines. The previous
    ``@asyncio.coroutine`` generators no longer work: that decorator was removed
    in Python 3.11, so the class failed at definition time.
    """
    __slots__ = (
        'address',
        # context object to open socket connections
        'context',
        # pull socket to receive check definitions from scheduler
        'pull',
        # poller object for `pull` socket
        'poller',
        # monitor socket for `pull` socket
        'monitor_socket',
        # poller object for monitor socket
        'monitor_poller',
        # UTC time of the first receive() without a check since the last check
        # arrived; None while checks keep arriving
        'first_missing',
    )

    def __init__(self, address):
        """Open the sockets for ``address`` using the shared ``Context.instance()``."""
        self.pull = self.poller = None
        self.monitor_poller = self.monitor_socket = None
        self.address = address
        self.context = Context.instance()
        self.open()
        self.first_missing = None

    def __str__(self):
        return self.address

    def __repr__(self):
        return 'Scheduler({})'.format(self.address)

    def open(self):
        """Create and connect the pull socket (plus monitor socket if enabled) and register pollers."""
        self.pull = self.context.socket(PULL)
        logger.info('%s - opening pull socket ...', self)
        self.pull.connect(self.address)
        if settings.SCHEDULER_MONITOR:
            logger.info('%s - opening monitor socket ...', self)
            self.monitor_socket = self.pull.get_monitor_socket(
                events=EVENT_DISCONNECTED
            )
        self.register()

    def register(self):
        """Create fresh pollers and register the open sockets for POLLIN."""
        self.poller = Poller()
        self.poller.register(self.pull, POLLIN)
        if settings.SCHEDULER_MONITOR:
            self.monitor_poller = Poller()
            self.monitor_poller.register(self.monitor_socket, POLLIN)
        logger.info('%s - all sockets are successfully registered '
                    'in poller objects ...', self)

    def close(self):
        """Unregister open sockets from poller objects and close them."""
        self.unregister()
        logger.info('%s - closing open sockets ...', self)
        self.pull.close()
        if settings.SCHEDULER_MONITOR:
            self.monitor_socket.close()
        logger.info('%s - connection closed successfully ...', self)

    def unregister(self):
        """Unregister open sockets from poller object."""
        logger.info('%s - unregistering sockets from poller objects ...', self)
        self.poller.unregister(self.pull)
        if settings.SCHEDULER_MONITOR:
            self.monitor_poller.unregister(self.monitor_socket)

    def reconnect(self):
        """Close and reopen all sockets, and reset the missing-check timer."""
        self.close()
        self.open()
        self.first_missing = None

    async def receive(self):
        """Wait up to 2 seconds for one check definition.

        Returns the JSON-decoded first frame of the message, or ``None`` if
        nothing arrived. While no (truthy) check arrives, ``first_missing``
        remembers when the gap started. Once the gap exceeds
        ``settings.SCHEDULER_LIVENESS_IN_MINUTES`` the connection is rebuilt via
        ``reconnect()``.
        """
        check = None
        events = await self.poller.poll(timeout=2000)
        if self.pull in dict(events):
            frames = await self.pull.recv_multipart()
            check = jsonapi.loads(frames[0])

        now = datetime.now(tz=pytz.utc)
        if check:
            self.first_missing = None
        elif self.first_missing is None:
            self.first_missing = now
        if self.first_missing:
            delta = timedelta(minutes=settings.SCHEDULER_LIVENESS_IN_MINUTES)
            if now - self.first_missing > delta:
                logger.warning(
                    'Alamo worker `%s` pid `%s` try to reconnect to '
                    '`%s` scheduler.',
                    settings.WORKER_FQDN, settings.WORKER_PID, self
                )
                self.reconnect()
        return check

    async def receive_event(self):
        """Wait up to 2 seconds for a monitor event; return it parsed, or ``None``."""
        event = None
        events = await self.monitor_poller.poll(timeout=2000)
        if self.monitor_socket in dict(events):
            msg = await self.monitor_socket.recv_multipart(
                flags=NOBLOCK)
            event = parse_monitor_message(msg)
        return event
Esempio n. 11
0
class AsyncRdsBusClient(object):
    """
    Asynchronous RDS-BUS client.

    Each request is sent as one JSON frame ``{"id", "operation", "argument"}``
    over a ZMQ DEALER socket. The ``start()`` task (scheduled from
    ``__init__``) reads JSON responses and resolves the pending future
    registered under the response's ``id``.
    """
    ASC = 1
    DESC = -1

    def __init__(self, url, logger, request_timeout=None, database=None):
        """
        :param url: RDS-BUS endpoint the DEALER socket connects to
        :param logger: logger used for debug/warning output
        :param request_timeout: seconds before a pending request fails with
            ServerTimeoutException (default 60)
        :param database: default RDS-BUS database class name for requests
        """
        self._logger = logger
        self._database = database
        self._context = Context.instance()
        self._poller = Poller()
        self._request = self._context.socket(zmq.DEALER)
        self._request_timeout = request_timeout or 60
        self._rds_bus_url = url
        self._request.connect(self._rds_bus_url)
        self._request_dict = dict()
        self._io_loop = asyncio.get_event_loop()
        self._running = False
        # Keep a reference to the receive task: the event loop only keeps weak
        # references to tasks, and stop() needs the task to end the loop.
        self._task = asyncio.ensure_future(self.start())

    @classmethod
    def pack(cls,
             database: str,
             key: str,
             parameter: dict,
             is_query: bool = False,
             order_by: list = None,
             page_no: int = None,
             per_page: int = None,
             found_rows: bool = False):
        """
        Build the request argument.

        :param database: RDS-BUS database class name
        :param key: instance name held by the database class
        :param parameter: parameter dict
        :param is_query: whether this is a query (adds order_by/limit/found_rows)
        :param order_by: sort spec [{"column": "<column name>", "order": AsyncRdsBusClient.ASC/AsyncRdsBusClient.DESC}]
        :param page_no: current page (range [0, n), n meaning the n-th page);
            only used together with per_page
        :param per_page: records per page; without it no limit is sent
        :param found_rows: whether to count the total number of rows
        :return: {"command": "<database>/<key>", "data": {...}}
        """
        command = "{}/{}".format(database, key)
        if not is_query:
            return dict(command=command, data=dict(var=parameter))

        amount = int(per_page) if per_page else None
        limit = None
        if amount:
            # Only compute the offset when a page size exists; previously
            # page_no without per_page evaluated int(page_no) * None (TypeError).
            offset = int(page_no) * amount if page_no else None
            limit = dict(amount=amount, offset=offset) if offset else dict(
                amount=amount)
        return dict(command=command,
                    data=dict(var=parameter,
                              order_by=order_by,
                              limit=limit,
                              found_rows=found_rows))

    async def query(self,
                    key: str,
                    parameter: dict,
                    order_by: list = None,
                    page_no: int = None,
                    per_page: int = None,
                    found_rows: bool = False,
                    database: str = None,
                    execute: bool = True):
        """
        Query interface.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param key: statement instance name held by the database class
        :param parameter: parameter dict
        :param order_by: sort spec [{"column": "<column name>", "order": AsyncRdsBusClient.ASC/AsyncRdsBusClient.DESC}]
        :param page_no: current page (range [0, n), n meaning the n-th page)
        :param per_page: records per page
        :param found_rows: whether to count the total number of rows
        :param execute: whether to execute; if False the packed argument is returned
        :return: RdsData wrapping the response, or the packed argument
        """
        _database = database or self._database
        argument = self.pack(database=_database,
                             key=key,
                             parameter=parameter,
                             is_query=True,
                             order_by=order_by,
                             page_no=page_no,
                             per_page=per_page,
                             found_rows=found_rows)
        if execute:
            response = await self._send(operation=OperationType.QUERY,
                                        argument=argument)
            result = RdsData(response)
        else:
            result = argument
        return result

    async def insert(self,
                     key: str,
                     parameter: dict,
                     database: str = None,
                     execute: bool = True):
        """
        Insert interface.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param key: statement instance name held by the database class
        :param parameter: parameter dict
        :param execute: whether to execute; if False the packed argument is returned
        :return: RdsData wrapping the response, or the packed argument
        """
        _database = database or self._database
        argument = self.pack(database=_database, key=key, parameter=parameter)
        if execute:
            response = await self._send(operation=OperationType.INSERT,
                                        argument=argument)
            result = RdsData(response)
        else:
            result = argument
        return result

    async def update(self,
                     key: str,
                     parameter: dict,
                     database: str = None,
                     execute: bool = True):
        """
        Update interface.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param key: statement instance name held by the database class
        :param parameter: parameter dict
        :param execute: whether to execute; if False the packed argument is returned
        :return: RdsData wrapping the response, or the packed argument
        """
        _database = database or self._database
        argument = self.pack(database=_database, key=key, parameter=parameter)
        if execute:
            response = await self._send(operation=OperationType.UPDATE,
                                        argument=argument)
            result = RdsData(response)
        else:
            result = argument
        return result

    async def delete(self,
                     key: str,
                     parameter: dict,
                     database: str = None,
                     execute: bool = False):
        """
        Delete interface.

        NOTE(review): unlike query/insert/update, ``execute`` defaults to False
        here, so a plain ``delete(key, parameter)`` call only returns the packed
        argument and deletes nothing. Confirm this is intentional.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param key: statement instance name held by the database class
        :param parameter: parameter dict
        :param execute: whether to execute; if False the packed argument is returned
        :return: RdsData wrapping the response, or the packed argument
        """
        _database = database or self._database
        argument = self.pack(database=_database, key=key, parameter=parameter)
        if execute:
            response = await self._send(operation=OperationType.DELETE,
                                        argument=argument)
            result = RdsData(response)
        else:
            result = argument
        return result

    async def transaction(self, data: list, database: str = None):
        """
        Transaction interface.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param data: list of operations
        :return: RdsListData wrapping the response
        """
        _database = database or self._database
        result = await self._send(
            operation=OperationType.TRANSACTION,
            argument=dict(command="{}/transaction".format(_database),
                          data=data))
        return RdsListData(result)

    async def batch(self, data: list, database: str = None):
        """
        Batch interface.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param data: list of operations
        :return: RdsListData wrapping the response
        """
        _database = database or self._database
        result = await self._send(operation=OperationType.BATCH,
                                  argument=dict(
                                      command="{}/batch".format(_database),
                                      data=data))
        return RdsListData(result)

    async def start(self):
        """
        Receive loop: read JSON responses and resolve the matching pending
        futures. Runs until stop() clears ``_running`` or cancels the task.
        """
        self._poller.register(self._request, zmq.POLLIN)
        self._running = True
        while self._running:
            events = await self._poller.poll()
            if self._request not in dict(events):
                continue
            try:
                response = await self._request.recv_json()
            except ValueError:
                # A malformed frame must not end the loop and strand every
                # pending request; log it and keep receiving.
                self._logger.warning(
                    "undecodable response from {}".format(self._rds_bus_url))
                continue
            self._logger.debug("received {}".format(response))
            self._dispatch(response)

    def _dispatch(self, response):
        """Resolve the pending future for ``response["id"]`` (helper for start())."""
        request_id = response.get("id") if isinstance(response, dict) else None
        future = self._request_dict.pop(request_id, None)
        if future is None:
            self._logger.warning("unknown response {}".format(response))
            return
        if future.done():
            # The awaiting caller was cancelled (e.g. by asyncio.wait_for);
            # setting a result now would raise InvalidStateError.
            return
        if HttpResult.is_duplicate_data_failure(response["code"]):
            future.set_exception(
                DuplicateDataException.new_exception(response["desc"]))
        elif HttpResult.is_failure(response["code"]):
            future.set_exception(
                CallServiceException(method="ZMQ",
                                     url=self._rds_bus_url,
                                     errmsg=response["desc"]))
        else:
            future.set_result(response["data"])

    def stop(self):
        """Stop receiving: unregister the socket and cancel the receive task."""
        if self._running:
            self._poller.unregister(self._request)
            self._running = False
        task = self._task
        if task is not None and not task.done():
            # start() is blocked in poll(); clearing _running alone never ends it.
            task.cancel()

    def shutdown(self):
        """Stop receiving and close the request socket."""
        self.stop()
        self._request.close()

    def _send(self, operation, argument):
        """
        Send one request and return the future that start() resolves.

        :param operation: an OperationType member; its ``.value`` is sent
        :param argument: request argument dict (see ``pack``)
        :return: future resolved with the response ``data``, or failed with
            CallServiceException / DuplicateDataException /
            ServerTimeoutException
        """
        request_id = get_unique_id()
        future = asyncio.Future()
        self._request_dict[request_id] = future
        self._io_loop.call_later(self._request_timeout, self._session_timeout,
                                 request_id)
        self._request.send_multipart([
            json.dumps(
                dict(id=request_id,
                     operation=operation.value,
                     argument=argument)).encode("utf-8")
        ])
        return future

    def _session_timeout(self, request_id):
        """Fail request ``request_id`` with ServerTimeoutException if it is still pending."""
        future = self._request_dict.pop(request_id, None)
        if future is None or future.done():
            # Already answered, or the caller cancelled its await.
            return
        future.set_exception(
            ServerTimeoutException(method="ZMQ", url=self._rds_bus_url))