class AsyncioAuthenticator(Authenticator):
    """ZAP authentication for use in the asyncio IO loop.

    After ``start()`` a background task polls ``self.zap_socket``
    (presumably created by ``Authenticator.start()``) and passes every
    received multipart message to ``handle_zap_message``.
    """

    def __init__(self, context=None, loop=None):
        super().__init__(context)
        # Kept for callers that read ``self.loop``; the polling task is
        # scheduled with ``asyncio.ensure_future`` on the current loop.
        self.loop = loop or asyncio.get_event_loop()
        self.__poller = None
        self.__task = None

    async def __handle_zap(self):
        """Poll the ZAP socket forever and dispatch each incoming request."""
        # Native coroutine: ``@asyncio.coroutine`` was removed in Python 3.11.
        while True:
            events = await self.__poller.poll()
            if self.zap_socket in dict(events):
                msg = await self.zap_socket.recv_multipart()
                self.handle_zap_message(msg)

    def start(self):
        """Start ZAP authentication"""
        super().start()
        self.__poller = Poller()
        self.__poller.register(self.zap_socket, zmq.POLLIN)
        self.__task = asyncio.ensure_future(self.__handle_zap())

    def stop(self):
        """Stop ZAP authentication.

        Cancels the polling task and unregisters the ZAP socket. Both
        references are cleared, so a second call does not touch them again.
        """
        if self.__task:
            self.__task.cancel()
            self.__task = None
        if self.__poller:
            self.__poller.unregister(self.zap_socket)
            self.__poller = None
        super().stop()
def run_worker(context):
    """Paranoid-Pirate style worker loop (generator-based coroutine).

    The worker:

    - echoes every 3-frame request back to the queue;
    - acknowledges single-frame PPP_HEARTBEAT messages;
    - sends its own heartbeat once HEARTBEAT_INTERVAL seconds have passed;
    - reconnects after HEARTBEAT_LIVENESS consecutive silent polls, doubling
      the wait (from INTERVAL_INIT, while below INTERVAL_MAX).

    After three handled requests it randomly simulates crashes and CPU
    overload.
    """
    poller = Poller()
    lives = HEARTBEAT_LIVENESS
    backoff = INTERVAL_INIT
    next_heartbeat = time.time() + HEARTBEAT_INTERVAL
    worker = yield from worker_socket(context, poller)
    handled = 0

    while True:
        events = yield from poller.poll(HEARTBEAT_INTERVAL * 1000)
        ready = dict(events)

        if ready.get(worker) != zmq.POLLIN:
            # Silence from the queue during this poll window.
            lives -= 1
            if lives == 0:
                print("W: Heartbeat failure, can't reach queue")
                print("W: Reconnecting in %0.2fs..." % backoff)
                yield from asyncio.sleep(backoff)
                if backoff < INTERVAL_MAX:
                    backoff *= 2
                # Throw away the stale socket and open a fresh one.
                poller.unregister(worker)
                worker.setsockopt(zmq.LINGER, 0)
                worker.close()
                worker = yield from worker_socket(context, poller)
                lives = HEARTBEAT_LIVENESS
        else:
            # 3-part envelope + content -> request; 1-part HEARTBEAT -> heartbeat
            frames = yield from worker.recv_multipart()
            if not frames:
                break  # Interrupted
            frame_count = len(frames)
            if frame_count == 3:
                handled += 1
                misbehave = handled > 3
                if misbehave and randint(0, 5) == 0:
                    print("I: Simulating a crash")
                    break
                if misbehave and randint(0, 5) == 0:
                    print("I: Simulating CPU overload")
                    yield from asyncio.sleep(3)
                print("I: Normal reply")
                yield from worker.send_multipart(frames)
                lives = HEARTBEAT_LIVENESS
                yield from asyncio.sleep(1)  # Do some heavy work
            elif frame_count == 1 and frames[0] == PPP_HEARTBEAT:
                print("I: Queue heartbeat")
                lives = HEARTBEAT_LIVENESS
            else:
                print("E: Invalid message: %s" % frames)
            # Any traffic from the queue resets the reconnect back-off.
            backoff = INTERVAL_INIT

        if time.time() > next_heartbeat:
            next_heartbeat = time.time() + HEARTBEAT_INTERVAL
            print("I: Worker heartbeat")
            yield from worker.send(PPP_HEARTBEAT)
class Server(Killable):
    """PULL-socket server that pushes every received message into an asyncio.Queue."""

    # Event type handed to Killable (presumably backs ``alive``/``kill`` — confirm).
    _event_class = asyncio.Event

    def __init__(self, address):
        super(Server, self).__init__()
        # Two-item (host, port) sequence, formatted as 'tcp://%s:%s' in start().
        self.address = address
        self.queue = asyncio.Queue()
        self.context = Context.instance()
        self.socket = self.context.socket(zmq.PULL)
        self.poller = Poller()
        self._listen_future = None

    async def _receive_into_queue(self):
        """Move messages from the PULL socket into ``self.queue`` while alive."""
        while self.alive:
            try:
                # NOTE(review): zmq poll timeouts are in milliseconds, so this
                # is effectively a busy poll.
                events = await self.poller.poll(timeout=1e-4)
                if self.socket in dict(events):
                    data = await self.socket.recv()
                    await self.queue.put(data)
            except zmq.error.ZMQError:
                await asyncio.sleep(1e-4)

    async def start(self):
        """Bind the socket and launch the background receive task."""
        self.socket.bind('tcp://%s:%s' % self.address)
        self.poller.register(self.socket, zmq.POLLIN)
        self._listen_future = asyncio.ensure_future(self._receive_into_queue())
        await asyncio.sleep(0)

    async def shutdown(self):
        """Stop receiving and close the socket.

        Safe to call before ``start()`` and safe to call more than once.
        """
        self.kill()
        if self._listen_future is not None and not self._listen_future.done():
            self._listen_future.cancel()
        self._listen_future = None
        try:
            self.poller.unregister(self.socket)
        except KeyError:
            # start() was never called (or shutdown already ran), so the
            # socket is not registered with the poller.
            pass
        self.socket.close(linger=0)
        await asyncio.sleep(0)

    def iter_messages(self):
        """Return an AsyncGenerator that yields queued messages while alive."""
        async def wrapped():
            while self.alive:
                try:
                    result = self.queue.get_nowait()
                except asyncio.QueueEmpty:
                    await asyncio.sleep(1e-6)
                else:
                    await asyncio.sleep(0)
                    return result
        return AsyncGenerator(wrapped)
