def __init__(self, srv_address, zmq_context=None, retry_timeout=-1,
             nretry=1, **kwargs):
    r"""Set up a ROUTER (frontend) / DEALER (backend) socket pair.

    Args:
        srv_address (str): Address that the backend DEALER socket binds to.
        zmq_context (zmq.Context, optional): Context used to create the
            sockets. Defaults to the module-level ``_global_context``.
        retry_timeout (float, optional): Passed to ``bind_socket`` for both
            sockets. Defaults to -1.
        nretry (int, optional): Passed to ``bind_socket`` for both sockets.
            Defaults to 1.
        **kwargs: ``protocol``, ``host`` and ``port`` are popped and, when
            present, override the corresponding components parsed from
            ``srv_address`` for the frontend. All remaining keywords are
            passed to the parent class constructor.

    Raises:
        Exception: Anything raised while binding the backend is re-raised
            after the frontend/backend sockets have been closed and the
            frontend has been unregistered.

    """
    # Get parameters
    srv_param = parse_address(srv_address)
    cli_param = {k: kwargs.pop(k, srv_param[k])
                 for k in ['protocol', 'host', 'port']}
    zmq_context = zmq_context or _global_context
    # Create new address for the frontend
    if cli_param['protocol'] in ['inproc', 'ipc']:
        cli_param['host'] = get_ipc_host()
    # NOTE(review): the port is not passed to format_address, so the
    # frontend does not reuse the backend's port; presumably bind_socket
    # then selects one. Confirm against bind_socket.
    cli_address = format_address(cli_param['protocol'], cli_param['host'])
    self.cli_socket = zmq_context.socket(zmq.ROUTER)
    self.cli_address = bind_socket(self.cli_socket, cli_address,
                                   nretry=nretry, retry_timeout=retry_timeout)
    self.cli_socket.setsockopt(zmq.LINGER, 0)
    CommBase.register_comm('ZMQComm', 'ROUTER_server_' + self.cli_address,
                           self.cli_socket)
    # Bind backend
    self.srv_socket = zmq_context.socket(zmq.DEALER)
    try:
        self.srv_socket.setsockopt(zmq.LINGER, 0)
        self.srv_address = bind_socket(self.srv_socket, srv_address,
                                       nretry=nretry,
                                       retry_timeout=retry_timeout)
    except Exception:
        # The object is never handed back to the caller on failure, so
        # close_sockets would never run: release both sockets and the
        # frontend registration here (same order as close_sockets).
        self.srv_socket.close()
        self.srv_socket = None
        self.cli_socket.close()
        self.cli_socket = None
        CommBase.unregister_comm('ZMQComm',
                                 'ROUTER_server_' + self.cli_address)
        raise
    CommBase.register_comm('ZMQComm', 'DEALER_server_' + self.srv_address,
                           self.srv_socket)
    self.reply_socket = None
    # Set name
    super(ZMQProxy, self).__init__(self.srv_address, self.cli_address,
                                   **kwargs)
    self._name = 'ZMQProxy.%s' % srv_address
def test_queue():
    r"""Test creation/removal of queue."""
    queue = IPCComm.get_queue()
    queue_key = str(queue.key)
    # NOTE(review): get_queue/remove_queue register under 'IPCComm';
    # confirm that 'ipc' is the registry key used by this version.
    assert CommBase.is_registered('ipc', queue_key)
    # Drop the registry entry without closing the queue, so remove_queue
    # no longer recognises it.
    IPCComm.IPCComm.unregister_comm(queue_key, dont_close=True)
    assert_raises(KeyError, IPCComm.remove_queue, queue)
    # Re-register so remove_queue can perform the real cleanup.
    IPCComm.IPCComm.register_comm(queue_key, queue)
    IPCComm.remove_queue(queue)
    assert not CommBase.is_registered('ipc', queue_key)
def backlog_thread(self):
    r"""tools.YggThread: Thread that handles sending or receiving
    backlogged messages; created on first access and cached."""
    if self._backlog_thread is not None:
        return self._backlog_thread
    # Pick the loop body and thread-name suffix from the comm direction.
    if self.direction == 'send':
        loop_target, loop_suffix = self.run_backlog_send, 'SendBacklog'
    else:
        loop_target, loop_suffix = self.run_backlog_recv, 'RecvBacklog'
    self._backlog_thread = CommBase.CommThreadLoop(
        self, target=loop_target, suffix=loop_suffix)
    return self._backlog_thread
