Example #1
0
    def test_0_normal_ops(self):
        """Exercise create/read/add/remove on KEY for every entry in DATA.

        For each test-data entry the key is cleaned first (a KEY_ERROR rcode
        is tolerated when the key did not exist yet), then created, read,
        extended, read, shrunk and read again. The final read must expose
        ``td['remaining']`` under ``td['remaining_key']``.

        The connection is closed in a ``finally`` so a failing assertion in
        the middle of the loop does not leave it open.
        """
        shm_mgr = SharedMemoryManager(config, sensitive=False)
        shm_mgr.connect('test')

        def check_step(title, resp):
            # Dump the response for debugging, then require that it succeeded.
            print(f'\n-----------{title}-----------')
            print(resp)
            self.assertTrue(resp.get('succeeded'))
            return resp

        try:
            print('conn_id: ', shm_mgr.current_connection.conn_id)
            self.assertIsInstance(shm_mgr.current_connection.conn_id, str)

            for td in DATA:
                print(f'\n==================={td["name"]}===================\n')

                # Try to remove this KEY first; a KEY_ERROR just means it
                # did not exist, which is fine.
                resp = shm_mgr.clean_key(KEY)
                if not resp.get('succeeded'):
                    self.assertEqual(resp.get('rcode'), ReturnCodes.KEY_ERROR)

                check_step(
                    'Create',
                    shm_mgr.create_key(KEY, td['type'], td['2_create']),
                )
                check_step('Read', shm_mgr.read_key(KEY))
                check_step('Add', shm_mgr.add_value(KEY, td['2_add']))
                check_step('Read', shm_mgr.read_key(KEY))
                check_step('Remove', shm_mgr.remove_value(KEY, td['2_remove']))

                resp = check_step('Read', shm_mgr.read_key(KEY))
                self.assertEqual(
                    resp.get('value')[td['remaining_key']],
                    td['remaining'],
                )
        finally:
            shm_mgr.disconnect()

        self.assertIsNone(shm_mgr.current_connection)