def run_client(context):
    """Lazy-Pirate REQ client (generator-based coroutine).

    Sends an increasing sequence number and waits up to REQUEST_TIMEOUT for
    the echoed reply. After a timeout the REQ socket is closed and replaced
    and the same request is resent, up to REQUEST_RETRIES times. Then the
    client gives up.
    """
    print("I: Connecting to server...")
    client = context.socket(zmq.REQ)
    client.connect(SERVER_ENDPOINT)
    poll = Poller()
    poll.register(client, zmq.POLLIN)

    sequence = 0
    retries_left = REQUEST_RETRIES
    try:
        while retries_left:
            sequence += 1
            request = str(sequence)
            print("I: Sending (%s)" % request)
            yield from client.send_string(request)

            expect_reply = True
            while expect_reply:
                socks = yield from poll.poll(REQUEST_TIMEOUT)
                socks = dict(socks)
                if socks.get(client) == zmq.POLLIN:
                    reply = yield from client.recv()
                    if not reply:
                        break
                    try:
                        matches = int(reply) == sequence
                    except ValueError:
                        # Non-numeric payload: treat as malformed, don't crash.
                        matches = False
                    if matches:
                        print("I: Server replied OK (%s)" % reply)
                        retries_left = REQUEST_RETRIES
                        expect_reply = False
                    else:
                        print("E: Malformed reply from server: %s" % reply)
                else:
                    print("W: No response from server, retrying...")
                    # Socket is confused. Close and remove it. (The socket was
                    # connect()ed, so unbind() would raise ZMQError; close it.)
                    print('W: confused')
                    poll.unregister(client)
                    client.setsockopt(zmq.LINGER, 0)
                    client.close()
                    retries_left -= 1
                    if retries_left == 0:
                        print("E: Server seems to be offline, abandoning")
                        return
                    print("I: Reconnecting and resending (%s)" % request)
                    # Create new connection
                    client = context.socket(zmq.REQ)
                    client.connect(SERVER_ENDPOINT)
                    poll.register(client, zmq.POLLIN)
                    yield from client.send_string(request)
    finally:
        # Don't leak the current socket when the coroutine ends or is closed;
        # closing an already-closed socket is a no-op.
        client.close(linger=0)
