예제 #1
0
def test_cache_default():
    """Check that a Cache-level default is what Cache.get() falls back to."""
    fallback_cache = Cache(default=True)

    # Missing key: the constructor-level default is returned but not stored.
    assert fallback_cache.get(1) is True
    assert 1 not in fallback_cache

    # A per-call default overrides the constructor-level one, also without storing.
    assert fallback_cache.get(2, default=False) is False
    assert 2 not in fallback_cache
예제 #2
0
def test_cache_get(cache: Cache):
    """Check cache.get() returns the stored value, or a default when the key is absent."""
    key = "key"
    value = "value"

    # Nothing stored yet: implicit None, then the explicit default; neither is cached.
    assert cache.get(key) is None
    assert cache.get(key, default=1) == 1
    assert key not in cache

    # Once stored, the real value comes back.
    cache.set(key, value)
    stored = cache.get(key)
    assert stored == value
예제 #3
0
def test_cache_get_default_callable(cache: Cache):
    """Check that a callable default passed to cache.get() populates the missing key."""
    def echo_key(missing_key):
        # Echo the missing key back so the computed value is easy to verify.
        return missing_key

    key = "key"

    assert cache.get(key) is None
    assert cache.get(key, default=echo_key) == key
    # Unlike a plain default, the callable's result is written into the cache.
    assert key in cache
예제 #4
0
def test_cache_default_callable():
    """Check that a callable Cache-level default can be overridden per get() call."""
    def always_false(key):
        return False

    def echo_key(key):
        return key

    cache_with_default = Cache(default=always_false)

    # First lookup populates "key1" via the constructor-level callable ...
    assert cache_with_default.get("key1") is False
    # ... so a later per-call override has no effect on the now-present key.
    assert cache_with_default.get("key1", default=echo_key) is False
    # For a fresh key, the per-call callable wins.
    assert cache_with_default.get("key2", default=echo_key) == "key2"
    # A non-callable per-call default is returned as-is.
    assert cache_with_default.get("key3", default=3) == 3
예제 #5
0
def test_cache_set_many(cache: Cache):
    """Check that cache.set_many() stores every key/value pair it is given."""
    expected = {"a": 1, "b": 2, "c": 3}
    cache.set_many(expected)

    for k in expected:
        assert cache.get(k) == expected[k]
예제 #6
0
class Predictor_Ocr:
    """Poll a Redis list for OCR tasks, run the requested models and store results.

    Each task popped from the ``"aiocr"`` list is a JSON object with
    ``create_user_id``, ``image_id``, ``polygon_id``, ``image`` (pixel data that
    is reshaped to ``(N, 64, 64, 1)`` and divided by 255) and ``algorithm`` (a
    list of model names). For every algorithm, the top-5 scores above 0.1 are
    mapped through that algorithm's ``elem_list.json``. Entries equal to -1 or
    lower are skipped. The combined list is written as JSON under
    ``rs_aiocr_<create_user_id>_<image_id>_<polygon_id>``.
    """

    def __init__(self, ip, port, password):
        self.ip = ip
        self.port = port
        self.password = password
        self.redis_pool = redis.ConnectionPool(host=ip,
                                               port=port,
                                               db=0,
                                               password=password,
                                               encoding='utf-8')
        # Loaded models, keyed by algorithm name.
        self.cache = Cache()
        # Label tables (elem_list.json), keyed by algorithm name. They are loaded
        # once, like the models, so labels stay in step with the cached model.
        self.elem_lists = {}
        self.cur = os.getcwd()

    def _model_dir(self, model_key):
        """Return the ``<cwd>/<model_key>/data`` directory holding a model's files."""
        return os.path.join(self.cur, model_key, "data")

    def get_model(self, model_key):
        """Return the model stored in ``<model_key>.h5``, loading and caching it on first use."""
        # A single get() avoids the has()/get() double lookup.
        model = self.cache.get(model_key)
        if model is None:
            path = os.path.join(self._model_dir(model_key), "%s.h5" % model_key)
            model = load_model(path)
            self.cache.set(model_key, model)
        return model

    def _get_elem_list(self, algorithm):
        """Return the index->label table for ``algorithm``, reading it from disk once."""
        elem_list = self.elem_lists.get(algorithm)
        if elem_list is None:
            path = os.path.join(self._model_dir(algorithm), "elem_list.json")
            with open(path, 'r', encoding='utf-8') as f:
                elem_list = json.load(f)
            self.elem_lists[algorithm] = elem_list
        return elem_list

    def work(self):
        """Process tasks from the ``"aiocr"`` Redis list forever."""
        width = 64
        height = 64
        top_k = 5
        thr = 0.1
        # One client for the worker's lifetime; the pool manages the connections.
        red = redis.Redis(connection_pool=self.redis_pool)
        while True:
            task_str = red.lpop("aiocr")
            if task_str is None:
                # Queue empty: back off briefly before polling again. When tasks
                # are waiting, they are processed back-to-back without sleeping.
                time.sleep(0.005)
                continue

            task = json.loads(task_str)
            image_batch = np.array([task["image"]], dtype=np.float32).reshape(
                -1, height, width, 1) / 255.0

            out_list = []
            for algorithm in task["algorithm"]:
                elem_list = self._get_elem_list(algorithm)
                out = self.get_model(algorithm).predict(image_batch)[0]
                # Indices of the top_k highest scores, best first.
                for item in out.argsort()[::-1][:top_k]:
                    if out[item] > thr and elem_list[item] > -1:
                        out_list.append(elem_list[item])

            key = "%s_%s_%s_%s" % ("rs_aiocr", task["create_user_id"],
                                   task["image_id"], task["polygon_id"])
            red.set(key, json.dumps(out_list))
예제 #7
0
def test_cache_iter(cache: Cache):
    """Check that iterating a cache yields each stored key with its stored value."""
    items: dict = {"a": 1, "b": 2, "c": 3}
    cache.set_many(items)

    seen = set()
    for k in cache:
        assert cache.get(k) == items[k]
        seen.add(k)

    assert seen == set(items)
예제 #8
0
def test_cache_add(cache: Cache):
    """Check that cache.add() only writes keys that are not already present."""
    key = "key"
    value = "value"
    ttl = 2

    cache.add(key, value, ttl)
    assert cache.get(key) == value
    assert cache.expire_times()[key] == ttl

    # A second add() for an existing key is a no-op, so the original TTL stands.
    cache.add(key, value, ttl + 1)
    assert cache.expire_times()[key] == ttl

    # set(), by contrast, always overwrites, including the TTL.
    longer_ttl = ttl + 1
    cache.set(key, value, longer_ttl)
    assert cache.expire_times()[key] == longer_ttl
class XMLParser():
    """Parse an XML document and hand out opaque UUID handles to XPath results."""

    def __init__(self, path):
        # Handle -> stored XPath result. NOTE(review): Cache() defaults to
        # maxsize=256, so older handles can be evicted and then resolve to None;
        # confirm that is acceptable for callers.
        self.cache = Cache()
        self.root = etree.parse(str(path))

    def storeElementWithKey(self, element):
        """Store ``element`` under a fresh random UUID string and return that handle."""
        udid = str(uuid.uuid4())
        self.cache.set(udid, element)
        return udid

    def getStorElementWithKey(self, udid):
        """Return the element stored under ``udid``, or None if unknown or evicted."""
        # (A leftover debug print of the handle was removed here.)
        return self.cache.get(udid)

    def getElementByXpath(self, xpath):
        """Evaluate ``xpath`` against the document, store the result and return its handle."""
        return self.storeElementWithKey(self.root.xpath(xpath))
예제 #10
0
def test_cache_memoize_arg_normalization(cache: Cache):
    """Check that cache.memoize() maps every mix of positional and keyword
    arguments for the same call onto a single cache entry."""
    @cache.memoize(typed=True)
    def func(a, b, c, d, **kwargs):
        return a, b, c, d

    # Leading parameters move one at a time from positional to keyword form.
    call_forms = [
        ((1, 2, 3, 4), {"e": 5}),
        ((1, 2, 3), {"d": 4, "e": 5}),
        ((1, 2), {"c": 3, "d": 4, "e": 5}),
        ((1,), {"b": 2, "c": 3, "d": 4, "e": 5}),
        ((), {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}),
        # Same all-keyword form again: the repeat call must hit the same entry.
        ((), {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}),
    ]

    for positional, keywords in call_forms:
        result = func(*positional, **keywords)
        assert cache.get(func.cache_key(*positional, **keywords)) is result
        assert len(cache) == 1