Example #2
0
class SpecialPacketManager:
    """Keep special packets and their repeat bookkeeping in shared memory.

    Packets are stored in the ``SHM_KEY_PKTS`` DICT container, keyed by
    serial number (``sn``). Packets that must be repeated additionally get
    their ``sn`` appended to a LIST container. They also get entries in DICT
    containers holding the last/next repeat time, the maximum repeat count
    and the number of repeats done so far. Those repeat containers are
    named per process: the ``SHM_KEY_TMP_*`` templates are formatted with
    ``NodeContext.pid``.

    Call ``init_shm`` before any method that touches shared memory, and
    ``close_shm`` when done.
    """

    # Name passed to SharedMemoryManager.connect(); formatted with the pid.
    SHM_SOCKET_NAME_TEMPLATE = 'SHM-SpecialPacketManager-%d.socket'

    def __init__(self, config):
        """Remember *config* and compute the shared-memory key names.

        No connection is made here; see ``init_shm``.
        """
        self.config = config
        self.pid = NodeContext.pid
        # Set by init_shm(); None means "not connected yet".
        self.shm_mgr = None

        self.shm_key_pkts = SHM_KEY_PKTS

        # These containers are for the SpecialPacketRepeater, the repeater
        # will also access special packets by the manager.
        self.shm_key_pkts_to_repeat = SHM_KEY_TMP_PKTS_TO_REPEAT % self.pid
        self.shm_key_last_repeat_time = SHM_KEY_TMP_LAST_REPEAT_TIME % self.pid
        self.shm_key_next_repeat_time = SHM_KEY_TMP_NEXT_REPEAT_TIME % self.pid
        self.shm_key_max_repeat_times = SHM_KEY_TMP_MAX_REPEAT_TIMES % self.pid
        self.shm_key_repeated_times = SHM_KEY_TMP_REPEATED_TIMES % self.pid

    def init_shm(self):
        """Connect to shared memory and create every container this class uses.

        Containers are created with ``create_key_and_ignore_conflict``, so
        keys that already exist are accepted rather than treated as errors.
        """
        self.shm_mgr = SharedMemoryManager(self.config)
        self.shm_mgr.connect(self.SHM_SOCKET_NAME_TEMPLATE % self.pid)

        containers = (
            (self.shm_key_pkts, SHMContainerTypes.DICT),
            (self.shm_key_pkts_to_repeat, SHMContainerTypes.LIST),
            (self.shm_key_last_repeat_time, SHMContainerTypes.DICT),
            (self.shm_key_next_repeat_time, SHMContainerTypes.DICT),
            (self.shm_key_max_repeat_times, SHMContainerTypes.DICT),
            (self.shm_key_repeated_times, SHMContainerTypes.DICT),
        )
        for key, container_type in containers:
            self.shm_mgr.create_key_and_ignore_conflict(key, container_type)

    def close_shm(self):
        """Disconnect from shared memory; a no-op if init_shm never ran."""
        if self.shm_mgr is None:
            return
        self.shm_mgr.disconnect()

    def store_pkt(self, pkt, need_repeat=False, max_rpt_times=5):
        """Store *pkt* in shared memory, optionally scheduling it for repeats.

        :param pkt: packet whose ``fields.sn`` is used as the storage key.
        :param need_repeat: when true, also append ``sn`` to the repeat list,
            record ``max_rpt_times`` and reset the repeated-times counter to 0.
        :param max_rpt_times: maximum repeat count recorded for the packet
            (only used when ``need_repeat`` is true).
        :raises InvalidPkt: if the packet has no serial number.
        """
        sn = pkt.fields.sn
        # Validate before doing any serialization work.
        if sn is None:
            raise InvalidPkt(
                'Packets to be stored must contain a serial number')

        type_ = pkt.fields.type

        # The salt field is bytes, so we cannot serialize it in a JSON.
        # So, we encode it into a base64 string before storing it.
        fields = pkt.fields.__to_dict__()
        salt = fields.get('salt')
        if salt is not None:
            fields.update(salt=base64.b64encode(salt).decode())

        value = {
            sn: {
                'type': type_,
                'fields': fields,
                'previous_hop': list(pkt.previous_hop),
                'next_hop': list(pkt.next_hop),
            }
        }

        # Always release the lock, even if add_value raises.
        self.shm_mgr.lock_key(self.shm_key_pkts)
        try:
            self.shm_mgr.add_value(self.shm_key_pkts, value)
        finally:
            self.shm_mgr.unlock_key(self.shm_key_pkts)

        if need_repeat:
            self.shm_mgr.add_value(self.shm_key_pkts_to_repeat, [sn])
            self.set_pkt_max_repeat_times(sn, max_rpt_times)
            self.set_pkt_repeated_times(sn, 0)

        hex_type = Converter.int_2_hex(type_)
        logger.debug(f'Stored a special packet, need_repeat: {need_repeat}, '
                     f'sn: {sn}, type: {hex_type}, dest: {pkt.fields.dest}')

    def get_pkt(self, sn):
        """Rebuild the packet stored under serial number *sn*.

        Returns a ``UDPPacket``, or None when shared memory returns no value
        for *sn*. A base64-encoded salt written by ``store_pkt`` is decoded
        back to bytes; a missing/None salt is left as-is.
        """
        shm_value = self._get_value(self.shm_key_pkts, sn)
        if shm_value is None:
            return None

        # Restore the base64-encoded salt into bytes. store_pkt only encodes
        # a salt that is not None, so decoding None would raise TypeError.
        fields = shm_value.get('fields')
        salt_b64 = fields.get('salt')
        if salt_b64 is not None:
            fields.update(salt=base64.b64decode(salt_b64))

        return UDPPacket(
            fields=fields,
            type=shm_value.get('type'),
            previous_hop=shm_value.get('previous_hop'),
            next_hop=shm_value.get('next_hop'),
        )

    def remove_pkt(self, sn):
        """Cancel any repeat for *sn*, then delete the stored packet."""
        self.cancel_repeat(sn)

        # Always release the lock, even if remove_value raises.
        self.shm_mgr.lock_key(self.shm_key_pkts)
        try:
            self.shm_mgr.remove_value(self.shm_key_pkts, sn)
        finally:
            self.shm_mgr.unlock_key(self.shm_key_pkts)

        logger.debug(f'Removed a special packet, sn: {sn}')

    def cancel_repeat(self, sn):
        """Remove *sn* from the repeat list and every repeat-bookkeeping dict."""
        repeat_keys = (
            self.shm_key_pkts_to_repeat,
            self.shm_key_last_repeat_time,
            self.shm_key_next_repeat_time,
            self.shm_key_max_repeat_times,
            self.shm_key_repeated_times,
        )
        for key in repeat_keys:
            self.shm_mgr.remove_value(key, sn)
        logger.debug(f'Cancelled repeat for a packet, sn: {sn}')

    def repeat_pkt(self, pkt, max_rpt_times=5):
        """Store *pkt* and schedule it for up to *max_rpt_times* repeats."""
        self.store_pkt(pkt, need_repeat=True, max_rpt_times=max_rpt_times)

    def get_repeating_sn_list(self):
        """Return the list of serial numbers scheduled for repetition.

        Returns whatever ``read_key`` reports as ``value``; None if absent.
        """
        shm_data = self.shm_mgr.read_key(self.shm_key_pkts_to_repeat)
        return shm_data.get('value')

    def _get_value(self, shm_key, sn):
        """Return the ``value`` stored for *sn* in container *shm_key*, or None."""
        return self.shm_mgr.get_value(shm_key, sn).get('value')

    def set_pkt_last_repeat_time(self, sn, timestamp):
        """Record *timestamp* as the last repeat time of packet *sn*."""
        self.shm_mgr.add_value(self.shm_key_last_repeat_time, {sn: timestamp})
        logger.debug(f'set_pkt_last_repeat_time, sn: {sn}, ts: {timestamp}')

    def get_pkt_last_repeat_time(self, sn):
        """Return the last repeat time recorded for *sn*, or None."""
        return self._get_value(self.shm_key_last_repeat_time, sn)

    def set_pkt_next_repeat_time(self, sn, timestamp):
        """Record *timestamp* as the next repeat time of packet *sn*."""
        self.shm_mgr.add_value(self.shm_key_next_repeat_time, {sn: timestamp})
        logger.debug(f'set_pkt_next_repeat_time, sn: {sn}, ts: {timestamp}')

    def get_pkt_next_repeat_time(self, sn):
        """Return the next repeat time recorded for *sn*, or None."""
        return self._get_value(self.shm_key_next_repeat_time, sn)

    def set_pkt_max_repeat_times(self, sn, times):
        """Record the maximum number of repeats allowed for packet *sn*."""
        self.shm_mgr.add_value(self.shm_key_max_repeat_times, {sn: times})
        logger.debug(f'set_pkt_max_repeat_times, sn: {sn}, times: {times}')

    def get_pkt_max_repeat_times(self, sn):
        """Return the maximum repeat count recorded for *sn*, or None."""
        return self._get_value(self.shm_key_max_repeat_times, sn)

    def set_pkt_repeated_times(self, sn, times):
        """Record how many times packet *sn* has been repeated."""
        self.shm_mgr.add_value(self.shm_key_repeated_times, {sn: times})
        logger.debug(f'set_pkt_repeated_times, sn: {sn}, times: {times}')

    def get_pkt_repeated_times(self, sn):
        """Return the repeated-times counter for *sn*, or None."""
        return self._get_value(self.shm_key_repeated_times, sn)

    def increase_pkt_repeated_times(self, sn):
        """Increment the repeated-times counter of *sn* (missing counts as 0).

        NOTE(review): this is a read-modify-write without lock_key, unlike
        store_pkt/remove_pkt; concurrent callers could lose an increment.
        """
        current = self.get_pkt_repeated_times(sn)
        self.set_pkt_repeated_times(sn, 1 if current is None else current + 1)