class QWeatherClient:
    """Client class for the QWeather messaging framework"""

    class serverclass:
        """Support class to represent the available servers as objects, with
        their exposed functions as callable attributes. The __repr__ makes it
        look like they are server objects"""

        def __init__(self, name, addr, methods, client):
            self.name = name
            self.addr = addr
            self.client = client
            # Each entry of ``methods`` is indexed as (method name, docstring).
            for amethod in methods:
                setattr(self, amethod[0], self.bindingfunc(amethod[0], amethod[1]))

        def bindingfunc(self, methodname, methoddoc):
            """Build a callable that asks the broker to run ``methodname`` on
            this server and returns the client's response (a value in sync
            mode, a Task in async mode)."""
            def func(*args, **kwargs):
                # 'timeout' is consumed here (default CSYNCTIMEOUT) and is not
                # forwarded to the remote method.
                timeout = kwargs.pop('timeout', CSYNCTIMEOUT)
                return self.client.send_request(
                    [self.name.encode(), methodname.encode(), pickle.dumps([args, kwargs])],
                    timeout=timeout)
            func.__name__ = methodname
            func.__doc__ = methoddoc
            func.__repr__ = lambda: methoddoc
            func.is_remote_server_method = True
            return func

        def __repr__(self):
            lst = [getattr(self, method) for method in dir(self)
                   if getattr(getattr(self, method), 'is_remote_server_method', False)]
            if len(lst) == 0:
                return 'No servers connected'
            msg = ""
            for amethod in lst:
                msg += amethod.__name__ + "\n"
            return msg.strip()

    context = None
    socket = None
    subsocket = None
    poller = None
    # Legacy class-level default; __init__ gives every instance its own dict so
    # pending requests of different clients can never collide.
    futureobjectdict = {}
    _closed = False

    def __init__(self, QWeatherStationIP, name=None, loop=None, debug=False, verbose=False):
        self.futureobjectdict = {}
        IpAndPort = re.search(IPREPATTERN, QWeatherStationIP)
        assert IpAndPort is not None, 'Ip not understood (tcp://xxx.xxx.xxx.xxx:XXXX or txp://localhost:XXXX)'
        self.QWeatherStationIP = IpAndPort.group(1)
        self.QWeatherStationSocket = IpAndPort.group(2)
        assert self.QWeatherStationIP[:6] == 'tcp://', 'Ip not understood (tcp://xxx.xxx.xxx.xxx:XXXX or txp://localhost:XXXX)'
        assert len(self.QWeatherStationSocket) == 4, 'Port not understood (tcp://xxx.xxx.xxx.xxx:XXXX or txp://localhost:XXXX)'
        if loop is None:
            self.loop = asyncio.get_event_loop()
        else:
            self.loop = loop
        if name is None:
            import socket
            name = socket.gethostname()
        formatting = '{:}: %(levelname)s: %(message)s'.format(name)
        if debug:
            logging.basicConfig(format=formatting, level=logging.DEBUG)
        if verbose:
            logging.basicConfig(format=formatting, level=logging.INFO)
        self.name = name.encode()
        self.reconnect()
        self.loop.run_until_complete(self.get_server_info())
        self.running = False
        self.messageid = 0
        atexit.register(self.close)

    def reconnect(self):
        '''connects or reconnects to the broker'''
        if self.poller:
            self.poller.unregister(self.socket)
        if self.socket:
            self.socket.close()
        if self.subsocket:
            # The SUB socket is recreated below too; close the old one so it
            # does not leak.
            self.subsocket.close()
        self.context = Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.connect(self.QWeatherStationIP + ':' + self.QWeatherStationSocket)
        self.subsocket = self.context.socket(zmq.SUB)
        self.subsocket.connect(self.QWeatherStationIP + ':' + str(int(self.QWeatherStationSocket) + SUBSOCKET))
        self.poller = Poller()
        self.poller.register(self.socket, zmq.POLLIN)
        self.poller.register(self.subsocket, zmq.POLLIN)

    def subscribe(self, servername, function):
        """Subscribe to a server with a callback function"""
        self.subsocket.setsockopt(zmq.SUBSCRIBE, servername.encode())
        self.subscribers[servername] = function

    def unsubscribe(self, servername):
        """Unsubscribe from a server"""
        self.subsocket.setsockopt(zmq.UNSUBSCRIBE, servername.encode())
        self.subscribers.pop(servername)

    async def get_server_info(self):
        """Get information about servers from the broker and expose each one
        as an attribute (and entry of ``serverlist``).

        Raises:
            Exception: if the broker rejects the client (CREADY + CFAIL).
        """
        msg = [b'', b'C', CREADY, PCLIENT, self.name]
        self.send_message(msg)
        msg = await self.recieve_message()
        empty = msg.pop(0)
        assert empty == b''
        command = msg.pop(0)
        self.serverlist = []
        self.subscribers = {}
        if command == CREADY + CFAIL:
            raise Exception(msg.pop(0).decode())
        # NOTE(review): unpickling data received from the network; only safe
        # if the broker is trusted.
        serverdict = pickle.loads(msg.pop(0))
        servermethoddict = pickle.loads(msg.pop(0))
        for addr, name in serverdict.items():
            methods = servermethoddict[addr]
            server = self.serverclass(name, addr, methods, self)
            server.is_remote_server = True
            setattr(self, name, server)
            self.serverlist.append(server)

    def send_request(self, body, timeout):
        """Send a request. If the client is running (i.e. in async mode) send an
        async request, else send a synchronous request.\n
        Attach a messageID to each request. (0-255)"""
        self.messageid += 1
        if self.messageid > 255:
            self.messageid = 0
        ident = self.messageid.to_bytes(1, 'big')
        if self.running:
            return asyncio.get_event_loop().create_task(self.async_send_request(body, ident))
        return self.sync_send_request(body, ident, timeout)

    def ping_broker(self):
        """Ping the broker.

        Raises:
            Exception: if no answer arrives within 2 s or the pong is wrong;
                the DEALER socket is unregistered and closed first.
        """
        self.send_message([b'', b'P'])
        try:
            # wait 2 seconds for a ping from the broker
            if len(self.loop.run_until_complete(self.poller.poll(timeout=2000))) == 0:
                raise Exception('QWeatherStation not found')
            msg = self.loop.run_until_complete(self.recieve_message())
            msg.pop(0)  # empty delimiter frame
            pong = msg.pop(0)
            logging.debug('Recieved Pong: {:}'.format(pong))
            if pong != b'b':
                raise Exception('QWeatherStation sent wrong Pong')
        except Exception:
            self.poller.unregister(self.socket)
            self.socket.close()
            raise

    def sync_send_request(self, body, ident, timeout):
        """Synchronously send a request and wait up to ``timeout`` ms for the reply.

        Returns the unpickled answer. On timeout an Exception instance is
        *returned* (not raised); existing callers rely on that.
        """
        msg = [b'', b'C', CREQUEST, ident] + body
        self.send_message(msg)
        if len(self.loop.run_until_complete(self.poller.poll(timeout=timeout))) == 0:
            return Exception('Synchronous request timed out. Try adding following keyword to function call: "timeout=XX" in ms')
        msg = self.loop.run_until_complete(self.recieve_message())
        empty = msg.pop(0)
        assert empty == b''
        msg.pop(0)  # command
        msg.pop(0)  # message id
        msg.pop(0)  # server name
        return pickle.loads(msg[0])

    async def async_send_request(self, body, ident):
        """Asynchronously send request. No explicit timeout on the client side
        for this. Relies on the "servertimeout" on the broker side"""
        server = body[0]
        msg = [b'', b'C', CREQUEST, ident] + body
        self.send_message(msg)
        key = ident + server
        try:
            # Waits here until the future is completed (or failed) by run().
            answ = await self.recieve_future_message(key)
        finally:
            # Always forget the future, also when it was failed or cancelled.
            self.futureobjectdict.pop(key, None)
        return answ

    def send_message(self, msg):
        """Send a multi-frame-message over the ZMQ socket"""
        self.socket.send_multipart(msg)

    def recieve_future_message(self, id):
        """Create a future for the async request and add it to the dict of
        futures (id = messageid + server)"""
        tmp = self.loop.create_future()
        self.futureobjectdict[id] = tmp
        return tmp

    async def recieve_message(self):
        """Recieve a multi-frame-message over the zmq socket"""
        msg = await self.socket.recv_multipart()
        return msg

    def handle_message(self, msg):
        """First step of handling an incoming message\n
        First asserts that the first frame is empty\n
        Then sorts the message into either request+success, request+fail or ping"""
        empty = msg.pop(0)
        assert empty == b''
        command = msg.pop(0)
        if command == CREQUEST + CSUCCESS:
            messageid = msg.pop(0)
            servername = msg.pop(0)
            msg = pickle.loads(msg[0])
            self.handle_request_success(messageid, servername, msg)
        elif command == CREQUEST + CFAIL:
            messageid = msg.pop(0)
            servername = msg.pop(0)
            # Remaining frames carry the reason for the failure.
            self.handle_request_fail(messageid, servername, msg)
        elif command == CPING:
            ping = msg.pop(0)
            if ping != b'P':
                raise Exception('QWeatherStation sent wrong ping')
            logging.debug('Recieved Ping from QWeatherStation')
            self.send_message([b'', b'b'])

    def handle_request_success(self, messageid, servername, msg):
        """Handle a successful request by setting the result of the matching
        future. Replies for unknown or already-finished requests are dropped."""
        future = self.futureobjectdict.get(messageid + servername)
        if future is None or future.done():
            logging.debug('Dropping reply for unknown or finished request')
            return
        future.set_result(msg)

    def handle_request_fail(self, messageid, servername, msg=None):
        """Handle a failed request by setting the matching future's exception.

        ``msg`` holds the remaining frames of the failure message. Its first
        frame, when present, is unpickled (presumably an Exception pickled by
        the broker) and used as the exception. Any other payload is wrapped in
        ``Exception``.
        """
        future = self.futureobjectdict.get(messageid + servername)
        if future is None or future.done():
            logging.debug('Dropping failure for unknown or finished request')
            return
        error = Exception('Request failed')
        if msg:
            try:
                # NOTE(review): unpickling network data; only safe with a trusted broker.
                payload = pickle.loads(msg[0])
            except Exception:
                payload = msg[0]
            error = payload if isinstance(payload, Exception) else Exception(payload)
        future.set_exception(error)

    def handle_broadcast(self, msg):
        """Handle a message on the broadcast socket by calling the callback
        function connected to the relevant server"""
        server = msg.pop(0).decode()
        msg = pickle.loads(msg.pop(0))
        self.subscribers[server](msg)

    async def run(self):
        """Asynchronously run the client by repeatedly polling both sockets"""
        self.running = True
        while True:
            try:
                socks = dict(await self.poller.poll(1000))
                if self.socket in socks:
                    msg = await self.recieve_message()
                    self.handle_message(msg)
                if self.subsocket in socks:
                    # Broadcasts arrive on the SUB socket, not the DEALER socket.
                    msg = await self.subsocket.recv_multipart()
                    self.handle_broadcast(msg)
            except KeyboardInterrupt:
                self.close()
                break

    def close(self):
        """Closing function. Tells the broker that it disconnects. Is not called
        if the terminal is closed or the process is force-killed.

        Safe to call more than once: it is registered with atexit and is also
        called from run() on KeyboardInterrupt."""
        if self._closed or self.socket is None or self.socket.closed:
            return
        self._closed = True
        self.send_message([b'', b'C', CDISCONNECT])
        self.poller.unregister(self.socket)
        self.socket.close()
        if self.subsocket is not None:
            self.subsocket.close()

    def __repr__(self):
        if len(self.serverlist) == 0:
            return 'No servers connected'
        msg = ""
        for aserver in self.serverlist:
            msg += aserver.name + "\n"
        return msg.strip()

    def __iter__(self):
        return (aserv for aserv in self.serverlist)

    def __getitem__(self, key):
        return self.serverlist[key]