예제 #11
0
class Server(Thread):
    """Bridge Redis subscription requests to exchange websocket gateways.

    Clients write a JSON list of ``"<gateway>&<symbol>&<feature>"`` strings to
    the Redis key ``sub``. A receiver thread instantiates and subscribes the
    requested gateways. A publisher thread republishes each channel's data on
    the Redis pub/sub channel ``json.dumps(channel)`` as
    ``{'code', 'data', 'msg'}``: 200 for new data, 403 for data equal to the
    last cached value, and 500 for an unknown gateway.
    """

    def __init__(self):
        super().__init__()
        self.channels = []  # save subscribed channels
        self.gateways = {
            'binance': BinanceWs,
            'okex': OkexWs,
            'huobi': HuobiWs,
            'wootrade': WootradeWs
        }  # class of exchanges
        self.ongoing_gateway = {}  # has been instantiated exchanges
        self.feature = [
            'price', 'orderbook', 'trade', 'kline', 'order', 'wallet'
        ]
        self.cache = Cache(maxsize=256, timer=time.time)
        self.server = redis.StrictRedis(host='localhost',
                                        port=6379)  # redis server
        LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
        DATE_FORMAT = '%m/%d/%Y %H:%M:%S %p'
        logging.basicConfig(filename='server.log',
                            level=logging.WARNING,
                            format=LOG_FORMAT,
                            datefmt=DATE_FORMAT)
        Thread(target=self.receiver).start()
        Thread(target=self.publisher).start()

    def _publish(self, channel, code, data, msg):
        """Publish ``{'code', 'data', 'msg'}`` as JSON on ``json.dumps(channel)``."""
        self.server.publish(channel=json.dumps(channel),
                            message=json.dumps({
                                'code': code,
                                'data': data,
                                'msg': msg
                            }))

    def _integrate_gateway(self, gateway, symbol: str, feature: str):
        """Subscribe ``gateway`` to ``feature`` for ``symbol`` via its ``sub_*`` method."""
        try:
            # Map each feature name to the gateway's subscription method.
            switch = {
                'price': gateway.sub_price,
                'orderbook': gateway.sub_orderbook,
                'trade': gateway.sub_trade,
                'kline': gateway.sub_kline,
                'order': gateway.sub_order,
                'wallet': gateway.sub_wallet_balance
            }
            subscribe = switch.get(feature)
            if subscribe is None:
                print(f'{feature} feature does not exist!')
            else:
                subscribe(symbol)
        except Exception as e:
            logging.error(f'Integrate gateway error: {e}')

    def _handle_receiver(self):
        """Subscribe every new item listed under ``sub``, then delete the key."""
        try:
            recv = self.server.get('sub')
            if recv:
                recv = json.loads(recv)
                for item in recv:
                    if item in self.channels:
                        continue
                    parts = item.split('&')
                    if len(parts) != 3:
                        # Skip a malformed item rather than aborting the batch,
                        # which would leave 'sub' in place and fail on every poll.
                        logging.error(f'invalid subscription request: {item}')
                        continue
                    gateway, symbol, feature = parts
                    if gateway not in self.gateways:
                        # Report on the item's own channel, where its data would
                        # have been published.
                        self._publish(item, 500, None,
                                      f'{gateway} does not exist')
                        continue
                    # Instantiate each exchange gateway once and reuse it.
                    if gateway not in self.ongoing_gateway:
                        self.ongoing_gateway[gateway] = self.gateways[gateway]()
                    self._integrate_gateway(self.ongoing_gateway[gateway],
                                            symbol, feature)
                    self.channels.append(item)
                self.server.delete('sub')
        except Exception as e:
            logging.error(f'handle receiver error: {e}')

    def receiver(self):
        """Every 2 seconds, process pending subscription requests under ``sub``."""
        while True:
            time.sleep(2)
            try:
                if self.server.exists('sub'):
                    self._handle_receiver()
            except Exception as e:
                # Log and keep polling, e.g. across a Redis hiccup, instead of
                # letting one failure end the thread for good.
                logging.error(f'receiver error: {e}')

    def _publish_channel(self, channel):
        """Publish one channel's current data, flagging repeats as 403 duplicates."""
        gateway, symbol, feature = channel.split('&')
        try:
            data = self.ongoing_gateway[gateway].data[symbol][feature]
        except KeyError:
            # The gateway has no data for this symbol/feature (yet).
            data = None
        # If the ticker does not exist or the exchange server errored.
        if data is None or len(data) == 0:
            data = f'{channel} does not exist'

        # For klines only the newest candle (data[0]) is cached for comparison.
        if feature == 'kline' and isinstance(data, list) and len(data) > 0:
            cache_value = data[0]
        else:
            cache_value = data

        cache_data = self.cache.get(channel)
        if cache_data is None:
            # First sighting, or the 15 s TTL lapsed: cache only; nothing is
            # published for this value.
            self.cache.set(channel, cache_value, ttl=15)
            return
        if data == cache_data or (isinstance(data, list)
                                  and data[0] == cache_data):
            self._publish(channel, 403, data, 'Duplicate Data')
            return
        # New data differs from the cached value: refresh the cache and publish.
        self.cache.set(channel, cache_value, ttl=15)
        self._publish(channel, 200, data, 'OK')

    def publisher(self):
        """Continuously publish each subscribed channel's data."""
        while True:
            time.sleep(0.001)
            for channel in self.channels:
                try:
                    self._publish_channel(channel)
                except Exception as e:
                    # One failing channel must not stop the others or end the thread.
                    logging.error(f'Publisher error: {e}')
예제 #12
0
def cache_tt():
    """Demonstrate cache-wide TTL: print a value, wait past the 2 s TTL, print again."""
    ttl_cache = Cache(maxsize=2, ttl=2)
    ttl_cache.set(1, 'one')
    for pause in (0, 3):
        if pause:
            time.sleep(pause)
        print(ttl_cache.get(1))
예제 #13
0
def test_cache_set(cache: Cache):
    """Check that a value written with cache.set() is read back by cache.get()."""
    key = "key"
    value = "value"
    cache.set(key, value)
    stored = cache.get(key)
    assert stored == value
예제 #14
0
class Redis(object):
    """In-process stand-in for a small subset of the Redis API, backed by Cache."""

    def __init__(self):
        self._data = Cache()
        self._tokens = Cache()

    def check_token(self, token):
        """Return True if ``token`` is currently registered via set_token()."""
        # has() is a direct lookup that respects TTL expiry. keys() may still
        # list entries that have expired but not yet been evicted.
        return self._tokens.has(token)

    def set_token(self, key, value):
        self._tokens.set(key, value)

    def keys(self, pattern):
        """Return the live keys whose name matches the regular expression ``pattern``."""
        return [
            key for key in self._data.keys()
            if self._data.has(key) and re.search(pattern, key) is not None
        ]

    def set(self, key, value, ttl=None):
        self._data.set(key, value, ttl)
        return 1

    def hset(self, hash, key, value, ttl=None):
        """Record ``key`` as the field of ``hash`` and store ``value`` under ``key``.

        NOTE: only one field per hash is remembered (``hash`` maps to the most
        recent ``key``), and the field is stored as a top-level key that shares
        the namespace with plain set()/get() keys.
        """
        self._data.set(hash, key, ttl)
        self._data.set(key, value, ttl)
        return 1

    def lset(self, name_list, index, value, ttl=None):
        """Set ``name_list[index] = value``, creating or zero-padding the list as needed."""
        # get() returns None for missing *and* expired keys. The old
        # "in keys()" check could see an expired list and then crash on len(None).
        array_of_data = self._data.get(name_list)
        if array_of_data is None:
            array_of_data = []
        if len(array_of_data) <= index:
            array_of_data.extend([0] * (index + 1 - len(array_of_data)))
        array_of_data[index] = value
        self._data.set(name_list, array_of_data, ttl)
        return 1

    def get(self, key):
        return self._data.get(key)

    def hget(self, hash, key):
        """Return the value of field ``key`` if it is the field recorded for ``hash``."""
        find_key = self._data.get(hash)
        if find_key != key:
            return None
        return self._data.get(find_key)

    def lget(self, name_of_list, index):
        """Return ``name_of_list[index]``, or None if the list or index is missing."""
        array_of_data = self._data.get(name_of_list)
        if array_of_data is None:
            return None
        if index < len(array_of_data):
            return array_of_data[index]
        return None

    def delete(self, key):
        """Delete ``key``; return 1 if a live entry was removed, else None."""
        if not self._data.has(key):
            return None
        self._data.delete(key)
        return 1
예제 #15
0
class KeyValue:
    """Async key/value store with optional per-key expiry.

    Entries written with a non-zero ``expire_time`` live in a TTL cache.
    Entries written with ``expire_time == 0`` live in a plain dict and never
    expire. Keys and values must be alphanumeric strings no longer than
    KEY_LENGTH / VALUE_LENGTH.
    """

    def __init__(self):
        # Non-expiring entries.
        self.cache = {}
        # Expiring entries; the per-key ttl is supplied on each put().
        self.cache_ttl = Cache(maxsize=MAX_SIZE,
                               ttl=0,
                               timer=time.time,
                               default=None)

    async def put(self, key, value, expire_time=DEFAULT_TIME):
        """Store ``value`` under ``key``; ``expire_time == 0`` means "never expire".

        Raises KeyError if ``key`` or ``value`` fails validation. Returns 1.
        """
        if not self._checkKey(key):
            raise KeyError
        if not self._checkValue(value):
            raise KeyError

        if expire_time != 0:
            # Expiring entry: keep it only in the TTL cache.
            self.cache_ttl.set(key, value, ttl=expire_time)
            await self._delete_cache(key)
        else:
            # Non-expiring entry: keep it only in the plain dict, dropping any
            # earlier expiring copy so each key lives in exactly one store.
            self.cache[key] = value
            await self._delete_cache_ttl(key)
        return 1

    async def retrieve(self, key):
        """Return the value for ``key``, preferring the non-expiring store.

        Raises KeyError if ``key`` is invalid or present in neither store.
        """
        if not self._checkKey(key):
            raise KeyError
        result = await self._retrieve_cache(key)
        result_ttl = await self._retrieve_cache_ttl(key)
        if result is False and result_ttl is False:
            raise KeyError
        elif result:
            return result
        else:
            return result_ttl

    async def delete(self, key):
        """Remove ``key`` from both stores (absent keys are fine). Returns 1."""
        if not self._checkKey(key):
            raise KeyError
        await self._delete_cache(key)
        await self._delete_cache_ttl(key)
        return 1

    # Retrieval helpers: return the stored value, or False when the key is absent.
    async def _retrieve_cache(self, key):
        if not await self._contains_cache(key):
            return False
        return self.cache[key]

    async def _retrieve_cache_ttl(self, key):
        if not await self._contains_cache_ttl(key):
            return False
        return self.cache_ttl.get(key)

    # Deletion helpers: remove the key if present; always return 1.
    async def _delete_cache(self, key):
        if not await self._contains_cache(key):
            return 1
        del self.cache[key]
        return 1

    async def _delete_cache_ttl(self, key):
        if not await self._contains_cache_ttl(key):
            return 1
        # Cache removes entries via .delete() (as used elsewhere with Cache),
        # not via the `del cache[key]` item-deletion protocol.
        self.cache_ttl.delete(key)
        return 1

    # Validation: keys and values must be alphanumeric strings of appropriate length.
    def _checkKey(self, key):
        return isinstance(key, str) and key.isalnum() and len(key) <= KEY_LENGTH

    def _checkValue(self, value):
        return (isinstance(value, str) and value.isalnum()
                and len(value) <= VALUE_LENGTH)

    # Membership checks for each store.
    async def _contains_cache(self, key):
        return key in self.cache

    async def _contains_cache_ttl(self, key):
        return self.cache_ttl.has(key)
