コード例 #1
0
ファイル: test_lru.py プロジェクト: meiliqi/cacheout
def test_lru_get_eviction(cache: LRUCache):
    """Reading keys in a random order must make eviction follow that same order.

    Every key currently in *cache* is read once with ``get`` in a shuffled
    order; the first key read is the least recently used one, so it must be
    the first to be evicted, and so on.
    """
    existing = list(cache.keys())
    access_order = random.sample(existing, len(cache))

    for key in access_order:
        cache.get(key)

    assert_keys_evicted_in_order(cache, access_order)
コード例 #2
0
class StackCache:
    """Stack cache.

    Keeps entries in the running process's memory (an ``LRUCache``) so that
    repeated lookups are fast; intended to help request handling under
    concurrent load. Every operation is delegated to the wrapped cache.
    """

    def __init__(self, maxsize=0xff, ttl=None):
        """Wrap a new LRUCache sized *maxsize* with default time-to-live *ttl*."""
        backend = LRUCache(maxsize, ttl)
        self._cache = backend

    def has(self, key):
        """Return the wrapped cache's answer to whether *key* is present."""
        return self._cache.has(key)

    def get(self, key, default=None):
        """Look up *key*, passing *default* through to the wrapped cache."""
        return self._cache.get(key, default)

    def set(self, key, val, ttl=None):
        """Store *val* under *key*, forwarding *ttl*; always returns None."""
        self._cache.set(key, val, ttl)

    def delete(self, key):
        """Remove *key*, returning whatever the wrapped cache's delete returns."""
        return self._cache.delete(key)

    def size(self):
        """Return the wrapped cache's size()."""
        return self._cache.size()