def close_sockets(self):
    r"""Close the frontend and backend sockets, then drop their
    registry entries."""
    self.debug('Closing sockets')
    for attr_name in ('cli_socket', 'srv_socket'):
        sock = getattr(self, attr_name)
        if sock:
            sock.close()
            setattr(self, attr_name, None)
    CommBase.unregister_comm('ZMQComm', 'ROUTER_server_' + self.cli_address)
    CommBase.unregister_comm('ZMQComm', 'DEALER_server_' + self.srv_address)
def remove_queue(mq):
    r"""Unregister a sysv_ipc.MessageQueue from the 'IPCComm' registry.

    Args:
        mq (:class:`sysv_ipc.MessageQueue`) Message queue.

    Raises:
        KeyError: If the provided queue is not registered.

    """
    queue_key = str(mq.key)
    if CommBase.is_registered('IPCComm', queue_key):
        CommBase.unregister_comm('IPCComm', queue_key)
        return
    raise KeyError("Queue not registered.")
def recv_message(self, *args, **kwargs):
    r"""Receive a message from the next forked comm that has one.

    Forked comms are polled in round-robin order (advancing
    ``self.curr_comm_index``) until a message is received, the timeout
    elapses, or this comm closes. An EOF is only returned once every
    forked comm has sent one.

    Args:
        *args: Arguments are passed to the forked comm's recv_message method.
        **kwargs: Keyword arguments are passed to the forked comm's
            recv_message method. ``timeout`` is consumed here (defaulting
            to ``self.recv_timeout``) and each forked comm is polled with
            ``timeout=0``.

    Returns:
        CommMessage: Received message whose ``args`` are wrapped in a dict
            keyed by the index of the forked comm that produced them.

    """
    timeout = kwargs.pop('timeout', None)
    if timeout is None:
        timeout = self.recv_timeout
    kwargs['timeout'] = 0
    first_comm = True
    T = self.start_timeout(timeout, key_suffix='recv:forkd')
    out = None
    # Index of the forked comm that produced ``out``. The for-loop counter
    # cannot be used: it restarts at 0 on every pass (independent of
    # curr_comm_index) and has already advanced once when the loop breaks.
    out_idx = None
    i = 0
    while ((not T.is_out) or first_comm) and self.is_open and (out is None):
        for i in range(len(self)):
            if out is not None:
                break
            idx = self.curr_comm_index % len(self)
            x = self.curr_comm
            if x.is_open:
                msg = x.recv_message(*args, **kwargs)
                self.errors += x.errors
                if msg.flag == CommBase.FLAG_EOF:
                    self.eof_recv[idx] = 1
                    if sum(self.eof_recv) == len(self):
                        out = msg
                        out_idx = idx
                    else:
                        x.finalize_message(msg)
                elif msg.flag not in [CommBase.FLAG_FAILURE,
                                      CommBase.FLAG_EMPTY]:
                    out = msg
                    out_idx = idx
            self.curr_comm_index += 1
        first_comm = False
        if out is None:
            self.sleep()
    self.stop_timeout(key_suffix='recv:forkd')
    if out is None:
        if self.is_closed:
            self.debug('Comm closed')
            out = CommBase.CommMessage(flag=CommBase.FLAG_FAILURE)
        else:
            out = CommBase.CommMessage(flag=CommBase.FLAG_EMPTY,
                                       args=self.last_comm.empty_obj_recv)
        # No message was received; keep the previous key (last loop index).
        out_idx = i
    out.args = {out_idx: out.args}
    return out