예제 #16
0
统计缓存
"""
# from cacheout import Cache# 如果选择LFUCache 就导入即可
# from cacheout import LFUCache
# cache = LFUCache()
import time
from cacheout import Cache

# Library defaults spelled out: maxsize 256, TTL disabled (0), wall-clock
# timer, and None returned for misses.
cache = Cache(maxsize=256, ttl=0, timer=time.time, default=None)

# Store two entries in one call, then read one back by key.
cache.set_many({1: 'foobar', 2: 'foobar2'})
ret = cache.get(1)
print("ret:", ret)

# 可以为每个键值对设置存活过期时间:
# cache.set(3, {"data": {}}, ttl=1)
# assert cache.get(3) == {"data": {}}
# time.sleep(2)
# assert cache.get(3) == {"data": {}}

# 为缓存函数提供了键值对的存活时间:
# @cache.memoize()
# def func(a, b):
#     pass
#
#
# @cache.memoize()
예제 #17
0
# -*- coding: utf-8 -*-
# __file__  : test_01.py
# __time__  : 2020/6/29 3:31 下午
import asyncio
import time

from cacheout import Cache, CacheManager

# Construct with the library defaults spelled out explicitly.
cache = Cache(maxsize=256, ttl=0, timer=time.time, default=None)  # defaults

# Basic set/get round trip.
cache.set(1, "foobar")
assert cache.get(1) == "foobar"

# Misses return None, or the per-call default, without storing anything.
assert cache.get(2) is None
assert cache.get(2, default=False) is False
assert 2 not in cache

# A callable default computes the value from the key and stores it.
assert cache.get(2, default=lambda key: key) == 2
assert cache.get(2) == 2
assert 2 in cache

# Per-entry TTL: the entry disappears once its TTL has elapsed. Sleep a little
# past the 1 s TTL so the check does not race the expiry boundary and the
# timer's resolution.
cache.set(3, {"data": {}}, ttl=1)
assert cache.get(3) == {"data": {}}
time.sleep(1.1)
assert cache.get(3) is None

예제 #18
0
class tcp_http_shark():
    def __init__(self, work_queue, interface, custom_tag, return_deep_info,
                 http_filter_json, cache_size, session_size, bpf_filter,
                 timeout, debug):
        """
        Constructor.

        :param work_queue: list-like sink; each captured asset record is appended as a JSON string
        :param interface: name of the network interface to capture on
        :param custom_tag: data tag used to tell different capture engines apart
        :param return_deep_info: whether to extract more detail (raw request, response headers and body)
        :param http_filter_json: HTTP filter config, filtering by status code and content type
        :param cache_size: number of processed records remembered; repeats within 120 s are not re-sent
        :param session_size: number of HTTP/TCP sessions tracked; sessions unused for 16 s are dropped
        :param bpf_filter: low-level BPF packet filter
        :param timeout: capture timeout passed to ``apply_on_packets``
        :param debug: debug switch
        """
        self.work_queue = work_queue
        self.debug = debug
        self.timeout = timeout
        self.bpf_filter = bpf_filter
        self.cache_size = cache_size
        self.session_size = session_size
        self.http_filter_json = http_filter_json
        self.return_deep_info = return_deep_info
        self.custom_tag = custom_tag
        self.interface = interface
        self.pktcap = pyshark.LiveCapture(interface=self.interface,
                                          bpf_filter=self.bpf_filter,
                                          use_json=False,
                                          debug=self.debug)
        if self.session_size:
            # Per-TCP-stream state: pending HTTP request / SYN+ACK server address.
            self.http_stream_cache = Cache(maxsize=self.session_size,
                                           ttl=16,
                                           timer=time.time,
                                           default=None)
            self.tcp_stream_cache = Cache(maxsize=self.session_size,
                                          ttl=16,
                                          timer=time.time,
                                          default=None)
        if self.cache_size:
            # Already-reported URLs / server addresses, for de-duplication.
            self.http_cache = LRUCache(maxsize=self.cache_size,
                                       ttl=120,
                                       timer=time.time,
                                       default=None)
            self.tcp_cache = LRUCache(maxsize=self.cache_size,
                                      ttl=120,
                                      timer=time.time,
                                      default=None)
        # Regex that finds the charset declared in an HTML <meta> tag.
        self.encode_regex = re.compile(
            rb'<meta [^>]*?charset=["\']?([a-z\-\d]+)["\'>]?', re.I)

    def http_filter(self, key, value):
        """
        Check whether ``value`` contains any configured rule for ``key``.

        :param key: rule name, ``response_code`` (status code) or ``content_type``
        :param value: string to check
        :return: True if some rule is a substring of ``value``, otherwise False
        """
        if key not in self.http_filter_json:
            return False
        return any(rule in value for rule in self.http_filter_json[key])

    def run(self):
        """
        Entry point: process captured packets until the timeout expires.
        """
        try:
            self.pktcap.apply_on_packets(self.proc_packet,
                                         timeout=self.timeout)
        except concurrent.futures.TimeoutError:
            print("\nTimeoutError.")

    def proc_packet(self, pkt):
        """
        Per-packet dispatch: classify, route to the HTTP/TCP handler, emit the result.

        :param pkt: packet
        :return: None; a produced record is appended to ``work_queue`` as JSON
        """
        try:
            pkt_json = None
            pkt_dict = dir(pkt)

            if 'ip' in pkt_dict:
                if 'http' in pkt_dict:
                    pkt_json = self.proc_http(pkt)
                elif 'tcp' in pkt_dict:
                    pkt_json = self.proc_tcp(pkt)

            if pkt_json:
                result = json.dumps(pkt_json)
                if self.debug:
                    print(result)
                self.work_queue.append(result)
        except Exception:
            # Keep capturing after a bad packet, but do not swallow
            # KeyboardInterrupt/SystemExit the way a bare except would.
            traceback.print_exc()

    def proc_http(self, pkt):
        """
        Handle an HTTP packet.

        A request is only remembered per TCP stream (when sessions are tracked),
        so the matching response can report its URL and method.

        :param pkt: packet
        :return: record dict for a response, otherwise None
        """
        http_dict = dir(pkt.http)

        if 'request' in http_dict and self.session_size:
            req = {
                'url':
                pkt.http.request_full_uri
                if 'request_full_uri' in http_dict else pkt.http.request_uri,
                'method':
                pkt.http.request_method
                if 'request_method' in http_dict else ''
            }
            self.http_stream_cache.set(pkt.tcp.stream, req)
        elif 'response' in http_dict:
            return self._proc_http_response(pkt, http_dict)

        return None

    def _proc_http_response(self, pkt, http_dict):
        """Build the record for an HTTP response, or None if deduplicated/filtered."""
        pkt_json = {}
        src_addr = pkt.ip.src
        src_port = pkt[pkt.transport_layer].srcport

        # Prefer the URL/method of the request seen earlier on this TCP stream.
        if self.session_size:
            cache_req = self.http_stream_cache.get(pkt.tcp.stream)
            if cache_req:
                pkt_json['url'] = cache_req['url']
                pkt_json['method'] = cache_req['method']
                self.http_stream_cache.delete(pkt.tcp.stream)

        if 'url' not in pkt_json:
            if 'response_for_uri' in http_dict:
                pkt_json["url"] = pkt.http.response_for_uri
            else:
                pkt_json["url"] = '/'

        # URI only (no scheme/host): rebuild an absolute URL from the server
        # address. An empty URL is treated as '/' instead of raising IndexError.
        url = pkt_json["url"] or '/'
        if url.startswith('/'):
            if src_port == '80':
                url = "http://%s%s" % (src_addr, url)
            else:
                url = "http://%s:%s%s" % (src_addr, src_port, url)
        pkt_json["url"] = url

        if self.cache_size:
            # Avoid re-reporting the same URL within the cache TTL.
            if self.http_cache.get(pkt_json['url']):
                return None
            self.http_cache.set(pkt_json["url"], True)

        pkt_json["pro"] = 'HTTP'
        pkt_json["tag"] = self.custom_tag
        pkt_json["ip"] = src_addr
        pkt_json["port"] = src_port

        if 'response_code' in http_dict:
            if self.http_filter_json and self.http_filter(
                    'response_code', pkt.http.response_code):
                return None
            pkt_json["code"] = pkt.http.response_code

        if 'content_type' in http_dict:
            if self.http_filter_json and self.http_filter(
                    'content_type', pkt.http.content_type):
                return None
            pkt_json["type"] = pkt.http.content_type.lower()
        else:
            # (sic) Emitted value kept as-is; downstream consumers may match on it.
            pkt_json["type"] = 'unkown'

        if 'server' in http_dict:
            pkt_json["server"] = pkt.http.server

        # Deep analysis: also return the response header and body.
        if self.return_deep_info:
            self._add_deep_info(pkt, http_dict, pkt_json)

        return pkt_json

    def _add_deep_info(self, pkt, http_dict, pkt_json):
        """Add the raw response ``header`` (when present) and decoded ``body`` to ``pkt_json``."""
        charset = 'utf-8'
        # Charset declared in Content-Type, e.g. "text/html; charset=gbk".
        # Stop at the next ';' and drop optional quotes around the name.
        if 'charset=' in pkt_json["type"]:
            declared = pkt_json["type"].split('charset=', 1)[1]
            charset = declared.split(';', 1)[0].strip().strip('"\'') or 'utf-8'

        data = b''
        if 'payload' in dir(pkt.tcp):
            payload = bytes.fromhex(str(pkt.tcp.payload).replace(':', ''))
            if payload.find(b'HTTP/') == 0:
                split_pos = payload.find(b'\r\n\r\n')
                if split_pos <= 0 or split_pos > 2048:
                    split_pos = 2048
                pkt_json["header"] = str(payload[:split_pos], 'utf-8',
                                         'ignore')
                # Body bytes from this segment. They are used only when none of
                # the dissector fields below supplies one. They used to be
                # stored as str and then always overwritten.
                data = payload[split_pos + 4:]
        if ('file_data' in http_dict and pkt.http.file_data.raw_value
                and pkt_json['type'] != 'application/octet-stream'):
            data = bytes.fromhex(pkt.http.file_data.raw_value)
        elif 'data' in http_dict:
            data = bytes.fromhex(pkt.http.data)
        elif 'segment_data' in dir(pkt.tcp):
            data = bytes.fromhex(pkt.tcp.segment_data.replace(":", ""))

        if not data:
            pkt_json["body"] = ''
            return

        # A charset in the page's <meta> (searched in the <head>) overrides the header.
        head_end = data.find(b'</head>', 0, 1024)
        data_head = data[:500] if head_end == -1 else data[:head_end]
        match = self.encode_regex.search(data_head)
        if match:
            charset = str(match.group(1).strip().lower(), 'utf-8', 'ignore')

        try:
            text = str(data, charset, 'ignore')
        except LookupError:
            # Unknown or non-text codec name from the traffic: fall back to
            # UTF-8 rather than dropping the whole record.
            text = str(data, 'utf-8', 'ignore')
        pkt_json["body"] = proc_body_str(text, 16 * 1024)

    @staticmethod
    def _is_syn_ack(flags):
        """True when the TCP flags hex string (e.g. '0x00000012' or '0x012') is exactly SYN+ACK."""
        try:
            return int(str(flags), 16) == 0x12
        except ValueError:
            return False

    def proc_tcp(self, pkt):
        """
        Handle a TCP packet.

        :param pkt: packet
        :return: record dict or None
        """
        tcp_stream = pkt.tcp.stream

        pkt_json = {}
        pkt_json["pro"] = 'TCP'
        pkt_json["tag"] = self.custom_tag

        # SYN+ACK: the sender is a listening server.
        if self._is_syn_ack(pkt.tcp.flags):
            server_ip = pkt.ip.src
            server_port = pkt[pkt.transport_layer].srcport
            tcp_info = '%s:%s' % (server_ip, server_port)

            if self.cache_size:
                if self.tcp_cache.get(tcp_info):
                    return None
                self.tcp_cache.set(tcp_info, True)

            if self.return_deep_info and self.session_size:
                # Wait for the server's first data packet on this stream.
                self.tcp_stream_cache.set(tcp_stream, tcp_info)
            else:
                pkt_json["ip"] = server_ip
                pkt_json["port"] = server_port
                return pkt_json

        # Deep analysis (-r on): capture the server's first response packet.
        if self.return_deep_info and pkt.tcp.seq == "1" and "payload" in dir(
                pkt.tcp) and self.session_size:
            tcp_info = self.tcp_stream_cache.get(tcp_stream)
            if tcp_info:
                # Ignore the case where the client sent the first data packet.
                src_host = '{}:{}'.format(pkt.ip.src,
                                          pkt[pkt.transport_layer].srcport)
                if tcp_info != src_host:
                    return None

                self.tcp_stream_cache.delete(tcp_stream)

                pkt_json["ip"] = pkt.ip.src
                pkt_json["port"] = pkt[pkt.transport_layer].srcport
                payload_data = pkt.tcp.payload.replace(":", "")
                if payload_data.startswith("48545450"):  # ^HTTP
                    return None

                # HTTPS protocol
                # TODO: support HTTPS on ports other than 443
                if pkt_json["port"] == "443" and payload_data.startswith(
                        "1603"):  # TLS handshake record
                    pkt_json["pro"] = 'HTTPS'
                    pkt_json["url"] = "https://{}/".format(pkt_json["ip"])
                else:
                    pkt_json["data"] = proc_data_str(payload_data, 16 * 1024)

                return pkt_json
        return None
예제 #19
0
class BasicData:
    """
    Fetch and store the base data used for calculations: securities (stocks,
    ETFs, ...), daily quotes and industry classifications.

    Data is pulled from JoinQuant (``jq``) and persisted to MySQL through
    SQLAlchemy; the data source can be swapped by replacing the underlying calls.
    """

    def __init__(self, data_config_file="dataBase.yaml"):
        """
        Initialise database access, application configuration and an in-memory cache.

        :param data_config_file: database configuration file passed to ``setup.MySql``
        """
        self.mysql_conf = setup.MySql(data_config_file)
        self.engine = create_engine(self.mysql_conf.PyMySql_STR)
        self.db_session_factory = sessionmaker(bind=self.engine)
        self.app = setup.App()
        self.cache = Cache()
        # Engine for DataFrame.to_sql() writes; created lazily on first use.
        self._mysql_con_engine = None

    def get_session(self):
        """Return a new ORM session bound to ``self.engine``; the caller must close it."""
        return self.db_session_factory()

    def _get_mysql_con_engine(self):
        """
        Return the engine used for ``DataFrame.to_sql`` writes, creating it once.

        Reusing a single engine avoids building a new connection pool on every
        write.
        """
        engine = getattr(self, '_mysql_con_engine', None)
        if engine is None:
            pymysql.install_as_MySQLdb()
            engine = create_engine(self.mysql_conf.MYSQL_CON_STR)
            self._mysql_con_engine = engine
        return engine

    @staticmethod
    def get_axr_date(security):
        """
        Return the most recent ex-rights dates (``a_xr_date``) of a security.

        :param security: security code
        :return: list of at most 5 dates, newest first
        """
        df = jq.finance.run_query(
            query(finance.STK_XR_XD.code, finance.STK_XR_XD.a_xr_date).filter(
                finance.STK_XR_XD.code == security,
                finance.STK_XR_XD.a_xr_date.isnot(None)).order_by(
                    finance.STK_XR_XD.a_xr_date.desc()).limit(5))
        return df['a_xr_date'].tolist()

    def get_trade_days(self):
        """
        Return the last 10 trading days up to now, cached for 6 hours (21600 s).

        :return: the sequence returned by ``jq.get_trade_days``
        """
        trade_days = self.cache.get('trade_days')
        if trade_days is None:
            trade_days = jq.get_trade_days(end_date=datetime.datetime.now(),
                                           count=10)
            self.cache.set('trade_days', trade_days, 21600)
        # Return the value we hold rather than re-reading the cache, which
        # could have expired in between.
        return trade_days

    def get_all_securities(self, types=None):
        """
        Sync the local security table (``YSecurity``) with JoinQuant.

        Runs at most once every ``Update.SecuritiesInterval`` days, tracked by
        the ``security.down.last`` setting. Securities missing locally are
        inserted with ``status=0``; existing ones get their names, start/end
        dates and update date refreshed.

        :param types: iterable of JoinQuant security types; ``None`` means none
        """
        flog.FinanceLoger.logger.info('证券信息更新开始...!')
        types = [] if types is None else types
        now = datetime.datetime.now()
        db_session = self.db_session_factory()
        try:
            list_screening = db_session.query(Setting).filter(
                Setting.name == 'security.down.last').first()
            list_date = datetime.datetime.strptime(list_screening.value,
                                                   '%Y-%m-%d')
            day_count = (now - list_date).days
            if day_count < self.app.conf.conf['Update']['SecuritiesInterval']:
                return

            for x in types:
                res = jq.get_all_securities(types=x, date=None)
                updated = 0
                for code, row in res.iterrows():
                    record = db_session.query(YSecurity).filter(
                        YSecurity.security == code).first()
                    if record:
                        record.display_name = row["display_name"]
                        record.name = row['name']
                        record.start_date = row['start_date']
                        record.end_date = row["end_date"]
                        record.update_date = now.date()
                    else:
                        # NOTE(review): the model is built with ``stype=`` while
                        # other queries filter on ``YSecurity.type``; confirm.
                        record = YSecurity(security=code,
                                           display_name=row["display_name"],
                                           name=row["name"],
                                           start_date=row["start_date"],
                                           end_date=row["end_date"],
                                           stype=row["type"],
                                           status=0,
                                           update_date=now.date())
                        db_session.add(record)
                    db_session.commit()
                    updated += 1
                flog.FinanceLoger.logger.info(
                    '本次标[{}]的更新完成,共更新{}条!'.format(x, updated))
            list_screening.value = now.date().strftime('%Y-%m-%d')
            db_session.commit()
        finally:
            # Also reached on the early return above, so the session never leaks.
            db_session.close()
        flog.FinanceLoger.logger.info('证券信息更新结束...!')
        return

    def execute_sql(self, sql):
        """
        Execute a raw SQL statement in its own session and commit it.

        Execution errors are logged and rolled back, not raised.

        :param sql: SQL text to execute
        """
        # Created outside ``try`` so a failure here propagates as-is instead of
        # turning into UnboundLocalError in the except/finally clauses.
        db_session = self.db_session_factory()
        try:
            db_session.execute(sql)
            db_session.commit()
        except Exception as e:
            flog.FinanceLoger.logger.error('excute sql:{0} error e-{1}'.format(
                sql, e))
            db_session.rollback()
        finally:
            db_session.close()
        flog.FinanceLoger.logger.debug('excute sql:{}'.format(sql))
        return

    def clean_data_by_table(self, table_name):
        """
        Truncate the given table.

        :param table_name: table to truncate (interpolated into the SQL as-is)
        """
        sql = 'truncate table {}'.format(table_name)
        self.execute_sql(sql)
        flog.FinanceLoger.logger.info(
            'truncate table {} success'.format(table_name))
        return

    def get_security_prices(self, security):
        """
        Return the number of stored daily bars and the latest bar of a security.

        :param security: security code
        :return: tuple ``(k_count, last_trade_k)``; ``last_trade_k`` is the
            ``KlineDay`` row with the latest ``kday`` or ``None``
        """
        db_session = self.db_session_factory()
        try:
            k_count = db_session.query(func.count(KlineDay.kday)).filter(
                KlineDay.security == security).scalar()
            last_trade_k = db_session.query(KlineDay).filter(
                KlineDay.security == security).order_by(
                    KlineDay.kday.desc()).first()
        finally:
            db_session.close()
        return k_count, last_trade_k

    def get_day_price(self, security):
        """
        Download forward-adjusted (``fq='pre'``) daily bars of one security into
        ``kline_day``.

        Normally continues from the day after the last stored bar; reloads one
        year of history when fewer than 180 bars are stored or when the latest
        trading day is an ex-rights date. Rows older than one year are purged
        once more than 240 bars are stored.

        :param security: security code
        """
        now = datetime.datetime.now()
        last_year_day = now - datetime.timedelta(days=366)
        scount, last_k = self.get_security_prices(security)
        xr_datas = self.get_axr_date(security)
        start_date = last_year_day.date()  # default: download one year of data

        if 180 > scount > 0:
            # Too little local history: drop it and reload a full year.
            sql = "delete from kline_day where security = '{0}' ".format(
                security)
            self.execute_sql(sql)
        elif scount >= 180 and last_k is not None:
            # Incremental update from the day after the last stored bar.
            start_date = last_k.kday + datetime.timedelta(days=1)

        end_date = self.get_trade_days()[-1]
        # Before 15:00, today's bar is treated as incomplete.
        if now.date() == end_date and now.hour < 15:
            end_date = end_date + datetime.timedelta(days=-1)
        if start_date > end_date:
            return

        # Ex-rights day: forward-adjusted history changes, reload it in full.
        # NOTE(review): only ``end_date`` is checked, so an ex-rights date
        # between the last stored bar and ``end_date`` is missed.
        if end_date in xr_datas:
            if scount > 0:
                sql = "delete from kline_day where security = '{0}' ".format(
                    security)
                self.execute_sql(sql)
            start_date = last_year_day.date()

        res = jq.get_price(security,
                           start_date=start_date,
                           end_date=end_date,
                           frequency='daily',
                           fields=[
                               'open', 'close', 'high', 'low', 'volume',
                               'money', 'factor', 'high_limit', 'low_limit',
                               'avg', 'pre_close', 'paused'
                           ],
                           skip_paused=True,
                           fq='pre')
        # Paused days are skipped, so the result may be empty.
        if res.empty:
            return
        # Tag every row with the security code and the update date.
        res['security'] = security
        res['update_date'] = now.date()
        try:
            res.to_sql(name="kline_day",
                       con=self._get_mysql_con_engine(),
                       if_exists='append',
                       index=True,
                       index_label='kday',
                       chunksize=1000)
            new_count, _ = self.get_security_prices(security)
            if new_count > 240:
                # Purge bars older than one year.
                sql = "delete from kline_day where security = '{0}' and kday <= '{1}'".format(
                    security, str(last_year_day))
                self.execute_sql(sql)
        except Exception as e:
            flog.FinanceLoger.logger.error("更新行情时出错,标的:{},错误信息:{}".format(
                security, e))
        else:
            # Only report success when the write actually succeeded.
            flog.FinanceLoger.logger.debug("更新了行情,标的:{}".format(security))
        return

    @staticmethod
    def verfiy_finance(security):
        """
        Check the fundamentals of a security.

        :param security: security code
        :return: True when ``market_cap > 80`` and ``circulating_market_cap > 50``;
            False otherwise or when no fundamentals are returned
        """
        fund_df = jq.get_fundamentals(
            query(valuation, indicator).filter(valuation.code == security))
        # Check before touching the frame: fillna on None would raise.
        if fund_df is None or fund_df.empty:
            flog.FinanceLoger.logger.info("标的{},获取不到财务数据".format(security))
            return False
        # NOTE: missing values become 100, which passes both thresholds below.
        fund_df = fund_df.fillna(value=100)

        # Criteria considered but disabled: turnover_ratio > 0.01,
        # roe > 0.01, net_profit_margin > 5.
        first = fund_df.iloc[0]
        return bool(first["market_cap"] > 80
                    and first["circulating_market_cap"] > 50)

    @staticmethod
    def get_finance(code=None):
        """
        Query fundamentals from JoinQuant.

        :param code: when given, return all valuation/indicator columns of this
            security; otherwise return securities with market_cap > 80,
            circulating_market_cap > 50, turnover_ratio > 0.1 and roe > 0.05,
            ordered by market cap descending
        :return: DataFrame from ``jq.get_fundamentals``
        """
        if code is not None:
            q = query(valuation, indicator).filter(valuation.code == code)
        else:
            q = query(valuation.code, valuation.market_cap,
                      valuation.circulating_market_cap, indicator.roe,
                      indicator.gross_profit_margin).filter(
                          valuation.market_cap > 80,
                          valuation.circulating_market_cap > 50,
                          valuation.turnover_ratio > 0.1,
                          indicator.roe > 0.05).order_by(
                              valuation.market_cap.desc())
        return jq.get_fundamentals(q)

    def get_all_price(self, stype):
        """
        Update daily bars of every selected security (``status == 1``) of the
        given types.

        :param stype: iterable of security types; ``None`` does nothing
        """
        if stype is None:
            return
        flog.FinanceLoger.logger.info("开始全部标的价格获取...")
        db_session = self.db_session_factory()
        try:
            for s in stype:
                securities = db_session.query(YSecurity).filter(
                    YSecurity.type == s, YSecurity.status == 1).all()
                for security in securities:
                    self.get_day_price(security=security.security)
                flog.FinanceLoger.logger.info(
                    "获取了指定标的类型{}的数据,共计拉取了{}条符合条件的价格信息".format(
                        s, len(securities)))
        finally:
            db_session.close()
        return

    def get_industries_store(self, name='jq_l2', date=None):
        """
        Fetch an industry classification and append it to the ``industries`` table.

        :param name: classification name passed to ``jq.get_industries``
            (stored in the ``type`` column)
        :param date: date passed to ``jq.get_industries``; ``None`` keeps its default
        """
        # Pass ``date`` through instead of always using None.
        res = jq.get_industries(name=name, date=date)
        res['type'] = name
        res.to_sql(name='industries',
                   con=self._get_mysql_con_engine(),
                   if_exists='append',
                   index=True,
                   index_label='index',
                   chunksize=1000)
        flog.FinanceLoger.logger.info('所有行业信息已经保存成功')
        return

    def clean_industries(self):
        """Truncate the ``industries`` table so it can be reloaded from upstream."""
        self.clean_data_by_table('industries')

    def get_swl1_daliy_price(self, date=None):
        """
        Fetch the latest Shenwan level-1 industry daily quote (on or before
        ``date``) for every ``sw_l1`` industry and append it to ``sw1_daily_price``.

        The code comments say upstream updates these quotes at 18:00, so running
        on the next day is preferable.

        :param date: ``'%Y-%m-%d'`` cut-off date; defaults to today
        """
        sql = "select * from industries i2 where i2.`type` = 'sw_l1' "
        industries = self.get_df_by_sql(sql)
        if date is None:
            date = datetime.datetime.now().strftime('%Y-%m-%d')

        stored = 0
        for industry in industries['index']:
            # NOTE(review): assumes the driver returns this column as bytes.
            s = industry.decode("utf-8")
            # Criteria must be separate filter() arguments: Python ``and`` on
            # SQL expressions silently drops the date condition.
            res = finance.run_query(
                query(finance.SW1_DAILY_PRICE).filter(
                    finance.SW1_DAILY_PRICE.code == s,
                    finance.SW1_DAILY_PRICE.date <= date).order_by(
                        finance.SW1_DAILY_PRICE.date.desc()).limit(1))
            mysqlconnect = self._get_mysql_con_engine()
            try:
                res.to_sql(name='sw1_daily_price',
                           con=mysqlconnect,
                           if_exists='append',
                           index=False,
                           chunksize=1000)
                stored += 1
            except Exception as e:
                flog.FinanceLoger.logger.error(
                    "获取申万一级行业的行情,存储数据库时出错,标的:{},出错信息:{}".format(s, e))
        flog.FinanceLoger.logger.info(
            "获取申万一级行业的行情信息,总计拉取了{}条符合条件的标的".format(stored))
        return

    def screening_security(self, types):
        """
        Re-screen securities and set ``YSecurity.status`` (1 = selected, 0 = not).

        Runs at most once every ``Update.SecuritiesInterval`` days, tracked by
        the ``security.status.update.date`` setting. A security qualifies when
        its end date is after today, its display name has no ``ST`` and it was
        listed more than 180 days ago; stocks must also pass ``verfiy_finance``.
        Newly selected securities get their daily bars downloaded; rejected ones
        have their ``kline_day`` and ``indicator_day`` rows deleted.

        :param types: iterable of security types; ``None`` does nothing
        """
        flog.FinanceLoger.logger.info('证券筛选更新开始...!')
        if types is None:
            return
        now = datetime.datetime.now()
        today = now.date()
        half_year_day = now - datetime.timedelta(days=180)
        db_session = self.db_session_factory()
        try:
            list_screening = db_session.query(Setting).filter(
                Setting.name == 'security.status.update.date').first()
            list_date = datetime.datetime.strptime(list_screening.value,
                                                   '%Y-%m-%d')
            day_count = (now - list_date).days
            if day_count < self.app.conf.conf['Update']['SecuritiesInterval']:
                return

            for x in types:
                rejected, added = 0, 0
                securities = db_session.query(YSecurity).filter(
                    YSecurity.type == x).all()
                for security in securities:
                    flag = (security.end_date > today
                            and 'ST' not in security.display_name
                            and security.start_date < half_year_day.date())
                    if x == 'stock':
                        # Cheap checks first: skip the remote fundamentals
                        # query when the security already fails them.
                        flag = flag and self.verfiy_finance(security.security)

                    # Default to "no change" so an unexpected status value can
                    # never reuse the previous row's state.
                    state = ScreenState.Nothing
                    if security.status == 1 and not flag:
                        state = ScreenState.Reject
                    elif security.status == 0 and flag:
                        state = ScreenState.Add

                    if state == ScreenState.Add:
                        security.status = 1
                        db_session.commit()
                        added += 1
                        # Download the daily bars of the newly selected security.
                        self.get_day_price(security.security)
                        flog.FinanceLoger.logger.debug(
                            "标的 {} - 代码 {}被增选为优质标的".format(
                                security.display_name, security.security))
                    elif state == ScreenState.Reject:
                        security.status = 0
                        db_session.commit()
                        # Purge the stored quotes and indicators.
                        self.execute_sql(
                            "delete from kline_day where `security` = '{}' ".
                            format(security.security))
                        self.execute_sql(
                            "delete from indicator_day where `security` = '{}' ".
                            format(security.security))
                        flog.FinanceLoger.logger.debug(
                            "标的 {} - 代码 {}被删除优质标的".format(
                                security.display_name, security.security))
                        rejected += 1
                flog.FinanceLoger.logger.info(
                    "对于标的类型{}共有{}被剔除,总共有{}被选择........".format(
                        x, rejected, added))
            list_screening.value = now.strftime('%Y-%m-%d')
            db_session.commit()
        finally:
            # Also reached on the early return above, so the session never leaks.
            db_session.close()
        flog.FinanceLoger.logger.info('证券筛选更新结束...!')
        return

    @staticmethod
    def get_industry_by_security(security):
        """
        Return the Shenwan industry names of a security.

        :param security: security code
        :return: tuple ``(sw_l1, sw_l2, sw_l3)`` industry names
        """
        d = jq.get_industry(security)
        return (d[security]['sw_l1']['industry_name'],
                d[security]['sw_l2']['industry_name'],
                d[security]['sw_l3']['industry_name'])

    def down_data(self, types):
        """
        Download data automatically; currently updates daily bars only.

        :param types: iterable of security types passed to ``get_all_price``
        """
        self.get_all_price(types)

    def clean_indicator(self):
        """Truncate the ``indicator_day`` table."""
        self.clean_data_by_table('indicator_day')

    def get_df_by_sql(self, sql):
        """
        Run a query on ``self.engine`` and return the result as a DataFrame.

        :param sql: SQL query text
        :return: pandas DataFrame
        """
        return pd.read_sql(sql, self.engine)
예제 #20
0
class tcp_http_pcap():
    def __init__(self, pcap_collection_data, max_queue_size, work_queue,
                 interface, custom_tag, return_deep_info, http_filter_json,
                 cache_size, session_size, bpf_filter, timeout, debug):
        """
        Configure the packet sniffer, the per-stream caches and the HTTP regexes.

        :param pcap_collection_data: container of protocol names to collect;
            ``run`` tests it for ``'TCP'`` and ``'HTTP'`` membership
        :param max_queue_size: maximum length of the asset queue
        :param work_queue: queue that captured asset messages are sent to
        :param interface: name of the network interface to capture on
        :param custom_tag: tag attached to every record to tell collectors apart
        :param return_deep_info: enables extra processing; ``decode_response``
            uses it to decode chunked / gzip / brotli bodies
        :param http_filter_json: filter rules keyed by ``response_code`` /
            ``content_type``; matching HTTP responses are dropped
        :param cache_size: size of the de-duplication caches (120 s TTL);
            falsy disables de-duplication
        :param session_size: size of the TCP session cache (30 s TTL)
        :param bpf_filter: BPF expression applied to the capture
        :param timeout: passed to ``pcap.pcap`` as ``timeout_ms``
            (NOTE(review): earlier docs call this an overall run timeout of one
            hour by default; confirm with the caller)
        :param debug: debug switch
        """
        # Capture settings.
        self.pcap_collection_data = pcap_collection_data
        self.interface = interface
        self.bpf_filter = bpf_filter
        self.timeout = timeout
        self.debug = debug
        # Output settings.
        self.custom_tag = custom_tag
        self.return_deep_info = return_deep_info
        self.http_filter_json = http_filter_json
        self.work_queue = work_queue
        self.max_queue_size = max_queue_size
        self.total_msg_num = 0
        # Cache sizing.
        self.cache_size = cache_size
        self.session_size = session_size

        self.sniffer = pcap.pcap(self.interface, snaplen=65535, promisc=True,
                                 timeout_ms=self.timeout, immediate=False)
        self.sniffer.setfilter(self.bpf_filter)

        # TCP/HTTP session state; idle entries expire after 30 seconds.
        self.tcp_stream_cache = Cache(maxsize=self.session_size, ttl=30,
                                      timer=time.time, default=None)
        if self.cache_size:
            # De-duplication: an endpoint / request seen within 120 s is skipped.
            dedup_options = {'maxsize': self.cache_size, 'ttl': 120,
                             'timer': time.time, 'default': None}
            self.tcp_cache = LRUCache(**dedup_options)
            self.http_cache = LRUCache(**dedup_options)

        # Request line followed by the raw header block.
        self.decode_request_regex = re.compile(
            r'^([A-Z]+) +([^ \r\n]+) +HTTP/\d+(?:\.\d+)?[^\r\n]*(.*?)$', re.S)
        # Status line followed by the raw header block.
        self.decode_response_regex = re.compile(
            r'^HTTP/(\d+(?:\.\d+)?) (\d+)[^\r\n]*(.*?)$', re.S)
        # charset declared in an HTML <meta> tag.
        self.decode_body_regex = re.compile(
            rb'<meta[^>]+?charset=[\'"]?([a-z\d\-]+)[\'"]?', re.I)

    def run(self):
        """
        Main capture loop: decode packets, track TCP handshakes and emit assets.

        State kept in ``self.tcp_stream_cache``:
          * ``'S_<SYN+ACK ack>'`` -> SYN+ACK ``seq + 1``, stored on SYN+ACK;
          * ``'C_<client packet ack>'`` -> payload of the first client data packet.
        A data packet is emitted as raw ``TCP`` when its seq equals the stored
        SYN+ACK ``seq + 1`` (server speaks first), or as ``HTTP``/``TCP`` when its
        seq matches a cached client request. Records go to ``self.send_msg``.
        The sniffer is closed only if the iteration ends.
        """
        for ts, pkt in self.sniffer:
            packet = self.pkt_decode(pkt)
            if not packet:
                continue

            # Sender endpoint "ip:port", used as the TCP de-duplication key.
            cache_key = '{}:{}'.format(packet.src, packet.sport)
            # SYN & ACK
            if packet.flags == 0x12:
                # Endpoint already reported within the de-duplication TTL.
                if self.cache_size and self.tcp_cache.get(cache_key):
                    continue

                # NOTE(review): seq + 1 is not reduced modulo 2**32, so a SYN+ACK
                # with seq 0xFFFFFFFF never matches the wrapped seq 0 later.
                self.tcp_stream_cache.set('S_{}'.format(packet.ack),
                                          packet.seq + 1)

            # ACK, PSH+ACK or FIN+PSH+ACK
            elif packet.flags in [0x10, 0x18, 0x19]:
                # Ignore packets without payload.
                if len(packet.data) == 0:
                    continue

                # First client data packet: its seq equals the SYN+ACK's ack.
                # Cache its payload (the request) under its own ack, which the
                # server's response will carry as its seq.
                pre_cs_seq = self.tcp_stream_cache.get('S_{}'.format(
                    packet.seq))
                if pre_cs_seq:
                    c_s_key = 'C_{}'.format(packet.ack)

                    self.tcp_stream_cache.set(c_s_key, packet.data)
                    self.tcp_stream_cache.delete('S_{}'.format(packet.seq))
                    continue

                # 1. Server-speaks-first protocols (e.g. MySQL):
                #    seq == SYN+ACK seq + 1, looked up by this packet's ack.
                if 'TCP' in self.pcap_collection_data:
                    pre_sc_seq = self.tcp_stream_cache.get('S_{}'.format(
                        packet.ack))
                    if pre_sc_seq == packet.seq:
                        self.tcp_stream_cache.delete('S_{}'.format(packet.ack))

                        # Suppress repeats of this endpoint for the de-dup TTL.
                        if self.cache_size:
                            self.tcp_cache.set(cache_key, True)

                        data = {
                            'pro': 'TCP',
                            'tag': self.custom_tag,
                            'ip': packet.src,
                            'port': packet.sport,
                            'data': packet.data.hex()
                        }
                        self.send_msg(data)
                        continue

                # 2. Request/response protocols (e.g. HTTP):
                #    seq == ack of the cached client request packet.
                send_data = self.tcp_stream_cache.get('C_{}'.format(
                    packet.seq))
                # Only responses to a cached request are handled.
                if send_data:
                    # The cached request is consumed.
                    self.tcp_stream_cache.delete('C_{}'.format(packet.seq))

                    # HTTP collection: the response starts with "HTTP/".
                    if 'HTTP' in self.pcap_collection_data and packet.data[:
                                                                           5] == b'HTTP/':
                        request_dict = self.decode_request(
                            send_data, packet.src, str(packet.sport))
                        if not request_dict:
                            continue

                        # HTTP de-duplication key: method + URL.
                        http_cache_key = '{}:{}'.format(
                            request_dict['method'], request_dict['uri'])
                        if self.cache_size and self.http_cache.get(
                                http_cache_key):
                            continue

                        response_dict = self.decode_response(packet.data)
                        if response_dict:
                            # Suppress repeats of this method + URL for the TTL.
                            if self.cache_size:
                                self.http_cache.set(http_cache_key, True)

                            response_code = response_dict['status']
                            content_type = response_dict['type']

                            # Drop responses whose status code or content type
                            # matches a configured filter rule.
                            if self.http_filter_json:
                                filter_code = self.http_filter(
                                    'response_code',
                                    response_code) if response_code else False
                                filter_type = self.http_filter(
                                    'content_type',
                                    content_type) if content_type else False
                                if filter_code or filter_type:
                                    continue

                            data = {
                                'pro': 'HTTP',
                                'tag': self.custom_tag,
                                'ip': packet.src,
                                'port': packet.sport,
                                'method': request_dict['method'],
                                'code': response_code,
                                'type': content_type,
                                'server': response_dict['server'],
                                'header': response_dict['headers'],
                                'url': request_dict['uri'],
                                'body': response_dict['body']
                            }

                            self.send_msg(data)
                            continue

                    # Otherwise (HTTP not collected, or the response is not
                    # HTTP): raw TCP collection.
                    elif 'TCP' in self.pcap_collection_data:
                        # Suppress repeats of this endpoint for the de-dup TTL.
                        if self.cache_size:
                            self.tcp_cache.set(cache_key, True)

                        # 2.2 Non-HTTP traffic
                        data = {
                            'pro': 'TCP',
                            'tag': self.custom_tag,
                            'ip': packet.src,
                            'port': packet.sport,
                            'data': packet.data.hex()
                        }
                        self.send_msg(data)

        self.sniffer.close()

    def http_filter(self, key, value):
        """
		检查字符串中是否包含特定的规则
		:param key: 规则键名,response_code(状态码)或 content_type(内容类型)
		:param value: 要检查的字符串
		:return: True - 包含, False - 不包含
		"""
        if key in self.http_filter_json:
            for rule in self.http_filter_json[key]:
                if rule in value:
                    return True
        return False

    def pkt_decode(self, pkt):
        try:
            ip_type = ''
            packet = dpkt.ethernet.Ethernet(pkt)
            if isinstance(packet.data, dpkt.ip.IP):
                ip_type = 'ip4'
            elif isinstance(packet.data, dpkt.ip6.IP6):
                ip_type = 'ip6'
            if ip_type and isinstance(packet.data.data, dpkt.tcp.TCP):
                if packet.data.data.flags == 0x12 or \
                 packet.data.data.flags in [0x10, 0x18, 0x19] and len(packet.data.data.data) > 0:
                    tcp_pkt = packet.data.data
                    if ip_type == 'ip4':
                        tcp_pkt.src = self.ip_addr(packet.data.src)
                        tcp_pkt.dst = self.ip_addr(packet.data.dst)
                    else:
                        tcp_pkt.src = self.ip6_addr(''.join(
                            ['%02X' % x for x in packet.data.src]))
                        tcp_pkt.dst = self.ip6_addr(''.join(
                            ['%02X' % x for x in packet.data.dst]))
                    return tcp_pkt
        except KeyboardInterrupt:
            print('\nExit.')
            os.kill(os.getpid(), signal.SIGKILL)
        except Exception as e:
            # print(str(e))
            # print(("".join(['%02X ' % b for b in pkt])))
            pass
        return None

    def ip_addr(self, ip):
        return '%d.%d.%d.%d' % tuple(ip)

    def ip6_addr(self, ip6):
        ip6_addr = ''
        ip6_list = re.findall(r'.{4}', ip6)
        for i in range(len(ip6_list)):
            ip6_addr += ':%s' % (ip6_list[i].lstrip('0')
                                 if ip6_list[i].lstrip('0') else '0')
        return ip6_addr.lstrip(':')

    def decode_request(self, data, sip, sport):
        pos = data.find(b'\r\n\r\n')
        body = data[pos + 4:] if pos > 0 else b''
        data_str = str(data[:pos] if pos > 0 else data, 'utf-8', 'ignore')
        m = self.decode_request_regex.match(data_str)
        if m:
            if m.group(2)[:1] != '/':
                return None

            headers = m.group(3).strip() if m.group(3) else ''
            header_dict = self.parse_headers(headers)
            host_domain = ''
            # host domain
            if 'host' in header_dict and re.search('[a-zA-Z]',
                                                   header_dict['host']):
                host_domain = header_dict['host']
            # host ip
            else:
                host_domain = sip + ':' + sport if sport != '80' else sip
            url = 'http://{}{}'.format(
                host_domain, m.group(2)) if host_domain else m.group(2)

            return {
                'method': m.group(1) if m.group(1) else '',
                'uri': url,
                'headers': headers,
                'body': str(body, 'utf-8', 'ignore')
            }

        return {
            'method':
            '',
            'uri':
            'http://{}:{}/'.format(sip if ':' not in sip else '[' + sip + ']',
                                   sport),
            'headers':
            '',
            'body':
            ''
        }

    def decode_response(self, data):
        pos = data.find(b'\r\n\r\n')
        body = data[pos + 4:] if pos > 0 else b''
        header_str = str(data[:pos] if pos > 0 else data, 'utf-8', 'ignore')
        m = self.decode_response_regex.match(header_str)
        if m:
            headers = m.group(3).strip() if m.group(3) else ''
            headers_dict = self.parse_headers(headers)
            if self.return_deep_info and 'transfer-encoding' in headers_dict and headers_dict[
                    'transfer-encoding'] == 'chunked':
                body = self.decode_chunked(body)

            if self.return_deep_info and 'content-encoding' in headers_dict:
                if headers_dict['content-encoding'] == 'gzip':
                    body = self.decode_gzip(body)
                elif headers_dict['content-encoding'] == 'br':
                    body = self.decode_brotli(body)

            content_type = '' if 'content-type' not in headers_dict else headers_dict[
                'content-type']
            server = '' if 'server' not in headers_dict else headers_dict[
                'server']
            return {
                'version': m.group(1) if m.group(1) else '',
                'status': m.group(2) if m.group(2) else '',
                'headers': headers,
                'type': content_type,
                'server': server,
                'body': self.decode_body(body, content_type)
            }

        return None

    def decode_gzip(self, data):
        '''
		还原 HTTP 响应中采用 gzip 压缩的数据
		标识:
		Content-Encoding: gzip
		'''
        try:
            buf = io.BytesIO(data)
            gf = gzip.GzipFile(fileobj=buf)
            content = gf.read()
            gf.close()

            return content
        except:
            return data

    def decode_brotli(self, data):
        '''
		还原 HTTP 响应中采用 brotli 压缩的数据
		标识:
		Content-Encoding: br
		'''
        try:
            return brotli.decompress(data)
        except:
            return data

    def decode_chunked(self, data):
        '''
        Reassemble an HTTP body sent with ``Transfer-Encoding: chunked``.

        Each chunk is a hexadecimal size line (optionally followed by
        ``;name=value`` chunk extensions, which are ignored), CRLF, the
        payload, CRLF. A chunk of size 0 ends the body, e.g.::

            1b
            {"ret":0, "messge":"error"}
            0

        Decoding is best-effort: if a size line is missing or cannot be
        parsed, the chunks decoded so far are returned followed by the
        remaining bytes unchanged. The loop is iterative, so bodies with many
        chunks do not hit the recursion limit.

        :param data: raw chunked body (bytes)
        :return: the de-chunked body (bytes)
        '''
        parts = []
        pos = 0
        while True:
            line_end = data.find(b'\r\n', pos)
            # No size line, or an empty one: stop and keep the rest as-is.
            if line_end <= pos:
                break
            # Ignore chunk extensions after ';'.
            size_field = data[pos:line_end].split(b';', 1)[0]
            try:
                chunk_len = int(size_field, 16)
            except ValueError:
                break
            if chunk_len == 0:
                # Last chunk: anything after it (trailers) is not body data.
                return b''.join(parts)
            if chunk_len < 0:
                break
            start = line_end + 2
            parts.append(data[start:start + chunk_len])
            # Skip the payload and the CRLF that terminates it.
            pos = start + chunk_len + 2
        parts.append(data[pos:])
        return b''.join(parts)

    def decode_body(self, data, content_type):
        '''
        Decode a raw HTTP response body to text.

        The charset is taken, in order of preference, from:
          1. the ``charset=`` parameter of *content_type*;
          2. the first capture group of ``self.decode_body_regex`` matched at
             the start of *data* (the pattern is defined elsewhere —
             presumably a ``<meta ... charset=...>`` declaration);
          3. UTF-8 as the final fallback.
        A candidate is used only if it is in the white list below, is not
        ``iso-8859-1`` (explicitly skipped), and is known to Python's codec
        registry. Undecodable bytes are dropped (``'ignore'``).

        :param data: raw body bytes
        :param content_type: Content-Type header value; may be empty or None
        :return: the decoded body as str
        '''
        charset_white_list = {
            'big5', 'big5-hkscs', 'cesu-8', 'euc-jp', 'euc-kr', 'gb18030',
            'gb2312', 'gbk', 'ibm-thai', 'ibm00858', 'ibm01140', 'ibm01141',
            'ibm01142', 'ibm01143', 'ibm01144', 'ibm01145', 'ibm01146',
            'ibm01147', 'ibm01148', 'ibm01149', 'ibm037', 'ibm1026', 'ibm1047',
            'ibm273', 'ibm277', 'ibm278', 'ibm280', 'ibm284', 'ibm285',
            'ibm290', 'ibm297', 'ibm420', 'ibm424', 'ibm437', 'ibm500',
            'ibm775', 'ibm850', 'ibm852', 'ibm855', 'ibm857', 'ibm860',
            'ibm861', 'ibm862', 'ibm863', 'ibm864', 'ibm865', 'ibm866',
            'ibm868', 'ibm869', 'ibm870', 'ibm871', 'ibm918',
            'iso-10646-ucs-2', 'iso-2022-cn', 'iso-2022-jp', 'iso-2022-jp-2',
            'iso-2022-kr', 'iso-8859-1', 'iso-8859-10', 'iso-8859-13',
            'iso-8859-15', 'iso-8859-16', 'iso-8859-2', 'iso-8859-3',
            'iso-8859-4', 'iso-8859-5', 'iso-8859-6', 'iso-8859-7',
            'iso-8859-8', 'iso-8859-9', 'jis_x0201', 'jis_x0212-1990',
            'koi8-r', 'koi8-u', 'shift_jis', 'tis-620', 'us-ascii', 'utf-16',
            'utf-16be', 'utf-16le', 'utf-32', 'utf-32be', 'utf-32le', 'utf-8',
            'windows-1250', 'windows-1251', 'windows-1252', 'windows-1253',
            'windows-1254', 'windows-1255', 'windows-1256', 'windows-1257',
            'windows-1258', 'windows-31j', 'x-big5-hkscs-2001',
            'x-big5-solaris', 'x-euc-jp-linux', 'x-euc-tw', 'x-eucjp-open',
            'x-ibm1006', 'x-ibm1025', 'x-ibm1046', 'x-ibm1097', 'x-ibm1098',
            'x-ibm1112', 'x-ibm1122', 'x-ibm1123', 'x-ibm1124', 'x-ibm1166',
            'x-ibm1364', 'x-ibm1381', 'x-ibm1383', 'x-ibm300', 'x-ibm33722',
            'x-ibm737', 'x-ibm833', 'x-ibm834', 'x-ibm856', 'x-ibm874',
            'x-ibm875', 'x-ibm921', 'x-ibm922', 'x-ibm930', 'x-ibm933',
            'x-ibm935', 'x-ibm937', 'x-ibm939', 'x-ibm942', 'x-ibm942c',
            'x-ibm943', 'x-ibm943c', 'x-ibm948', 'x-ibm949', 'x-ibm949c',
            'x-ibm950', 'x-ibm964', 'x-ibm970', 'x-iscii91',
            'x-iso-2022-cn-cns', 'x-iso-2022-cn-gb', 'x-iso-8859-11',
            'x-jis0208', 'x-jisautodetect', 'x-johab', 'x-macarabic',
            'x-maccentraleurope', 'x-maccroatian', 'x-maccyrillic',
            'x-macdingbat', 'x-macgreek', 'x-machebrew', 'x-maciceland',
            'x-macroman', 'x-macromania', 'x-macsymbol', 'x-macthai',
            'x-macturkish', 'x-macukraine', 'x-ms932_0213', 'x-ms950-hkscs',
            'x-ms950-hkscs-xp', 'x-mswin-936', 'x-pck', 'x-sjis',
            'x-sjis_0213', 'x-utf-16le-bom', 'x-utf-32be-bom',
            'x-utf-32le-bom', 'x-windows-50220', 'x-windows-50221',
            'x-windows-874', 'x-windows-949', 'x-windows-950',
            'x-windows-iso2022jp'
        }

        def _try_decode(charset):
            # Return the decoded text, or None when *charset* must not or
            # cannot be used.
            if charset == 'iso-8859-1' or charset not in charset_white_list:
                return None
            try:
                return str(data, charset, 'ignore')
            except LookupError:
                # Several white-listed names (e.g. 'x-ibm1006') are not
                # Python codecs; fall through instead of crashing.
                return None

        content_type = content_type.lower() if content_type else ''
        if 'charset=' in content_type:
            param = content_type[content_type.find('charset=') + 8:]
            # Drop any parameters that follow the charset ("; foo=bar").
            charset = param.split(';', 1)[0].strip('"\' \r\n')
            text = _try_decode(charset)
            if text is not None:
                return text

        m = self.decode_body_regex.match(data)
        if m and m.group(1):
            declared = m.group(1)
            # data is bytes, so the capture is bytes too; the white list
            # holds str, so convert before comparing (bytes never equal str).
            if isinstance(declared, bytes):
                declared = declared.decode('ascii', 'ignore')
            text = _try_decode(declared.lower())
            if text is not None:
                return text

        return str(data, 'utf-8', 'ignore')

    def parse_headers(self, data):
        '''
        Parse CRLF-separated ``Name: value`` lines into a dict.

        Names are lower-cased and values stripped of surrounding whitespace.
        Lines without a colon, or that begin with one, are ignored. When a
        name repeats, the last occurrence wins.

        :param data: header block as str
        :return: dict mapping lower-cased header names to values
        '''
        parsed = {}
        for line in data.split('\r\n'):
            name, sep, value = line.partition(':')
            # Require both a colon and a non-empty name before it.
            if sep and name:
                parsed[name.lower()] = value.strip()
        return parsed

    def send_msg(self, data):
        '''
        JSON-encode *data* and append it to ``self.work_queue``.

        The encoded message is printed first when ``self.debug`` is set. If
        the queue already holds at least 95% of ``self.max_queue_size``
        entries, all pending entries are discarded before appending.

        :param data: any JSON-serialisable object
        '''
        message = json.dumps(data)
        if self.debug:
            print(message)
        queue = self.work_queue
        nearly_full = len(queue) >= self.max_queue_size * 0.95
        if nearly_full:
            # Make room by dropping everything still pending.
            queue.clear()
        queue.append(message)
예제 #21
0
class tcp_http_sniff():

	def __init__(self,interface,display_filter,syslog_ip,syslog_port,custom_tag,return_deep_info,filter_rules,cache_size,bpf_filter,timeout,debug):
		'''
		Set up the live capture, the result logger and the de-duplication caches.

		:param interface: capture interface, passed to pyshark.LiveCapture
		:param display_filter: display filter, passed to pyshark.LiveCapture
		:param syslog_ip: host given to _logging, whose info() receives every record
		:param syslog_port: port given to _logging
		:param custom_tag: value copied into the "tag" field of every record
		:param return_deep_info: truthy to also collect HTTP headers/bodies and
			the first TCP payload of each newly seen server
		:param filter_rules: mapping consulted by http_filter; its keys
			'response_code' / 'content_type' list substrings whose presence
			causes a response to be dropped
		:param cache_size: maxsize of both caches (ttl=120 with time.time as
			timer, i.e. presumably 120 seconds)
		:param bpf_filter: BPF capture filter, passed to pyshark.LiveCapture
		:param timeout: timeout that run() passes to apply_on_packets
		:param debug: truthy to print records; also passed to pyshark.LiveCapture
		'''
		self.debug = debug
		self.timeout = timeout
		self.bpf_filter = bpf_filter
		self.cache_size = cache_size
		self.filter_rules = filter_rules
		self.return_deep_info = return_deep_info
		self.custom_tag = custom_tag
		self.syslog_ip = syslog_ip
		self.syslog_port = syslog_port
		self.log_obj = _logging(self.syslog_ip,self.syslog_port)
		self.interface = interface
		self.display_filter = display_filter
		self.pktcap = pyshark.LiveCapture(interface=self.interface, bpf_filter=self.bpf_filter, use_json=True, display_filter=self.display_filter, debug=self.debug)
		# http_cache: stream id -> request URL, and URL -> True for de-duplication.
		self.http_cache = Cache(maxsize=self.cache_size, ttl=120, timer=time.time, default=None)
		# tcp_cache: stream id -> "ip:port", and "ip:port" -> True for de-duplication.
		self.tcp_cache = Cache(maxsize=self.cache_size, ttl=120, timer=time.time, default=None)
		# Regex that captures the charset declared in an HTML <meta> tag. The
		# raw bytes literal keeps "\s" a regex escape instead of an invalid
		# string escape (a SyntaxWarning on Python 3.12+).
		self.encode_regex = re.compile(rb'<meta\s[^>]*?charset=["\']?([^"\'\s]+)["\']?', re.I)

	# 根据response_code和content_type过滤
	def http_filter(self,key,value):
		if key in self.filter_rules:
			for rule in self.filter_rules[key]:
				if rule in value:
					return True
		return False
	
	def run(self):
		try:
			self.pktcap.apply_on_packets(self.proc_packet,timeout=self.timeout)
		except concurrent.futures.TimeoutError:
			print("\nTimeoutError.")
	def proc_packet(self, pkt):
		try:
			pkt_json = None
			pkt_dict = dir(pkt)
			
			if 'ip' in pkt_dict:
				if 'http' in pkt_dict:
					pkt_json = self.proc_http(pkt)
				elif 'tcp' in pkt_dict:
					pkt_json = self.proc_tcp(pkt)

			if pkt_json:
				if self.debug:
					print(json.dumps(pkt_json))
				self.log_obj.info(json.dumps(pkt_json))

		except Exception:
			traceback.format_exc()
			# error_log_json = {}
			# error_log_json["custom_tag"] = self.custom_tag
			# error_log_json["error_log"] = str(traceback.format_exc())
			# if self.debug:
			# 	print(json.dumps(error_log_json))
			# self.log_obj.error(json.dumps(error_log_json))
	
	def proc_http(self, pkt):
		http_dict = dir(pkt.http)
		
		if self.return_deep_info:
			if 'request' in http_dict:
				self.http_cache.set(pkt.tcp.stream, pkt.http.request_full_uri if 'request_full_uri' in http_dict else pkt.http.request_uri)
		
		if 'response' in http_dict:
			pkt_json = {}
			src_addr = pkt.ip.src
			src_port = pkt[pkt.transport_layer].srcport
			
			cache_url = self.http_cache.get(pkt.tcp.stream)
			if cache_url:
				pkt_json['url'] = cache_url
				self.http_cache.delete(pkt.tcp.stream)
			
			if 'url' not in pkt_json:
				if 'response_for_uri' in http_dict:
					pkt_json["url"] = pkt.http.response_for_uri
				else:
					pkt_json["url"] = '/'

			# 处理 URL 只有URI的情况
			if pkt_json["url"][0] == '/':
				if src_port == '80':
					pkt_json["url"] = "http://%s%s"%(src_addr,pkt_json["url"])
				else:
					pkt_json["url"] = "http://%s:%s%s"%(src_addr,src_port,pkt_json["url"])

			# 缓存机制,防止短时间大量处理重复响应
			exists = self.http_cache.get(pkt_json['url'])
			if exists:
				return None

			self.http_cache.set(pkt_json["url"], True)

			pkt_json["pro"] = 'HTTP'
			pkt_json["tag"] = self.custom_tag
			pkt_json["ip"] = src_addr
			pkt_json["port"] = src_port

			if 'response_code' in http_dict:
				if self.filter_rules:
					return_status = self.http_filter('response_code', pkt.http.response_code)
					if return_status:
						return None
				pkt_json["code"] = pkt.http.response_code
			
			if 'content_type' in http_dict:
				if self.filter_rules:
					return_status = self.http_filter('content_type', pkt.http.content_type)
					if return_status:
						return None
				pkt_json["type"] = pkt.http.content_type.lower()
			else:
				pkt_json["type"] = 'unkown'

			if 'server' in http_dict:
				pkt_json["server"] = pkt.http.server

			# -r on开启深度数据分析,返回header和body等数据
			if self.return_deep_info:
				charset = 'utf-8'
				# 检测 Content-Type 中的编码信息
				if 'type' in pkt_json and 'charset=' in pkt_json["type"]:
					charset = pkt_json["type"][pkt_json["type"].find('charset=')+8:].strip().lower()
					if not charset :
						charset = 'utf-8'
				if 'payload' in dir(pkt.tcp):
					payload = bytes.fromhex(str(pkt.tcp.payload).replace(':', ''))
					if payload.find(b'HTTP/') == 0:
						split_pos = payload.find(b'\r\n\r\n')
						if split_pos <= 0 or split_pos > 4096:
							split_pos = 4096
						pkt_json["header"] = str(payload[:split_pos], 'utf-8', 'ignore')

				if 'file_data' in http_dict and pkt.http.file_data.raw_value and pkt_json['type'] != 'application/octet-stream':
					data = bytes.fromhex(pkt.http.file_data.raw_value)
					# 检测页面 Meta 中的编码信息
					data_head = data[:500] if data.find(b'</head>', 0, 1024) == -1 else data[:data.find(b'</head>')]
					match = self.encode_regex.search(data_head)
					if match:
						charset = str(match.group(1).lower(), 'utf-8', 'ignore')
					response_body = self.proc_body(str(data, charset, 'ignore'), 16*1024)
					pkt_json["body"] = response_body
				else:
					pkt_json["body"] = ''
			
			return pkt_json
		
		return None

	def proc_tcp(self, pkt):
		tcp_stream = pkt.tcp.stream
		
		pkt_json = {}
		pkt_json["pro"] = 'TCP'
		pkt_json["tag"] = self.custom_tag

		# SYN+ACK
		if pkt.tcp.flags == '0x00000012' : 
			server_ip = pkt.ip.src
			server_port = pkt[pkt.transport_layer].srcport
			tcp_info = '%s:%s' % (server_ip, server_port)

			exists = self.tcp_cache.get(tcp_info)
			if exists:
				return None
				
			if self.return_deep_info and tcp_info:
				self.tcp_cache.set(tcp_stream, tcp_info)
				self.tcp_cache.set(tcp_info,True)
			else:
				pkt_json["ip"] = server_ip
				pkt_json["port"] = server_port
				self.tcp_cache.set(tcp_info,True)
				return pkt_json
		
		# -r on开启深度数据分析,采集server第一个响应数据包
		if self.return_deep_info and pkt.tcp.seq == "1" and "payload" in dir(pkt.tcp) :
			tcp_info = self.tcp_cache.get(tcp_stream)
			if tcp_info:
				tcp_info_list = tcp_info.split(":")
				tcp_ip = tcp_info_list[0]
				tcp_port = tcp_info_list[1]
				pkt_json["ip"] = tcp_ip
				pkt_json["port"] = tcp_port
				payload_data = pkt.tcp.payload.replace(":","")
				if payload_data.startswith("48545450"):
					return None
				# HTTPS Protocol
				# TODO: other https port support 
				if tcp_port == "443" and payload_data.startswith("1603"):
					pkt_json["pro"] = 'HTTPS'
					pkt_json["url"] = "https://%s/"%(tcp_ip)
				else:
					pkt_json["data"] = payload_data
				self.tcp_cache.delete(tcp_stream)
				return pkt_json
		return None

	def proc_body(self, data, length):
		'''
		Shorten *data* so that its JSON encoding stays within *length* characters.

		If json.dumps(data) is at most *length* characters long, *data* is
		returned unchanged. Otherwise the result is the longest prefix of
		*data* whose JSON encoding, without its closing quote, fits in
		*length* characters. The cut never falls inside a backslash escape.
		A high surrogate left dangling by splitting a surrogate pair is
		dropped.

		:param data: decoded body text
		:param length: character budget for the JSON-encoded text
		:return: *data* or a prefix of it
		'''
		encoded = json.dumps(data)
		if len(encoded) <= length:
			return data

		# Walk back from the limit to the first cut that parses once the
		# closing quote is added. A cut inside a backslash escape never
		# parses, and escapes are at most 6 characters long, so only a few
		# attempts are needed.
		for end in range(length, 0, -1):
			try:
				text = json.loads(encoded[:end] + '"')
			except ValueError:
				continue
			# A lone high surrogate cannot be encoded as UTF-8; drop it.
			if text and '\ud800' <= text[-1] <= '\udbff':
				text = text[:-1]
			return text
		return ''