async def _cothread_user_socket_proxy(self, router_info: RouterInfo, lease: Lease,
                                      socket: Socket):
    """Forward multipart messages between ``socket`` and the physical socket
    currently returned by ``get_socket_for_lease`` for ``lease``.

    Frames read on either side are sent to the other. A send that does not
    finish within ``len(frames) * 5`` seconds is buffered. Buffered sends are
    retried in order before polling resumes. When there is no physical
    address, or it reports ``is_connected`` false, a new one is requested.
    If none is available (or it has no socket), ``socket`` is closed and the
    proxy returns.
    """
    current_physical_address: Optional[PhysicalAddress] = await self.get_socket_for_lease(
        router_info, lease)
    if not current_physical_address or not current_physical_address.socket:
        socket.close()
        return

    poller = Poller()
    # Track the socket we actually registered so it can always be unregistered,
    # even if the address object's ``socket`` attribute changes later.
    physical_socket = current_physical_address.socket
    # POLLIN only: the loop handles incoming data only. The default
    # POLLIN | POLLOUT would make poll() return whenever the socket is
    # writable, spinning the loop.
    poller.register(physical_socket, zmq.POLLIN)
    poller.register(socket, zmq.POLLIN)
    buffer: List[Tuple[Socket, List[Frame]]] = []

    while True:
        if (not current_physical_address) or (not current_physical_address.is_connected):
            if physical_socket is not None:
                poller.unregister(physical_socket)
                physical_socket = None
            current_physical_address = await self.get_socket_for_lease(router_info, lease)
            if not current_physical_address or not current_physical_address.socket:
                socket.close()
                return
            # Watch the *new* physical socket; the user socket stays registered.
            physical_socket = current_physical_address.socket
            poller.register(physical_socket, zmq.POLLIN)

        if buffer:
            sent = 0
            for target_socket, frames in buffer:
                # Anything not destined for the user socket goes to whichever
                # physical socket is current now (the original may be gone).
                target_socket = (current_physical_address.socket
                                 if target_socket != socket else socket)
                try:
                    await asyncio.wait_for(
                        target_socket.send_multipart(frames, copy=False, track=True),
                        len(frames) * 5,
                    )
                except asyncio.TimeoutError:
                    break
                sent += 1
            # Keep exactly the messages that were not delivered, in order.
            buffer = buffer[sent:]
        else:
            pevents: List[Tuple[Socket, int]] = await poller.poll()
            for sock, ev in pevents:
                target_sock = (current_physical_address.socket
                               if sock == socket else socket)
                if ev & zmq.POLLIN:
                    frames = await sock.recv_multipart(copy=False)
                    try:
                        await asyncio.wait_for(
                            target_sock.send_multipart(frames, copy=False, track=True),
                            len(frames) * 5,
                        )
                    except asyncio.TimeoutError:
                        buffer.append((target_sock, frames))