def test_send_recv(self):
    r"""Test sending/receiving with queues closed."""
    self.instance.close_comm()
    self.send_comm.close()
    self.recv_comm.close()
    assert self.instance.is_comm_closed
    assert self.send_comm.is_closed
    assert self.recv_comm.is_closed
    # 'value' comms do not report failure flags, so flag checks are
    # skipped for them.
    check_flags = (self.instance.icomm._commtype != 'value')
    flag = self.instance.send_message(
        CommBase.CommMessage(args=self.test_msg))
    assert not flag
    flag = self.instance.recv_message()
    if check_flags:
        assert not flag
    # Short messages first, then long (no-limit) messages.
    send_recv_pairs = (
        (self.send_comm.send, self.recv_comm.recv),
        (self.send_comm.send_nolimit, self.recv_comm.recv_nolimit),
    )
    for send_fn, recv_fn in send_recv_pairs:
        if check_flags:
            flag = send_fn(self.test_msg)
            assert not flag
        flag, ret = recv_fn()
        if check_flags:
            assert not flag
        self.assert_equal(ret, None)
    self.instance.confirm_output(timeout=1.0)
def remove_model(self, direction, name):
    r"""Remove a model from the list of models.

    If ``name`` is a known input client, a raw YGG_CLIENT_EOF message
    tagged with that model name is forwarded first. If the parent class
    reports that every model has signed off, EOF is sent.

    Args:
        direction (str): Direction of model.
        name (str): Name of model exiting.

    Returns:
        bool: True if all of the input/output models have signed off;
            False otherwise.

    """
    with self.lock:
        is_known_client = (direction == "input") and (name in self.clients)
        if is_known_client:
            client_eof = CommBase.CommMessage(args=YGG_CLIENT_EOF,
                                              flag=CommBase.FLAG_SUCCESS)
            eof_header = {'raw': True, 'model': name}
            super(RPCRequestDriver, self).send_message(
                client_eof, header_kwargs=eof_header, skip_processing=True)
        all_signed_off = super(RPCRequestDriver, self).remove_model(
            direction, name)
        if all_signed_off:
            self.send_eof()
        return all_signed_off
def test_send_recv_closed(self, instance, send_comm, recv_comm, test_msg):
    r"""Test sending/receiving with queues closed."""
    instance.close_comm()
    send_comm.close()
    recv_comm.close()
    assert instance.is_comm_closed
    assert send_comm.is_closed
    assert recv_comm.is_closed
    # 'value' comms do not report failure flags, so flag checks are
    # skipped for them.
    check_flags = (instance.icomm._commtype != 'value')
    flag = instance.send_message(CommBase.CommMessage(args=test_msg))
    assert not flag
    flag = instance.recv_message()
    if check_flags:
        assert not flag
    # Short messages first, then long (no-limit) messages.
    send_recv_pairs = (
        (send_comm.send, recv_comm.recv),
        (send_comm.send_nolimit, recv_comm.recv_nolimit),
    )
    for send_fn, recv_fn in send_recv_pairs:
        if check_flags:
            flag = send_fn(test_msg)
            assert not flag
        flag, ret = recv_fn()
        if check_flags:
            assert not flag
        assert ret is None
    instance.confirm_output(timeout=1.0)
def recv_message(self, timeout=None, **kwargs):
    r"""Return the next message from the receive backlog.

    Waits up to ``timeout`` for ``self.backlog_ready`` to be set, polling
    every ``self.sleeptime``, before inspecting the backlog.

    Args:
        timeout (float, optional): Time to wait for a backlogged message.
            Defaults to ``self.recv_timeout`` when None.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        CommMessage: The next backlogged message. If the backlog is empty,
            a FLAG_FAILURE message when the comm is closed, otherwise a
            FLAG_EMPTY message carrying ``self.empty_obj_recv``.

    """
    # Sleep until there is a message. ``timeout`` is a named parameter, so
    # it can never also be present in ``kwargs``.
    if timeout is None:
        timeout = self.recv_timeout
    T = self.start_timeout(timeout, key_suffix='.recv:backlog')
    while (not T.is_out) and (not self.backlog_ready.is_set()):
        self.backlog_ready.wait(self.sleeptime)
    self.stop_timeout(key_suffix='.recv:backlog')
    # Handle absence of messages
    if self.n_msg_backlog == 0:
        self.verbose_debug("No messages waiting.")
        if self.is_closed:
            self.debug(("No messages waiting and comm closed."
                        "%s, %s, %s")
                       % (self.backlog_thread is not None,
                          not self.backlog_thread.was_break,
                          self.backlog_thread.is_alive()))
            self.printStatus(level='debug')
            if self.backlog_thread.was_break:
                self.debug("Break stack:\n%s",
                           self.backlog_thread.break_stack)
            out = CommBase.CommMessage(flag=CommBase.FLAG_FAILURE)
        else:
            out = CommBase.CommMessage(flag=CommBase.FLAG_EMPTY,
                                       args=self.empty_obj_recv)
    # Return backlogged message
    else:
        self.debug('Returning backlogged received message')
        out = self.pop_backlog()
    return out
