Example #1
    def test_999_backlog(self):
        global do_not_kill_shm_worker

        key = 'bl0'

        shm_mgr = SharedMemoryManager(config, sensitive=False)
        shm_mgr.connect('test_bl')
        shm_mgr.create_key(key, SHMContainerTypes.LIST)
        shm_mgr.lock_key(key)

        # os.fork() raises OSError on failure rather than returning -1,
        # so no explicit error check is needed here.
        pid = os.fork()

        if pid == 0:
            # This socket will cause a warning, so we close it here
            shm_mgr.current_connection.socket.close()

            do_not_kill_shm_worker = True
            shm_mgr1 = SharedMemoryManager(config, sensitive=False)
            shm_mgr1.connect('test_bl1')

            print('\n\n==============access-locked-container===============')
            t0 = time.time()
            resp = shm_mgr1.read_key(key)
            t1 = time.time()

            print(resp)

            delay = t1 - t0
            print(f'delay: {delay}')

            self.assertEqual(resp.get('succeeded'), True)
            self.assertTrue(delay >= 1)

            shm_mgr1.disconnect()

            # Exit the child explicitly so it does not fall through and
            # continue running the rest of the test suite in this process.
            os._exit(0)
        else:
            time.sleep(1)
            print('------------release-lock-------------')
            shm_mgr.unlock_key(key)
            shm_mgr.disconnect()
            os.waitpid(-1, 0)
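
The test above exercises the backlog behaviour: a read on a locked key does not fail, it blocks until the lock holder releases the key. A minimal sketch of that pattern, reusing only the SharedMemoryManager API visible in the snippet (the connection name and key are illustrative):

import time

def timed_read(config, key):
    # Sketch only: SharedMemoryManager is assumed to be importable from
    # this project; the connection name is arbitrary.
    mgr = SharedMemoryManager(config, sensitive=False)
    mgr.connect('timed_read')

    t0 = time.time()
    resp = mgr.read_key(key)     # blocks while another manager holds the lock
    delay = time.time() - t0

    mgr.disconnect()
    return resp, delay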
Example #2
class BaseCore():
    ''' The base model of cores

    The core is meant to be a kernel-like component. It organizes the
    other components, makes them work together, and administers them.
    Some components are pluggable, and some others are necessary.

    Here is the list of all components:
        afferents in neverland.afferents, pluggable
        efferents in neverland.efferents, necessary
        logic handlers in neverland.logic, necessary
        protocol wrappers in neverland.protocol, necessary

    In the initial version, all of these components are necessary, and
    there may be multiple afferents.
    '''

    EV_MASK = select.EPOLLIN

    SHM_SOCKET_NAME_TEMPLATE = 'SHM-Core-%d.socket'

    # SHM container that holds the allocated core IDs.
    # Data structure:
    #     [1, 2, 3, 4]
    SHM_KEY_CORE_ID = 'Core_id'

    # The shared state of cluster controlling,
    # enumerated in neverland.core.state.ClusterControllingStates
    SHM_KEY_CC_STATE = 'Core_CCState'

    def __init__(
            self,
            config,
            efferent,
            logic_handler,
            protocol_wrapper,
            main_afferent,
            minor_afferents=tuple(),
    ):
        ''' constructor

        :param config: the config
        :param efferent: an efferent instance
        :param logic_handler: a logic handler instance
        :param protocol_wrapper: a protocol wrapper instance
        :param main_afferent: the main afferent
        :param minor_afferents: a group of minor afferents,
                                any iterable that contains afferent instances
        '''

        self.__running = False

        self.core_id = None
        self._epoll = select.epoll()
        self.afferent_mapping = {}

        self.config = config
        self.main_afferent = main_afferent
        self.efferent = efferent
        self.logic_handler = logic_handler
        self.protocol_wrapper = protocol_wrapper

        self.shm_mgr = SharedMemoryManager(self.config)

        self.plug_afferent(self.main_afferent)

        for afferent in minor_afferents:
            self.plug_afferent(afferent)

    def set_cc_state(self, status):
        self.shm_mgr.set_value(self.SHM_KEY_CC_STATE, status)

    def get_cc_state(self):
        resp = self.shm_mgr.read_key(self.SHM_KEY_CC_STATE)
        return resp.get('value')

    @property
    def cc_state(self):
        return self.get_cc_state()

    def init_shm(self):
        self.shm_mgr.connect(self.SHM_SOCKET_NAME_TEMPLATE % NodeContext.pid)

        self.shm_mgr.create_key_and_ignore_conflict(
            self.SHM_KEY_CORE_ID,
            SHMContainerTypes.LIST,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.SHM_KEY_CC_STATE,
            SHMContainerTypes.INT,
            CCStates.INIT,
        )

        logger.debug(f'init_shm for core of worker {NodeContext.pid} is done')

    def close_shm(self):
        self.shm_mgr.disconnect()

    def self_allocate_core_id(self):
        ''' Let the core pick up an id for itself
        '''

        try:
            # The response of lock_key is not needed here; a timeout is
            # the only failure mode we care about.
            self.shm_mgr.lock_key(self.SHM_KEY_CORE_ID)
        except SHMResponseTimeout:
            # Currently, SHM_MAX_BLOCKING_TIME is 4 seconds, and this
            # work can definitely be done within 4 seconds. So, if a
            # SHMResponseTimeout occurred, there must be a deadlock.
            raise SHMResponseTimeout(
                f'deadlock of key: {self.SHM_KEY_CORE_ID}')

        resp = self.shm_mgr.read_key(self.SHM_KEY_CORE_ID)
        allocated_ids = resp.get('value')

        if len(allocated_ids) == 0:
            id_ = 0
        else:
            last_id = allocated_ids[-1]
            id_ = last_id + 1

        self.shm_mgr.add_value(
            self.SHM_KEY_CORE_ID,
            [id_],
        )
        self.core_id = id_

        self.shm_mgr.unlock_key(self.SHM_KEY_CORE_ID)
        logger.debug(
            f'core of worker {NodeContext.pid} has self-allocated id: {id_}')

    def plug_afferent(self, afferent):
        self._epoll.register(afferent.fd, self.EV_MASK)
        self.afferent_mapping.update({afferent.fd: afferent})

    def unplug_afferent(self, fd):
        ''' remove an afferent from the core

        :param fd: the file descriptor of the afferent, int
        '''

        if fd not in self.afferent_mapping:
            return

        self._epoll.unregister(fd)
        self.afferent_mapping.pop(fd)

    def request_to_join_cluster(self):
        ''' send a request that the node is going to join the cluster
        '''

        entrance = self.config.cluster_entrance
        identification = self.config.net.identification

        if entrance is None:
            raise ConfigError("cluster_entrance is not defined")
        if identification is None:
            raise ConfigError("identification is not defined")

        logger.info('Trying to join cluster...')

        content = {
            'identification': identification,
            'ip': get_localhost_ip(),
            'listen_port': self.config.net.aff_listen_port,
        }
        subject = CCSubjects.JOIN_CLUSTER
        dest = (entrance.ip, entrance.port)

        pkt = UDPPacket()
        pkt.fields = ObjectifiedDict(
            type=PktTypes.CTRL,
            dest=dest,
            subject=subject,
            content=content,
        )
        pkt.next_hop = dest
        pkt = self.protocol_wrapper.wrap(pkt)

        NodeContext.pkt_mgr.repeat_pkt(pkt)
        logger.info(
            f'Sending request to cluster entrance {entrance.ip}:{entrance.port}'
        )

        logger.info('[Node Status] WAITING_FOR_JOIN')
        self.set_cc_state(CCStates.WAITING_FOR_JOIN)

    def request_to_leave_cluster(self):
        ''' send a request that the node is going to detach from the cluster
        '''

        entrance = self.config.cluster_entrance
        identification = self.config.net.identification

        if entrance is None:
            raise ConfigError("cluster_entrance is not defined")
        if identification is None:
            raise ConfigError("identification is not defined")

        logger.info('Trying to leave cluster...')

        content = {"identification": identification}
        subject = CCSubjects.LEAVE_CLUSTER
        dest = (entrance.ip, entrance.port)

        pkt = UDPPacket()
        pkt.fields = ObjectifiedDict(
            type=PktTypes.CTRL,
            dest=dest,
            subject=subject,
            content=content,
        )
        pkt.next_hop = dest
        pkt = self.protocol_wrapper.wrap(pkt)

        NodeContext.pkt_mgr.repeat_pkt(pkt)
        logger.info(
            f'Sent request to cluster entrance {entrance.ip}:{entrance.port}')

        logger.info('[Node Status] WAITING_FOR_LEAVE')
        self.set_cc_state(CCStates.WAITING_FOR_LEAVE)

    def handle_pkt(self, pkt):
        pkt = self.protocol_wrapper.unwrap(pkt)
        if not pkt.valid:
            return

        try:
            pkt = self.logic_handler.handle_logic(pkt)
        except DropPacket:
            return

        pkt = self.protocol_wrapper.wrap(pkt)
        self.efferent.transmit(pkt)

    def _poll(self):
        events = self._epoll.poll(POLL_TIMEOUT)

        for fd, evt in events:
            afferent = self.afferent_mapping[fd]

            if evt & select.EPOLLERR:
                self.unplug_afferent(fd)
                afferent.destroy()
            elif evt & select.EPOLLIN:
                pkt = afferent.recv()
                self.handle_pkt(pkt)

    def run(self):
        self.set_cc_state(CCStates.WORKING)
        self.__running = True

        self.main_afferent.listen()
        addr = self.main_afferent.listen_addr
        port = self.main_afferent.listen_port
        logger.info(f'Main afferent is listening on {addr}:{port}')

        while self.__running:
            self._poll()

    def run_for_a_while(self, duration=None, polling_times=None):
        ''' run the core for the specified duration or number of polls

        :param duration: the duration time in seconds, int
        :param polling_times: how many times the _poll method shall be
                              invoked, int

        These two arguments do not work together; if duration is
        specified, polling_times will be ignored.
        '''

        if duration is None and polling_times is None:
            raise ArgumentError('no argument passed')

        self.main_afferent.listen()
        addr = self.main_afferent.listen_addr
        port = self.main_afferent.listen_port
        logger.info(f'Main afferent is listening on {addr}:{port}')

        if duration is not None:
            starting_time = time.time()
            while time.time() - starting_time <= duration:
                self._poll()
        elif polling_times is not None:
            for _ in range(polling_times):
                self._poll()

    def shutdown(self):
        self.__running = False
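
Taken together, the lifecycle methods above suggest the order in which a core is set up and run. A hedged wiring sketch that uses only the constructor and methods shown in this example; the component arguments are prepared elsewhere and are placeholders here:

def build_and_run_core(config, main_afferent, efferent,
                       logic_handler, protocol_wrapper):
    # Sketch only: the components are assumed to be constructed by the
    # caller from whatever implementations the project provides.
    core = BaseCore(
        config,
        efferent=efferent,
        logic_handler=logic_handler,
        protocol_wrapper=protocol_wrapper,
        main_afferent=main_afferent,
    )

    core.init_shm()                # connect to SHM and create the shared keys
    core.self_allocate_core_id()   # pick the next free core id under the lock
    try:
        core.run()                 # poll afferents until shutdown() is called
    finally:
        core.close_shm()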
Example #3
    def test_2_rcode(self):
        shm_mgr = SharedMemoryManager(config, sensitive=False)
        shm_mgr.connect('test_err')

        print('\n\n=====================key-error-test====================')
        resp = shm_mgr.add_value(
            key='not_exists',
            value='a',
            backlogging=False,
        )
        print(resp)
        self.assertEqual(resp.get('succeeded'), False)
        self.assertEqual(resp.get('rcode'), ReturnCodes.KEY_ERROR)

        print('\n\n=====================type-error-test====================')
        resp = shm_mgr.create_key(
            'testing',
            SHMContainerTypes.DICT,
            value='a',
            backlogging=False,
        )
        print(resp)
        self.assertEqual(resp.get('succeeded'), False)
        self.assertEqual(resp.get('rcode'), ReturnCodes.TYPE_ERROR)

        print(
            '\n\n===================unlock-not-locked-test==================')
        resp = shm_mgr.unlock_key('testing')
        print(resp)
        self.assertEqual(resp.get('succeeded'), True)
        self.assertEqual(resp.get('rcode'), ReturnCodes.NOT_LOCKED)

        print('\n\n=====================lock-test====================')
        print('------create-container------')
        resp = shm_mgr.create_key(
            'testing',
            SHMContainerTypes.LIST,
            backlogging=False,
        )
        self.assertEqual(resp.get('succeeded'), True)
        self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

        print('------lock-container------')
        resp = shm_mgr.lock_key('testing')
        print(resp)
        self.assertEqual(resp.get('succeeded'), True)
        self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

        print('------access-locked-container------')
        shm_mgr1 = SharedMemoryManager(config, sensitive=False)
        shm_mgr1.connect('test_err_1')
        resp = shm_mgr1.add_value(
            'testing',
            [1, 2, 3],
            backlogging=False,
        )
        print(resp)
        self.assertEqual(resp.get('succeeded'), False)
        self.assertEqual(resp.get('rcode'), ReturnCodes.LOCKED)

        print('------unlock-container----------')
        resp = shm_mgr.unlock_key('testing')
        print(resp)
        self.assertEqual(resp.get('succeeded'), True)
        self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

        print('------access-locked-container-again------')
        resp = shm_mgr1.add_value(
            'testing',
            [1, 2, 3],
            backlogging=False,
        )
        print(resp)
        self.assertEqual(resp.get('succeeded'), True)
        self.assertEqual(resp.get('rcode'), ReturnCodes.OK)

        shm_mgr.disconnect()
        shm_mgr1.disconnect()
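
With backlogging=False, calls on a locked container fail fast with ReturnCodes.LOCKED instead of blocking, which the test above relies on. That return code enables a simple polling retry; a hypothetical helper (not part of the project) sketched on top of the API shown here:

import time

def add_value_with_retry(mgr, key, value, attempts=5, interval=0.1):
    # Hypothetical helper: retry a non-blocking add_value while the
    # target container is locked by another manager.
    for _ in range(attempts):
        resp = mgr.add_value(key, value, backlogging=False)
        if resp.get('rcode') != ReturnCodes.LOCKED:
            return resp            # OK or a genuine error; stop retrying
        time.sleep(interval)       # locked; back off briefly, then try again
    return resp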
Example #4
class SpecialPacketManager():

    SHM_SOCKET_NAME_TEMPLATE = 'SHM-SpecialPacketManager-%d.socket'

    def __init__(self, config):
        self.config = config
        self.pid = NodeContext.pid

        self.shm_key_pkts = SHM_KEY_PKTS

        # These containers are for the SpecialPacketRepeater; the repeater
        # also accesses special packets through this manager.
        self.shm_key_pkts_to_repeat = SHM_KEY_TMP_PKTS_TO_REPEAT % self.pid
        self.shm_key_last_repeat_time = SHM_KEY_TMP_LAST_REPEAT_TIME % self.pid
        self.shm_key_next_repeat_time = SHM_KEY_TMP_NEXT_REPEAT_TIME % self.pid
        self.shm_key_max_repeat_times = SHM_KEY_TMP_MAX_REPEAT_TIMES % self.pid
        self.shm_key_repeated_times = SHM_KEY_TMP_REPEATED_TIMES % self.pid

    def init_shm(self):
        ''' initialize the shared memory manager
        '''

        self.shm_mgr = SharedMemoryManager(self.config)
        self.shm_mgr.connect(self.SHM_SOCKET_NAME_TEMPLATE % self.pid)
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_pkts,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_pkts_to_repeat,
            SHMContainerTypes.LIST,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_last_repeat_time,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_next_repeat_time,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_max_repeat_times,
            SHMContainerTypes.DICT,
        )
        self.shm_mgr.create_key_and_ignore_conflict(
            self.shm_key_repeated_times,
            SHMContainerTypes.DICT,
        )

    def close_shm(self):
        self.shm_mgr.disconnect()

    def store_pkt(self, pkt, need_repeat=False, max_rpt_times=5):
        sn = pkt.fields.sn
        type_ = pkt.fields.type

        # The salt field is bytes, so it cannot be serialized into JSON.
        # We encode it into a base64 string before storing it.
        fields = pkt.fields.__to_dict__()
        salt = fields.get('salt')
        if salt is not None:
            salt_b64 = base64.b64encode(salt).decode()
            fields.update(salt=salt_b64)

        previous_hop = list(pkt.previous_hop)
        next_hop = list(pkt.next_hop)

        if sn is None:
            raise InvalidPkt(
                'Packets to be stored must contain a serial number')

        value = {
            sn: {
                'type': type_,
                'fields': fields,
                'previous_hop': previous_hop,
                'next_hop': next_hop,
            }
        }

        self.shm_mgr.lock_key(self.shm_key_pkts)
        self.shm_mgr.add_value(self.shm_key_pkts, value)
        self.shm_mgr.unlock_key(self.shm_key_pkts)

        if need_repeat:
            self.shm_mgr.add_value(self.shm_key_pkts_to_repeat, [sn])
            self.set_pkt_max_repeat_times(sn, max_rpt_times)
            self.set_pkt_repeated_times(sn, 0)

        hex_type = Converter.int_2_hex(type_)
        logger.debug(f'Stored a special packet, need_repeat: {need_repeat}, '
                     f'sn: {sn}, type: {hex_type}, dest: {pkt.fields.dest}')

    def get_pkt(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_pkts, sn)
        shm_value = shm_data.get('value')

        if shm_value is None:
            return None

        # And here, we restore the base64-encoded salt into bytes.
        # Mirror the None check in store_pkt, in case the packet had no salt.
        fields = shm_value.get('fields')
        salt_b64 = fields.get('salt')
        if salt_b64 is not None:
            fields.update(salt=base64.b64decode(salt_b64))

        return UDPPacket(
            fields=fields,
            type=shm_value.get('type'),
            previous_hop=shm_value.get('previous_hop'),
            next_hop=shm_value.get('next_hop'),
        )

    def remove_pkt(self, sn):
        self.cancel_repeat(sn)

        self.shm_mgr.lock_key(self.shm_key_pkts)
        self.shm_mgr.remove_value(self.shm_key_pkts, sn)
        self.shm_mgr.unlock_key(self.shm_key_pkts)

        logger.debug(f'Removed a special packet, sn: {sn}')

    def cancel_repeat(self, sn):
        self.shm_mgr.remove_value(self.shm_key_pkts_to_repeat, sn)
        self.shm_mgr.remove_value(self.shm_key_last_repeat_time, sn)
        self.shm_mgr.remove_value(self.shm_key_next_repeat_time, sn)
        self.shm_mgr.remove_value(self.shm_key_max_repeat_times, sn)
        self.shm_mgr.remove_value(self.shm_key_repeated_times, sn)
        logger.debug(f'Cancelled repeat for a packet, sn: {sn}')

    def repeat_pkt(self, pkt, max_rpt_times=5):
        self.store_pkt(pkt, need_repeat=True, max_rpt_times=max_rpt_times)

    def get_repeating_sn_list(self):
        shm_data = self.shm_mgr.read_key(self.shm_key_pkts_to_repeat)
        return shm_data.get('value')

    def set_pkt_last_repeat_time(self, sn, timestamp):
        self.shm_mgr.add_value(self.shm_key_last_repeat_time, {sn: timestamp})
        logger.debug(f'set_pkt_last_repeat_time, sn: {sn}, ts: {timestamp}')

    def get_pkt_last_repeat_time(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_last_repeat_time, sn)
        return shm_data.get('value')

    def set_pkt_next_repeat_time(self, sn, timestamp):
        self.shm_mgr.add_value(self.shm_key_next_repeat_time, {sn: timestamp})
        logger.debug(f'set_pkt_next_repeat_time, sn: {sn}, ts: {timestamp}')

    def get_pkt_next_repeat_time(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_next_repeat_time, sn)
        return shm_data.get('value')

    def set_pkt_max_repeat_times(self, sn, times):
        self.shm_mgr.add_value(self.shm_key_max_repeat_times, {sn: times})
        logger.debug(f'set_pkt_max_repeat_times, sn: {sn}, times: {times}')

    def get_pkt_max_repeat_times(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_max_repeat_times, sn)
        return shm_data.get('value')

    def set_pkt_repeated_times(self, sn, times):
        self.shm_mgr.add_value(self.shm_key_repeated_times, {sn: times})
        logger.debug(f'set_pkt_repeated_times, sn: {sn}, times: {times}')

    def get_pkt_repeated_times(self, sn):
        shm_data = self.shm_mgr.get_value(self.shm_key_repeated_times, sn)
        return shm_data.get('value')

    def increase_pkt_repeated_times(self, sn):
        repeated_times = self.get_pkt_repeated_times(sn)

        if repeated_times is None:
            repeated_times = 1
        else:
            repeated_times += 1

        self.set_pkt_repeated_times(sn, repeated_times)
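
The per-sn timing containers above are what a repeater loop would consume. A minimal sketch of one repeater pass built on the manager's getters and setters; the send callable and the fixed interval are assumptions, not the project's actual SpecialPacketRepeater:

import time

def repeat_once(mgr, send, interval=2):
    # Sketch only: one pass over the packets registered for repeating.
    now = time.time()
    for sn in mgr.get_repeating_sn_list():
        next_time = mgr.get_pkt_next_repeat_time(sn)
        if next_time is not None and now < next_time:
            continue                   # not due yet

        repeated = mgr.get_pkt_repeated_times(sn) or 0
        max_times = mgr.get_pkt_max_repeat_times(sn) or 0
        if repeated >= max_times:
            mgr.cancel_repeat(sn)      # repeat budget exhausted
            continue

        pkt = mgr.get_pkt(sn)
        if pkt is not None:
            send(pkt)                  # e.g. an efferent's transmit method

        mgr.set_pkt_last_repeat_time(sn, now)
        mgr.set_pkt_next_repeat_time(sn, now + interval)
        mgr.increase_pkt_repeated_times(sn)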