Code example #1
    def init_shm(self):
        ''' initialize the shared memory manager
        '''

        self.shm_mgr = SharedMemoryManager(self.config)
        self.shm_mgr.connect(self.SHM_SOCKET_NAME_TEMPLATE % self.pid)
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_pkts,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_pkts_to_repeat,
            SHMContainerTypes.LIST,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_last_repeat_time,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_next_repeat_time,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_max_repeat_times,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_repeated_times,
            SHMContainerTypes.DICT,
        )
Code example #2
    def init_shm(self):
        ''' initialize the shared memory manager
        '''

        self.shm_mgr = SharedMemoryManager(self.config)
        self.shm_mgr.connect(self.SHM_SOCKET_NAME_TEMPLATE % self.pid)

        self.shm_key_conns = self.SHM_KEY_TMP_CONNS % self.pid
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_conns,
            SHMContainerTypes.DICT,
        )
Code example #3
def launch_shm_worker():
    global worker_pid, shm_mgr

    pid = os.fork()

    if pid == -1:
        raise OSError('fork failed, unable to run SharedMemoryManager')

    # run SHM worker in the child process
    elif pid == 0:
        shm_mgr = SharedMemoryManager(config)
        shm_mgr.run_as_worker()

    # send testing request in the parent process
    else:
        # wait for shm worker
        time.sleep(1)
        worker_pid = pid

    return pid
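
For completeness, a matching teardown for this test fixture would terminate the forked worker once testing is done. The following is only a sketch under assumptions: the helper name kill_shm_worker is illustrative and not part of the project code, while worker_pid and do_not_kill_shm_worker are the module-level globals used by the snippets in this listing.

import os
import signal


def kill_shm_worker():
    # Hypothetical teardown helper (illustrative only): skip the kill when a
    # child test process set do_not_kill_shm_worker, otherwise terminate the
    # SHM worker forked by launch_shm_worker and reap it to avoid a zombie.
    if do_not_kill_shm_worker:
        return
    os.kill(worker_pid, signal.SIGTERM)
    os.waitpid(worker_pid, 0)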
Code example #4
File: base.py  Project: erosb/demos
    def _start_shm_mgr(self):
        self.shm_mgr = SharedMemoryManager(self.config)

        # start SharedMemoryManager worker
        pid = os.fork()
        if pid == -1:
            raise OSError('fork failed')
        elif pid == 0:
            self._sig_shm_worker()
            try:
                self.shm_mgr.run_as_worker()
            except Exception:
                err_msg = traceback.format_exc()
                shm_logger.error(
                    f'Unexpected error occurred, SHM worker crashed. '
                    f'Traceback:\n{err_msg}')
                sys.exit(1)

            sys.exit(0)  # the sub-process ends here
        else:
            self.shm_worker_pid = pid
            logger.info(f'Started SharedMemoryManager: {pid}')
Code example #5
    def __init__(
            self,
            config,
            efferent,
            logic_handler,
            protocol_wrapper,
            main_afferent,
            minor_afferents=tuple(),
    ):
        ''' constructor

        :param config: the config
        :param efferent: an efferent instance
        :param logic_handler: a logic handler instance
        :param protocol_wrapper: a protocol wrapper instance
        :param main_afferent: the main afferent
        :param minor_afferents: a group of minor afferents,
                                any iterable type that contains afferent instances
        '''

        self.__running = False

        self.core_id = None
        self._epoll = select.epoll()
        self.afferent_mapping = {}

        self.config = config
        self.main_afferent = main_afferent
        self.efferent = efferent
        self.logic_handler = logic_handler
        self.protocol_wrapper = protocol_wrapper

        self.shm_mgr = SharedMemoryManager(self.config)

        self.plug_afferent(self.main_afferent)

        for afferent in minor_afferents:
            self.plug_afferent(afferent)
Code example #6
    def test_1_set(self):
        KEY_FOR_SET = 'set_test'
        VALUE_FOR_SET = 'aaaaaaaaaaaaa'

        print('\n\n=====================set-value====================')
        shm_mgr = SharedMemoryManager(config, sensitive=False)
        shm_mgr.connect('test_set')
        print('conn_id: ', shm_mgr.current_connection.conn_id)
        self.assertIsInstance(shm_mgr.current_connection.conn_id, str)

        resp = shm_mgr.create_key(KEY_FOR_SET, SHMContainerTypes.STR)
        self.assertTrue(resp.get('succeeded'))

        resp = shm_mgr.set_value(KEY_FOR_SET, VALUE_FOR_SET)
        self.assertTrue(resp.get('succeeded'))

        resp = shm_mgr.read_key(KEY_FOR_SET)
        self.assertTrue(resp.get('succeeded'))
        self.assertEqual(resp.get('value'), VALUE_FOR_SET)
        print(resp)

        shm_mgr.disconnect()
        self.assertEqual(shm_mgr.current_connection, None)
Code example #7
class BaseCore():
    ''' The base model of cores

    Literally, the core is supposed to be a kernel-like component.
    It organizes the other components, makes them work together and
    administers them. Some components are pluggable, and some others
    are necessary.

    Here is the list of all components:
        afferents in neverland.afferents, pluggable
        efferents in neverland.efferents, necessary
        logic handlers in neverland.logic, necessary
        protocol wrappers in neverland.protocol, necessary

    In the initial version, all of these components are necessary, and there
    can be multiple afferents.
    '''

    EV_MASK = select.EPOLLIN

    SHM_SOCKET_NAME_TEMPLATE = 'SHM-Core-%d.socket'

    # SHM container for storing allocated core IDs
    # data structure:
    #     [1, 2, 3, 4]
    SHM_KEY_CORE_ID = 'Core_id'

    # The shared status of cluster controlling,
    # enumerated in neverland.core.state.ClusterControllingStates
    SHM_KEY_CC_STATE = 'Core_CCState'

    def __init__(
            self,
            config,
            efferent,
            logic_handler,
            protocol_wrapper,
            main_afferent,
            minor_afferents=tuple(),
    ):
        ''' constructor

        :param config: the config
        :param efferent: an efferent instance
        :param logic_handler: a logic handler instance
        :param protocol_wrapper: a protocol wrapper instance
        :param main_afferent: the main afferent
        :param minor_afferents: a group of minor afferents,
                                any iterable type that contains afferent instances
        '''

        self.__running = False

        self.core_id = None
        self._epoll = select.epoll()
        self.afferent_mapping = {}

        self.config = config
        self.main_afferent = main_afferent
        self.efferent = efferent
        self.logic_handler = logic_handler
        self.protocol_wrapper = protocol_wrapper

        self.shm_mgr = SharedMemoryManager(self.config)

        self.plug_afferent(self.main_afferent)

        for afferent in minor_afferents:
            self.plug_afferent(afferent)

    def set_cc_state(self, status):
        self.shm_mgr.set_value(self.SHM_KEY_CC_STATE, status)

    def get_cc_state(self):
        resp = self.shm_mgr.read_key(self.SHM_KEY_CC_STATE)
        return resp.get('value')

    @property
    def cc_state(self):
        return self.get_cc_state()

    def init_shm(self):
        self.shm_mgr.connect(self.SHM_SOCKET_NAME_TEMPLATE % NodeContext.pid)

        self.shm_mgr.create_key_and_ignore_conflict(
            self.SHM_KEY_CORE_ID,
            SHMContainerTypes.LIST,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.SHM_KEY_CC_STATE,
            SHMContainerTypes.INT,
            CCStates.INIT,
        )

        logger.debug(f'init_shm for core of worker {NodeContext.pid} has finished')

    def close_shm(self):
        self.shm_mgr.disconnect()

    def self_allocate_core_id(self):
        ''' Let the core pick up an id for itself
        '''

        try:
            resp = self.shm_mgr.lock_key(self.SHM_KEY_CORE_ID)
        except SHMResponseTimeout:
            # Currently, SHM_MAX_BLOCKING_TIME is 4 seconds and
            # this work can definitely be done within 4 seconds.
            # If a SHMResponseTimeout occurs, then there must
            # be a deadlock.
            raise SHMResponseTimeout(
                f'deadlock of key: {self.SHM_KEY_CORE_ID}')

        resp = self.shm_mgr.read_key(self.SHM_KEY_CORE_ID)
        allocated_id = resp.get('value')

        if len(allocated_id) == 0:
            id_ = 0
        else:
            last_id = allocated_id[-1]
            id_ = last_id + 1

        self.shm_mgr.add_value(
            self.SHM_KEY_CORE_ID,
            [id_],
        )
        self.core_id = id_

        self.shm_mgr.unlock_key(self.SHM_KEY_CORE_ID)
        logger.debug(
            f'core of worker {NodeContext.pid} has self-allocated id: {id_}')

    def plug_afferent(self, afferent):
        self._epoll.register(afferent.fd, self.EV_MASK)
        self.afferent_mapping.update({afferent.fd: afferent})

    def unplug_afferent(self, fd):
        ''' remove an afferent from the core

        :param fd: the file descriptor of the afferent, int
        '''

        if fd not in self.afferent_mapping:
            return

        self._epoll.unregister(fd)
        self.afferent_mapping.pop(fd)

    def request_to_join_cluster(self):
        ''' send a request indicating that the node is going to join the cluster
        '''

        entrance = self.config.cluster_entrance
        identification = self.config.net.identification

        if entrance is None:
            raise ConfigError("cluster_entrance is not defined")
        if identification is None:
            raise ConfigError("identification is not defined")

        logger.info('Trying to join cluster...')

        content = {
            'identification': identification,
            'ip': get_localhost_ip(),
            'listen_port': self.config.net.aff_listen_port,
        }
        subject = CCSubjects.JOIN_CLUSTER
        dest = (entrance.ip, entrance.port)

        pkt = UDPPacket()
        pkt.fields = ObjectifiedDict(
            type=PktTypes.CTRL,
            dest=dest,
            subject=subject,
            content=content,
        )
        pkt.next_hop = dest
        pkt = self.protocol_wrapper.wrap(pkt)

        NodeContext.pkt_mgr.repeat_pkt(pkt)
        logger.info(
            f'Sending request to cluster entrance {entrance.ip}:{entrance.port}'
        )

        logger.info('[Node Status] WAITING_FOR_JOIN')
        self.set_cc_state(CCStates.WAITING_FOR_JOIN)

    def request_to_leave_cluster(self):
        ''' send a request indicating that the node is going to detach from the cluster
        '''

        entrance = self.config.cluster_entrance
        identification = self.config.net.identification

        if entrance is None:
            raise ConfigError("cluster_entrance is not defined")
        if identification is None:
            raise ConfigError("identification is not defined")

        logger.info('Trying to leave cluster...')

        content = {"identification": identification}
        subject = CCSubjects.LEAVE_CLUSTER
        dest = (entrance.ip, entrance.port)

        pkt = UDPPacket()
        pkt.fields = ObjectifiedDict(
            type=PktTypes.CTRL,
            dest=dest,
            subject=subject,
            content=content,
        )
        pkt.next_hop = dest
        pkt = self.protocol_wrapper.wrap(pkt)

        NodeContext.pkt_mgr.repeat_pkt(pkt)
        logger.info(
            f'Sent request to cluster entrance {entrance.ip}:{entrance.port}')

        logger.info('[Node Status] WAITING_FOR_LEAVE')
        self.set_cc_state(CCStates.WAITING_FOR_LEAVE)

    def handle_pkt(self, pkt):
        pkt = self.protocol_wrapper.unwrap(pkt)
        if not pkt.valid:
            return

        try:
            pkt = self.logic_handler.handle_logic(pkt)
        except DropPacket:
            return

        pkt = self.protocol_wrapper.wrap(pkt)
        self.efferent.transmit(pkt)

    def _poll(self):
        events = self._epoll.poll(POLL_TIMEOUT)

        for fd, evt in events:
            afferent = self.afferent_mapping[fd]

            if evt & select.EPOLLERR:
                self.unplug_afferent(fd)
                afferent.destroy()
            elif evt & select.EPOLLIN:
                pkt = afferent.recv()
                self.handle_pkt(pkt)

    def run(self):
        self.set_cc_state(CCStates.WORKING)
        self.__running = True

        self.main_afferent.listen()
        addr = self.main_afferent.listen_addr
        port = self.main_afferent.listen_port
        logger.info(f'Main afferent is listening on {addr}:{port}')

        while self.__running:
            self._poll()

    def run_for_a_while(self, duration=None, polling_times=None):
        ''' run the core for the specified duration or number of polls

        :param duration: the duration in seconds, int
        :param polling_times: the number of times the _poll method shall be
                              invoked, int

        These two arguments do not work together; if duration is specified,
        polling_times will be ignored
        '''

        if duration is None and polling_times is None:
            raise ArgumentError('no argument passed')

        self.main_afferent.listen()
        addr = self.main_afferent.listen_addr
        port = self.main_afferent.listen_port
        logger.info(f'Main afferent is listening on {addr}:{port}')

        if duration is not None:
            starting_time = time.time()
            while time.time() - starting_time <= duration:
                self._poll()
        elif polling_times is not None:
            for _ in range(polling_times):
                self._poll()

    def shutdown(self):
        self.__running = False
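
A minimal usage sketch of the core shown above. This is an assumption about the wiring, not code from the project: config and the component instances (main_afferent, efferent, logic_handler, protocol_wrapper) are presumed to have been constructed elsewhere, as Code example #11 does in _load_modules.

# Illustrative wiring only; the component objects are assumed to exist.
core = BaseCore(
    config,
    efferent=efferent,
    logic_handler=logic_handler,
    protocol_wrapper=protocol_wrapper,
    main_afferent=main_afferent,
)
core.init_shm()                     # connect to SHM and create the core keys
core.self_allocate_core_id()        # pick the next free core id under a lock
core.run_for_a_while(duration=5)    # poll the epoll loop for about 5 seconds
core.close_shm()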
Code example #8
    def test_999_backlog(self):
        global do_not_kill_shm_worker

        key = 'bl0'

        shm_mgr = SharedMemoryManager(config, sensitive=False)
        shm_mgr.connect('test_bl')
        shm_mgr.create_key(key, SHMContainerTypes.LIST)
        shm_mgr.lock_key(key)

        pid = os.fork()
        if pid == -1:
            raise OSError('fork failed')

        if pid == 0:
            # This socket will cause a warning, so we close it here
            shm_mgr.current_connection.socket.close()

            do_not_kill_shm_worker = True
            shm_mgr1 = SharedMemoryManager(config, sensitive=False)
            shm_mgr1.connect('test_bl1')

            print('\n\n==============access-locked-container===============')
            t0 = time.time()
            resp = shm_mgr1.read_key(key)
            t1 = time.time()

            print(resp)

            delay = t1 - t0
            print(f'delay: {delay}')

            self.assertEqual(resp.get('succeeded'), True)
            self.assertTrue(delay >= 1)

            shm_mgr1.disconnect()
        else:
            time.sleep(1)
            print('------------release-lock-------------')
            shm_mgr.unlock_key(key)
            shm_mgr.disconnect()
            os.waitpid(-1, 0)
Code example #9
    def test_2_rcode(self):
        shm_mgr = SharedMemoryManager(config, sensitive=False)
        shm_mgr.connect('test_err')

        print('\n\n=====================key-error-test====================')
        resp = shm_mgr.add_value(
            key='not_exists',
            value='a',
            backlogging=False,
        )
        print(resp)
        self.assertEqual(resp.get('succeeded'), False)
        self.assertEqual(resp.get('rcode'), ReturnCodes.KEY_ERROR)

        print('\n\n=====================type-error-test====================')
        resp = shm_mgr.create_key(
            'testing',
            SHMContainerTypes.DICT,
            value='a',
            backlogging=False,
        )
        print(resp)
        self.assertEqual(resp.get('succeeded'), False)
        self.assertEqual(resp.get('rcode'), ReturnCodes.TYPE_ERROR)

        print(
            '\n\n===================unlock-not-locked-test==================')
        resp = shm_mgr.unlock_key('testing')
        print(resp)
        self.assertEqual(resp.get('succeeded'), True)
        self.assertEqual(resp.get('rcode'), ReturnCodes.NOT_LOCKED)

        print('\n\n=====================lock-test====================')
        print('------create-container------')
        resp = shm_mgr.create_key(
            'testing',
            SHMContainerTypes.LIST,
            backlogging=False,
        )
        self.assertEqual(resp.get('succeeded'), True)
        self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

        print('------lock-container------')
        resp = shm_mgr.lock_key('testing')
        print(resp)
        self.assertEqual(resp.get('succeeded'), True)
        self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

        print('------access-locked-container------')
        shm_mgr1 = SharedMemoryManager(config, sensitive=False)
        shm_mgr1.connect('test_err_1')
        resp = shm_mgr1.add_value(
            'testing',
            [1, 2, 3],
            backlogging=False,
        )
        print(resp)
        self.assertEqual(resp.get('succeeded'), False)
        self.assertEqual(resp.get('rcode'), ReturnCodes.LOCKED)

        print('------unlock-container----------')
        resp = shm_mgr.unlock_key('testing')
        print(resp)
        self.assertEqual(resp.get('succeeded'), True)
        self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

        print('------access-locked-container-again------')
        resp = shm_mgr1.add_value(
            'testing',
            [1, 2, 3],
            backlogging=False,
        )
        print(resp)
        self.assertEqual(resp.get('succeeded'), True)
        self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

        shm_mgr.disconnect()
        shm_mgr1.disconnect()
Code example #10
    def test_0_normal_ops(self):
        shm_mgr = SharedMemoryManager(config, sensitive=False)
        shm_mgr.connect('test')
        print('conn_id: ', shm_mgr.current_connection.conn_id)
        self.assertIsInstance(shm_mgr.current_connection.conn_id, str)

        for td in DATA:
            print(f'\n==================={td["name"]}===================\n')

            # try to remove this KEY first
            resp = shm_mgr.clean_key(KEY)
            if not resp.get('succeeded'):
                self.assertEqual(resp.get('rcode'), ReturnCodes.KEY_ERROR)

            resp = shm_mgr.create_key(KEY, td['type'], td['2_create'])
            print('\n-----------Create-----------')
            print(resp)
            self.assertTrue(resp.get('succeeded'))

            resp = shm_mgr.read_key(KEY)
            print('\n-----------Read-----------')
            print(resp)
            self.assertTrue(resp.get('succeeded'))

            resp = shm_mgr.add_value(KEY, td['2_add'])
            print('\n-----------Add-----------')
            print(resp)
            self.assertTrue(resp.get('succeeded'))

            resp = shm_mgr.read_key(KEY)
            print('\n-----------Read-----------')
            print(resp)
            self.assertTrue(resp.get('succeeded'))

            resp = shm_mgr.remove_value(KEY, td['2_remove'])
            print('\n-----------Remove-----------')
            print(resp)
            self.assertTrue(resp.get('succeeded'))

            resp = shm_mgr.read_key(KEY)
            print('\n-----------Read-----------')
            print(resp)
            self.assertTrue(resp.get('succeeded'))
            self.assertEqual(
                resp.get('value')[td['remaining_key']],
                td['remaining'],
            )

        shm_mgr.disconnect()
        self.assertEqual(shm_mgr.current_connection, None)
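
The KEY constant and the DATA table used by this test are defined outside the excerpt. Judging only from how the test reads each entry, one DATA item presumably has the shape sketched below; the concrete key name and values are purely illustrative assumptions.

# Hypothetical shape of one DATA entry, inferred from the test above.
KEY = 'normal_ops_test'                  # illustrative key name
DATA = [
    {
        'name': 'dict',                  # printed as the section title
        'type': SHMContainerTypes.DICT,  # container type for create_key
        '2_create': {'a': 1, 'b': 2},    # initial value passed to create_key
        '2_add': {'c': 3},               # value passed to add_value
        '2_remove': ['a', 'b'],          # members passed to remove_value
        'remaining_key': 'c',            # key expected to remain afterwards
        'remaining': 3,                  # its expected value
    },
]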
Code example #11
File: base.py  Project: erosb/demos
class BaseNode():
    ''' The Base Class of Nodes
    '''

    role = None

    def __init__(self, config, role=None):
        self.config = config
        self.role = role or self.role

        self.worker_pids = []
        self.shm_worker_pid = None
        self.pkt_rpter_worker_pid = None

        self.node_id = self.config.basic.node_id

    def _write_master_pid(self):
        pid_path = self.config.basic.pid_file
        pid = os.getpid()

        with open(pid_path, 'w') as f:
            f.write(str(pid))

        logger.debug(
            f'wrote pid file {pid_path} for master process, pid: {pid}')

    def _read_master_pid(self):
        pid_path = self.config.basic.pid_file

        try:
            with open(pid_path, 'r') as f:
                content = f.read()
        except FileNotFoundError:
            raise PidFileNotExists

        try:
            return int(content)
        except ValueError:
            raise ValueError('pid file has been tampered with')

    def _handle_term_master(self, signal, sf):
        logger.debug(f'Master process received signal: {signal}')
        logger.debug(f'The master process starts to shut down workers')
        self.shutdown_workers()

        pid_path = self.config.basic.pid_file
        if os.path.isfile(pid_path):
            logger.debug(f'Remove pid file: {pid_path}')
            os.remove(pid_path)

    def _handle_term_worker(self, signal, sf):
        pid = os.getpid()
        logger.debug(f'Worker {pid} received signal: {signal}')
        logger.debug(f'Shutting down worker {pid}')
        self.core.shutdown()

    def _handle_term_shm(self, signal, sf):
        pid = os.getpid()
        logger.debug(f'SharedMemoryManager {pid} received signal: {signal}')
        logger.debug(f'Shutting down SharedMemoryManager {pid}')
        self.shm_mgr.shutdown_worker()

    def _handle_term_pkt_rpter(self, signal, sf):
        pid = os.getpid()
        logger.debug(f'SpecialPacketRepeater {pid} received signal: {signal}')
        logger.debug(f'Shutting down SpecialPacketRepeater {pid}')
        self.pkt_rpter.shutdown()

    def _sig_master(self):
        sig.signal(sig.SIGHUP, sig.SIG_IGN)
        for s in TERM_SIGNALS:
            sig.signal(s, self._handle_term_master)

    def _sig_normal_worker(self):
        sig.signal(sig.SIGHUP, sig.SIG_IGN)
        for s in TERM_SIGNALS:
            sig.signal(s, self._handle_term_worker)

    def _sig_shm_worker(self):
        sig.signal(sig.SIGHUP, sig.SIG_IGN)
        for s in TERM_SIGNALS:
            sig.signal(s, self._handle_term_shm)

    def _sig_pkt_rpter_worker(self):
        sig.signal(sig.SIGHUP, sig.SIG_IGN)
        for s in TERM_SIGNALS:
            sig.signal(s, self._handle_term_pkt_rpter)

    def shutdown_workers(self):
        for pid in self.worker_pids:
            self._kill(pid)

        # wait for workers to exit
        remaining = list(self.worker_pids)
        while True:
            for pid in list(remaining):
                if self._process_exists(pid):
                    os.waitpid(pid, os.WNOHANG)
                else:
                    logger.debug(f'Worker {pid} terminated')
                    remaining.remove(pid)

            if len(remaining) == 0:
                break

            time.sleep(0.5)

        # shutdown SharedMemoryManager worker at last
        shm_pid = self.shm_worker_pid
        self._kill(shm_pid)
        os.waitpid(shm_pid, 0)
        logger.debug(f'SharedMemoryManager worker {shm_pid} terminated')

        logger.debug('All workers terminated')

    def _kill(self, pid):
        try:
            logger.debug(f'Sending SIGTERM to {pid}')
            os.kill(pid, sig.SIGTERM)
        except ProcessLookupError:
            pass

    def _process_exists(self, pid):
        try:
            os.kill(pid, 0)
        except OSError:
            logger.debug(f'Process {pid} does not exist')
            return False
        else:
            logger.debug(f'Process {pid} exists')
            return True

    def daemonize(self):
        pid = os.fork()
        if pid == -1:
            raise OSError('fork failed when doing daemonize')
        elif pid > 0:
            # double fork magic
            sys.exit(0)

        pid = os.fork()
        if pid == -1:
            raise OSError('fork failed when doing daemonize')

        def quit(sg, sf):
            sys.exit(0)

        if pid > 0:
            for s in TERM_SIGNALS:
                sig.signal(s, quit)
            time.sleep(5)
        else:
            self._sig_master()
            ppid = os.getppid()
            os.kill(ppid, sig.SIGTERM)
            os.setsid()

        logger.debug('Node daemonized')

    def _start_shm_mgr(self):
        self.shm_mgr = SharedMemoryManager(self.config)

        # start SharedMemoryManager worker
        pid = os.fork()
        if pid == -1:
            raise OSError('fork failed')
        elif pid == 0:
            self._sig_shm_worker()
            try:
                self.shm_mgr.run_as_worker()
            except Exception:
                err_msg = traceback.format_exc()
                shm_logger.error(
                    f'Unexpected error occurred, SHM worker crashed. '
                    f'Traceback:\n{err_msg}')
                sys.exit(1)

            sys.exit(0)  # the sub-process ends here
        else:
            self.shm_worker_pid = pid
            logger.info(f'Started SharedMemoryManager: {pid}')

    def _start_pkt_rpter(self):
        '''
        The Repeater needs some components from the node,
        so we shouldn't use this method before components are initialized
        '''

        pid = os.fork()
        if pid == -1:
            raise OSError('fork failed')
        elif pid == 0:
            self._sig_pkt_rpter_worker()
            self.pkt_rpter = SpecialPacketRepeater(
                self.config,
                self.efferent,
                self.protocol_wrapper,
            )

            try:
                self.pkt_rpter.init_shm()
                self.pkt_rpter.run()
            except Exception:
                err_msg = traceback.format_exc()
                logger.error(
                    f'Unexpected error occurred, SpecialPacketRepeater worker '
                    f'crashed. Traceback:\n{err_msg}')
                sys.exit(1)

            sys.exit(0)  # the sub-process ends here
        else:
            self.pkt_rpter = SpecialPacketRepeater(
                self.config,
                self.efferent,
                self.protocol_wrapper,
            )
            self.pkt_rpter_worker_pid = pid
            logger.info(f'Started SpecialPacketRepeater: {pid}')

    def _load_modules(self):
        self.afferent_cls = AFFERENT_MAPPING[self.role]
        self.main_afferent = self.afferent_cls(self.config)

        self.efferent = UDPTransmitter(self.config)

        self.protocol_wrapper = ProtocolWrapper(
            self.config,
            HeaderFormat,
            DataPktFormat,
            CtrlPktFormat,
            ConnCtrlPktFormat,
        )

        self.logic_handler_cls = LOGIC_HANDLER_MAPPING[self.role]
        self.logic_handler = self.logic_handler_cls(self.config)

        self.core_cls = CORE_MAPPING.get(self.role)
        self.core = self.core_cls(
            self.config,
            main_afferent=self.main_afferent,
            minor_afferents=[],
            efferent=self.efferent,
            logic_handler=self.logic_handler,
            protocol_wrapper=self.protocol_wrapper,
        )

        self.pkt_mgr = SpecialPacketManager(self.config)

        # The packet repeater is a part of the packet manager, so we will
        # use it as a normal module. Each worker shall have its own packet
        # repeater rather than share one like the shared memory manager worker
        self._start_pkt_rpter()

        self.logic_handler.init_shm()

        self.core.init_shm()
        self.core.self_allocate_core_id()

        self.pkt_mgr.init_shm()

        pid = os.getpid()
        logger.debug(f'Worker {pid} loaded modules')

    def _clean_modules(self):
        self._kill(self.pkt_rpter_worker_pid)
        os.waitpid(self.pkt_rpter_worker_pid, 0)
        logger.debug(f'SpecialPacketRepeater worker '
                     f'{self.pkt_rpter_worker_pid} terminated')

        self.core.shutdown()
        self.main_afferent.destroy()

        self.logic_handler.close_shm()
        self.core.close_shm()
        self.pkt_mgr.close_shm()

        self.main_afferent = None
        self.efferent = None
        self.protocol_wrapper = None
        self.logic_handler = None
        self.core = None
        self.pkt_mgr = None

        pid = os.getpid()
        logger.debug(f'Worker {pid} cleaned modules')

    def get_context(self):
        return NodeContext

    def _create_context(self):
        NodeContext.pkt_rpter_pid = self.pkt_rpter_worker_pid
        NodeContext.local_ip = get_localhost_ip()
        NodeContext.listen_port = self.config.net.aff_listen_port
        NodeContext.core = self.core
        NodeContext.main_efferent = self.efferent
        NodeContext.protocol_wrapper = self.protocol_wrapper
        NodeContext.pkt_mgr = self.pkt_mgr

        NodeContext.id_generator = IDGenerator(self.node_id, self.core.core_id)

        pid = os.getpid()
        logger.debug(f'Worker {pid} created NodeContext')

    def _clean_context(self):
        NodeContext.pkt_rpter_pid = None
        NodeContext.id_generator = None
        NodeContext.local_ip = None
        NodeContext.listen_port = None
        NodeContext.core = None
        NodeContext.main_efferent = None
        NodeContext.protocol_wrapper = None

        pid = os.getpid()
        logger.debug(f'Worker {pid} cleaned NodeContext')

    def join_cluster(self):
        if self.role == Roles.CONTROLLER:
            raise RuntimeError(
                'Controller node is the root node of the cluster')

        self.core.request_to_join_cluster()
        self.core.run_for_a_while(5)
        raise TimeoutError

    def run(self):
        pid_fl = self.config.basic.pid_file
        try:
            pid = self._read_master_pid()
            logger.warn(
                f'\n\tThe Neverland node is already running or the pid file\n'
                f'\t{pid_fl} is not removed, current pid: {pid}.\n'
                f'\tMake sure that the node is not running and try again.\n\n'
                f'\tIf you need to run multiple nodes on this computer, then\n'
                f'\tyou need to at least configure another pid file for it.')
            return
        except ValueError:
            logger.error(
                f'\n\tThe pid file {pid_fl} exists but seems it\'s not\n'
                f'\twritten by the Neverland node. Please make sure the node\n'
                f'\tis not running and the pid file is not occupied.')
            return
        except PidFileNotExists:
            pass

        self.daemonize()
        NodeContext.pid = os.getpid()

        self._write_master_pid()
        self._start_shm_mgr()

        # Before we start workers, we need to join the cluster first.
        if self.role != Roles.CONTROLLER:
            # Before we join the cluster, we need to load modules first;
            # once we have joined the cluster, modules in the Master worker
            # shall be removed.
            self._load_modules()
            self._create_context()

            try:
                self.join_cluster()
            except SuccessfullyJoinedCluster:
                logger.info('Successfully joined the cluster.')
            except FailedToJoinCluster:
                logger.error('Cannot join the cluster, request not permitted')
                self._clean_modules()
                self._clean_context()
                self._on_break()
                return
            except TimeoutError:
                logger.error(
                    'No response from entrance node, failed to join the cluster'
                )
                self._clean_modules()
                self._clean_context()
                self._on_break()
                return

            self._clean_modules()
            self._clean_context()

        # start normal workers
        worker_amount = self.config.basic.worker_amount
        for _ in range(worker_amount):
            pid = os.fork()
            NodeContext.pid = os.getpid()

            if pid == -1:
                raise OSError('fork failed')
            elif pid == 0:
                self._sig_normal_worker()
                self._load_modules()
                self._create_context()

                try:
                    self.core.run()
                except Exception:
                    err_msg = traceback.format_exc()
                    logger.error(f'Unexpected error occurred, node crashed. '
                                 f'Traceback:\n{err_msg}')

                    self._clean_modules()
                    self._clean_context()
                    sys.exit(1)

                self._clean_modules()
                self._clean_context()
                sys.exit(0)  # the sub-process ends here
            else:
                self.worker_pids.append(pid)
                logger.info(f'Started Worker: {pid}')

        while True:
            try:
                os.waitpid(-1, 0)
            except ChildProcessError:
                break

    def shutdown(self):
        pid = self._read_master_pid()
        self._kill(pid)
        logger.info('Sent SIGTERM to the master process')

    def _on_break(self):
        '''
        a hook that needs to be invoked when self.run has been interrupted
        by some exception
        '''

        shm_pid = self.shm_worker_pid
        self._kill(shm_pid)
        os.waitpid(shm_pid, 0)
        logger.debug(f'SharedMemoryManager worker {shm_pid} terminated')

        pid_fl = self.config.basic.pid_file
        os.remove(pid_fl)
        logger.debug(f'Removed pid file: {pid_fl}')
        logger.info('Master process exits.\n\n')
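
The excerpt does not include the command-line entry point that drives BaseNode. A minimal launcher sketch is given below under assumptions: the config loading and the argument handling are illustrative, and Roles.CONTROLLER is used only because it is the sole role name visible in the excerpt.

import sys

def main():
    # Illustrative only: 'config' is assumed to be loaded beforehand.
    node = BaseNode(config, role=Roles.CONTROLLER)

    if len(sys.argv) > 1 and sys.argv[1] == 'stop':
        node.shutdown()   # read the pid file and SIGTERM the master process
    else:
        node.run()        # daemonize, start the SHM worker and fork workers

if __name__ == '__main__':
    main()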
Code example #12
class ConnectionManager():
    ''' The Connection Manager

    We store all information about established connections in the shared
    memory. This ConnectionManager is aimed at converting this information
    between JSON and Connection objects: it provides Connection objects to
    the upper layer and stores Connection objects in the shared memory as
    JSON.

    As a manager, it should also provide the functionality of establishing
    and closing connections.
    '''

    SHM_SOCKET_NAME_TEMPLATE = 'SHM-ConnectionManager-%d.socket'

    # The SHM container to store established connections.
    #
    # Data structure:
    #     {
    #         "ip:port": {
    #             "slot-0": {
    #                 "status": int,
    #                 "sn": int,
    #                 "iv": b64encode(iv),
    #                 "iv_duration": int,
    #             },
    #             "slot-1": {
    #                 "status": int,
    #                 "sn": int,
    #                 "iv": b64encode(iv),
    #                 "iv_duration": int,
    #             },
    #             "slot-2": {
    #                 "status": int,
    #                 "sn": int,
    #                 "iv": b64encode(iv),
    #                 "iv_duration": int,
    #             },
    #         }
    #     }
    SHM_KEY_TMP_CONNS = 'ConnectionManager-%d_Conns'

    def __init__(self, config):
        self.config = config
        self.iv_len = self.config.net.crypto.iv_len
        self.iv_duration_range = self.config.net.crypto.iv_duration_range

        if not 0 < self.iv_len < EVP_MAX_IV_LENGTH:
            raise ArgumentError('iv_len out of range')

        self.pid = NodeContext.pid

    def init_shm(self):
        ''' initialize the shared memory manager
        '''

        self.shm_mgr = SharedMemoryManager(self.config)
        self.shm_mgr.connect(self.SHM_SOCKET_NAME_TEMPLATE % self.pid)

        self.shm_key_conns = self.SHM_KEY_TMP_CONNS % self.pid
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_conns,
            SHMContainerTypes.DICT,
        )

    def _remote_sa_2_key(self, remote):
        ''' convert remote socket address to a key string
        '''

        ip = remote[0]
        port = remote[1]
        return f'{ip}:{port}'

    def _get_native_conn_info(self, remote):
        ''' get the native JSON data of connections
        '''

        remote_name = self._remote_sa_2_key(remote)
        shm_data = self.shm_mgr.get_dict_value(self.shm_key_conns, remote_name)
        shm_value = shm_data.get('value')

        if shm_value is None:
            return {
                SLOT_0: None,
                SLOT_1: None,
                SLOT_2: None,
            }
        else:
            return shm_value

    def get_conns(self, remote):
        ''' get all connections of a remote node

        :param remote: socket address in tuple format, (ip, port)
        :returns: a dict of Connection objects:
                    {
                        SLOT_0: conn,
                        SLOT_1: conn,
                        SLOT_2: conn,
                    }
        '''

        ip = remote[0]
        port = remote[1]
        native_info = self._get_native_conn_info(remote)

        result = dict()
        for slot_name in SLOTS:
            conn_info = native_info.get(slot_name)

            if conn_info is None:
                conn = None
            else:
                iv = conn_info.get('iv')
                if iv is not None:
                    iv = base64.b64decode(iv)

                conn_info.update(
                    slot=slot_name,
                    remote={
                        'ip': ip,
                        'port': port
                    },
                    iv=iv,
                )
                conn = Connection(**conn_info)

            result.update({slot_name: conn})
        return result

    def store_conn(self, conn, slot, override=False):
        ''' store a connection object to a slot

        :param conn: a Connection object
        :param slot: slot name, enumerated in SLOTS
        '''

        remote = (conn.remote.ip, conn.remote.port)
        remote_name = self._remote_sa_2_key(remote)

        if not override:
            usable_slots = self.get_usable_slots(remote)
            if slot not in usable_slots:
                raise ConnSlotNotAvailable

        iv = conn.iv
        if iv is not None:
            # the base64 string must be str, not bytes
            iv = b64encode(iv).decode()

        conn_info = conn.__to_dict__()
        conn_info.update(iv=iv)

        self.shm_mgr.update_dict(
            key=self.shm_key_conns,
            dict_key=remote_name,
            value=conn_info,
        )

    def get_usable_slots(self, remote):
        ''' get all usable slots of a remote

        :param remote: remote socket address, (ip, port)
        :returns: a list of slot names
        '''

        usable_slots = list(SLOTS)

        conns = self.get_conns(remote)
        for slot in SLOTS:
            conn = conns.get(slot)
            if conn is None:
                break
            else:
                usable_slots.remove(slot)

        return usable_slots

    def new_conn(self, remote, synchronous=False, timeout=2, interval=0.1):
        ''' establish a new connection

        The connection being established will be placed in slot-2.

        After the new connection is established, if we have already
        established 2 connections with the specified node, then the connection
        in slot-0 will be removed and the connection in slot-1 will be moved
        to slot-0. The new connection will then be placed in slot-1.

        :param remote: remote socket address, (ip, port)
        :param synchronous:
                    If the synchronous argument is True, then the new_conn
                    method will try to wait for the connection to complete and
                    return a Connection object. This operation will block until
                    it reaches the timeout or the connection completes.

                    If the synchronous argument is False, then the new_conn
                    method will return None immediately without waiting.
        :param timeout: seconds until timeout
        :param interval: the interval, in seconds, between connection checks
        '''

        usable_slots = self.get_usable_slots(remote)
        remote_name = self._remote_sa_2_key(remote)

        if SLOT_2 not in usable_slots:
            raise ConnSlotNotAvailable(f'slot-2 to {remote_name} is in use, '
                                       f'cannot establish connection now')

        iv = os.urandom(self.iv_len)
        iv_duration = random.randint(*self.iv_duration_range)

        pkt = UDPPacket()
        pkt.fields = ObjectifiedDict(
            type=PktTypes.CONN_CTRL,
            dest=remote,
            communicating=1,
            iv_changed=1,
            iv_duration=iv_duration,
            iv=iv,
        )
        pkt.next_hop = remote
        pkt = NodeContext.protocol_wrapper.wrap(pkt)
        NodeContext.pkt_mgr.repeat_pkt(pkt)

        conn_sn = NodeContext.id_generator.gen()
        conn = {
            "remote": {
                "ip": remote[0],
                "port": remote[1],
            },
            "sn": conn_sn,
            "state": ConnStates.ESTABLISHING,
            "slot": SLOT_2,
            "iv": iv,
            "iv_duration": iv_duration,
        }
        conn = Connection(**conn)

        # Although we have already checked SLOT_2,
        # it is still possible that it was seized in the meantime
        try:
            self.store_conn(conn, SLOT_2)
        except ConnSlotNotAvailable:
            logger.warn(
                f'slot-2 to {remote_name} was seized, aborting the establishment')

        # The request is sending, now we wait for the response
        if not synchronous:
            return None

        # while timeout > 0:
        # # TODO complete this after the response processing logic is done

        # timeout -= interval
        # time.sleep(interval)

    def get_conn(self, remote):
        ''' get a Connection object of the specified remote

        :param remote: remote socket address, (ip, port)
        :returns: Connection object
        '''

        conns = self.get_conns(remote)

        # according to the explanations above, the priority of slots is 1 > 0
        conn_s1 = conns.get(SLOT_1)
        if conn_s1 is not None and conn_s1.state == ConnStates.ESTABLISHED:
            return conn_s1

        conn_s0 = conns.get(SLOT_0)
        if conn_s0 is not None and conn_s0.state == ConnStates.ESTABLISHED:
            return conn_s0

        raise NoConnAvailable

    def remove_conn(self, remote, slot):
        ''' close a connection

        :param remote: remote socket address, (ip, port)
        :param slot: slot name, enumerated in SLOTS
        '''

        ip = remote[0]
        port = remote[1]
        remote_name = f'{ip}:{port}'

        native_info = self._get_native_conn_info(remote)
        native_info[slot] = None

        self.shm_mgr.update_dict(
            key=self.shm_key_conns,
            dict_key=remote_name,
            value=native_info,
        )
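
A minimal usage sketch of the ConnectionManager shown above. It assumes that NodeContext (pid, id_generator, protocol_wrapper, pkt_mgr) and config have already been initialized by the worker, as in the other examples; the remote address is illustrative.

# Illustrative usage only; NodeContext and config are assumed to be ready.
cm = ConnectionManager(config)
cm.init_shm()                      # connect to SHM and create the conns dict

remote = ('10.0.0.2', 9000)        # illustrative remote address
cm.new_conn(remote)                # request a new connection in slot-2

conns = cm.get_conns(remote)       # {SLOT_0: conn, SLOT_1: conn, SLOT_2: conn}
establishing = conns.get(SLOT_2)

cm.shm_mgr.disconnect()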
Code example #13
class SpecialPacketManager():

    SHM_SOCKET_NAME_TEMPLATE = 'SHM-SpecialPacketManager-%d.socket'

    def __init__(self, config):
        self.config = config
        self.pid = NodeContext.pid

        self.shm_key_pkts = SHM_KEY_PKTS

        # These containers are for the SpecialPacketRepeater; the repeater
        # will also access special packets through the manager.
        self.shm_key_pkts_to_repeat = SHM_KEY_TMP_PKTS_TO_REPEAT % self.pid
        self.shm_key_last_repeat_time = SHM_KEY_TMP_LAST_REPEAT_TIME % self.pid
        self.shm_key_next_repeat_time = SHM_KEY_TMP_NEXT_REPEAT_TIME % self.pid
        self.shm_key_max_repeat_times = SHM_KEY_TMP_MAX_REPEAT_TIMES % self.pid
        self.shm_key_repeated_times = SHM_KEY_TMP_REPEATED_TIMES % self.pid

    def init_shm(self):
        ''' initialize the shared memory manager
        '''

        self.shm_mgr = SharedMemoryManager(self.config)
        self.shm_mgr.connect(self.SHM_SOCKET_NAME_TEMPLATE % self.pid)
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_pkts,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_pkts_to_repeat,
            SHMContainerTypes.LIST,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_last_repeat_time,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_next_repeat_time,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_max_repeat_times,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_repeated_times,
            SHMContainerTypes.DICT,
        )

    def close_shm(self):
        self.shm_mgr.disconnect()

    def store_pkt(self, pkt, need_repeat=False, max_rpt_times=5):
        sn = pkt.fields.sn
        type_ = pkt.fields.type

        # The salt field is bytes, so we cannot serialize it in JSON.
        # Thus, we shall encode it into a base64 string before storing it.
        fields = pkt.fields.__to_dict__()
        salt = fields.get('salt')
        if salt is not None:
            salt_b64 = base64.b64encode(salt).decode()
            fields.update(salt=salt_b64)

        previous_hop = list(pkt.previous_hop)
        next_hop = list(pkt.next_hop)

        if sn is None:
            raise InvalidPkt(
                'Packets to be stored must contain a serial number')

        value = {
            sn: {
                'type': type_,
                'fields': fields,
                'previous_hop': previous_hop,
                'next_hop': next_hop,
            }
        }

        self.shm_mgr.lock_key(self.shm_key_pkts)
        self.shm_mgr.add_value(self.shm_key_pkts, value)
        self.shm_mgr.unlock_key(self.shm_key_pkts)

        if need_repeat:
            self.shm_mgr.add_value(self.shm_key_pkts_to_repeat, [sn])
            self.set_pkt_max_repeat_times(sn, max_rpt_times)
            self.set_pkt_repeated_times(sn, 0)

        hex_type = Converter.int_2_hex(type_)
        logger.debug(f'Stored a special packet, need_repeat: {need_repeat}, '
                     f'sn: {sn}, type: {hex_type}, dest: {pkt.fields.dest}')

    def get_pkt(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_pkts, sn)
        shm_value = shm_data.get('value')

        if shm_value is None:
            return None

        # and here, we restore the base64-encoded salt back into bytes
        fields = shm_value.get('fields')
        salt_b64 = fields.get('salt')
        salt = base64.b64decode(salt_b64)
        fields.update(salt=salt)

        return UDPPacket(
            fields=fields,
            type=shm_value.get('type'),
            previous_hop=shm_value.get('previous_hop'),
            next_hop=shm_value.get('next_hop'),
        )

    def remove_pkt(self, sn):
        self.cancel_repeat(sn)

        self.shm_mgr.lock_key(self.shm_key_pkts)
        self.shm_mgr.remove_value(self.shm_key_pkts, sn)
        self.shm_mgr.unlock_key(self.shm_key_pkts)

        logger.debug(f'Removed a special packet, sn: {sn}')

    def cancel_repeat(self, sn):
        self.shm_mgr.remove_value(self.shm_key_pkts_to_repeat, sn)
        self.shm_mgr.remove_value(self.shm_key_last_repeat_time, sn)
        self.shm_mgr.remove_value(self.shm_key_next_repeat_time, sn)
        self.shm_mgr.remove_value(self.shm_key_max_repeat_times, sn)
        self.shm_mgr.remove_value(self.shm_key_repeated_times, sn)
        logger.debug(f'Cancelled repeat for a packet, sn: {sn}')

    def repeat_pkt(self, pkt, max_rpt_times=5):
        self.store_pkt(pkt, need_repeat=True, max_rpt_times=max_rpt_times)

    def get_repeating_sn_list(self):
        shm_data = self.shm_mgr.read_key(self.shm_key_pkts_to_repeat)
        return shm_data.get('value')

    def set_pkt_last_repeat_time(self, sn, timestamp):
        self.shm_mgr.add_value(self.shm_key_last_repeat_time, {sn: timestamp})
        logger.debug(f'set_pkt_last_repeat_time, sn: {sn}, ts: {timestamp}')

    def get_pkt_last_repeat_time(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_last_repeat_time, sn)
        return shm_data.get('value')

    def set_pkt_next_repeat_time(self, sn, timestamp):
        self.shm_mgr.add_value(self.shm_key_next_repeat_time, {sn: timestamp})
        logger.debug(f'set_pkt_next_repeat_time, sn: {sn}, ts: {timestamp}')

    def get_pkt_next_repeat_time(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_next_repeat_time, sn)
        return shm_data.get('value')

    def set_pkt_max_repeat_times(self, sn, times):
        self.shm_mgr.add_value(self.shm_key_max_repeat_times, {sn: times})
        logger.debug(f'set_pkt_max_repeat_times, sn: {sn}, times: {times}')

    def get_pkt_max_repeat_times(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_max_repeat_times, sn)
        return shm_data.get('value')

    def set_pkt_repeated_times(self, sn, times):
        self.shm_mgr.add_value(self.shm_key_repeated_times, {sn: times})
        logger.debug(f'set_pkt_repeated_times, sn: {sn}, times: {times}')

    def get_pkt_repeated_times(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_repeated_times, sn)
        return shm_data.get('value')

    def increase_pkt_repeated_times(self, sn):
        repeated_times = self.get_pkt_repeated_times(sn)

        if repeated_times is None:
            repeated_times = 1
        else:
            repeated_times += 1

        self.set_pkt_repeated_times(sn, repeated_times)
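
A minimal usage sketch of the SpecialPacketManager shown above, assuming NodeContext.pid is set and that pkt is a wrapped UDPPacket whose fields carry a serial number (pkt.fields.sn), as the earlier examples produce.

# Illustrative usage only; pkt is assumed to be a wrapped UDPPacket with sn.
mgr = SpecialPacketManager(config)
mgr.init_shm()

mgr.repeat_pkt(pkt, max_rpt_times=5)      # store and mark for repeating

for sn in mgr.get_repeating_sn_list():    # what the repeater worker iterates
    stored = mgr.get_pkt(sn)              # rebuild the UDPPacket from SHM
    mgr.increase_pkt_repeated_times(sn)

mgr.close_shm()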
Code example #14
    def __init__(self, config):
        self.config = config

        self.shm_mgr = SharedMemoryManager(self.config)