def get_queue(qid=None):
    r"""Create or return a sysv_ipc.MessageQueue and register it under
    'IPCComm'.

    Args:
        qid (int, optional): If provided, ID for existing queue that should
            be returned. Defaults to None and a new queue is created
            (with IPC_CREX).

    Returns:
        :class:`sysv_ipc.MessageQueue`: Message queue, or None if sysv_ipc
            is not installed.

    """
    if not _ipc_installed:  # pragma: windows
        logger.warning("IPC not installed. Queue cannot be returned.")
        return None
    queue_kws = {'max_message_size': tools.get_YGG_MSG_MAX()}
    if qid is None:
        queue_kws['flags'] = sysv_ipc.IPC_CREX
    mq = sysv_ipc.MessageQueue(qid, **queue_kws)
    CommBase.register_comm('IPCComm', str(mq.key), mq)
    return mq
def reply_thread(self):
    r"""tools.YggTask: Task whose target confirms sent messages (send
    direction) or received messages (otherwise). It is created on first
    access for non-async comms only; async comms get None."""
    if self._reply_thread is not None or self.is_async:
        return self._reply_thread

    def reply_target():
        # Stop the task loop once the comm has been closed.
        if self.is_closed:
            raise multitasking.BreakLoopException("Comm closed")
        if self.direction == 'send':
            self.confirm_send(sleep=True)
        else:
            self.confirm_recv(sleep=True)

    self._reply_thread = CommBase.CommTaskLoop(
        self, target=reply_target, suffix='Reply')
    return self._reply_thread
def send_eof(self, **kwargs):
    r"""Send an EOF message at most once.

    Args:
        **kwargs: Keyword arguments are passed to send_message.

    Returns:
        bool: Success or failure of send; False if EOF was already sent.

    """
    with self.lock:
        # Mark EOF as sent under the lock so only one caller proceeds.
        if self._eof_sent:  # pragma: debug
            self.debug('Already sent EOF')
            return False
        self._eof_sent = True
    self.debug('Sent EOF')
    eof_message = CommBase.CommMessage(flag=CommBase.FLAG_EOF,
                                       args=self.ocomm.eof_msg)
    return self.send_message(eof_message, **kwargs)
def on_eof(self, msg):
    r"""Handle a client's EOF by removing that client.

    The parent class's EOF handling only runs once the number of clients
    drops to 0; otherwise an empty message is returned.

    Args:
        msg (CommMessage): Message object that provided the EOF.

    Returns:
        CommMessage, bool: Value that should be returned by recv_message on
            EOF.

    """
    with self.lock:
        client_model = msg.header.get('model', '')
        self.remove_model('input', client_model)
        if self.nclients == 0:
            self.debug("All clients have signed off (EOF).")
            return super(RPCRequestDriver, self).on_eof(msg)
    return CommBase.CommMessage(flag=CommBase.FLAG_EMPTY,
                                args=self.icomm.empty_obj_recv)