class QWeatherStation:
    """Central broker for the communcation done in QWeather"""

    def __init__(self, IP, loop=None, verbose=False, debug=False):
        if loop is None:
            self.loop = asyncio.get_event_loop()
        else:
            self.loop = loop
        IpAndPort = re.search(IPREPATTERN, IP)
        assert IpAndPort is not None, 'Ip not understood (tcp://xxx.xxx.xxx.xxx:XXXX or txp://*:XXXX)'
        self.StationIP = IpAndPort.group(1)
        self.StationSocket = IpAndPort.group(2)
        assert self.StationIP[:6] == 'tcp://', 'Ip not understood (tcp://xxx.xxx.xxx.xxx:XXXX or txp://*:XXXX)'
        assert len(self.StationSocket) == 4, 'Port not understood (tcp://xxx.xxx.xxx.xxx:XXXX or txp://*:XXXX)'
        formatting = '{:}: %(levelname)s: %(message)s'.format('QWeatherStation')
        if debug:
            logging.basicConfig(format=formatting, level=logging.DEBUG)
        if verbose:
            logging.basicConfig(format=formatting, level=logging.INFO)
        self.servers = {}  # server address (bytes) -> server name (str)
        self.clients = {}  # client address (bytes) -> client name (str)
        self.servermethods = {}  # server address -> unpickled method list
        self.serverjobs = {}  # server address -> list of queued request messages
        self.pinged = []  # server addresses pinged but not yet answered
        self.requesttimeoutdict = {}  # messageid + client address -> timeout handle
        self.cnx = Context()
        self.socket = self.cnx.socket(zmq.ROUTER)
        self.socket.bind(self.StationIP + ':' + self.StationSocket)
        self.proxy = ThreadProxy(zmq.XSUB, zmq.XPUB)
        self.proxy.bind_in(self.StationIP + ':' + str(int(self.StationSocket) + PUBLISHSOCKET))
        self.proxy.bind_out(self.StationIP + ':' + str(int(self.StationSocket) + SUBSOCKET))
        self.proxy.start()
        self.poller = Poller()
        self.poller.register(self.socket, zmq.POLLIN)
        logging.info('Ready to run on IP: {:}'.format(self.get_own_ip()))

    def get_own_ip(self):
        """Best-effort local IP: the source address the OS would use for an
        outbound UDP 'connection' (no packet is sent), else 127.0.0.1."""
        import socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            # doesn't even have to be reachable
            s.connect(('10.255.255.255', 1))
            IP = s.getsockname()[0]
        except OSError:
            IP = '127.0.0.1'
        finally:
            s.close()
        return IP

    async def async_run(self):
        """Asynchronously run the broker by polling the socket repeatedly"""
        while True:
            try:
                items = await self.poller.poll(1000)
            except KeyboardInterrupt:
                self.close()
                break
            if items:
                msg = await self.recieve_message()
                self.handle_message(msg)

    def run(self):
        """Runs the broker, enabling message handling (blocking if called from a script)"""
        self.loop.run_until_complete(self.async_run())

    def close(self):
        """Closing function, called at exit"""
        self.poller.unregister(self.socket)
        self.socket.close()

    def handle_message(self, msg):
        """The first step of message handling.\n
        First assert that the second frame is empty\n
        Then process either [S]erver, [C]lient, [P]ing [b]pong, or [#] for executing broker functions
        """
        sender = msg.pop(0)
        if sender in self.clients:
            logging.debug('Recieved message from {:}:\n{:}'.format(self.clients[sender], msg))
        else:
            logging.debug('Recieved message from ID:{:}:\n{:}'.format(
                int.from_bytes(sender, byteorder='big'), msg))
        empty = msg.pop(0)
        assert empty == b''
        SenderType = msg.pop(0)

        if SenderType == b'S':  # Server
            command = msg.pop(0)  # 0xF? for server and 0x0? for client
            self.process_server(sender, command, msg)
        elif SenderType == b'C':  # Client
            command = msg.pop(0)  # 0xF? for server and 0x0? for client
            self.process_client(sender, command, msg)
        elif SenderType == b'P':  # Ping
            if sender in self.clients:
                logging.debug('Recieved Ping from "{:}"'.format(self.clients[sender]))
            else:
                logging.debug('Recieved Ping from ID:{:}'.format(
                    int.from_bytes(sender, byteorder='big')))
            # Sending an upside down P (b) to indicate a pong
            self.socket.send_multipart([sender, b'', b'b'])
        elif SenderType == b'b':  # Pong
            logging.debug('Recieved Pong from ID:{:}'.format(
                int.from_bytes(sender, byteorder='big')))
            if sender in self.pinged:
                self.pinged.remove(sender)
        elif SenderType == b'#':  # Execute broker command
            command = msg.pop(0)
            if command == b'P':
                # request broker to ping all servers and remove old ones
                logging.debug('Ping of all servers requested')
                self.loop.create_task(self.ping_connections())
            elif command == b'R':
                # requests the broker to "restart" by removing all connections.
                # Log first, while the requester can still be looked up by name.
                if sender in self.clients:
                    logging.debug('Restart requested by "{:}"'.format(self.clients[sender]))
                else:
                    logging.debug('Restart requested by ID:{:}'.format(
                        int.from_bytes(sender, byteorder='big')))
                for atask in self.requesttimeoutdict.values():
                    atask.cancel()
                self.requesttimeoutdict = {}
                self.servers = {}
                self.clients = {}
                self.servermethods = {}
                self.serverjobs = {}
        else:  # SenderType not understood
            logging.info('Invalid message')

    def process_client(self, sender, command, msg):
        """Second stage of the message handling. Messages go here if they came from a client"""
        if command == CREADY:
            version = msg.pop(0)
            self.handle_client_ready(sender, version, msg)
        elif command == CREQUEST:
            messageid = msg.pop(0)
            servername = msg.pop(0).decode()
            self.handle_client_request(sender, messageid, servername, msg)
        elif command == CDISCONNECT:
            self.handle_client_disconnect(sender)

    def handle_client_ready(self, sender, version, msg):
        """Check the client is using the same version of QWeather, add client to
        clientlist and send client list of servers and servermethods"""
        if not version == PCLIENT:
            newmsg = [sender, b'', CREADY + CFAIL,
                      'Mismatch in protocol between client and broker'.encode()]
        else:
            newmsg = [sender, b'', CREADY + CSUCCESS] + [pickle.dumps(self.servers)] + [
                pickle.dumps(self.servermethods)]
            name = msg.pop(0).decode()
            self.clients[sender] = name
            logging.info('Client ready at ID:{:} name:{:}'.format(
                int.from_bytes(sender, byteorder='big'), self.clients[sender]))
        self.send_message(newmsg)

    def handle_client_request(self, sender, messageid, servername, msg):
        """Send a client request to the correct server. Add a timeout callback
        in case the server response times out"""
        # Find the server address in the server dict based on the name {address:name}
        serveraddr = next((key for key, value in self.servers.items() if value == servername), None)
        if serveraddr is None:
            logging.debug('Trying to contact a server that does not exist')
            return
        key = messageid + sender
        failmsg = [sender, b'', CREQUEST + CFAIL, messageid, servername.encode(),
                   pickle.dumps(Exception('Timeout error'))]
        # If the server does not answer in time, tell the client; the handle
        # is kept so a timely reply can cancel it.
        self.requesttimeoutdict[key] = self.loop.call_later(
            B_SERVERRESPONSE_TIMEOUT, self._request_timed_out, key, failmsg)
        msg = [serveraddr, b'', CREQUEST, messageid, sender] + msg
        # If the joblist for the requested server is empty, send it to the
        # server, else add it to the serverjoblist for later execution.
        # NOTE(review): nothing is appended when a job is sent directly, so the
        # list only fills if something else populates it.
        if len(self.serverjobs[serveraddr]) == 0:
            self.send_message(msg)
            logging.debug('Client request from "{:}":\n{:}'.format(self.clients.get(sender), msg))
        else:
            self.serverjobs[serveraddr].append(msg)

    def _request_timed_out(self, key, failmsg):
        """Timeout callback: forget the pending request and tell the client it failed."""
        self.requesttimeoutdict.pop(key, None)
        self.send_message(failmsg)

    def handle_client_disconnect(self, sender):
        """Remove the client from the client dictionary (unknown clients are ignored)"""
        logging.debug('Client "{:}" disconnecting'.format(self.clients.get(sender)))
        self.clients.pop(sender, None)

    def process_server(self, sender, command, msg):
        """Second stage of the message handling. Messages go here if they came from a server"""
        if command == CREADY:
            version = msg.pop(0)
            self.handle_server_ready(sender, version, msg)
        elif command == CREPLY:
            messageid = msg.pop(0)
            servername = self.servers.get(sender)
            if servername is None:
                logging.debug('Reply from unknown server ID:{:}'.format(
                    int.from_bytes(sender, byteorder='big')))
                return
            clientaddr = msg.pop(0)
            answ = msg.pop(0)
            self.handle_server_reply(sender, messageid, servername, clientaddr, answ)
        elif command == SDISCONNECT:
            self.handle_server_disconnect(sender)

    def handle_server_ready(self, sender, version, msg):
        """Check the server is using the same version of QWeather.\n
        Add the server to the serverdict, add the methods to the servermethods
        dict, add an empty list to the serverjobs dict\n
        Keys for all 3 dicts are the serveraddress/id assigned by ZMQ (the first
        frame of every message recieved)"""
        if not version == PSERVER:
            newmsg = [sender, b'', CREADY + CFAIL,
                      'Mismatch in protocol between server and broker'.encode()]
        else:
            servername = msg.pop(0).decode()
            servermethods = pickle.loads(msg.pop(0))
            self.servers[sender] = servername
            self.servermethods[sender] = servermethods
            self.serverjobs[sender] = []
            newmsg = [sender, b'', CREADY + CSUCCESS]
            logging.info('Server "{:}" ready at: {:}'.format(
                servername, int.from_bytes(sender, byteorder='big')))
        self.send_message(newmsg)

    def handle_server_reply(self, sender, messageid, servername, clientaddr, answer):
        """Forward the server reply to the client that requested it.\n
        Cancel the timeout callback if it is still pending. If it already fired,
        the client was told the request failed, so the late answer is dropped.\n
        Then, if more jobs are queued for this server, send the oldest one."""
        msg = [clientaddr, b'', CREQUEST + CSUCCESS, messageid, servername.encode(), answer]
        timeouttask = self.requesttimeoutdict.pop(messageid + clientaddr, None)
        if timeouttask is None:
            logging.debug('Dropping server answer for unknown or timed-out request')
        else:
            timeouttask.cancel()
            self.send_message(msg)
            logging.debug('Server answer to Client "{:}":\n{:}'.format(
                self.clients.get(clientaddr), msg))
        # The server is free again; hand it the oldest queued job, if any.
        jobs = self.serverjobs.get(sender)
        if jobs:
            self.send_message(jobs.pop(0))

    def handle_server_disconnect(self, sender):
        """Remove the server from the server, serverjobs and servermethods dictionaries"""
        logging.debug('Server "{:}" disconnecting'.format(self.servers.get(sender)))
        self.servers.pop(sender, None)
        self.serverjobs.pop(sender, None)
        self.servermethods.pop(sender, None)

    def send_message(self, msg):
        """Send a multi-frame-message over the zmq socket"""
        self.socket.send_multipart(msg)

    async def recieve_message(self):
        """Recieve a multi-frame-message over the zmq socket (async)"""
        msg = await self.socket.recv_multipart()
        return msg

    async def ping_connections(self):
        """Ping all servers, wait 2 seconds, then drop the ones that did not answer"""
        self.__ping()
        await asyncio.sleep(2)
        self.__check_ping()

    def __ping(self):
        """Send a ping to every known server and remember who was pinged."""
        self.pinged = []
        for addresse in self.servers.keys():
            self.socket.send_multipart([addresse, b'', CPING, b'P'])
            self.pinged.append(addresse)

    def __check_ping(self):
        """Remove every server still in ``self.pinged`` (no pong received)."""
        for addr in self.pinged:
            name = self.servers.pop(addr, None)
            self.serverjobs.pop(addr, None)
            self.servermethods.pop(addr, None)
            logging.info('Server "{:}" did not answer ping; removed'.format(name))
        logging.debug('servers: {:}'.format(self.servers))
        self.pinged = []

    def get_servers(self):
        """Return the server dictionary"""
        return self.servers

    def get_clients(self):
        """Return the client dictionary"""
        return self.clients