コード例 #3
0
ファイル: test_lru.py プロジェクト: johnbergvall/cacheout
def test_lru_get_set_eviction(cache: LRUCache):
    """Test that LRUCache evicts least recently set/accessed entries first.

    Half of the keys (rounded down) are read with ``get``; every remaining
    key is rewritten with ``set``. The resulting recency order (reads first,
    then writes) must be the eviction order.
    """
    all_keys = list(cache.keys())
    get_keys = random.sample(all_keys, len(all_keys) // 2)
    # Rewrite *all* keys that were not read, so the two groups partition the
    # cache even when it holds an odd number of entries (taking len // 2 for
    # both groups would leave one key out and break the assertion below).
    read = set(get_keys)
    remaining = [key for key in all_keys if key not in read]
    set_keys = random.sample(remaining, len(remaining))

    assert not set(get_keys).intersection(set_keys)
    assert set(get_keys + set_keys) == set(all_keys)

    for key in get_keys:
        cache.get(key)

    for key in set_keys:
        cache.set(key, key)

    keys = get_keys + set_keys

    assert_keys_evicted_in_order(cache, keys)
コード例 #4
0
class StackCache:
    """In-memory cache facade that forwards every call to an ``LRUCache``."""

    def __init__(self, maxsize=0xff, ttl=None):
        """Create the backing LRUCache from *maxsize* and the default *ttl*."""
        cache = LRUCache(maxsize, ttl)
        self._cache = cache

    def has(self, key):
        """Delegate a membership check for *key*."""
        return self._cache.has(key)

    def get(self, key, default=None):
        """Delegate a lookup of *key* with fallback *default*."""
        return self._cache.get(key, default)

    def set(self, key, val, ttl=None):
        """Delegate storing *val* under *key* with optional *ttl* (returns None)."""
        self._cache.set(key, val, ttl)

    def delete(self, key):
        """Delegate removal of *key* and return its result."""
        return self._cache.delete(key)

    def size(self):
        """Delegate the size query."""
        return self._cache.size()
コード例 #5
0
class tcp_http_pcap():
    """Sniff TCP/HTTP traffic on an interface and emit asset records as JSON.

    Server-first TCP payloads (e.g. MySQL greetings) and HTTP
    request/response pairs are matched via the SYN-ACK sequence numbers and
    the first client data packet; each record is serialized with
    ``json.dumps`` and appended to ``work_queue``.
    """

    # Charset names (Java-style) accepted from Content-Type or <meta charset>.
    # Some of them are unknown to Python's codec registry; those fall back to
    # UTF-8 instead of raising (see _decode_charset).
    _CHARSET_WHITE_LIST = frozenset([
        'big5', 'big5-hkscs', 'cesu-8', 'euc-jp', 'euc-kr', 'gb18030',
        'gb2312', 'gbk', 'ibm-thai', 'ibm00858', 'ibm01140', 'ibm01141',
        'ibm01142', 'ibm01143', 'ibm01144', 'ibm01145', 'ibm01146',
        'ibm01147', 'ibm01148', 'ibm01149', 'ibm037', 'ibm1026', 'ibm1047',
        'ibm273', 'ibm277', 'ibm278', 'ibm280', 'ibm284', 'ibm285',
        'ibm290', 'ibm297', 'ibm420', 'ibm424', 'ibm437', 'ibm500',
        'ibm775', 'ibm850', 'ibm852', 'ibm855', 'ibm857', 'ibm860',
        'ibm861', 'ibm862', 'ibm863', 'ibm864', 'ibm865', 'ibm866',
        'ibm868', 'ibm869', 'ibm870', 'ibm871', 'ibm918',
        'iso-10646-ucs-2', 'iso-2022-cn', 'iso-2022-jp', 'iso-2022-jp-2',
        'iso-2022-kr', 'iso-8859-1', 'iso-8859-10', 'iso-8859-13',
        'iso-8859-15', 'iso-8859-16', 'iso-8859-2', 'iso-8859-3',
        'iso-8859-4', 'iso-8859-5', 'iso-8859-6', 'iso-8859-7',
        'iso-8859-8', 'iso-8859-9', 'jis_x0201', 'jis_x0212-1990',
        'koi8-r', 'koi8-u', 'shift_jis', 'tis-620', 'us-ascii', 'utf-16',
        'utf-16be', 'utf-16le', 'utf-32', 'utf-32be', 'utf-32le', 'utf-8',
        'windows-1250', 'windows-1251', 'windows-1252', 'windows-1253',
        'windows-1254', 'windows-1255', 'windows-1256', 'windows-1257',
        'windows-1258', 'windows-31j', 'x-big5-hkscs-2001',
        'x-big5-solaris', 'x-euc-jp-linux', 'x-euc-tw', 'x-eucjp-open',
        'x-ibm1006', 'x-ibm1025', 'x-ibm1046', 'x-ibm1097', 'x-ibm1098',
        'x-ibm1112', 'x-ibm1122', 'x-ibm1123', 'x-ibm1124', 'x-ibm1166',
        'x-ibm1364', 'x-ibm1381', 'x-ibm1383', 'x-ibm300', 'x-ibm33722',
        'x-ibm737', 'x-ibm833', 'x-ibm834', 'x-ibm856', 'x-ibm874',
        'x-ibm875', 'x-ibm921', 'x-ibm922', 'x-ibm930', 'x-ibm933',
        'x-ibm935', 'x-ibm937', 'x-ibm939', 'x-ibm942', 'x-ibm942c',
        'x-ibm943', 'x-ibm943c', 'x-ibm948', 'x-ibm949', 'x-ibm949c',
        'x-ibm950', 'x-ibm964', 'x-ibm970', 'x-iscii91',
        'x-iso-2022-cn-cns', 'x-iso-2022-cn-gb', 'x-iso-8859-11',
        'x-jis0208', 'x-jisautodetect', 'x-johab', 'x-macarabic',
        'x-maccentraleurope', 'x-maccroatian', 'x-maccyrillic',
        'x-macdingbat', 'x-macgreek', 'x-machebrew', 'x-maciceland',
        'x-macroman', 'x-macromania', 'x-macsymbol', 'x-macthai',
        'x-macturkish', 'x-macukraine', 'x-ms932_0213', 'x-ms950-hkscs',
        'x-ms950-hkscs-xp', 'x-mswin-936', 'x-pck', 'x-sjis',
        'x-sjis_0213', 'x-utf-16le-bom', 'x-utf-32be-bom',
        'x-utf-32le-bom', 'x-windows-50220', 'x-windows-50221',
        'x-windows-874', 'x-windows-949', 'x-windows-950',
        'x-windows-iso2022jp'
    ])

    # HTTP parsing regexes, compiled once and shared by all instances.
    decode_request_regex = re.compile(
        r'^([A-Z]+) +([^ \r\n]+) +HTTP/\d+(?:\.\d+)?[^\r\n]*(.*?)$', re.S)
    decode_response_regex = re.compile(
        r'^HTTP/(\d+(?:\.\d+)?) (\d+)[^\r\n]*(.*?)$', re.S)
    decode_body_regex = re.compile(
        rb'<meta[^>]+?charset=[\'"]?([a-z\d\-]+)[\'"]?', re.I)

    def __init__(self, pcap_collection_data, max_queue_size, work_queue,
                 interface, custom_tag, return_deep_info, http_filter_json,
                 cache_size, session_size, bpf_filter, timeout, debug):
        """Open the capture handle and create the session and dedup caches.

        :param pcap_collection_data: container tested for 'TCP' and/or 'HTTP'
            membership to choose which record types are emitted
        :param max_queue_size: maximum length of the asset queue
        :param work_queue: outgoing queue for captured asset records (must
            support len(), clear() and append())
        :param interface: name of the network interface to capture on
        :param custom_tag: data tag copied into every record, used to tell
            different collection engines apart
        :param return_deep_info: when true, HTTP response bodies are
            de-chunked and decompressed (gzip / br) before decoding
        :param http_filter_json: HTTP filter rules keyed by 'response_code'
            and/or 'content_type'; responses matching any rule are dropped
        :param cache_size: size of the dedup caches; a repeated record within
            120 seconds is not collected again (falsy disables dedup)
        :param session_size: number of cached HTTP/TCP session entries;
            entries unused for 30 seconds expire
        :param bpf_filter: BPF filter expression applied to the capture
        :param timeout: passed to pcap as ``timeout_ms``. NOTE(review): it was
            documented as the collector's overall run timeout (exit one hour
            after start by default) — confirm the intended units/meaning.
        :param debug: print every emitted record when true
        """
        self.pcap_collection_data = pcap_collection_data
        self.total_msg_num = 0
        self.max_queue_size = max_queue_size
        self.work_queue = work_queue
        self.debug = debug
        self.timeout = timeout
        self.bpf_filter = bpf_filter
        self.cache_size = cache_size
        self.session_size = session_size
        self.http_filter_json = http_filter_json
        self.return_deep_info = return_deep_info
        self.custom_tag = custom_tag
        self.interface = interface
        self.sniffer = pcap.pcap(self.interface,
                                 snaplen=65535,
                                 promisc=True,
                                 timeout_ms=self.timeout,
                                 immediate=False)
        self.sniffer.setfilter(self.bpf_filter)
        self.tcp_stream_cache = Cache(maxsize=self.session_size,
                                      ttl=30,
                                      timer=time.time,
                                      default=None)
        if self.cache_size:
            self.tcp_cache = LRUCache(maxsize=self.cache_size,
                                      ttl=120,
                                      timer=time.time,
                                      default=None)
            self.http_cache = LRUCache(maxsize=self.cache_size,
                                       ttl=120,
                                       timer=time.time,
                                       default=None)

    def run(self):
        """Capture loop: pair handshakes with first payloads and emit records.

        Iterates the pcap handle until it is exhausted, then closes it.
        """
        for _ts, pkt in self.sniffer:
            packet = self.pkt_decode(pkt)
            if not packet:
                continue

            cache_key = '{}:{}'.format(packet.src, packet.sport)
            # SYN-ACK (0x12)
            if packet.flags == 0x12:
                if self.cache_size and self.tcp_cache.get(cache_key):
                    continue

                # Remember the server's next sequence number, keyed by the ack
                # the client's first data packet will carry as its seq.
                # TCP sequence numbers are 32-bit and wrap around.
                self.tcp_stream_cache.set('S_{}'.format(packet.ack),
                                          (packet.seq + 1) & 0xFFFFFFFF)

            # ACK (0x10) / PSH-ACK (0x18) / FIN-PSH-ACK (0x19)
            elif packet.flags in [0x10, 0x18, 0x19]:
                # Packets without payload are ignored.
                if len(packet.data) == 0:
                    continue

                # First client->server payload (its seq == SYN-ACK's ack):
                # cache it until the server's reply arrives.
                pre_cs_seq = self.tcp_stream_cache.get('S_{}'.format(
                    packet.seq))
                # `is not None`: a wrapped sequence number can legitimately be 0.
                if pre_cs_seq is not None:
                    c_s_key = 'C_{}'.format(packet.ack)

                    self.tcp_stream_cache.set(c_s_key, packet.data)
                    self.tcp_stream_cache.delete('S_{}'.format(packet.seq))
                    continue

                # 1. Server-initiated protocols, e.g. MySQL:
                #    seq == SYN-ACK seq + 1
                if 'TCP' in self.pcap_collection_data:
                    pre_sc_seq = self.tcp_stream_cache.get('S_{}'.format(
                        packet.ack))
                    if pre_sc_seq == packet.seq:
                        self.tcp_stream_cache.delete('S_{}'.format(packet.ack))

                        # Suppress short-term TCP duplicates.
                        if self.cache_size:
                            self.tcp_cache.set(cache_key, True)

                        data = {
                            'pro': 'TCP',
                            'tag': self.custom_tag,
                            'ip': packet.src,
                            'port': packet.sport,
                            'data': packet.data.hex()
                        }
                        self.send_msg(data)
                        continue

                # 2. Request/response protocols, e.g. HTTP:
                #    server reply seq == client PSH-ACK ack
                send_data = self.tcp_stream_cache.get('C_{}'.format(
                    packet.seq))
                if send_data:
                    # The cached request is consumed now.
                    self.tcp_stream_cache.delete('C_{}'.format(packet.seq))

                    if ('HTTP' in self.pcap_collection_data
                            and packet.data.startswith(b'HTTP/')):
                        request_dict = self.decode_request(
                            send_data, packet.src, str(packet.sport))
                        if not request_dict:
                            continue

                        http_cache_key = '{}:{}'.format(
                            request_dict['method'], request_dict['uri'])
                        if self.cache_size and self.http_cache.get(
                                http_cache_key):
                            continue

                        response_dict = self.decode_response(packet.data)
                        if response_dict:
                            # Suppress short-term HTTP duplicates.
                            if self.cache_size:
                                self.http_cache.set(http_cache_key, True)

                            response_code = response_dict['status']
                            content_type = response_dict['type']

                            # Drop responses matching the status/type filters.
                            if self.http_filter_json:
                                filter_code = self.http_filter(
                                    'response_code',
                                    response_code) if response_code else False
                                filter_type = self.http_filter(
                                    'content_type',
                                    content_type) if content_type else False
                                if filter_code or filter_type:
                                    continue

                            data = {
                                'pro': 'HTTP',
                                'tag': self.custom_tag,
                                'ip': packet.src,
                                'port': packet.sport,
                                'method': request_dict['method'],
                                'code': response_code,
                                'type': content_type,
                                'server': response_dict['server'],
                                'header': response_dict['headers'],
                                'url': request_dict['uri'],
                                'body': response_dict['body']
                            }

                            self.send_msg(data)
                            continue

                    elif 'TCP' in self.pcap_collection_data:
                        # Suppress short-term TCP duplicates.
                        if self.cache_size:
                            self.tcp_cache.set(cache_key, True)

                        # 2.2 Non-HTTP request/response traffic.
                        data = {
                            'pro': 'TCP',
                            'tag': self.custom_tag,
                            'ip': packet.src,
                            'port': packet.sport,
                            'data': packet.data.hex()
                        }
                        self.send_msg(data)

        self.sniffer.close()

    def http_filter(self, key, value):
        """Return True if *value* contains any rule configured under *key*.

        :param key: rule name, 'response_code' or 'content_type'
        :param value: string to check
        :return: True when some rule is a substring of *value*, else False
        """
        if key in self.http_filter_json:
            return any(rule in value for rule in self.http_filter_json[key])
        return False

    def pkt_decode(self, pkt):
        """Parse an Ethernet frame and return its TCP segment if it is tracked.

        Only SYN-ACK segments, and ACK / PSH-ACK / FIN-PSH-ACK segments with a
        non-empty payload, are returned; printable ``src``/``dst`` addresses
        are attached to the returned TCP object. Anything else yields None.
        """
        try:
            ip_type = ''
            packet = dpkt.ethernet.Ethernet(pkt)
            if isinstance(packet.data, dpkt.ip.IP):
                ip_type = 'ip4'
            elif isinstance(packet.data, dpkt.ip6.IP6):
                ip_type = 'ip6'
            if ip_type and isinstance(packet.data.data, dpkt.tcp.TCP):
                tcp_pkt = packet.data.data
                if (tcp_pkt.flags == 0x12 or
                        (tcp_pkt.flags in [0x10, 0x18, 0x19]
                         and len(tcp_pkt.data) > 0)):
                    if ip_type == 'ip4':
                        tcp_pkt.src = self.ip_addr(packet.data.src)
                        tcp_pkt.dst = self.ip_addr(packet.data.dst)
                    else:
                        tcp_pkt.src = self.ip6_addr(''.join(
                            ['%02X' % x for x in packet.data.src]))
                        tcp_pkt.dst = self.ip6_addr(''.join(
                            ['%02X' % x for x in packet.data.dst]))
                    return tcp_pkt
        except KeyboardInterrupt:
            print('\nExit.')
            os.kill(os.getpid(), signal.SIGKILL)
        except Exception:
            # Best effort: frames that fail to decode are skipped.
            pass
        return None

    def ip_addr(self, ip):
        """Format 4 raw IPv4 address bytes as dotted decimal."""
        return '%d.%d.%d.%d' % tuple(ip)

    def ip6_addr(self, ip6):
        """Format a 32-char hex IPv6 string as 8 colon-separated groups.

        Leading zeros are stripped per group ('0' for an all-zero group);
        runs of zero groups are NOT compressed to '::'.
        """
        groups = re.findall(r'.{4}', ip6)
        return ':'.join(group.lstrip('0') or '0' for group in groups)

    def decode_request(self, data, sip, sport):
        """Parse raw HTTP request bytes into method, URL, headers and body.

        :param data: raw request bytes
        :param sip: server IP (IPv4 or IPv6 text)
        :param sport: server port as a string
        :return: a dict with 'method', 'uri', 'headers' and 'body'; None when
            the request target is not an origin-form path ('/...'); a stub
            dict with an empty method when the request line does not parse
        """
        pos = data.find(b'\r\n\r\n')
        body = data[pos + 4:] if pos > 0 else b''
        data_str = str(data[:pos] if pos > 0 else data, 'utf-8', 'ignore')
        # IPv6 literals must be bracketed inside a URL authority.
        host_ip = '[{}]'.format(sip) if ':' in sip else sip

        m = self.decode_request_regex.match(data_str)
        if not m:
            return {
                'method': '',
                'uri': 'http://{}:{}/'.format(host_ip, sport),
                'headers': '',
                'body': ''
            }

        path = m.group(2)
        if path[:1] != '/':
            return None

        headers = m.group(3).strip() if m.group(3) else ''
        header_dict = self.parse_headers(headers)
        # Prefer a Host header that looks like a domain name; otherwise
        # build the authority from the server address.
        if 'host' in header_dict and re.search('[a-zA-Z]', header_dict['host']):
            host_domain = header_dict['host']
        elif sport != '80':
            host_domain = host_ip + ':' + sport
        else:
            host_domain = host_ip
        url = 'http://{}{}'.format(host_domain, path) if host_domain else path

        return {
            'method': m.group(1) if m.group(1) else '',
            'uri': url,
            'headers': headers,
            'body': str(body, 'utf-8', 'ignore')
        }

    def decode_response(self, data):
        """Parse raw HTTP response bytes; return None if the status line is invalid.

        When ``return_deep_info`` is set, chunked bodies are reassembled and
        gzip/br content encodings are decompressed before text decoding.
        """
        pos = data.find(b'\r\n\r\n')
        body = data[pos + 4:] if pos > 0 else b''
        header_str = str(data[:pos] if pos > 0 else data, 'utf-8', 'ignore')
        m = self.decode_response_regex.match(header_str)
        if not m:
            return None

        headers = m.group(3).strip() if m.group(3) else ''
        headers_dict = self.parse_headers(headers)
        if self.return_deep_info:
            # Coding tokens are case-insensitive ("Chunked", "GZIP", ...).
            if headers_dict.get('transfer-encoding', '').lower() == 'chunked':
                body = self.decode_chunked(body)
            content_encoding = headers_dict.get('content-encoding', '').lower()
            if content_encoding == 'gzip':
                body = self.decode_gzip(body)
            elif content_encoding == 'br':
                body = self.decode_brotli(body)

        content_type = headers_dict.get('content-type', '')
        server = headers_dict.get('server', '')
        return {
            'version': m.group(1) if m.group(1) else '',
            'status': m.group(2) if m.group(2) else '',
            'headers': headers,
            'type': content_type,
            'server': server,
            'body': self.decode_body(body, content_type)
        }

    def decode_gzip(self, data):
        """Decompress a gzip (Content-Encoding: gzip) body; return *data* unchanged on failure."""
        try:
            with gzip.GzipFile(fileobj=io.BytesIO(data)) as gf:
                return gf.read()
        except Exception:
            # Truncated or non-gzip payloads are kept as captured.
            return data

    def decode_brotli(self, data):
        """Decompress a brotli (Content-Encoding: br) body; return *data* unchanged on failure."""
        try:
            return brotli.decompress(data)
        except Exception:
            return data

    def decode_chunked(self, data):
        """Reassemble a Transfer-Encoding: chunked body.

        Example input::

            1b
            {"ret":0, "messge":"error"}

        Chunk extensions (``size;name=value``) are ignored. Decoding stops at
        the terminating zero-size chunk; if a size line cannot be parsed, the
        remaining bytes are appended unchanged. Iterative, so bodies with
        many chunks cannot exhaust the recursion limit.
        """
        parts = []
        pos = 0
        while True:
            line_end = data.find(b'\r\n', pos)
            if line_end <= pos:
                # No non-empty size line left (covers "not found" == -1).
                parts.append(data[pos:])
                break
            size_field = data[pos:line_end].split(b';', 1)[0].strip()
            try:
                chunk_len = int(size_field, 16)
            except ValueError:
                parts.append(data[pos:])
                break
            if chunk_len == 0:
                break
            if chunk_len < 0:
                parts.append(data[pos:])
                break
            start = line_end + 2
            parts.append(data[start:start + chunk_len])
            # Skip the chunk payload and its trailing CRLF.
            pos = start + chunk_len + 2
        return b''.join(parts)

    def _decode_charset(self, data, charset):
        """Decode *data* with a whitelisted *charset*; None if unusable.

        iso-8859-1 is deliberately skipped so the caller keeps looking for a
        more specific charset, and names Python does not know (LookupError)
        are treated as unusable instead of crashing the capture loop.
        """
        if charset == 'iso-8859-1' or charset not in self._CHARSET_WHITE_LIST:
            return None
        try:
            return str(data, charset, 'ignore')
        except LookupError:
            return None

    def decode_body(self, data, content_type):
        """Decode a response body to text.

        The charset comes from the Content-Type header first, then from a
        ``<meta ... charset=...>`` tag anywhere in the body; UTF-8 (with
        undecodable bytes dropped) is the fallback.
        """
        content_type = content_type.lower() if content_type else ''
        if 'charset=' in content_type:
            charset = content_type[content_type.find('charset=') + 8:]
            # Drop any parameters that follow the charset.
            charset = charset.split(';', 1)[0].strip('" ;\r\n')
            text = self._decode_charset(data, charset)
            if text is not None:
                return text

        # search, not match: the meta tag is inside the document, not at byte 0.
        m = self.decode_body_regex.search(data)
        if m:
            # The pattern is bytes and only matches ASCII [a-z0-9-].
            charset = m.group(1).decode('ascii').lower()
            text = self._decode_charset(data, charset)
            if text is not None:
                return text

        return str(data, 'utf-8', 'ignore')

    def parse_headers(self, data):
        """Parse CRLF-separated "Name: value" lines into a dict with lowercased names."""
        headers = {}
        for line in data.split('\r\n'):
            pos = line.find(':')
            if pos > 0:
                headers[line[:pos].lower()] = line[pos + 1:].strip()
        return headers

    def send_msg(self, data):
        """Serialize *data* to JSON and append it to the work queue.

        If the queue is at least 95% of ``max_queue_size``, it is cleared
        first (dropping queued records) so the collector never blocks.
        """
        result = json.dumps(data)
        if self.debug:
            print(result)
        if len(self.work_queue) >= self.max_queue_size * 0.95:
            self.work_queue.clear()
        self.work_queue.append(result)
コード例 #6
0
ファイル: db.py プロジェクト: glenlancer/my_dict
class DbOperator():
    # MySQL login used by db_connect* and the mysqldump/mysql shell commands
    # (placeholder values here). The double underscore makes Python mangle
    # these names, so they are only reachable as self.__DB_* inside the class.
    __DB_USERNAME = '******'
    __DB_PASSWORD = '******'
    __DB_DBNAME = 'dict_db'
    # Maximum number of entries held by each LRU cache.
    __CACHE_MAXSIZE = 512

    # Use extra spaces, since no ordinary key allows this.
    # (Sentinel cache keys for the "all words" / "all article titles" lists.)
    __ALL_WORDS_KEY = ' ALL_WORDS '
    __ALL_ARTICLE_KEY = ' ALL_ARTICLE '

    def __init__(self):
        """Create the SQL message log and one LRU cache per kind of query result."""
        # Human-readable descriptions of failed SQL statements.
        self.messages = []
        # TODO Implement in memory data handling to avoid unnecessary
        # Database accessing.
        limit = self.__CACHE_MAXSIZE
        self.words_detail_cache = LRUCache(maxsize=limit)
        self.words_name_cache = LRUCache(maxsize=limit)
        self.usage_cache = LRUCache(maxsize=limit)
        self.article_detail_cache = LRUCache(maxsize=limit)
        self.article_name_cache = LRUCache(maxsize=limit)
        self.reference_cache = LRUCache(maxsize=limit)

    def __del__(self):
        self.db_close()

    def __cache_analysis(self):
        pass

    def try_db_connect(self):
        try:
            self.db_connect()
        except Exception as e:
            '''
            Error code:
            (1) 2003 - Can't connect to server, possibly due to MySql service is not up or installed.
            (2) 1044 - Access denied, possibly due to db doesn't exist.
            (3) 1045 - Access denied, possibly due to user doesn't exist.
            (4) 1049 - Unknown database, possibly due to db doesn't exist.
            '''
            return e.args
        return (0, 'Success')

    def db_connect(self):
        """Connect to the dict database on localhost and open a cursor (self.conn / self.cursor)."""
        connection = pymysql.connect(host='localhost',
                                     user=self.__DB_USERNAME,
                                     password=self.__DB_PASSWORD,
                                     database=self.__DB_DBNAME,
                                     charset='utf8')
        self.conn = connection
        self.cursor = connection.cursor()

    def db_connect_with_no_specified_db(self):
        """Connect to the MySQL server on localhost without selecting a database."""
        connection = pymysql.connect(host='localhost',
                                     user=self.__DB_USERNAME,
                                     password=self.__DB_PASSWORD,
                                     charset='utf8')
        self.conn = connection
        self.cursor = connection.cursor()

    def db_create_database(self):
        try:
            self.db_connect_with_no_specified_db()
        except Exception as e:
            os.system(f'cat {e.args[0]}:{e.args[1]} > dict_error.log')
            return
        self.db_create_db_and_tables()

    def db_create_db_and_tables(self):
        """Create dict_db (if missing) and its Words, Usage, Article and Reference tables.

        The statements run in order through execute_all_sqls with
        need_commit=False; execution stops at the first failing statement and
        the failure is recorded in self.messages. The CREATE TABLE statements
        have no IF NOT EXISTS, so re-running on an existing schema fails at
        the first table.
        """
        # Usage and Reference declare foreign keys into Words/Article, so the
        # referenced tables are created first.
        sqls = [
            'create database if not exists dict_db', 'use dict_db', '''
                create table Words (
                    WID           serial,
                    Word          varchar(50) not null unique,
                    Meaning       varchar(250) not null,
                    Pronunciation varchar(50),
                    Exchange      varchar(100),
                    `date`        datetime not null,
                    primary key (Word)
                ) ENGINE=InnoDB Default Charset=utf8
            ''', '''
                create table `Usage` (
                    UID     serial,
                    Word    varchar(50) not null,
                    `Usage` text not null,
                    primary key (UID),
                    foreign key (Word) references Words(Word)
                ) ENGINE=InnoDB Default Charset=utf8
            ''', '''
                create table Article (
                    AID     serial,
                    Title   varchar(100) not null unique,
                    Content text not null,
                    primary key (AID)
                ) ENGINE=InnoDB Default Charset=utf8
            ''', '''
                create table Reference (
                    RID serial,
                    Word varchar(50) not null,
                    Title varchar(100) not null,
                    primary key (RID),
                    foreign key (Word) references Words(Word),
                    foreign key (Title) references Article(Title)
                ) ENGINE=InnoDB Default Charset=utf8
            '''
        ]
        self.execute_all_sqls(sqls, False)

    def db_close(self):
        if hasattr(self, 'cursor'):
            self.cursor.close()
        if hasattr(self, 'conn'):
            self.conn.close()

    def db_commit(self):
        self.conn.commit()

    def db_export_to_file(self, file_name):
        """Dump the dictionary database to *file_name* using ``mysqldump``.

        :return: the status value returned by ``os.system``.
        NOTE(review): the password is passed on the command line and is
        visible to other local users via the process list.
        """
        import shlex
        # Quote every interpolated value so paths with spaces or shell
        # metacharacters (and '*' in the password) reach mysqldump literally;
        # ordinary values are left unchanged by shlex.quote.
        user = shlex.quote(self.__DB_USERNAME)
        password = shlex.quote(self.__DB_PASSWORD)
        db_name = shlex.quote(self.__DB_DBNAME)
        target = shlex.quote(str(file_name))
        return os.system(
            f'mysqldump -u{user} -p{password} {db_name} > {target}'
        )

    def db_import_from_file(self, file_name):
        """Load SQL from *file_name* into the dictionary database using ``mysql``.

        :return: the status value returned by ``os.system``.
        NOTE(review): the password is passed on the command line and is
        visible to other local users via the process list.
        """
        import shlex
        # Quote every interpolated value so paths with spaces or shell
        # metacharacters (and '*' in the password) reach mysql literally;
        # ordinary values are left unchanged by shlex.quote.
        user = shlex.quote(self.__DB_USERNAME)
        password = shlex.quote(self.__DB_PASSWORD)
        db_name = shlex.quote(self.__DB_DBNAME)
        source = shlex.quote(str(file_name))
        return os.system(
            f'mysql -u{user} -p{password} {db_name} < {source}'
        )

    def db_fetchone(self, sql):
        try:
            self.cursor.execute(sql)
            # When result is empty, fetchone() returns None.
            res = self.cursor.fetchone()
            if res is None:
                return None
            return list(res)
        except Exception as e:
            self.messages.append(f'SQL failed: {sql}, due to {e.args[-1]}')
            return None

    def db_fetchall(self, sql):
        try:
            self.cursor.execute(sql)
            return list(map(lambda x: list(x), self.cursor.fetchall()))
        except Exception as e:
            self.messages.append(f'SQL failed: {sql}, due to {e.args[-1]}')
            return list()

    def db_execute(self, sql):
        try:
            self.cursor.execute(sql)
        except Exception as e:
            self.messages.append(f'SQL failed: {sql}, due to {e.args[-1]}')
            return False
        return True

    def execute_all_sqls(self, sqls, need_commit=True):
        try:
            for sql in sqls:
                self.cursor.execute(sql)
            if need_commit:
                self.db_commit()
        except Exception as e:
            self.messages.append(f'SQL failed: {sqls}, due to {e.args[-1]}')
            return False
        return True

    def select_word(self, word):
        """Return [Meaning, Pronunciation, Exchange] for *word*, or None if absent.

        Truthy results are served from / stored in words_detail_cache.
        """
        record = self.words_detail_cache.get(word)
        if record:
            return record
        # Escape the word with escape_double_quotes, as this class already does
        # for titles and other text values interpolated into SQL, so a '"' in
        # the input cannot break the statement.
        esd_word = escape_double_quotes(word)
        sql = f'SELECT Meaning, Pronunciation, Exchange FROM Words WHERE Word = "{esd_word}"'
        record = self.db_fetchone(sql)
        if record:
            self.words_detail_cache.add(word, record)
        return record

    def select_like_word(self, word, clear_cache=False):
        """Return the words containing *word* as a substring.

        :param clear_cache: drop the cached result for *word* and re-query.
        :return: list of matching words (also stored in words_name_cache).
        """
        if clear_cache:
            self.words_name_cache.delete(word)
        else:
            records = self.words_name_cache.get(word)
            if records is not None:
                return records
        # Escape like the other text values this class puts into SQL.
        # NOTE: '%' and '_' in *word* still act as LIKE wildcards.
        esd_word = escape_double_quotes(word)
        sql = f'SELECT Word FROM Words WHERE Word LIKE "%{esd_word}%"'
        records = [row[0] for row in self.db_fetchall(sql)]
        self.words_name_cache.add(word, records)
        return records

    def select_all_words(self, clear_cache=False):
        """Return every word in the Words table.

        :param clear_cache: drop the cached list and re-query the database.
        :return: list of words (also stored in words_name_cache).
        """
        cache = self.words_name_cache
        cache_key = self.__ALL_WORDS_KEY
        if not clear_cache:
            cached = cache.get(cache_key)
            if cached is not None:
                return cached
        else:
            cache.delete(cache_key)
        rows = self.db_fetchall('SELECT Word FROM Words')
        words = [row[0] for row in rows]
        cache.add(cache_key, words)
        return words

    def select_usages(self, word):
        """Return the list of usage texts recorded for *word* (cached in usage_cache)."""
        records = self.usage_cache.get(word)
        if records is not None:
            return records
        # Escape like the other text values this class puts into SQL.
        esd_word = escape_double_quotes(word)
        sql = f'SELECT `Usage` FROM `Usage` WHERE Word = "{esd_word}"'
        records = [row[0] for row in self.db_fetchall(sql)]
        self.usage_cache.add(word, records)
        return records

    def select_article_for_word(self, word):
        """Return titles of articles that reference *word* (cached in reference_cache)."""
        records = self.reference_cache.get(word)
        if records is not None:
            return records
        # Escape like the other text values this class puts into SQL.
        esd_word = escape_double_quotes(word)
        sql = ''.join([
            'SELECT Article.Title from Article ',
            'JOIN Reference ON Reference.Title = Article.Title ',
            f'WHERE Word = "{esd_word}"'
        ])
        records = [row[0] for row in self.db_fetchall(sql)]
        self.reference_cache.add(word, records)
        return records

    def select_article(self, title):
        """Return the Content of the article titled *title*, or None if absent.

        Truthy cached content is returned directly; a fetched row's content
        is stored in article_detail_cache.
        """
        cached = self.article_detail_cache.get(title)
        if cached:
            return cached
        quoted_title = escape_double_quotes(title)
        row = self.db_fetchone(
            f'SELECT Content FROM Article WHERE Title = "{quoted_title}"')
        if row is None:
            return None
        if row:
            self.article_detail_cache.add(title, row[0])
        return row[0]

    def select_like_article(self, title, clear_cache=False):
        """Return titles containing *title* as a substring.

        :param clear_cache: drop the cached result for *title* and re-query.
        :return: list of matching titles (also stored in article_name_cache).
        """
        cache = self.article_name_cache
        if not clear_cache:
            cached = cache.get(title)
            if cached is not None:
                return cached
        else:
            cache.delete(title)
        quoted_title = escape_double_quotes(title)
        rows = self.db_fetchall(
            f'SELECT Title FROM Article WHERE Title LIKE "%{quoted_title}%"')
        titles = [row[0] for row in rows]
        cache.add(title, titles)
        return titles

    def select_all_article_titles(self, clear_cache=False):
        """Return every article title.

        :param clear_cache: drop the cached list and re-query the database.
        :return: list of titles (also stored in article_name_cache).
        """
        cache = self.article_name_cache
        cache_key = self.__ALL_ARTICLE_KEY
        if not clear_cache:
            cached = cache.get(cache_key)
            if cached is not None:
                return cached
        else:
            cache.delete(cache_key)
        rows = self.db_fetchall('SELECT Title FROM Article')
        titles = [row[0] for row in rows]
        cache.add(cache_key, titles)
        return titles

    def select_all_articles(self):
        """Return all [Title, Content] rows and reseed article_detail_cache.

        The detail cache is cleared, then filled with at most
        __CACHE_MAXSIZE title -> content entries, in row order.
        """
        rows = self.db_fetchall('SELECT Title, Content FROM Article')
        cache = self.article_detail_cache
        cache.clear()
        for row in rows[:self.__CACHE_MAXSIZE]:
            cache.add(row[0], row[1])
        return rows

    def insert_word(self, word, meaning, pron, exchange):
        """Insert a new word with today's date.

        On success, the word's details are cached and the cached word-name
        lists are invalidated.

        :return: True on success, False if the INSERT failed.
        """
        # Escape every text value, including the word itself, like the rest
        # of this class does before interpolating into SQL.
        esd_word = escape_double_quotes(word)
        esd_meaning = escape_double_quotes(meaning)
        esd_pronunciation = escape_double_quotes(pron)
        esd_exchange = escape_double_quotes(exchange)
        sql = ''.join([
            'INSERT INTO Words\n',
            '(Word, Meaning, Pronunciation, Exchange, `date`)\n', 'VALUES\n',
            f'("{esd_word}", "{esd_meaning}", "{esd_pronunciation}", "{esd_exchange}", CURDATE())'
        ])
        res = self.db_execute(sql)
        if not res:
            return False
        self.words_detail_cache.add(word, [meaning, pron, exchange])
        self.words_name_cache.clear()
        return True

    def update_word(self, word, meaning, pron, exchange):
        """Update an existing word's details and date; refresh its cache entry on success.

        :return: True on success, False if the UPDATE failed.
        """
        # Escape every text value, including the word itself, like the rest
        # of this class does before interpolating into SQL.
        esd_word = escape_double_quotes(word)
        esd_meaning = escape_double_quotes(meaning)
        esd_pronunciation = escape_double_quotes(pron)
        esd_exchange = escape_double_quotes(exchange)
        sql = ''.join([
            f'UPDATE Words SET Meaning="{esd_meaning}", Pronunciation="{esd_pronunciation}", Exchange="{esd_exchange}", `date`=CURDATE()\n',
            f'WHERE Word="{esd_word}"'
        ])
        res = self.db_execute(sql)
        if not res:
            return False
        self.words_detail_cache.set(word, [meaning, pron, exchange])
        return True

    def insert_article(self, title, content):
        """Insert an article row and refresh the article caches.

        NOTE(review): this class defines ``insert_article`` twice with the
        same body; Python keeps the later definition, so this one is shadowed.

        :return: True on success, False if the statement failed.
        """
        sql = 'INSERT INTO Article (Title, Content) VALUES ("{}", "{}")'.format(
            escape_double_quotes(title), escape_double_quotes(content))
        if self.db_execute(sql):
            self.article_detail_cache.add(title, content)
            # The cached list of titles no longer matches the table.
            self.article_name_cache.clear()
            return True
        return False

    def update_article(self, title, content):
        """Replace the content of the article called ``title``.

        :return: True on success, False if the statement failed.
        """
        esd_title = escape_double_quotes(title)
        esd_content = escape_double_quotes(content)
        sql = (f'UPDATE Article SET Content="{esd_content}"\n'
               f'WHERE Title="{esd_title}"')
        if not self.db_execute(sql):
            return False
        self.article_detail_cache.set(title, content)
        return True

    def insert_usage(self, word, usage):
        """Store a usage example for ``word`` and append it to the cached list.

        ``word`` is now escaped with ``escape_double_quotes`` like ``usage``.

        :param word: the word the usage belongs to.
        :param usage: the usage text.
        :return: True on success, False if the statement failed.
        """
        esd_word = escape_double_quotes(word)
        esd_usage = escape_double_quotes(usage)
        sql = f'INSERT INTO `Usage` (Word, `Usage`) VALUES ("{esd_word}", "{esd_usage}")'
        if not self.db_execute(sql):
            return False
        usages = self.usage_cache.get(word)
        if usages is None:
            # NOTE(review): on a cache miss this caches a list holding only the
            # new usage, although the table may already contain others for
            # this word; confirm readers do not treat it as complete.
            usages = []
        usages.append(usage)
        self.usage_cache.set(word, usages)
        return True

    def insert_article(self, title, content):
        """Insert an article row and refresh the article caches.

        NOTE(review): an identical ``insert_article`` is defined earlier in
        this class; being the later definition, this is the one Python keeps.

        :return: True on success, False if the statement failed.
        """
        quoted_title = escape_double_quotes(title)
        quoted_content = escape_double_quotes(content)
        statement = f'INSERT INTO Article (Title, Content) VALUES ("{quoted_title}", "{quoted_content}")'
        if not self.db_execute(statement):
            return False
        self.article_detail_cache.add(title, content)
        # The cached list of titles no longer matches the table.
        self.article_name_cache.clear()
        return True

    def truncate_reference(self):
        """Empty the Reference table and, if that succeeds, its cache.

        :return: True on success, False if the statement failed.
        """
        if self.db_execute('TRUNCATE TABLE Reference'):
            self.reference_cache.clear()
            return True
        return False

    def insert_reference(self, word, title):
        """Record that article ``title`` references ``word`` and update the cache.

        Both values are escaped with ``escape_double_quotes`` before being
        interpolated into the SQL.

        :param word: the referenced word (cache key).
        :param title: the referencing article's title.
        :return: True on success, False if the statement failed.
        """
        esd_word = escape_double_quotes(word)
        esd_title = escape_double_quotes(title)
        sql = f'INSERT INTO Reference (Word, Title) VALUES ("{esd_word}", "{esd_title}")'
        if not self.db_execute(sql):
            return False
        titles = self.reference_cache.get(word)
        if titles is None:
            # NOTE(review): on a cache miss this caches only the new title,
            # although the table may hold other references for this word.
            self.reference_cache.add(word, [title])
        else:
            # list.append mutates in place and returns None; storing its
            # result (as before) replaced the cached list with None.
            titles.append(title)
            self.reference_cache.set(word, titles)
        return True

    def drop_all_tables(self):
        """Drop the Reference, Usage, Article and Words tables.

        All in-memory caches are emptied first. The DROP statements are passed
        to ``execute_all_sqls`` with ``False`` as its second argument.

        :return: whatever ``execute_all_sqls`` returns.
        """
        self.clear_all_caches()
        tables = ('Reference', '`Usage`', 'Article', 'Words')
        return self.execute_all_sqls(['DROP TABLE ' + name for name in tables],
                                     False)

    def delete_a_word(self, word):
        """Delete a word together with its usages and references.

        ``word`` is escaped with ``escape_double_quotes`` before being put in
        the WHERE clauses. On success the word's per-word cache entries are
        removed and the cached word-name list is invalidated.

        :param word: the word to delete.
        :return: True on success, False if the statements failed.
        """
        esd_word = escape_double_quotes(word)
        sqls = [
            'DELETE FROM Reference WHERE Word="{}"'.format(esd_word),
            'DELETE FROM `Usage` WHERE Word="{}"'.format(esd_word),
            'DELETE FROM Words WHERE Word="{}"'.format(esd_word)
        ]
        if not self.execute_all_sqls(sqls):
            return False
        self.words_detail_cache.delete(word)
        self.usage_cache.delete(word)
        self.reference_cache.delete(word)
        # insert_word clears the cached word-name list; do the same here so a
        # deleted word does not linger in it for callers that skip clear_cache.
        self.words_name_cache.clear()
        return True

    def delete_a_article(self, title):
        """Delete an article and every reference it made.

        On success the article's detail entry is dropped, ``title`` is removed
        from every cached reference list, and the cached title list is
        invalidated (mirroring insert_article).

        :param title: the title of the article to delete.
        :return: True on success, False if the statements failed.
        """
        esd_title = escape_double_quotes(title)
        sqls = [
            'DELETE FROM Reference WHERE Title="{}"'.format(esd_title),
            'DELETE FROM Article WHERE Title="{}"'.format(esd_title)
        ]
        if not self.execute_all_sqls(sqls):
            return False
        self.article_detail_cache.delete(title)
        # Snapshot the keys because entries are rewritten inside the loop.
        for key in list(self.reference_cache.keys()):
            titles = self.reference_cache.get(key)
            # get() can yield None (missing entry, or a None stored earlier);
            # skip it instead of raising TypeError on ``in``.
            if not titles or title not in titles:
                continue
            # The DELETE removes every matching row, so drop all occurrences.
            self.reference_cache.set(key, [t for t in titles if t != title])
        self.article_name_cache.clear()
        return True

    def clear_all_caches(self):
        """Empty every in-memory cache held by this object, in a fixed order."""
        for cache_name in ('words_detail_cache', 'words_name_cache',
                           'usage_cache', 'article_detail_cache',
                           'article_name_cache', 'reference_cache'):
            getattr(self, cache_name).clear()

    def print_messages(self):
        """Print, then discard, the collected messages when DEBUG_FLAG is set."""
        if not DEBUG_FLAG:
            return
        print('--- All messages ---')
        for msg in self.messages:
            print(msg)
        self.messages = []
        print('--- End of all messages ---')
コード例 #7
0
ファイル: test_lru.py プロジェクト: meiliqi/cacheout
def test_lru_get_default(cache: LRUCache):
    """LRUCache.get() must hand back the supplied default for a missing key."""
    fallback = "bar"
    result = cache.get("foo", default=fallback)
    assert result == fallback
コード例 #8
0
ファイル: test_lru.py プロジェクト: meiliqi/cacheout
def test_lru_get(cache: LRUCache):
    """LRUCache.get() must return the value stored under every cached key."""
    for key, expected in cache.items():
        actual = cache.get(key)
        assert actual == expected
コード例 #9
0
class tcp_http_shark():
    def __init__(self, work_queue, interface, custom_tag, return_deep_info,
                 http_filter_json, cache_size, session_size, bpf_filter,
                 timeout, debug):
        """
        Constructor.

        :param work_queue: queue that receives captured asset records (appended as JSON strings)
        :param interface: name of the network interface to capture on
        :param custom_tag: data tag used to tell different capture engines apart
        :param return_deep_info: whether to extract more detail, including the raw request, response headers and body
        :param http_filter_json: HTTP filter configuration, supports filtering by status code and content type
        :param cache_size: number of processed records to remember; duplicates within 120 seconds are not reported again
        :param session_size: number of HTTP/TCP sessions to remember; session entries expire after 16 seconds
        :param bpf_filter: low-level (BPF) packet filter
        :param timeout: run timeout for the capture (the original notes a 1-hour default, set by the caller)
        :param debug: debug switch
        """
        self.interface = interface
        self.custom_tag = custom_tag
        self.return_deep_info = return_deep_info
        self.http_filter_json = http_filter_json
        self.session_size = session_size
        self.cache_size = cache_size
        self.bpf_filter = bpf_filter
        self.timeout = timeout
        self.debug = debug
        self.work_queue = work_queue

        self.pktcap = pyshark.LiveCapture(
            interface=self.interface,
            bpf_filter=self.bpf_filter,
            use_json=False,
            debug=self.debug,
        )

        if self.session_size:
            # Per-stream request/handshake state; ttl=16 with time.time as the clock.
            session_opts = dict(maxsize=self.session_size, ttl=16,
                                timer=time.time, default=None)
            self.http_stream_cache = Cache(**session_opts)
            self.tcp_stream_cache = Cache(**session_opts)

        if self.cache_size:
            # De-duplication of recently reported URLs / host:port pairs; ttl=120.
            dedup_opts = dict(maxsize=self.cache_size, ttl=120,
                              timer=time.time, default=None)
            self.http_cache = LRUCache(**dedup_opts)
            self.tcp_cache = LRUCache(**dedup_opts)

        # Extracts the charset from an HTML <meta ... charset=...> tag.
        self.encode_regex = re.compile(
            rb'<meta [^>]*?charset=["\']?([a-z\-\d]+)["\'>]?', re.I)

    def http_filter(self, key, value):
        """
		检查字符串中是否包含特定的规则
		:param key: 规则键名,response_code(状态码)或 content_type(内容类型)
		:param value: 要检查的字符串
		:return: True - 包含, False - 不包含
		"""
        if key in self.http_filter_json:
            for rule in self.http_filter_json[key]:
                if rule in value:
                    return True
        return False

    def run(self):
        """
		入口函数
		"""
        try:
            self.pktcap.apply_on_packets(self.proc_packet,
                                         timeout=self.timeout)
        except concurrent.futures.TimeoutError:
            print("\nTimeoutError.")

    def proc_packet(self, pkt):
        """
		全局数据包处理:识别、路由及结果发送
		:param pkt: 数据包
		:return: JSON or None
		"""
        try:
            pkt_json = None
            pkt_dict = dir(pkt)

            if 'ip' in pkt_dict:
                if 'http' in pkt_dict:
                    pkt_json = self.proc_http(pkt)
                elif 'tcp' in pkt_dict:
                    pkt_json = self.proc_tcp(pkt)

            if pkt_json:
                result = json.dumps(pkt_json)
                if self.debug:
                    print(result)
                self.work_queue.append(result)
        except:
            traceback.print_exc()

    def proc_http(self, pkt):
        """
		处理 HTTP 包
		:param pkt: 数据包
		:return: JSON or None
		"""
        http_dict = dir(pkt.http)

        if 'request' in http_dict and self.session_size:
            req = {
                'url':
                pkt.http.request_full_uri
                if 'request_full_uri' in http_dict else pkt.http.request_uri,
                'method':
                pkt.http.request_method
                if 'request_method' in http_dict else ''
            }

            self.http_stream_cache.set(pkt.tcp.stream, req)

        elif 'response' in http_dict:
            pkt_json = {}
            src_addr = pkt.ip.src
            src_port = pkt[pkt.transport_layer].srcport

            if self.session_size:
                cache_req = self.http_stream_cache.get(pkt.tcp.stream)
                if cache_req:
                    pkt_json['url'] = cache_req['url']
                    pkt_json['method'] = cache_req['method']
                    self.http_stream_cache.delete(pkt.tcp.stream)

            if 'url' not in pkt_json:
                if 'response_for_uri' in http_dict:
                    pkt_json["url"] = pkt.http.response_for_uri
                else:
                    pkt_json["url"] = '/'

            # 处理 URL 只有URI的情况
            if pkt_json["url"][0] == '/':
                if src_port == '80':
                    pkt_json["url"] = "http://%s%s" % (src_addr,
                                                       pkt_json["url"])
                else:
                    pkt_json["url"] = "http://%s:%s%s" % (src_addr, src_port,
                                                          pkt_json["url"])

            if self.cache_size:
                # 缓存机制,防止短时间大量处理重复响应
                exists = self.http_cache.get(pkt_json['url'])
                if exists:
                    return None

                self.http_cache.set(pkt_json["url"], True)

            pkt_json["pro"] = 'HTTP'
            pkt_json["tag"] = self.custom_tag
            pkt_json["ip"] = src_addr
            pkt_json["port"] = src_port

            if 'response_code' in http_dict:
                if self.http_filter_json:
                    return_status = self.http_filter('response_code',
                                                     pkt.http.response_code)
                    if return_status:
                        return None
                pkt_json["code"] = pkt.http.response_code

            if 'content_type' in http_dict:
                if self.http_filter_json:
                    return_status = self.http_filter('content_type',
                                                     pkt.http.content_type)
                    if return_status:
                        return None
                pkt_json["type"] = pkt.http.content_type.lower()
            else:
                pkt_json["type"] = 'unkown'

            if 'server' in http_dict:
                pkt_json["server"] = pkt.http.server

            # 开启深度数据分析,返回header和body等数据
            if self.return_deep_info:
                charset = 'utf-8'
                # 检测 Content-Type 中的编码信息
                if 'type' in pkt_json and 'charset=' in pkt_json["type"]:
                    charset = pkt_json["type"][pkt_json["type"].find('charset='
                                                                     ) +
                                               8:].strip().lower()
                    if not charset:
                        charset = 'utf-8'
                if 'payload' in dir(pkt.tcp):
                    payload = bytes.fromhex(
                        str(pkt.tcp.payload).replace(':', ''))
                    if payload.find(b'HTTP/') == 0:
                        split_pos = payload.find(b'\r\n\r\n')
                        if split_pos <= 0 or split_pos > 2048:
                            split_pos = 2048
                        pkt_json["header"] = str(payload[:split_pos], 'utf-8',
                                                 'ignore')
                        data = str(payload[split_pos + 4:], 'utf-8', 'ignore')
                if 'file_data' in http_dict and pkt.http.file_data.raw_value and pkt_json[
                        'type'] != 'application/octet-stream':
                    data = bytes.fromhex(pkt.http.file_data.raw_value)
                elif 'data' in http_dict:
                    data = bytes.fromhex(pkt.http.data)
                elif 'segment_data' in dir(pkt.tcp):
                    data = bytes.fromhex(pkt.tcp.segment_data.replace(":", ""))
                else:
                    data = ''

                if data:
                    # 检测页面 Meta 中的编码信息
                    data_head = data[:500] if data.find(
                        b'</head>', 0,
                        1024) == -1 else data[:data.find(b'</head>')]
                    match = self.encode_regex.search(data_head)
                    if match:
                        charset = str(
                            match.group(1).strip().lower(), 'utf-8', 'ignore')

                    response_body = proc_body_str(str(data, charset, 'ignore'),
                                                  16 * 1024)
                    # response_body = self.proc_body_json(str(data, charset, 'ignore'), 16*1024)
                    pkt_json["body"] = response_body
                else:
                    pkt_json["body"] = ''

            return pkt_json

        return None

    def proc_tcp(self, pkt):
        """
		处理 TCP 包
		:param pkt: 数据包
		:return: JSON or None
		"""
        tcp_stream = pkt.tcp.stream

        pkt_json = {}
        pkt_json["pro"] = 'TCP'
        pkt_json["tag"] = self.custom_tag

        # SYN+ACK
        if pkt.tcp.flags == '0x00000012':
            server_ip = pkt.ip.src
            server_port = pkt[pkt.transport_layer].srcport
            tcp_info = '%s:%s' % (server_ip, server_port)

            if self.cache_size:
                exists = self.tcp_cache.get(tcp_info)
                if exists:
                    return None
                self.tcp_cache.set(tcp_info, True)

            if self.return_deep_info and self.session_size:
                self.tcp_stream_cache.set(tcp_stream, tcp_info)
            else:
                pkt_json["ip"] = server_ip
                pkt_json["port"] = server_port

                return pkt_json

        # -r on开启深度数据分析,采集server第一个响应数据包
        if self.return_deep_info and pkt.tcp.seq == "1" and "payload" in dir(
                pkt.tcp) and self.session_size:
            tcp_info = self.tcp_stream_cache.get(tcp_stream)
            if tcp_info:
                # 防止误处理客户端发第一个包的情况
                src_host = '{}:{}'.format(pkt.ip.src,
                                          pkt[pkt.transport_layer].srcport)
                if tcp_info != src_host:
                    return None

                self.tcp_stream_cache.delete(tcp_stream)

                pkt_json["ip"] = pkt.ip.src
                pkt_json["port"] = pkt[pkt.transport_layer].srcport
                payload_data = pkt.tcp.payload.replace(":", "")
                if payload_data.startswith("48545450"):  # ^HTTP
                    return None

                # HTTPS Protocol
                # TODO: other https port support
                if pkt_json["port"] == "443" and payload_data.startswith(
                        "1603"):  # SSL
                    pkt_json["pro"] = 'HTTPS'
                    pkt_json["url"] = "https://{}/".format(pkt_json["ip"])
                else:
                    pkt_json["data"] = proc_data_str(payload_data, 16 * 1024)

                return pkt_json
        return None