def launch_shm_worker():
    global worker_pid, shm_mgr

    pid = os.fork()

    if pid == -1:
        raise OSError('fork failed, unable to run SharedMemoryManager')
    # run SHM worker in the child process
    elif pid == 0:
        shm_mgr = SharedMemoryManager(config)
        shm_mgr.run_as_worker()
    # send testing request in the parent process
    else:
        # wait for shm worker
        time.sleep(1)
        worker_pid = pid

    return pid
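# --- Illustrative only (not in the original source): a matching teardown
# sketch for the fixture above. It assumes the forked worker can simply be
# stopped with SIGTERM; `do_not_kill_shm_worker` mirrors the flag set by
# test_999_backlog further below.
import signal

def teardown_shm_worker():
    if do_not_kill_shm_worker:
        return

    os.kill(worker_pid, signal.SIGTERM)
    os.waitpid(worker_pid, 0)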
def test_1_set(self):
    KEY_FOR_SET = 'set_test'
    VALUE_FOR_SET = 'aaaaaaaaaaaaa'

    print('\n\n=====================set-value====================')
    shm_mgr = SharedMemoryManager(config, sensitive=False)
    shm_mgr.connect('test_set')
    print('conn_id: ', shm_mgr.current_connection.conn_id)
    self.assertIsInstance(shm_mgr.current_connection.conn_id, str)

    resp = shm_mgr.create_key(KEY_FOR_SET, SHMContainerTypes.STR)
    self.assertTrue(resp.get('succeeded'))

    resp = shm_mgr.set_value(KEY_FOR_SET, VALUE_FOR_SET)
    self.assertTrue(resp.get('succeeded'))

    resp = shm_mgr.read_key(KEY_FOR_SET)
    self.assertTrue(resp.get('succeeded'))
    self.assertEqual(resp.get('value'), VALUE_FOR_SET)
    print(resp)

    shm_mgr.disconnect()
    self.assertEqual(shm_mgr.current_connection, None)
class BaseCore():

    ''' The base model of cores

    Literally, the core is supposed to be a kernel-like component.
    It organizes other components to work together and administrates them.
    Some components are plugable, and some others are necessary.

    Here is the list of all components:
        afferents in neverland.afferents, plugable
        efferents in neverland.efferents, necessary
        logic handlers in neverland.logic, necessary
        protocol wrappers in neverland.protocol, necessary

    In the initial version, all these components are necessary, and
    afferents could be multiple.
    '''

    EV_MASK = select.EPOLLIN

    SHM_SOCKET_NAME_TEMPLATE = 'SHM-Core-%d.socket'

    # SHM container for containing allocated core ids
    # data structure:
    #     [1, 2, 3, 4]
    SHM_KEY_CORE_ID = 'Core_id'

    # The shared status of cluster controlling,
    # enumerated in neverland.core.state.ClusterControllingStates
    SHM_KEY_CC_STATE = 'Core_CCState'

    def __init__(
        self,
        config,
        efferent,
        logic_handler,
        protocol_wrapper,
        main_afferent,
        minor_afferents=tuple(),
    ):
        ''' constructor

        :param config: the config
        :param efferent: an efferent instance
        :param logic_handler: a logic handler instance
        :param protocol_wrapper: a protocol wrapper instance
        :param main_afferent: the main afferent
        :param minor_afferents: a group of minor afferents,
                                any iterable type that contains
                                afferent instances
        '''
        self.__running = False
        self.core_id = None
        self._epoll = select.epoll()
        self.afferent_mapping = {}

        self.config = config
        self.main_afferent = main_afferent
        self.efferent = efferent
        self.logic_handler = logic_handler
        self.protocol_wrapper = protocol_wrapper
        self.shm_mgr = SharedMemoryManager(self.config)

        self.plug_afferent(self.main_afferent)

        for afferent in minor_afferents:
            self.plug_afferent(afferent)

    def set_cc_state(self, status):
        self.shm_mgr.set_value(self.SHM_KEY_CC_STATE, status)

    def get_cc_state(self):
        resp = self.shm_mgr.read_key(self.SHM_KEY_CC_STATE)
        return resp.get('value')

    @property
    def cc_state(self):
        return self.get_cc_state()

    def init_shm(self):
        self.shm_mgr.connect(self.SHM_SOCKET_NAME_TEMPLATE % NodeContext.pid)

        self.shm_mgr.create_key_and_ignore_conflict(
            self.SHM_KEY_CORE_ID,
            SHMContainerTypes.LIST,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.SHM_KEY_CC_STATE,
            SHMContainerTypes.INT,
            CCStates.INIT,
        )

        logger.debug(f'init_shm for core of worker {NodeContext.pid} is done')

    def close_shm(self):
        self.shm_mgr.disconnect()

    def self_allocate_core_id(self):
        ''' let the core pick up an id for itself
        '''
        try:
            resp = self.shm_mgr.lock_key(self.SHM_KEY_CORE_ID)
        except SHMResponseTimeout:
            # Currently, SHM_MAX_BLOCKING_TIME is 4 seconds and this work
            # can definitely be done within 4 seconds.
            # If a SHMResponseTimeout occurred, then there must
            # be a deadlock.
            raise SHMResponseTimeout(
                f'deadlock on key: {self.SHM_KEY_CORE_ID}')

        resp = self.shm_mgr.read_key(self.SHM_KEY_CORE_ID)
        allocated_id = resp.get('value')

        if len(allocated_id) == 0:
            id_ = 0
        else:
            last_id = allocated_id[-1]
            id_ = last_id + 1

        self.shm_mgr.add_value(
            self.SHM_KEY_CORE_ID,
            [id_],
        )
        self.core_id = id_

        self.shm_mgr.unlock_key(self.SHM_KEY_CORE_ID)
        logger.debug(
            f'core of worker {NodeContext.pid} has self-allocated id: {id_}')

    def plug_afferent(self, afferent):
        self._epoll.register(afferent.fd, self.EV_MASK)
        self.afferent_mapping.update({afferent.fd: afferent})

    def unplug_afferent(self, fd):
        ''' remove an afferent from the core

        :param fd: the file descriptor of the afferent, int
        '''
        if fd not in self.afferent_mapping:
            return

        self._epoll.unregister(fd)
        self.afferent_mapping.pop(fd)

    def request_to_join_cluster(self):
        ''' send a request that the node is going to join the cluster
        '''
        entrance = self.config.cluster_entrance
        identification = self.config.net.identification

        if entrance is None:
            raise ConfigError("cluster_entrance is not defined")
        if identification is None:
            raise ConfigError("identification is not defined")

        logger.info('Trying to join cluster...')

        content = {
            'identification': identification,
            'ip': get_localhost_ip(),
            'listen_port': self.config.net.aff_listen_port,
        }
        subject = CCSubjects.JOIN_CLUSTER
        dest = (entrance.ip, entrance.port)

        pkt = UDPPacket()
        pkt.fields = ObjectifiedDict(
            type=PktTypes.CTRL,
            dest=dest,
            subject=subject,
            content=content,
        )
        pkt.next_hop = dest
        pkt = self.protocol_wrapper.wrap(pkt)

        NodeContext.pkt_mgr.repeat_pkt(pkt)
        logger.info(
            f'Sending request to cluster entrance {entrance.ip}:{entrance.port}'
        )

        logger.info('[Node Status] WAITING_FOR_JOIN')
        self.set_cc_state(CCStates.WAITING_FOR_JOIN)

    def request_to_leave_cluster(self):
        ''' send a request that the node is going to detach from the cluster
        '''
        entrance = self.config.cluster_entrance
        identification = self.config.net.identification

        if entrance is None:
            raise ConfigError("cluster_entrance is not defined")
        if identification is None:
            raise ConfigError("identification is not defined")

        logger.info('Trying to leave cluster...')

        content = {"identification": identification}
        subject = CCSubjects.LEAVE_CLUSTER
        dest = (entrance.ip, entrance.port)

        pkt = UDPPacket()
        pkt.fields = ObjectifiedDict(
            type=PktTypes.CTRL,
            dest=dest,
            subject=subject,
            content=content,
        )
        pkt.next_hop = dest
        pkt = self.protocol_wrapper.wrap(pkt)

        NodeContext.pkt_mgr.repeat_pkt(pkt)
        logger.info(
            f'Sent request to cluster entrance {entrance.ip}:{entrance.port}')

        logger.info('[Node Status] WAITING_FOR_LEAVE')
        self.set_cc_state(CCStates.WAITING_FOR_LEAVE)

    def handle_pkt(self, pkt):
        pkt = self.protocol_wrapper.unwrap(pkt)
        if not pkt.valid:
            return

        try:
            pkt = self.logic_handler.handle_logic(pkt)
        except DropPacket:
            return

        pkt = self.protocol_wrapper.wrap(pkt)
        self.efferent.transmit(pkt)

    def _poll(self):
        events = self._epoll.poll(POLL_TIMEOUT)

        for fd, evt in events:
            afferent = self.afferent_mapping[fd]

            if evt & select.EPOLLERR:
                self.unplug_afferent(fd)
                afferent.destroy()
            elif evt & select.EPOLLIN:
                pkt = afferent.recv()
                self.handle_pkt(pkt)

    def run(self):
        self.set_cc_state(CCStates.WORKING)
        self.__running = True

        self.main_afferent.listen()
        addr = self.main_afferent.listen_addr
        port = self.main_afferent.listen_port
        logger.info(f'Main afferent is listening on {addr}:{port}')

        while self.__running:
            self._poll()

    def run_for_a_while(self, duration=None, polling_times=None):
        ''' run the core within the specified duration or number of polls

        :param duration: the duration time, seconds in int
        :param polling_times: times the _poll method shall be invoked, int

        These 2 arguments will not work together; if duration is specified,
        then polling_times will be ignored.
        '''
        if duration is None and polling_times is None:
            raise ArgumentError('no argument passed')

        self.main_afferent.listen()
        addr = self.main_afferent.listen_addr
        port = self.main_afferent.listen_port
        logger.info(f'Main afferent is listening on {addr}:{port}')

        if duration is not None:
            starting_time = time.time()

            while time.time() - starting_time <= duration:
                self._poll()
        elif polling_times is not None:
            # iterate over a range, not the int itself
            for _ in range(polling_times):
                self._poll()

    def shutdown(self):
        self.__running = False
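# --- Illustrative only (not in the original source): a minimal wiring
# sketch for BaseCore. `build_components` is a hypothetical factory; the
# objects it returns must satisfy the interfaces used above (afferent.fd /
# listen() / recv(), efferent.transmit(), logic_handler.handle_logic(),
# protocol_wrapper.wrap() / unwrap()).
def _example_run_core(config):
    main_afferent, efferent, handler, wrapper = build_components(config)

    core = BaseCore(
        config,
        efferent=efferent,
        logic_handler=handler,
        protocol_wrapper=wrapper,
        main_afferent=main_afferent,
    )
    core.init_shm()               # connect to the per-worker SHM socket
    core.self_allocate_core_id()  # pick a unique id under an SHM lock
    try:
        core.run()                # blocks: epoll loop over all afferents
    finally:
        core.close_shm()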
def test_999_backlog(self):
    global do_not_kill_shm_worker

    key = 'bl0'
    shm_mgr = SharedMemoryManager(config, sensitive=False)
    shm_mgr.connect('test_bl')
    shm_mgr.create_key(key, SHMContainerTypes.LIST)
    shm_mgr.lock_key(key)

    pid = os.fork()
    if pid == -1:
        raise OSError('fork failed')

    if pid == 0:
        # This socket will cause a warning, so we close it here
        shm_mgr.current_connection.socket.close()

        do_not_kill_shm_worker = True
        shm_mgr1 = SharedMemoryManager(config, sensitive=False)
        shm_mgr1.connect('test_bl1')

        print('\n\n==============access-locked-container===============')
        t0 = time.time()
        resp = shm_mgr1.read_key(key)
        t1 = time.time()
        print(resp)

        delay = t1 - t0
        print(f'delay: {delay}')

        self.assertEqual(resp.get('succeeded'), True)
        self.assertTrue(delay >= 1)
        shm_mgr1.disconnect()
    else:
        time.sleep(1)
        print('------------release-lock-------------')
        shm_mgr.unlock_key(key)
        shm_mgr.disconnect()
        os.waitpid(-1, 0)
def test_2_rcode(self):
    shm_mgr = SharedMemoryManager(config, sensitive=False)
    shm_mgr.connect('test_err')

    print('\n\n=====================key-error-test====================')
    resp = shm_mgr.add_value(
        key='not_exists',
        value='a',
        backlogging=False,
    )
    print(resp)
    self.assertEqual(resp.get('succeeded'), False)
    self.assertEqual(resp.get('rcode'), ReturnCodes.KEY_ERROR)

    print('\n\n=====================type-error-test====================')
    resp = shm_mgr.create_key(
        'testing',
        SHMContainerTypes.DICT,
        value='a',
        backlogging=False,
    )
    print(resp)
    self.assertEqual(resp.get('succeeded'), False)
    self.assertEqual(resp.get('rcode'), ReturnCodes.TYPE_ERROR)

    print('\n\n===================unlock-not-locked-test==================')
    resp = shm_mgr.unlock_key('testing')
    print(resp)
    self.assertEqual(resp.get('succeeded'), True)
    self.assertEqual(resp.get('rcode'), ReturnCodes.NOT_LOCKED)

    print('\n\n=====================lock-test====================')
    print('------create-container------')
    resp = shm_mgr.create_key(
        'testing',
        SHMContainerTypes.LIST,
        backlogging=False,
    )
    self.assertEqual(resp.get('succeeded'), True)
    self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

    print('------lock-container------')
    resp = shm_mgr.lock_key('testing')
    print(resp)
    self.assertEqual(resp.get('succeeded'), True)
    self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

    print('------access-locked-container------')
    shm_mgr1 = SharedMemoryManager(config, sensitive=False)
    shm_mgr1.connect('test_err_1')
    resp = shm_mgr1.add_value(
        'testing',
        [1, 2, 3],
        backlogging=False,
    )
    print(resp)
    self.assertEqual(resp.get('succeeded'), False)
    self.assertEqual(resp.get('rcode'), ReturnCodes.LOCKED)

    print('------unlock-container----------')
    resp = shm_mgr.unlock_key('testing')
    print(resp)
    self.assertEqual(resp.get('succeeded'), True)
    self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

    print('------access-locked-container-again------')
    resp = shm_mgr1.add_value(
        'testing',
        [1, 2, 3],
        backlogging=False,
    )
    print(resp)
    self.assertEqual(resp.get('succeeded'), True)
    self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

    shm_mgr.disconnect()
    shm_mgr1.disconnect()
def test_0_normal_ops(self):
    shm_mgr = SharedMemoryManager(config, sensitive=False)
    shm_mgr.connect('test')
    print('conn_id: ', shm_mgr.current_connection.conn_id)
    self.assertIsInstance(shm_mgr.current_connection.conn_id, str)

    for td in DATA:
        print(f'\n==================={td["name"]}===================\n')

        # try to remove this KEY first
        resp = shm_mgr.clean_key(KEY)
        if not resp.get('succeeded'):
            self.assertEqual(resp.get('rcode'), ReturnCodes.KEY_ERROR)

        resp = shm_mgr.create_key(KEY, td['type'], td['2_create'])
        print('\n-----------Create-----------')
        print(resp)
        self.assertTrue(resp.get('succeeded'))

        resp = shm_mgr.read_key(KEY)
        print('\n-----------Read-----------')
        print(resp)
        self.assertTrue(resp.get('succeeded'))

        resp = shm_mgr.add_value(KEY, td['2_add'])
        print('\n-----------Add-----------')
        print(resp)
        self.assertTrue(resp.get('succeeded'))

        resp = shm_mgr.read_key(KEY)
        print('\n-----------Read-----------')
        print(resp)
        self.assertTrue(resp.get('succeeded'))

        resp = shm_mgr.remove_value(KEY, td['2_remove'])
        print('\n-----------Remove-----------')
        print(resp)
        self.assertTrue(resp.get('succeeded'))

        resp = shm_mgr.read_key(KEY)
        print('\n-----------Read-----------')
        print(resp)
        self.assertTrue(resp.get('succeeded'))
        self.assertEqual(
            resp.get('value')[td['remaining_key']],
            td['remaining'],
        )

    shm_mgr.disconnect()
    self.assertEqual(shm_mgr.current_connection, None)
class BaseNode():

    ''' The Base Class of Nodes
    '''

    role = None

    def __init__(self, config, role=None):
        self.config = config
        self.role = role or self.role

        self.worker_pids = []
        self.shm_worker_pid = None
        self.pkt_rpter_worker_pid = None

        self.node_id = self.config.basic.node_id

    def _write_master_pid(self):
        pid_path = self.config.basic.pid_file
        pid = os.getpid()

        with open(pid_path, 'w') as f:
            f.write(str(pid))

        logger.debug(
            f'wrote pid file {pid_path} for master process, pid: {pid}')

    def _read_master_pid(self):
        pid_path = self.config.basic.pid_file

        try:
            with open(pid_path, 'r') as f:
                content = f.read()
        except FileNotFoundError:
            raise PidFileNotExists

        try:
            return int(content)
        except ValueError:
            raise ValueError('pid file has been tampered with')

    def _handle_term_master(self, signal, sf):
        logger.debug(f'Master process received signal: {signal}')
        logger.debug('The master process starts to shut down workers')

        self.shutdown_workers()

        pid_path = self.config.basic.pid_file
        if os.path.isfile(pid_path):
            logger.debug(f'Remove pid file: {pid_path}')
            os.remove(pid_path)

    def _handle_term_worker(self, signal, sf):
        pid = os.getpid()
        logger.debug(f'Worker {pid} received signal: {signal}')
        logger.debug(f'Shutting down worker {pid}')
        self.core.shutdown()

    def _handle_term_shm(self, signal, sf):
        pid = os.getpid()
        logger.debug(f'SharedMemoryManager {pid} received signal: {signal}')
        logger.debug(f'Shutting down SharedMemoryManager {pid}')
        self.shm_mgr.shutdown_worker()

    def _handle_term_pkt_rpter(self, signal, sf):
        pid = os.getpid()
        logger.debug(f'SpecialPacketRepeater {pid} received signal: {signal}')
        logger.debug(f'Shutting down SpecialPacketRepeater {pid}')
        self.pkt_rpter.shutdown()

    def _sig_master(self):
        sig.signal(sig.SIGHUP, sig.SIG_IGN)
        for s in TERM_SIGNALS:
            sig.signal(s, self._handle_term_master)

    def _sig_normal_worker(self):
        sig.signal(sig.SIGHUP, sig.SIG_IGN)
        for s in TERM_SIGNALS:
            sig.signal(s, self._handle_term_worker)

    def _sig_shm_worker(self):
        sig.signal(sig.SIGHUP, sig.SIG_IGN)
        for s in TERM_SIGNALS:
            sig.signal(s, self._handle_term_shm)

    def _sig_pkt_rpter_worker(self):
        sig.signal(sig.SIGHUP, sig.SIG_IGN)
        for s in TERM_SIGNALS:
            sig.signal(s, self._handle_term_pkt_rpter)

    def shutdown_workers(self):
        for pid in self.worker_pids:
            self._kill(pid)

        # wait for workers to exit
        remaining = list(self.worker_pids)
        while True:
            for pid in list(remaining):
                if self._process_exists(pid):
                    os.waitpid(pid, os.WNOHANG)
                else:
                    logger.debug(f'Worker {pid} terminated')
                    remaining.remove(pid)

            if len(remaining) == 0:
                break

            time.sleep(0.5)

        # shutdown the SharedMemoryManager worker at last
        shm_pid = self.shm_worker_pid
        self._kill(shm_pid)
        os.waitpid(shm_pid, 0)
        logger.debug(f'SharedMemoryManager worker {shm_pid} terminated')

        logger.debug('All workers terminated')

    def _kill(self, pid):
        try:
            logger.debug(f'Sending SIGTERM to {pid}')
            os.kill(pid, sig.SIGTERM)
        except ProcessLookupError:
            pass

    def _process_exists(self, pid):
        try:
            os.kill(pid, 0)
        except OSError:
            logger.debug(f'Process {pid} does not exist')
            return False
        else:
            logger.debug(f'Process {pid} exists')
            return True

    def daemonize(self):
        pid = os.fork()

        if pid == -1:
            raise OSError('fork failed when doing daemonize')
        elif pid > 0:
            # double fork magic
            sys.exit(0)

        pid = os.fork()
        if pid == -1:
            raise OSError('fork failed when doing daemonize')

        def quit(sg, sf):
            sys.exit(0)

        if pid > 0:
            # the intermediate process waits to be terminated by the daemon
            for s in TERM_SIGNALS:
                sig.signal(s, quit)
            time.sleep(5)
        else:
            self._sig_master()
            ppid = os.getppid()
            os.kill(ppid, sig.SIGTERM)
            os.setsid()
            logger.debug('Node daemonized')

    def _start_shm_mgr(self):
        self.shm_mgr = SharedMemoryManager(self.config)

        # start the SharedMemoryManager worker
        pid = os.fork()

        if pid == -1:
            raise OSError('fork failed')
        elif pid == 0:
            self._sig_shm_worker()

            try:
                self.shm_mgr.run_as_worker()
            except Exception:
                err_msg = traceback.format_exc()
                shm_logger.error(
                    f'Unexpected error occurred, SHM worker crashed. '
                    f'Traceback:\n{err_msg}')
                sys.exit(1)

            sys.exit(0)  # the sub-process ends here
        else:
            self.shm_worker_pid = pid
            logger.info(f'Started SharedMemoryManager: {pid}')

    def _start_pkt_rpter(self):
        ''' The repeater needs some components from the node, so we
        shouldn't use this method before components are initialized
        '''
        pid = os.fork()

        if pid == -1:
            raise OSError('fork failed')
        elif pid == 0:
            self._sig_pkt_rpter_worker()
            self.pkt_rpter = SpecialPacketRepeater(
                self.config,
                self.efferent,
                self.protocol_wrapper,
            )

            try:
                self.pkt_rpter.init_shm()
                self.pkt_rpter.run()
            except Exception:
                err_msg = traceback.format_exc()
                logger.error(
                    f'Unexpected error occurred, SpecialPacketRepeater worker '
                    f'crashed. Traceback:\n{err_msg}')
                sys.exit(1)

            sys.exit(0)  # the sub-process ends here
        else:
            self.pkt_rpter = SpecialPacketRepeater(
                self.config,
                self.efferent,
                self.protocol_wrapper,
            )
            self.pkt_rpter_worker_pid = pid
            logger.info(f'Started SpecialPacketRepeater: {pid}')

    def _load_modules(self):
        self.afferent_cls = AFFERENT_MAPPING[self.role]
        self.main_afferent = self.afferent_cls(self.config)
        self.efferent = UDPTransmitter(self.config)

        self.protocol_wrapper = ProtocolWrapper(
            self.config,
            HeaderFormat,
            DataPktFormat,
            CtrlPktFormat,
            ConnCtrlPktFormat,
        )

        self.logic_handler_cls = LOGIC_HANDLER_MAPPING[self.role]
        self.logic_handler = self.logic_handler_cls(self.config)

        self.core_cls = CORE_MAPPING.get(self.role)
        self.core = self.core_cls(
            self.config,
            main_afferent=self.main_afferent,
            minor_afferents=[],
            efferent=self.efferent,
            logic_handler=self.logic_handler,
            protocol_wrapper=self.protocol_wrapper,
        )

        self.pkt_mgr = SpecialPacketManager(self.config)

        # The packet repeater is a part of the packet manager, so we will
        # use it as a normal module. Each worker shall have its own packet
        # repeater rather than sharing one, unlike the shared memory
        # manager worker.
        self._start_pkt_rpter()

        self.logic_handler.init_shm()
        self.core.init_shm()
        self.core.self_allocate_core_id()
        self.pkt_mgr.init_shm()

        pid = os.getpid()
        logger.debug(f'Worker {pid} loaded modules')

    def _clean_modules(self):
        self._kill(self.pkt_rpter_worker_pid)
        os.waitpid(self.pkt_rpter_worker_pid, 0)
        logger.debug(f'SpecialPacketRepeater worker '
                     f'{self.pkt_rpter_worker_pid} terminated')

        self.core.shutdown()
        self.main_afferent.destroy()
        self.logic_handler.close_shm()
        self.core.close_shm()
        self.pkt_mgr.close_shm()

        self.main_afferent = None
        self.efferent = None
        self.protocol_wrapper = None
        self.logic_handler = None
        self.core = None
        self.pkt_mgr = None

        pid = os.getpid()
        logger.debug(f'Worker {pid} cleaned modules')

    def get_context(self):
        return NodeContext

    def _create_context(self):
        NodeContext.pkt_rpter_pid = self.pkt_rpter_worker_pid
        NodeContext.local_ip = get_localhost_ip()
        NodeContext.listen_port = self.config.net.aff_listen_port
        NodeContext.core = self.core
        NodeContext.main_efferent = self.efferent
        NodeContext.protocol_wrapper = self.protocol_wrapper
        NodeContext.pkt_mgr = self.pkt_mgr
        NodeContext.id_generator = IDGenerator(self.node_id, self.core.core_id)

        pid = os.getpid()
        logger.debug(f'Worker {pid} created NodeContext')

    def _clean_context(self):
        NodeContext.pkt_rpter_pid = None
        NodeContext.id_generator = None
        NodeContext.local_ip = None
        NodeContext.listen_port = None
        NodeContext.core = None
        NodeContext.main_efferent = None
        NodeContext.protocol_wrapper = None
        NodeContext.pkt_mgr = None

        pid = os.getpid()
        logger.debug(f'Worker {pid} cleaned NodeContext')

    def join_cluster(self):
        if self.role == Roles.CONTROLLER:
            raise RuntimeError(
                'Controller node is the root node of the cluster')

        self.core.request_to_join_cluster()

        # The join result is reported by exceptions raised from the
        # handling logic while polling; if nothing arrives within the
        # duration, we fall through and time out.
        self.core.run_for_a_while(5)
        raise TimeoutError

    def run(self):
        pid_fl = self.config.basic.pid_file

        try:
            pid = self._read_master_pid()
            logger.warning(
                f'\n\tThe Neverland node is already running or the pid file\n'
                f'\t{pid_fl} is not removed, current pid: {pid}.\n'
                f'\tMake sure that the node is not running and try again.\n\n'
                f'\tIf you need to run multiple nodes on this computer, then\n'
                f'\tyou need to at least configure another pid file for it.')
            return
        except ValueError:
            logger.error(
                f'\n\tThe pid file {pid_fl} exists but seems it\'s not\n'
                f'\twritten by the Neverland node. Please make sure the node\n'
                f'\tis not running and the pid file is not occupied.')
            return
        except PidFileNotExists:
            pass

        self.daemonize()
        NodeContext.pid = os.getpid()
        self._write_master_pid()
        self._start_shm_mgr()

        # Before we start workers, we need to join the cluster first.
        if self.role != Roles.CONTROLLER:
            # Before we join the cluster, we need to load modules first;
            # once we have joined the cluster, modules in the master
            # worker shall be removed.
            self._load_modules()
            self._create_context()

            try:
                self.join_cluster()
            except SuccessfullyJoinedCluster:
                logger.info('Successfully joined the cluster.')
            except FailedToJoinCluster:
                logger.error('Cannot join the cluster, request not permitted')
                self._clean_modules()
                self._clean_context()
                self._on_break()
                return
            except TimeoutError:
                logger.error(
                    'No response from entrance node, failed to join the cluster'
                )
                self._clean_modules()
                self._clean_context()
                self._on_break()
                return

            self._clean_modules()
            self._clean_context()

        # start normal workers
        worker_amount = self.config.basic.worker_amount
        for _ in range(worker_amount):
            pid = os.fork()
            NodeContext.pid = os.getpid()

            if pid == -1:
                raise OSError('fork failed')
            elif pid == 0:
                self._sig_normal_worker()
                self._load_modules()
                self._create_context()

                try:
                    self.core.run()
                except Exception:
                    err_msg = traceback.format_exc()
                    logger.error(f'Unexpected error occurred, node crashed. '
                                 f'Traceback:\n{err_msg}')
                    self._clean_modules()
                    self._clean_context()
                    sys.exit(1)

                self._clean_modules()
                self._clean_context()
                sys.exit(0)  # the sub-process ends here
            else:
                self.worker_pids.append(pid)
                logger.info(f'Started Worker: {pid}')

        while True:
            try:
                os.waitpid(-1, 0)
            except ChildProcessError:
                break

    def shutdown(self):
        pid = self._read_master_pid()
        self._kill(pid)
        logger.info('Sent SIGTERM to the master process')

    def _on_break(self):
        ''' a hook that needs to be invoked when self.run has been
        interrupted by some exception
        '''
        shm_pid = self.shm_worker_pid
        self._kill(shm_pid)
        os.waitpid(shm_pid, 0)
        logger.debug(f'SharedMemoryManager worker {shm_pid} terminated')

        pid_fl = self.config.basic.pid_file
        os.remove(pid_fl)
        logger.debug(f'Removed pid file: {pid_fl}')

        logger.info('Master process exits.\n\n')
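# --- Illustrative only (not in the original source): a minimal sketch of
# launching and stopping a node from a CLI entry point. `load_config` is a
# hypothetical helper standing in for the project's real config loading;
# Roles.CONTROLLER is used because it is the only role named in this file.
# A non-controller node would additionally join the cluster inside run().
def _example_cli(action):
    config = load_config('/etc/neverland/config.json')  # hypothetical path
    node = BaseNode(config, role=Roles.CONTROLLER)

    if action == 'start':
        # daemonizes, writes the pid file, starts the SHM worker,
        # then forks config.basic.worker_amount core workers
        node.run()
    elif action == 'stop':
        # reads the pid file and sends SIGTERM to the master process
        node.shutdown()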
class ConnectionManager():

    ''' The Connection Manager

    We store all information of established connections in the shared
    memory. This ConnectionManager aims at converting this information
    between JSONs and Connection objects: providing Connection objects
    to the upper layer and storing Connection objects into the shared
    memory as JSONs.

    As a manager, it should provide functionalities of establishing
    and closing connections as well.
    '''

    SHM_SOCKET_NAME_TEMPLATE = 'SHM-ConnectionManager-%d.socket'

    # The SHM container to store established connections.
    #
    # Data structure:
    #     {
    #         "ip:port": {
    #             "slot-0": {
    #                 "status": int,
    #                 "sn": int,
    #                 "iv": b64encode(iv),
    #                 "iv_duration": int,
    #             },
    #             "slot-1": {
    #                 "status": int,
    #                 "sn": int,
    #                 "iv": b64encode(iv),
    #                 "iv_duration": int,
    #             },
    #             "slot-2": {
    #                 "status": int,
    #                 "sn": int,
    #                 "iv": b64encode(iv),
    #                 "iv_duration": int,
    #             },
    #         }
    #     }
    SHM_KEY_TMP_CONNS = 'ConnectionManager-%d_Conns'

    def __init__(self, config):
        self.config = config
        self.iv_len = self.config.net.crypto.iv_len
        self.iv_duration_range = self.config.net.crypto.iv_duration_range

        if not 0 < self.iv_len < EVP_MAX_IV_LENGTH:
            raise ArgumentError('iv_len out of range')

        self.pid = NodeContext.pid

    def init_shm(self):
        ''' initialize the shared memory manager
        '''
        self.shm_mgr = SharedMemoryManager(self.config)
        self.shm_mgr.connect(self.SHM_SOCKET_NAME_TEMPLATE % self.pid)

        self.shm_key_conns = self.SHM_KEY_TMP_CONNS % self.pid
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_conns,
            SHMContainerTypes.DICT,
        )

    def _remote_sa_2_key(self, remote):
        ''' convert a remote socket address to a key string
        '''
        ip = remote[0]
        port = remote[1]
        return f'{ip}:{port}'

    def _get_native_conn_info(self, remote):
        ''' get the native JSON data of connections
        '''
        remote_name = self._remote_sa_2_key(remote)
        shm_data = self.shm_mgr.get_dict_value(self.shm_key_conns, remote_name)
        shm_value = shm_data.get('value')

        if shm_value is None:
            return {
                SLOT_0: None,
                SLOT_1: None,
                SLOT_2: None,
            }
        else:
            return shm_value

    def get_conns(self, remote):
        ''' get all connections of a remote node

        :param remote: socket address in tuple format, (ip, port)
        :returns: a dict of Connection objects:
                  {
                      SLOT_0: conn,
                      SLOT_1: conn,
                      SLOT_2: conn,
                  }
        '''
        ip = remote[0]
        port = remote[1]

        native_info = self._get_native_conn_info(remote)
        result = dict()

        for slot_name in SLOTS:
            conn_info = native_info.get(slot_name)

            if conn_info is None:
                conn = None
            else:
                iv = conn_info.get('iv')
                if iv is not None:
                    iv = base64.b64decode(iv)

                conn_info.update(
                    slot=slot_name,
                    remote={
                        'ip': ip,
                        'port': port,
                    },
                    iv=iv,
                )
                conn = Connection(**conn_info)

            result.update({slot_name: conn})

        return result

    def store_conn(self, conn, slot, override=False):
        ''' store a Connection object into a slot

        :param conn: a Connection object
        :param slot: slot name, enumerated in SLOTS
        '''
        remote = (conn.remote.ip, conn.remote.port)
        remote_name = self._remote_sa_2_key(remote)

        if not override:
            usable_slots = self.get_usable_slots(remote)
            if slot not in usable_slots:
                raise ConnSlotNotAvailable

        iv = conn.iv
        if iv is not None:
            # the base64 string must be str but not bytes
            iv = base64.b64encode(iv).decode()

        conn_info = conn.__to_dict__()
        conn_info.update(iv=iv)

        self.shm_mgr.update_dict(
            key=self.shm_key_conns,
            dict_key=remote_name,
            value=conn_info,
        )

    def get_usable_slots(self, remote):
        ''' get all usable slots of a remote

        :param remote: remote socket address, (ip, port)
        :returns: a list of slot names
        '''
        usable_slots = list(SLOTS)
        conns = self.get_conns(remote)

        for slot in SLOTS:
            conn = conns.get(slot)

            # slots are filled in order, so the first empty slot means
            # all subsequent slots are empty as well
            if conn is None:
                break
            else:
                usable_slots.remove(slot)

        return usable_slots

    def new_conn(self, remote, synchronous=False, timeout=2, interval=0.1):
        ''' establish a new connection

        The establishing connection will be placed in slot-2.

        After the new connection is established, if we have already
        established 2 connections with the specified node, then the
        connection in slot-0 will be removed and the connection in slot-1
        will be moved to slot-0. The new connection will be placed in
        slot-1.

        :param remote: remote socket address, (ip, port)
        :param synchronous: If True, the new_conn method will wait until
                            the connection completes and return a
                            Connection object. This operation blocks until
                            it reaches the timeout or the connection
                            completes.

                            If False, the new_conn method will return None
                            immediately without waiting.
        :param timeout: seconds to timeout
        :param interval: the interval time of connection checking
        '''
        usable_slots = self.get_usable_slots(remote)
        remote_name = self._remote_sa_2_key(remote)

        if SLOT_2 not in usable_slots:
            raise ConnSlotNotAvailable(f'slot-2 to {remote_name} is in use, '
                                       f'cannot establish connection now')

        iv = os.urandom(self.iv_len)
        iv_duration = random.randint(*self.iv_duration_range)

        pkt = UDPPacket()
        pkt.fields = ObjectifiedDict(
            type=PktTypes.CONN_CTRL,
            dest=remote,
            communicating=1,
            iv_changed=1,
            iv_duration=iv_duration,
            iv=iv,
        )
        pkt.next_hop = remote
        pkt = NodeContext.protocol_wrapper.wrap(pkt)
        NodeContext.pkt_mgr.repeat_pkt(pkt)

        conn_sn = NodeContext.id_generator.gen()
        conn = {
            "remote": {
                "ip": remote[0],
                "port": remote[1],
            },
            "sn": conn_sn,
            "state": ConnStates.ESTABLISHING,
            "slot": SLOT_2,
            "iv": iv,
            "iv_duration": iv_duration,
        }
        conn = Connection(**conn)

        # Though we have checked SLOT_2 already,
        # it could still have been seized in the meantime...
        try:
            self.store_conn(conn, SLOT_2)
        except ConnSlotNotAvailable:
            logger.warning(
                f'slot-2 to {remote_name} seized, abort the establishment')

        # The request has been sent, now we wait for the response
        if not synchronous:
            return None

        # while timeout > 0:
        #     # TODO complete this after the response processing logic is done
        #     timeout -= interval
        #     time.sleep(interval)

    def get_conn(self, remote):
        ''' get a Connection object of the specified remote

        :param remote: remote socket address, (ip, port)
        :returns: Connection object
        '''
        conns = self.get_conns(remote)

        # according to the explanations above, the priority of slots is 1 > 0
        conn_s1 = conns.get(SLOT_1)
        if conn_s1 is not None and conn_s1.state == ConnStates.ESTABLISHED:
            return conn_s1

        conn_s0 = conns.get(SLOT_0)
        if conn_s0 is not None and conn_s0.state == ConnStates.ESTABLISHED:
            return conn_s0

        raise NoConnAvailable

    def remove_conn(self, remote, slot):
        ''' close a connection

        :param remote: remote socket address, (ip, port)
        :param slot: slot name, enumerated in SLOTS
        '''
        ip = remote[0]
        port = remote[1]
        remote_name = f'{ip}:{port}'

        native_info = self._get_native_conn_info(remote)
        native_info[slot] = None

        self.shm_mgr.update_dict(
            key=self.shm_key_conns,
            dict_key=remote_name,
            value=native_info,
        )
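# --- Illustrative only (not in the original source): a minimal sketch of
# the intended call flow, assuming NodeContext has been populated by the
# worker (pid, id_generator, protocol_wrapper, pkt_mgr) and that `remote`
# is the (ip, port) tuple of a reachable node.
def _example_conn_flow(config, remote):
    conn_mgr = ConnectionManager(config)
    conn_mgr.init_shm()

    # fire-and-forget establishment; the CONN_CTRL packet is handed to the
    # packet repeater and the connection sits in slot-2 as ESTABLISHING
    conn_mgr.new_conn(remote, synchronous=False)

    try:
        # raises NoConnAvailable until a slot reaches ESTABLISHED
        conn = conn_mgr.get_conn(remote)
    except NoConnAvailable:
        conn = None
    return conn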
class SpecialPacketManager():

    SHM_SOCKET_NAME_TEMPLATE = 'SHM-SpecialPacketManager-%d.socket'

    def __init__(self, config):
        self.config = config
        self.pid = NodeContext.pid
        self.shm_key_pkts = SHM_KEY_PKTS

        # These containers are for the SpecialPacketRepeater; the repeater
        # will also access special packets through the manager.
        self.shm_key_pkts_to_repeat = SHM_KEY_TMP_PKTS_TO_REPEAT % self.pid
        self.shm_key_last_repeat_time = SHM_KEY_TMP_LAST_REPEAT_TIME % self.pid
        self.shm_key_next_repeat_time = SHM_KEY_TMP_NEXT_REPEAT_TIME % self.pid
        self.shm_key_max_repeat_times = SHM_KEY_TMP_MAX_REPEAT_TIMES % self.pid
        self.shm_key_repeated_times = SHM_KEY_TMP_REPEATED_TIMES % self.pid

    def init_shm(self):
        ''' initialize the shared memory manager
        '''
        self.shm_mgr = SharedMemoryManager(self.config)
        self.shm_mgr.connect(self.SHM_SOCKET_NAME_TEMPLATE % self.pid)

        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_pkts,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_pkts_to_repeat,
            SHMContainerTypes.LIST,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_last_repeat_time,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_next_repeat_time,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_max_repeat_times,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_repeated_times,
            SHMContainerTypes.DICT,
        )

    def close_shm(self):
        self.shm_mgr.disconnect()

    def store_pkt(self, pkt, need_repeat=False, max_rpt_times=5):
        sn = pkt.fields.sn
        type_ = pkt.fields.type

        # The salt field is bytes, so we cannot serialize it into a JSON.
        # We shall encode it into a base64 string before storing it.
        fields = pkt.fields.__to_dict__()
        salt = fields.get('salt')
        if salt is not None:
            salt_b64 = base64.b64encode(salt).decode()
            fields.update(salt=salt_b64)

        previous_hop = list(pkt.previous_hop)
        next_hop = list(pkt.next_hop)

        if sn is None:
            raise InvalidPkt(
                'Packets to be stored must contain a serial number')

        value = {
            sn: {
                'type': type_,
                'fields': fields,
                'previous_hop': previous_hop,
                'next_hop': next_hop,
            }
        }

        self.shm_mgr.lock_key(self.shm_key_pkts)
        self.shm_mgr.add_value(self.shm_key_pkts, value)
        self.shm_mgr.unlock_key(self.shm_key_pkts)

        if need_repeat:
            self.shm_mgr.add_value(self.shm_key_pkts_to_repeat, [sn])
            self.set_pkt_max_repeat_times(sn, max_rpt_times)
            self.set_pkt_repeated_times(sn, 0)

        hex_type = Converter.int_2_hex(type_)
        logger.debug(f'Stored a special packet, need_repeat: {need_repeat}, '
                     f'sn: {sn}, type: {hex_type}, dest: {pkt.fields.dest}')

    def get_pkt(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_pkts, sn)
        shm_value = shm_data.get('value')

        if shm_value is None:
            return None

        # and here, we restore the base64 encoded salt into bytes;
        # the salt may be absent, so only decode it when present
        fields = shm_value.get('fields')
        salt_b64 = fields.get('salt')
        if salt_b64 is not None:
            salt = base64.b64decode(salt_b64)
            fields.update(salt=salt)

        return UDPPacket(
            fields=fields,
            type=shm_value.get('type'),
            previous_hop=shm_value.get('previous_hop'),
            next_hop=shm_value.get('next_hop'),
        )

    def remove_pkt(self, sn):
        self.cancel_repeat(sn)

        self.shm_mgr.lock_key(self.shm_key_pkts)
        self.shm_mgr.remove_value(self.shm_key_pkts, sn)
        self.shm_mgr.unlock_key(self.shm_key_pkts)

        logger.debug(f'Removed a special packet, sn: {sn}')

    def cancel_repeat(self, sn):
        self.shm_mgr.remove_value(self.shm_key_pkts_to_repeat, sn)
        self.shm_mgr.remove_value(self.shm_key_last_repeat_time, sn)
        self.shm_mgr.remove_value(self.shm_key_next_repeat_time, sn)
        self.shm_mgr.remove_value(self.shm_key_max_repeat_times, sn)
        self.shm_mgr.remove_value(self.shm_key_repeated_times, sn)
        logger.debug(f'Cancelled repeat for a packet, sn: {sn}')

    def repeat_pkt(self, pkt, max_rpt_times=5):
        self.store_pkt(pkt, need_repeat=True, max_rpt_times=max_rpt_times)

    def get_repeating_sn_list(self):
        shm_data = self.shm_mgr.read_key(self.shm_key_pkts_to_repeat)
        return shm_data.get('value')

    def set_pkt_last_repeat_time(self, sn, timestamp):
        self.shm_mgr.add_value(self.shm_key_last_repeat_time, {sn: timestamp})
        logger.debug(f'set_pkt_last_repeat_time, sn: {sn}, ts: {timestamp}')

    def get_pkt_last_repeat_time(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_last_repeat_time, sn)
        return shm_data.get('value')

    def set_pkt_next_repeat_time(self, sn, timestamp):
        self.shm_mgr.add_value(self.shm_key_next_repeat_time, {sn: timestamp})
        logger.debug(f'set_pkt_next_repeat_time, sn: {sn}, ts: {timestamp}')

    def get_pkt_next_repeat_time(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_next_repeat_time, sn)
        return shm_data.get('value')

    def set_pkt_max_repeat_times(self, sn, times):
        self.shm_mgr.add_value(self.shm_key_max_repeat_times, {sn: times})
        logger.debug(f'set_pkt_max_repeat_times, sn: {sn}, times: {times}')

    def get_pkt_max_repeat_times(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_max_repeat_times, sn)
        return shm_data.get('value')

    def set_pkt_repeated_times(self, sn, times):
        self.shm_mgr.add_value(self.shm_key_repeated_times, {sn: times})
        logger.debug(f'set_pkt_repeated_times, sn: {sn}, times: {times}')

    def get_pkt_repeated_times(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_repeated_times, sn)
        return shm_data.get('value')

    def increase_pkt_repeated_times(self, sn):
        repeated_times = self.get_pkt_repeated_times(sn)

        if repeated_times is None:
            repeated_times = 1
        else:
            repeated_times += 1

        self.set_pkt_repeated_times(sn, repeated_times)
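# --- Illustrative only (not in the original source): a minimal sketch of
# one repeat cycle as the SpecialPacketRepeater might drive it through the
# manager's API above. `now` is a Unix timestamp; the fixed `interval` is
# an assumed placeholder for the repeater's real scheduling policy.
def _example_repeat_step(pkt_mgr, efferent, now, interval=2):
    for sn in pkt_mgr.get_repeating_sn_list() or []:
        next_time = pkt_mgr.get_pkt_next_repeat_time(sn) or 0
        if now < next_time:
            continue  # not due yet

        repeated = pkt_mgr.get_pkt_repeated_times(sn) or 0
        if repeated >= pkt_mgr.get_pkt_max_repeat_times(sn):
            pkt_mgr.cancel_repeat(sn)  # give up after max_rpt_times
            continue

        efferent.transmit(pkt_mgr.get_pkt(sn))
        pkt_mgr.set_pkt_last_repeat_time(sn, now)
        pkt_mgr.set_pkt_next_repeat_time(sn, now + interval)
        pkt_mgr.increase_pkt_repeated_times(sn)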
def __init__(self, config):
    self.config = config
    self.shm_mgr = SharedMemoryManager(self.config)