class SchedulerConnection(object):
    """PULL connection to one scheduler address, with optional socket monitor
    and automatic reconnect when no check arrives for too long."""

    __slots__ = (
        # scheduler endpoint passed to ``connect``
        'address',
        # context object to open socket connections
        'context',
        # pull socket to receive check definitions from scheduler
        'pull',
        # poller object for `pull` socket
        'poller',
        # monitor socket for `pull` socket
        'monitor_socket',
        # poller object for monitor socket
        'monitor_poller',
        # UTC time of the first empty receive since the last check, or None
        'first_missing',
    )

    def __init__(self, address):
        self.pull = self.poller = None
        self.monitor_poller = self.monitor_socket = None
        self.address = address
        self.context = Context.instance()
        self.open()
        self.first_missing = None

    def __str__(self):
        return self.address

    def __repr__(self):
        return 'Scheduler({})'.format(self.address)

    def open(self):
        """Create and connect the PULL socket (plus monitor socket if enabled) and register them."""
        self.pull = self.context.socket(PULL)
        logger.info('%s - opening pull socket ...', self)
        self.pull.connect(self.address)
        if settings.SCHEDULER_MONITOR:
            logger.info('%s - opening monitor socket ...', self)
            self.monitor_socket = self.pull.get_monitor_socket(events=EVENT_DISCONNECTED)
        self.register()

    def register(self):
        """Create fresh poller objects and register the open sockets in them."""
        self.poller = Poller()
        self.poller.register(self.pull, POLLIN)
        if settings.SCHEDULER_MONITOR:
            self.monitor_poller = Poller()
            self.monitor_poller.register(self.monitor_socket, POLLIN)
        logger.info('%s - all sockets are successfully registered '
                    'in poller objects ...', self)

    def close(self):
        """Unregister open sockets from poller objects and close them."""
        self.unregister()
        logger.info('%s - closing open sockets ...', self)
        self.pull.close()
        if settings.SCHEDULER_MONITOR:
            self.monitor_socket.close()
        logger.info('%s - connection closed successfully ...', self)

    def unregister(self):
        """Unregister open sockets from poller object."""
        logger.info('%s - unregistering sockets from poller objects ...', self)
        self.poller.unregister(self.pull)
        if settings.SCHEDULER_MONITOR:
            self.monitor_poller.unregister(self.monitor_socket)

    def reconnect(self):
        """Close and reopen all sockets and reset the missing-check timer."""
        self.close()
        self.open()
        self.first_missing = None

    async def receive(self):
        """Wait up to 2 s for one check definition and return it (or None).

        When checks stay missing for longer than
        ``settings.SCHEDULER_LIVENESS_IN_MINUTES``, a warning is logged and the
        connection is reopened.
        """
        # Native coroutine: ``@asyncio.coroutine`` was removed in Python 3.11.
        check = None
        events = await self.poller.poll(timeout=2000)
        if self.pull in dict(events):
            frames = await self.pull.recv_multipart()
            check = jsonapi.loads(frames[0])

        if check:
            self.first_missing = None
        elif self.first_missing is None:
            self.first_missing = datetime.now(tz=pytz.utc)

        if self.first_missing:
            diff = datetime.now(tz=pytz.utc) - self.first_missing
            delta = timedelta(minutes=settings.SCHEDULER_LIVENESS_IN_MINUTES)
            if diff > delta:
                logger.warning(
                    'Alamo worker `%s` pid `%s` try to reconnect to '
                    '`%s` scheduler.',
                    settings.WORKER_FQDN, settings.WORKER_PID, self
                )
                self.reconnect()
        return check

    async def receive_event(self):
        """Wait up to 2 s for a monitor event and return it parsed (or None)."""
        event = None
        events = await self.monitor_poller.poll(timeout=2000)
        if self.monitor_socket in dict(events):
            msg = await self.monitor_socket.recv_multipart(flags=NOBLOCK)
            event = parse_monitor_message(msg)
        return event