def recv(self, timeout=None, return_message_object=False,
         dont_finalize=False, **kwargs):
    r"""Receive a message from the backlog.

    Args:
        timeout (float, optional): Time to wait for a backlogged message.
            Defaults to ``self.recv_timeout`` when None.
        return_message_object (bool, optional): If True, the full wrapped
            CommMessage message object is returned instead of the tuple.
            Defaults to False.
        dont_finalize (bool, optional): If True, finalize_message will not
            be called even if async_recv_method is 'recv_message', and an
            EOF will not close the comm. Defaults to False.
        **kwargs: Keywords named in ``self._finalize_message_kws`` are
            passed to the wrapped comm's finalize_message method; any other
            keywords are ignored.

    Returns:
        tuple (bool, obj): Success or failure of receive and received
            message. If return_message_object is True, the CommMessage is
            returned instead.

    """
    self.precheck('recv')
    # Sleep until there is a message. ``timeout`` is a named parameter, so
    # it can never also be present in ``kwargs``.
    if timeout is None:
        timeout = self.recv_timeout
    T = self.start_timeout(timeout, key_suffix='.recv:backlog')
    while (not T.is_out) and (not self.backlog_ready.is_set()):
        self.backlog_ready.wait(self.sleeptime)
    self.stop_timeout(key_suffix='.recv:backlog')
    # Handle absence of messages
    if self.n_msg_backlog == 0:
        self.verbose_debug("No messages waiting.")
        if self.is_closed:
            self.info(("No messages waiting and comm closed."
                       "%s, %s, %s")
                      % (self.backlog_thread is not None,
                         not self.backlog_thread.was_break,
                         self.backlog_thread.is_alive()))
            self.printStatus()
            if self.backlog_thread.was_break:
                self.info("Break stack:\n%s", self.backlog_thread.break_stack)
            out = CommBase.CommMessage(flag=CommBase.FLAG_FAILURE)
        else:
            out = CommBase.CommMessage(flag=CommBase.FLAG_EMPTY,
                                       args=self.empty_obj_recv)
    # Return backlogged message
    else:
        self.debug('Returning backlogged received message')
        out = self.pop_backlog()
    if not dont_finalize:
        # An EOF closes the comm (when configured) and is reported as a
        # failure to the caller.
        if (out.flag == CommBase.FLAG_EOF) and self.close_on_eof_recv:
            self.close()
            out.flag = CommBase.FLAG_FAILURE
    self._used = True
    if not dont_finalize:
        kws_finalize = {
            k: kwargs.pop(k) for k in self._finalize_message_kws
            if k in kwargs}
        if self.async_recv_method != 'recv_message':
            out.finalized = False
            kws_finalize['skip_processing'] = True
        out = self._wrapped.finalize_message(out, **kws_finalize)
    if not return_message_object:
        out = (bool(out.flag), out.args)
    return out
def recv_message(self, *args, **kwargs):
    r"""Receive a message from the forked comms.

    With the 'cycle' pattern, the first forked comm (polled round-robin,
    advancing ``self.curr_comm_index``) that yields a message wins. With
    the 'gather' pattern, one message is collected from every forked comm.
    Messages taken from ``self.comm_list_backlog`` are used before polling
    a comm. Messages collected before the timeout that do not complete the
    result are pushed back onto ``self.comm_list_backlog`` for the next call.

    Args:
        *args: Arguments are passed to the forked comm's recv_message method.
        **kwargs: Keyword arguments are passed to the forked comm's
            recv_message method. ``timeout`` is consumed here (defaulting
            to ``self.recv_timeout``) and each comm is polled with
            ``timeout=0``.

    Returns:
        CommMessage: On success, a message whose ``args`` map forked-comm
            index to a received message object (a deep copy of the message
            for 'cycle'). If incomplete, a FLAG_FAILURE message when this
            comm is closed, otherwise a FLAG_EMPTY message.

    """
    timeout = kwargs.pop('timeout', None)
    if timeout is None:
        timeout = self.recv_timeout
    kwargs['timeout'] = 0
    first_comm = True
    T = self.start_timeout(timeout, key_suffix='recv:forkd')
    out = None
    # Collected messages keyed by forked-comm index.
    out_gather = {}
    idx = None
    # 'gather' needs a message from every comm; other patterns need one.
    if self.pattern == 'gather':
        def complete():
            return (len(out_gather) == len(self))
    else:
        def complete():
            return bool(out_gather)
    while ((not T.is_out) or first_comm) and self.is_open and (not complete()):
        for i in range(len(self)):
            if complete():
                break
            idx = self.curr_comm_index % len(self)
            x = self.curr_comm
            if idx not in out_gather:
                # Previously collected-but-unused messages take priority.
                if self.comm_list_backlog[idx]:
                    out_gather[idx] = self.comm_list_backlog[idx].pop(0)
                elif x.is_open:
                    msg = x.recv_message(*args, **kwargs)
                    self.errors += x.errors
                    if msg.flag == CommBase.FLAG_EOF:
                        self.eof_recv[idx] = 1
                        if self.pattern == 'gather':
                            # NOTE(review): assert is stripped under -O.
                            assert (all((v.flag == CommBase.FLAG_EOF) for v in out_gather.values()))
                            out_gather[idx] = msg
                        elif sum(self.eof_recv) == len(self):
                            # 'cycle': EOF only once every comm sent one.
                            out_gather[idx] = msg
                        else:
                            x.finalize_message(msg)
                    elif msg.flag == CommBase.FLAG_SUCCESS:
                        out_gather[idx] = msg
                    # Messages with any other flag are dropped here.
            self.curr_comm_index += 1
        first_comm = False
        if not complete():
            self.sleep()
    self.stop_timeout(key_suffix='recv:forkd')
    if complete():
        if self.pattern == 'cycle':
            idx, out = next(iter(out_gather.items()))
            # NOTE(review): this deep-copies the whole message (not
            # out.args), so args maps index -> message copy, matching the
            # 'gather' branch; the copy avoids out.args containing out.
            args_copy = copy.deepcopy(out)
            out.args = {idx: args_copy}
        elif self.pattern == 'gather':
            out = copy.deepcopy(next(iter(out_gather.values())))
            out.args = {idx: v for idx, v in out_gather.items()}
            # TODO: Gather header/type etc?
    else:
        # Keep partial results for the next call.
        for idx, v in out_gather.items():
            self.comm_list_backlog[idx].append(v)
        if self.is_closed:
            self.debug('Comm closed')
            out = CommBase.CommMessage(flag=CommBase.FLAG_FAILURE)
        else:
            out = CommBase.CommMessage(flag=CommBase.FLAG_EMPTY)
            if self.pattern == 'cycle':
                out.args = self.last_comm.empty_obj_recv
            else:
                out.args = []
    return out