class AsyncRdsBusClient(object):
    """RDS-BUS client.

    Requests are JSON-encoded and sent over a DEALER socket. Each request
    gets a Future that a background receive loop (``start``) resolves when
    the reply with the same ``id`` arrives. Unanswered requests fail with
    ServerTimeoutException after ``request_timeout`` seconds.
    """

    ASC = 1
    DESC = -1

    # Background receive task created in __init__ (None for bare instances).
    _receiver = None

    def __init__(self, url, logger, request_timeout=None, database=None):
        self._logger = logger
        self._database = database
        self._context = Context.instance()
        self._poller = Poller()
        self._request = self._context.socket(zmq.DEALER)
        self._request_timeout = request_timeout or 60
        self._rds_bus_url = url
        self._request.connect(self._rds_bus_url)
        self._request_dict = dict()
        self._io_loop = asyncio.get_event_loop()
        self._running = False
        # Keep a handle on the receive loop so stop() can cancel it.
        self._receiver = asyncio.ensure_future(self.start())

    @classmethod
    def pack(cls, database: str, key: str, parameter: dict, is_query: bool = False,
             order_by: list = None, page_no: int = None, per_page: int = None,
             found_rows: bool = False):
        """Build the request payload.

        :param database: RDS-BUS database class name
        :param key: statement instance name held by the database class
        :param parameter: parameter dict
        :param is_query: whether this is a query operation
        :param order_by: sort spec, e.g.
            [{"column": "name", "order": AsyncRdsBusClient.ASC/AsyncRdsBusClient.DESC}]
        :param page_no: zero-based page index (range [0, n))
        :param per_page: records per page; without it no limit is sent
        :param found_rows: whether to count the total number of rows
        :return: dict with ``command`` and ``data`` keys
        """
        command = "{}/{}".format(database, key)
        if not is_query:
            return dict(command=command, data=dict(var=parameter))
        amount = int(per_page) if per_page else None
        # An offset only makes sense together with a page size; page_no
        # without per_page used to raise TypeError (int * None).
        offset = int(page_no) * amount if (page_no and amount) else None
        if not amount:
            limit = None
        elif offset:
            limit = dict(amount=amount, offset=offset)
        else:
            limit = dict(amount=amount)
        return dict(command=command,
                    data=dict(var=parameter, order_by=order_by, limit=limit,
                              found_rows=found_rows))

    async def query(self, key: str, parameter: dict, order_by: list = None,
                    page_no: int = None, per_page: int = None, found_rows: bool = False,
                    database: str = None, execute: bool = True):
        """Query.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param key: statement instance name held by the database class
        :param parameter: parameter dict
        :param order_by: sort spec (see ``pack``)
        :param page_no: zero-based page index (range [0, n))
        :param per_page: records per page
        :param found_rows: whether to count the total number of rows
        :param execute: if False, return the packed request instead of sending it
        :return: RdsData, or the packed request dict when ``execute`` is False
        """
        _database = database or self._database
        argument = self.pack(database=_database, key=key, parameter=parameter, is_query=True,
                             order_by=order_by, page_no=page_no, per_page=per_page,
                             found_rows=found_rows)
        return await self._execute(OperationType.QUERY, argument, execute)

    async def insert(self, key: str, parameter: dict, database: str = None,
                     execute: bool = True):
        """Insert.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param key: statement instance name held by the database class
        :param parameter: parameter dict
        :param execute: if False, return the packed request instead of sending it
        :return: RdsData, or the packed request dict when ``execute`` is False
        """
        _database = database or self._database
        argument = self.pack(database=_database, key=key, parameter=parameter)
        return await self._execute(OperationType.INSERT, argument, execute)

    async def update(self, key: str, parameter: dict, database: str = None,
                     execute: bool = True):
        """Update.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param key: statement instance name held by the database class
        :param parameter: parameter dict
        :param execute: if False, return the packed request instead of sending it
        :return: RdsData, or the packed request dict when ``execute`` is False
        """
        _database = database or self._database
        argument = self.pack(database=_database, key=key, parameter=parameter)
        return await self._execute(OperationType.UPDATE, argument, execute)

    async def delete(self, key: str, parameter: dict, database: str = None,
                     execute: bool = False):
        """Delete.

        NOTE: unlike the other operations, ``execute`` defaults to False here,
        so by default only the packed request is returned.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param key: statement instance name held by the database class
        :param parameter: parameter dict
        :param execute: if False, return the packed request instead of sending it
        :return: RdsData, or the packed request dict when ``execute`` is False
        """
        _database = database or self._database
        argument = self.pack(database=_database, key=key, parameter=parameter)
        return await self._execute(OperationType.DELETE, argument, execute)

    async def transaction(self, data: list, database: str = None):
        """Run a list of operations as one transaction.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param data: list of packed operations
        :return: RdsListData
        """
        _database = database or self._database
        result = await self._send(
            operation=OperationType.TRANSACTION,
            argument=dict(command="{}/transaction".format(_database), data=data))
        return RdsListData(result)

    async def batch(self, data: list, database: str = None):
        """Run a list of operations as a batch.

        :param database: RDS-BUS database class name (defaults to the client's)
        :param data: list of packed operations
        :return: RdsListData
        """
        _database = database or self._database
        result = await self._send(
            operation=OperationType.BATCH,
            argument=dict(command="{}/batch".format(_database), data=data))
        return RdsListData(result)

    async def _execute(self, operation, argument, execute):
        """Send ``argument`` and wrap the reply in RdsData, or return it unsent."""
        if not execute:
            return argument
        response = await self._send(operation=operation, argument=argument)
        return RdsData(response)

    async def start(self):
        """Receive loop: resolve pending request futures from incoming replies."""
        self._poller.register(self._request, zmq.POLLIN)
        self._running = True
        while self._running:
            events = await self._poller.poll()
            if self._request not in dict(events):
                continue
            response = await self._request.recv_json()
            self._logger.debug("received {}".format(response))
            future = self._request_dict.pop(response["id"], None)
            if future is None:
                self._logger.warning("unknown response {}".format(response))
                continue
            if future.done():
                # The caller cancelled/abandoned the request; setting a result
                # would raise InvalidStateError and kill this loop.
                continue
            if HttpResult.is_duplicate_data_failure(response["code"]):
                future.set_exception(DuplicateDataException.new_exception(response["desc"]))
            elif HttpResult.is_failure(response["code"]):
                future.set_exception(CallServiceException(
                    method="ZMQ", url=self._rds_bus_url, errmsg=response["desc"]))
            else:
                future.set_result(response["data"])

    def stop(self):
        """Stop the receive loop and unregister the socket from the poller."""
        if self._receiver is not None and not self._receiver.done():
            self._receiver.cancel()
        if self._running:
            self._poller.unregister(self._request)
            self._running = False

    def shutdown(self):
        """Stop receiving and close the DEALER socket."""
        self.stop()
        self._request.close()

    def _send(self, operation, argument):
        """Send one request and return the Future that will hold its reply.

        :param operation: OperationType member (its ``value`` is sent)
        :param argument: packed request dict
        :return: asyncio Future resolved by ``start`` or failed by the timeout
        """
        request_id = get_unique_id()
        future = self._io_loop.create_future()
        self._request_dict[request_id] = future
        self._io_loop.call_later(self._request_timeout, self._session_timeout, request_id)
        self._request.send_multipart([
            json.dumps(dict(id=request_id, operation=operation.value,
                            argument=argument)).encode("utf-8")
        ])
        return future

    def _session_timeout(self, request_id):
        """Fail a still-pending request with ServerTimeoutException."""
        future = self._request_dict.pop(request_id, None)
        if future is None or future.done():
            return
        future.set_exception(ServerTimeoutException(method="ZMQ", url=self._rds_bus_url))