def test_registry():
    r"""Test registry of comm."""
    cls_name = 'CommBase'
    reg_key = 'key1'
    reg_value = None
    # Nothing is registered to begin with.
    assert not CommBase.is_registered(cls_name, reg_key)
    assert not CommBase.unregister_comm(cls_name, reg_key)
    assert_equal(CommBase.get_comm_registry(None), {})
    assert_equal(CommBase.get_comm_registry(cls_name), {})
    # Register, then unregister without closing.
    CommBase.register_comm(cls_name, reg_key, reg_value)
    assert reg_key in CommBase.get_comm_registry(cls_name)
    assert CommBase.is_registered(cls_name, reg_key)
    assert not CommBase.unregister_comm(cls_name, reg_key, dont_close=True)
    # Register again, then unregister with the default close behaviour.
    CommBase.register_comm(cls_name, reg_key, reg_value)
    assert not CommBase.unregister_comm(cls_name, reg_key)
def test_cleanup_comms(self):
    r"""Test that cleanup_comms empties the registry for the comm class."""
    target_class = self.recv_instance.comm_class
    CommBase.cleanup_comms(target_class)
    remaining = CommBase.get_comm_registry(target_class)
    assert len(remaining) == 0
def terminate(self, *args, **kwargs):
    r"""Unregister the server address, then terminate via the parent class.

    Args:
        *args: Passed to the parent class's terminate method.
        **kwargs: Passed to the parent class's terminate method.

    """
    registry_key = self.srv_address
    CommBase.unregister_comm('RMQComm', registry_key)
    super(RMQServer, self).terminate(*args, **kwargs)
def test_registry():
    r"""Test registry of comm."""
    from yggdrasil.communication import CommBase
    cls_name = 'CommBase'
    reg_key = 'key1'
    reg_value = None
    # Nothing is registered to begin with.
    assert not CommBase.is_registered(cls_name, reg_key)
    assert not CommBase.unregister_comm(cls_name, reg_key)
    assert CommBase.get_comm_registry(None) == {}
    assert CommBase.get_comm_registry(cls_name) == {}
    # Register, then unregister without closing.
    CommBase.register_comm(cls_name, reg_key, reg_value)
    assert reg_key in CommBase.get_comm_registry(cls_name)
    assert CommBase.is_registered(cls_name, reg_key)
    assert not CommBase.unregister_comm(cls_name, reg_key, dont_close=True)
    # Register again, then unregister with the default close behaviour.
    CommBase.register_comm(cls_name, reg_key, reg_value)
    assert not CommBase.unregister_comm(cls_name, reg_key)