Пример #1
0
    def __init__(
            self, account_id, init_balance,
            trade_class: Union[Type[SimTrade], Type[SimTradeStock]]) -> None:
        self._account_id = account_id
        super(BaseSim, self).__init__()

        self.trade_log = {}  # 日期->交易记录及收盘时的权益及持仓
        self.tqsdk_stat = {}  # 回测结束后储存回测报告信息
        self._init_balance = init_balance
        self._current_datetime = "1990-01-01 00:00:00.000000"  # 当前行情时间(最新的 quote 时间)
        self._trading_day_end = "1990-01-01 18:00:00.000000"
        self._local_time_record = float("nan")  # 记录获取最新行情时的本地时间
        self._sim_trade = trade_class(
            account_key=self._account_key,
            account_id=self._account_id,
            init_balance=self._init_balance,
            get_trade_timestamp=self._get_trade_timestamp,
            is_in_trading_time=self._is_in_trading_time)
        self._data = Entity()
        self._data._instance_entity([])
        self._prototype = {
            "quotes": {
                "#": Quote(self),  # 行情的数据原型
            }
        }
        self._quote_tasks = {}
Пример #2
0
 def __init__(self, logger):
     self._logger = logger
     self._resend_request = {}  # 重连时需要重发的请求
     self._un_processed = False  # 重连后尚未处理完标志
     self._pending_diffs = []
     self._data = Entity()
     self._data._instance_entity([])
Пример #3
0
 async def _run(self, api, sim_send_chan, sim_recv_chan, md_send_chan,
                md_recv_chan):
     """回测task"""
     self._api = api
     self._logger = api._logger.getChild("TqBacktest")  # 调试信息输出
     self._sim_send_chan = sim_send_chan
     self._sim_recv_chan = sim_recv_chan
     self._md_send_chan = md_send_chan
     self._md_recv_chan = md_recv_chan
     self._pending_peek = False
     self._data = Entity()  # 数据存储
     self._data._instance_entity([])
     self._serials = {}  # 所有原始数据序列
     self._quotes = {}
     self._diffs = []
     self._is_first_send = True
     md_task = self._api.create_task(self._md_handler())
     try:
         await self._send_snapshot()
         async for pack in self._sim_send_chan:
             self._logger.debug("TqBacktest message received: %s", pack)
             if pack["aid"] == "subscribe_quote":
                 self._diffs.append({"ins_list": pack["ins_list"]})
                 for ins in pack["ins_list"].split(","):
                     await self._ensure_quote(ins)
                 await self._send_diff()  # 处理上一次未处理的 peek_message
             elif pack["aid"] == "set_chart":
                 if pack["ins_list"]:
                     # 回测模块中已保证每次将一个行情时间的数据全部发送给api,因此更新行情时 保持与初始化时一样的charts信息(即不作修改)
                     self._diffs.append({
                         "charts": {
                             pack["chart_id"]: {
                                 # 两个id设置为0:保证api在回测中判断此值时不是-1,即直接通过对数据接收完全的验证
                                 "left_id": 0,
                                 "right_id": 0,
                                 "more_data":
                                 False,  # 直接发送False给api,表明数据发送完全,使api中通过数据接收完全的验证
                                 "state": pack
                             }
                         }
                     })
                     await self._ensure_serial(pack["ins_list"],
                                               pack["duration"])
                 else:
                     self._diffs.append(
                         {"charts": {
                             pack["chart_id"]: None
                         }})
                 await self._send_diff()  # 处理上一次未处理的 peek_message
             elif pack["aid"] == "peek_message":
                 self._pending_peek = True
                 await self._send_diff()
     finally:
         # 关闭所有serials
         for s in self._serials.values():
             await s["generator"].aclose()
         md_task.cancel()
         await asyncio.gather(md_task, return_exceptions=True)
Пример #4
0
 async def _run(self, api, api_send_chan, api_recv_chan, ws_send_chan, ws_recv_chan):
     self._api = api
     send_task = self._api.create_task(self._send_handler(api_send_chan, ws_send_chan))
     try:
         async for pack in ws_recv_chan:
             self._record_upper_data(pack)
             if self._un_processed:  # 处理重连后数据
                 pack_data = pack.get("data", [])
                 self._pending_diffs.extend(pack_data)
                 for d in pack_data:
                     # _merge_diff 之后, self._data 会用于判断是否接收到了完整截面数据
                     _merge_diff(self._data, d, self._api._prototype, persist=False, reduce_diff=False)
                 if self._is_all_received():
                     # 重连后收到完整数据截面
                     self._un_processed = False
                     pack = {
                         "aid": "rtn_data",
                         "data": self._pending_diffs
                     }
                     await api_recv_chan.send(pack)
                     self._logger = self._logger.bind(status=self._status)
                     self._logger.debug("data completed", pack=pack)
                 else:
                     await ws_send_chan.send({"aid": "peek_message"})
                     self._logger.debug("wait for data completed", pack={"aid": "peek_message"})
             else:
                 is_reconnected = False
                 for i in range(len(pack.get("data", []))):
                     for _, notify in pack["data"][i].get("notify", {}).items():
                         if notify["code"] == 2019112902:  # 重连建立
                             is_reconnected = True
                             self._un_processed = True
                             self._logger = self._logger.bind(status=self._status)
                             if i > 0:
                                 ws_send_chan.send_nowait({
                                     "aid": "rtn_data",
                                     "data": pack.get("data", [])[0:i]
                                 })
                             self._pending_diffs = pack.get("data", [])[i:]
                             break
                 if is_reconnected:
                     self._data = Entity()
                     self._data._instance_entity([])
                     for d in self._pending_diffs:
                         _merge_diff(self._data, d, self._api._prototype, persist=False, reduce_diff=False)
                     # 发送所有 resend_request
                     for msg in self._resend_request.values():
                         # 这里必须用 send_nowait 而不是 send,因为如果使用异步写法,在循环中,代码可能执行到 send_task, 可能会修改 _resend_request
                         ws_send_chan.send_nowait(msg)
                         self._logger.debug("resend request", pack=msg)
                     await ws_send_chan.send({"aid": "peek_message"})
                 else:
                     await api_recv_chan.send(pack)
     finally:
         send_task.cancel()
         await asyncio.gather(send_task, return_exceptions=True)
Пример #5
0
def _get_obj(root, path, default=None):
    """获取业务数据"""
    d = root
    for i in range(len(path)):
        if path[i] not in d:
            if i != len(path) - 1 or default is None:
                dv = Entity()
            else:
                dv = copy.copy(default)
            dv._instance_entity(d["_path"] + [path[i]])
            d[path[i]] = dv
        d = d[path[i]]
    return d
Пример #6
0
def lots_at_price(order: Entity, p: float) -> int:
    order_id: str
    order: Order
    lots: int = 0
    for order_id, order in order.items():
        if order.limit_price == p:
            lots += order.volume_left
    return lots
Пример #7
0
 def __init__(self, api):
     self._api = api
     self._data = Entity()  # 交易业务信息截面,需要定于数据原型,使用 Entity 类型 和 _merge_diff
     self._data._instance_entity([])
     self._new_objs_list = []
     self._prototype = {
         "trade": {
             "*": {
                 "@": CustomDict(self._api, self._new_objs_list)
             }
         }
     }
     self._data_quotes = {}  # 行情信息截面,只需要 quotes 数据。这里不需要定义数据原型,使用普通 dict 和 _simple_merge_diff
     self._diffs = []
     self._all_trade_symbols = set()  # 所有持仓、委托、成交中的合约
     self._query_symbols = set()  # 已经发送合约信息请求 + 已经知道合约信息的合约
     self._need_wait_symbol_info = set()  # 需要发送合约信息请求 + 不知道合约信息的合约
Пример #8
0
    def __init__(self, init_balance: float = 10000000.0, account_id: str = None) -> None:
        """
        Args:
            init_balance (float): [可选]初始资金, 默认为一千万

            account_id (str): [可选]帐号, 默认为 TQSIM

        Example::

            # 修改TqSim模拟帐号的初始资金为100000
            from tqsdk import TqApi, TqSim, TqAuth
            api = TqApi(TqSim(init_balance=100000), auth=TqAuth("信易账户", "账户密码"))

        """
        self.trade_log = {}  # 日期->交易记录及收盘时的权益及持仓
        self.tqsdk_stat = {}  # 回测结束后储存回测报告信息
        self._account_id = "TQSIM" if account_id is None else account_id
        self._account_type = "FUTURE"
        self._broker_id = "TQSIM" if self._account_type == "FUTURE" else "TQSIM_STOCK"
        self._account_key = str(id(self))
        self._init_balance = float(init_balance)
        if self._init_balance <= 0:
            raise Exception("初始资金(init_balance) %s 错误, 请检查 init_balance 是否填写正确" % (init_balance))
        self._current_datetime = "1990-01-01 00:00:00.000000"  # 当前行情时间(最新的 quote 时间)
        self._trading_day_end = "1990-01-01 18:00:00.000000"
        self._local_time_record = float("nan")  # 记录获取最新行情时的本地时间
        self._sim_trade = SimTrade(account_key=self._account_key, init_balance=self._init_balance,
                                     get_trade_timestamp=self._get_trade_timestamp,
                                     is_in_trading_time=self._is_in_trading_time)
        self._data = Entity()
        self._data._instance_entity([])
        self._prototype = {
            "quotes": {
                "#": Quote(self),  # 行情的数据原型
            }
        }
        self._quote_tasks = {}
Пример #9
0
 def __init__(self, api):
     self._api = api
     self._data = Entity()  # 业务信息截面
     self._data._instance_entity([])
     self._diffs = []
     self._all_subscribe = set()
Пример #10
0
class TqBacktest(object):
    """
    天勤回测类

    将该类传入 TqApi 的构造函数, 则策略就会进入回测模式。

    回测模式下 k线会在刚创建出来时和结束时分别更新一次, 在这之间 k线是不会更新的。

    回测模式下 quote 的更新频率由所订阅的 tick 和 k线周期确定:
        * 只要订阅了 tick, 则对应合约的 quote 就会使用 tick 生成, 更新频率也和 tick 一致, 但 **只有下字段** :
          datetime/ask&bid_price1/ask&bid_volume1/last_price/highest/lowest/average/volume/amount/open_interest/
          price_tick/price_decs/volume_multiple/max&min_limit&market_order_volume/underlying_symbol/strike_price

        * 如果没有订阅 tick, 但是订阅了 k线, 则对应合约的 quote 会使用 k线生成, 更新频率和 k线的周期一致, 如果订阅了某个合约的多个周期的 k线,
          则任一个周期的 k线有更新时, quote 都会更新. 使用 k线生成的 quote 的盘口由收盘价分别加/减一个最小变动单位, 并且 highest/lowest/average/amount
          始终为 nan, volume 始终为0

        * 如果即没有订阅 tick, 也没有订阅k线或 订阅的k线周期大于分钟线, 则 TqBacktest 会 **自动订阅分钟线** 来生成 quote

    **注意** :如果未订阅 quote,模拟交易在下单时会自动为此合约订阅 quote ,根据回测时 quote 的更新规则,如果此合约没有订阅K线或K线周期大于分钟线 **则会自动订阅一个分钟线** 。

    模拟交易要求报单价格大于等于对手盘价格才会成交, 例如下买单, 要求价格大于等于卖一价才会成交, 如果不能立即成交则会等到下次行情更新再重新判断。

    回测模式下 wait_update 每次最多推进一个行情时间。

    回测结束后会抛出 BacktestFinished 例外。

    对 **组合合约** 进行回测时需注意:只能通过订阅 tick 数据来回测,不能订阅K线,因为K线是由最新价合成的,而交易所发回的组合合约数据中无最新价。
    """
    def __init__(self, start_dt: Union[date, datetime],
                 end_dt: Union[date, datetime]) -> None:
        """
        创建天勤回测类

        Args:
            start_dt (date/datetime): 回测起始时间, 如果类型为 date 则指的是交易日, 如果为 datetime 则指的是具体时间点

            end_dt (date/datetime): 回测结束时间, 如果类型为 date 则指的是交易日, 如果为 datetime 则指的是具体时间点
        """
        if isinstance(start_dt, datetime):
            self._start_dt = int(start_dt.timestamp() * 1e9)
        elif isinstance(start_dt, date):
            self._start_dt = _get_trading_day_start_time(
                int(
                    datetime(start_dt.year, start_dt.month,
                             start_dt.day).timestamp()) * 1000000000)
        else:
            raise Exception(
                "回测起始时间(start_dt)类型 %s 错误, 请检查 start_dt 数据类型是否填写正确" %
                (type(start_dt)))
        if isinstance(end_dt, datetime):
            self._end_dt = int(end_dt.timestamp() * 1e9)
        elif isinstance(end_dt, date):
            self._end_dt = _get_trading_day_end_time(
                int(
                    datetime(end_dt.year, end_dt.month,
                             end_dt.day).timestamp()) * 1000000000)
        else:
            raise Exception("回测结束时间(end_dt)类型 %s 错误, 请检查 end_dt 数据类型是否填写正确" %
                            (type(end_dt)))
        self._current_dt = self._start_dt

    async def _run(self, api, sim_send_chan, sim_recv_chan, md_send_chan,
                   md_recv_chan):
        """回测task"""
        self._api = api
        self._logger = api._logger.getChild("TqBacktest")  # 调试信息输出
        self._sim_send_chan = sim_send_chan
        self._sim_recv_chan = sim_recv_chan
        self._md_send_chan = md_send_chan
        self._md_recv_chan = md_recv_chan
        self._pending_peek = False
        self._data = Entity()  # 数据存储
        self._data._instance_entity([])
        self._serials = {}  # 所有原始数据序列
        self._quotes = {}
        self._diffs = []
        self._is_first_send = True
        md_task = self._api.create_task(self._md_handler())
        try:
            await self._send_snapshot()
            async for pack in self._sim_send_chan:
                self._logger.debug("TqBacktest message received: %s", pack)
                if pack["aid"] == "subscribe_quote":
                    self._diffs.append({"ins_list": pack["ins_list"]})
                    for ins in pack["ins_list"].split(","):
                        await self._ensure_quote(ins)
                    await self._send_diff()  # 处理上一次未处理的 peek_message
                elif pack["aid"] == "set_chart":
                    if pack["ins_list"]:
                        # 回测模块中已保证每次将一个行情时间的数据全部发送给api,因此更新行情时 保持与初始化时一样的charts信息(即不作修改)
                        self._diffs.append({
                            "charts": {
                                pack["chart_id"]: {
                                    # 两个id设置为0:保证api在回测中判断此值时不是-1,即直接通过对数据接收完全的验证
                                    "left_id": 0,
                                    "right_id": 0,
                                    "more_data":
                                    False,  # 直接发送False给api,表明数据发送完全,使api中通过数据接收完全的验证
                                    "state": pack
                                }
                            }
                        })
                        await self._ensure_serial(pack["ins_list"],
                                                  pack["duration"])
                    else:
                        self._diffs.append(
                            {"charts": {
                                pack["chart_id"]: None
                            }})
                    await self._send_diff()  # 处理上一次未处理的 peek_message
                elif pack["aid"] == "peek_message":
                    self._pending_peek = True
                    await self._send_diff()
        finally:
            # 关闭所有serials
            for s in self._serials.values():
                await s["generator"].aclose()
            md_task.cancel()
            await asyncio.gather(md_task, return_exceptions=True)

    async def _md_handler(self):
        async for pack in self._md_recv_chan:
            await self._md_send_chan.send({"aid": "peek_message"})
            for d in pack.get("data", []):
                _merge_diff(self._data, d, self._api._prototype, False)

    async def _send_snapshot(self):
        """发送初始合约信息"""
        async with TqChan(self._api,
                          last_only=True) as update_chan:  # 等待与行情服务器连接成功
            self._data["_listener"].add(update_chan)
            while self._data.get("mdhis_more_data", True):
                await update_chan.recv()
        # 发送合约信息截面
        quotes = {}
        for ins, quote in self._data["quotes"].items():
            if not ins.startswith("_"):
                quotes[ins] = {
                    "open": None,  # 填写None: 删除api中的这个字段
                    "close": None,
                    "settlement": None,
                    "lower_limit": None,
                    "upper_limit": None,
                    "pre_open_interest": None,
                    "pre_settlement": None,
                    "pre_close": None,
                    "ins_class": quote.get("ins_class", ""),
                    'instrument_id': quote.get("instrument_id", ""),
                    "margin": quote.get(
                        "margin"),  # 用于内部实现模拟交易, 不作为api对外可用数据(即 Quote 类中无此字段)
                    "commission":
                    quote.get("commission"
                              ),  # 用于内部实现模拟交易, 不作为api对外可用数据(即 Quote 类中无此字段)
                    "price_tick": quote["price_tick"],
                    "price_decs": quote["price_decs"],
                    "volume_multiple": quote["volume_multiple"],
                    "max_limit_order_volume": quote["max_limit_order_volume"],
                    "max_market_order_volume":
                    quote["max_market_order_volume"],
                    "min_limit_order_volume": quote["min_limit_order_volume"],
                    "min_market_order_volume":
                    quote["min_market_order_volume"],
                    "underlying_symbol": quote["underlying_symbol"],
                    "strike_price": quote["strike_price"],
                    "expired": None,
                    "trading_time": quote.get("trading_time"),
                    "expire_datetime": quote.get("expire_datetime"),
                    "delivery_month": quote.get("delivery_month"),
                    "delivery_year": quote.get("delivery_year"),
                    "option_class": quote.get("option_class", ""),
                    "product_id": quote.get("product_id", ""),
                }
        self._diffs.append({
            "quotes": quotes,
            "ins_list": "",
            "mdhis_more_data": False,
        })

    async def _send_diff(self):
        """发送数据到 api, 如果 self._diffs 不为空则发送 self._diffs, 不推进行情时间, 否则将时间推进一格, 并发送对应的行情"""
        if self._pending_peek:
            quotes = {}
            if not self._diffs:
                while self._serials:
                    min_serial = min(
                        self._serials.keys(),
                        key=lambda serial: self._serials[serial]["timestamp"])
                    timestamp = self._serials[min_serial][
                        "timestamp"]  # 所有已订阅数据中的最小行情时间
                    quotes_diff = self._serials[min_serial]["quotes"]
                    # 推进时间,一次只会推进最多一个(补数据时有可能是0个)行情时间,并确保<=该行情时间的行情都被发出
                    # 如果行情时间大于当前回测时间 则 判断是否diff中已有数据;否则表明此行情时间的数据未全部保存在diff中,则继续append
                    if timestamp > self._current_dt:
                        if self._diffs:  # 如果diffs中已有数据:退出循环并发送数据给下游api
                            break
                        else:
                            self._current_dt = timestamp  # 否则将回测时间更新至最新行情时间
                    self._diffs.append(self._serials[min_serial]["diff"])
                    quote_info = self._quotes[min_serial[0]]
                    if quotes_diff and (quote_info["min_duration"] != 0
                                        or min_serial[1] == 0):
                        quotes[min_serial[0]] = quotes_diff
                    await self._fetch_serial(min_serial)
                if not self._serials and not self._diffs:  # 当无可发送数据时则抛出BacktestFinished例外,包括未订阅任何行情 或 所有已订阅行情的最后一笔行情获取完成
                    self._logger.warning("回测结束")
                    if self._current_dt < self._end_dt:
                        self._current_dt = 2145888000000000000  # 一个远大于 end_dt 的日期 20380101
                    await self._sim_recv_chan.send({
                        "aid":
                        "rtn_data",
                        "data": [{
                            "_tqsdk_backtest": {
                                "start_dt": self._start_dt,
                                "current_dt": self._current_dt,
                                "end_dt": self._end_dt
                            }
                        }]
                    })
                    raise BacktestFinished(self._api) from None
            for ins, diff in quotes.items():
                for d in diff:
                    self._diffs.append({"quotes": {ins: d}})
            if self._diffs:
                # 发送数据集中添加 backtest 字段,开始时间、结束时间、当前时间,表示当前行情推进是由 backtest 推进
                if self._is_first_send:
                    self._diffs.append({
                        "_tqsdk_backtest": {
                            "start_dt": self._start_dt,
                            "current_dt": self._current_dt,
                            "end_dt": self._end_dt
                        }
                    })
                    self._is_first_send = False
                else:
                    self._diffs.append(
                        {"_tqsdk_backtest": {
                            "current_dt": self._current_dt
                        }})
                rtn_data = {
                    "aid": "rtn_data",
                    "data": self._diffs,
                }
                self._diffs = []
                self._pending_peek = False
                self._logger.debug("backtest message send: %s", rtn_data)
                await self._sim_recv_chan.send(rtn_data)

    async def _ensure_serial(self, ins, dur):
        if (ins, dur) not in self._serials:
            quote = self._quotes.setdefault(
                ins,
                {  # 在此处设置 min_duration: 每次生成K线的时候会自动生成quote, 记录某一合约的最小duration
                    "min_duration": dur
                })
            quote["min_duration"] = min(quote["min_duration"], dur)
            self._serials[(ins, dur)] = {
                "generator": self._gen_serial(ins, dur),
            }
            await self._fetch_serial((ins, dur))

    async def _ensure_quote(self, ins):
        if ins not in self._quotes or self._quotes[ins][
                "min_duration"] > 60000000000:
            await self._ensure_serial(ins, 60000000000)

    async def _fetch_serial(self, serial):
        s = self._serials[serial]
        try:
            s["timestamp"], s["diff"], s["quotes"] = await s["generator"
                                                             ].__anext__()
        except StopAsyncIteration:
            del self._serials[serial]  # 删除一个行情时间超过结束时间的serial

    async def _gen_serial(self, ins, dur):
        """k线/tick 序列的 async generator, yield 出来的行情数据带有时间戳, 因此 _send_diff 可以据此归并"""
        # 先定位左端点, focus_datetime 是 lower_bound ,这里需要的是 upper_bound
        # 因此将 view_width 和 focus_position 设置成一样,这样 focus_datetime 所对应的 k线刚好位于屏幕外
        chart_info = {
            "aid": "set_chart",
            "chart_id": _generate_uuid("PYSDK_backtest"),
            "ins_list": ins,
            "duration": dur,
            "view_width":
            8964,  # 设为8964原因:可满足用户所有的订阅长度,并在backtest中将所有的 相同合约及周期 的K线用同一个serial存储
            "focus_datetime": int(self._current_dt),
            "focus_position": 8964,
        }
        chart = _get_obj(self._data, ["charts", chart_info["chart_id"]])
        current_id = None  # 当前数据指针
        serial = _get_obj(
            self._data,
            ["klines", ins, str(dur)] if dur != 0 else ["ticks", ins])
        async with TqChan(self._api, last_only=True) as update_chan:
            serial["_listener"].add(update_chan)
            chart["_listener"].add(update_chan)
            await self._md_send_chan.send(chart_info.copy())
            try:
                async for _ in update_chan:
                    if not (chart_info.items() <= _get_obj(chart,
                                                           ["state"]).items()):
                        # 当前请求还没收齐回应, 不应继续处理
                        continue
                    left_id = chart.get("left_id", -1)
                    right_id = chart.get("right_id", -1)
                    last_id = serial.get("last_id", -1)
                    if (left_id == -1 and right_id == -1) or last_id == -1:
                        # 定位信息还没收到, 或数据序列还没收到
                        continue
                    if self._data.get("mdhis_more_data", True):
                        self._data["_listener"].add(update_chan)
                        continue
                    else:
                        self._data["_listener"].discard(update_chan)
                    if current_id is None:
                        current_id = max(left_id, 0)
                    while True:
                        if current_id > last_id:
                            # 当前 id 已超过 last_id
                            return
                        if current_id - chart_info.get("left_kline_id",
                                                       left_id) > 5000:
                            # 当前 id 已超出订阅范围, 需重新订阅后续数据
                            chart_info["left_kline_id"] = current_id
                            chart_info.pop("focus_datetime", None)
                            chart_info.pop("focus_position", None)
                            await self._md_send_chan.send(chart_info.copy())
                        # 将订阅的8964长度的窗口中的数据都遍历完后,退出循环,然后再次进入并处理下一窗口数据
                        # (因为在处理过5000条数据的同时向服务器订阅从当前id开始的新一窗口的数据,在当前窗口剩下的3000条数据处理完后,下一窗口数据也已经收到)
                        if current_id > right_id:
                            break
                        item = {
                            k: v
                            for k, v in serial["data"].get(
                                str(current_id), {}).items()
                        }
                        if dur == 0:
                            diff = {
                                "ticks": {
                                    ins: {
                                        "last_id": current_id,
                                        "data": {
                                            str(current_id): item,
                                            str(current_id - 8964): None,
                                        }
                                    }
                                }
                            }
                            if item["datetime"] > self._end_dt:  # 超过结束时间
                                return
                            yield item[
                                "datetime"], diff, self._get_quotes_from_tick(
                                    item)
                        else:
                            diff = {
                                "klines": {
                                    ins: {
                                        str(dur): {
                                            "last_id": current_id,
                                            "data": {
                                                str(current_id): {
                                                    "datetime":
                                                    item["datetime"],
                                                    "open": item["open"],
                                                    "high": item["open"],
                                                    "low": item["open"],
                                                    "close": item["open"],
                                                    "volume": 0,
                                                    "open_oi": item["open_oi"],
                                                    "close_oi":
                                                    item["open_oi"],
                                                },
                                                str(current_id - 8964): None,
                                            }
                                        }
                                    }
                                }
                            }
                            timestamp = item[
                                "datetime"] if dur < 86400000000000 else _get_trading_day_start_time(
                                    item["datetime"])
                            if timestamp > self._end_dt:  # 超过结束时间
                                return
                            yield timestamp, diff, self._get_quotes_from_kline_open(
                                self._data["quotes"][ins], timestamp,
                                item)  # K线刚生成时的数据都为开盘价
                            diff = {
                                "klines": {
                                    ins: {
                                        str(dur): {
                                            "data": {
                                                str(current_id): item,
                                            }
                                        }
                                    }
                                }
                            }
                            timestamp = item[
                                "datetime"] + dur - 1000 if dur < 86400000000000 else _get_trading_day_end_time(
                                    item["datetime"]) - 999
                            if timestamp > self._end_dt:  # 超过结束时间
                                return
                            yield timestamp, diff, self._get_quotes_from_kline(
                                self._data["quotes"][ins], timestamp,
                                item)  # K线结束时生成quote数据
                        current_id += 1
            finally:
                # 释放chart资源
                chart_info["ins_list"] = ""
                await self._md_send_chan.send(chart_info.copy())

    @staticmethod
    def _get_quotes_from_tick(tick):
        quote = {k: v for k, v in tick.items()}
        quote["datetime"] = datetime.fromtimestamp(
            tick["datetime"] / 1e9).strftime("%Y-%m-%d %H:%M:%S.%f")
        return [quote]

    @staticmethod
    def _get_quotes_from_kline_open(info, timestamp, kline):
        return [
            {  # K线刚生成时的数据都为开盘价
                "datetime": datetime.fromtimestamp(timestamp / 1e9).strftime("%Y-%m-%d %H:%M:%S.%f"),
                "ask_price1": kline["open"] + info["price_tick"],
                "ask_volume1": 1,
                "bid_price1": kline["open"] - info["price_tick"],
                "bid_volume1": 1,
                "last_price": kline["open"],
                "highest": float("nan"),
                "lowest": float("nan"),
                "average": float("nan"),
                "volume": 0,
                "amount": float("nan"),
                "open_interest": kline["open_oi"],
            },
        ]

    @staticmethod
    def _get_quotes_from_kline(info, timestamp, kline):
        return [{
            "datetime":
            datetime.fromtimestamp(timestamp /
                                   1e9).strftime("%Y-%m-%d %H:%M:%S.%f"),
            "ask_price1":
            kline["high"] + info["price_tick"],
            "ask_volume1":
            1,
            "bid_price1":
            kline["high"] - info["price_tick"],
            "bid_volume1":
            1,
            "last_price":
            kline["close"],
            "highest":
            float("nan"),
            "lowest":
            float("nan"),
            "average":
            float("nan"),
            "volume":
            0,
            "amount":
            float("nan"),
            "open_interest":
            kline["close_oi"],
        }, {
            "ask_price1": kline["low"] + info["price_tick"],
            "bid_price1": kline["low"] - info["price_tick"],
        }, {
            "ask_price1": kline["close"] + info["price_tick"],
            "bid_price1": kline["close"] - info["price_tick"],
        }]
Пример #11
0
class TqSim(object):
    """
    天勤模拟交易类

    该类实现了一个本地的模拟账户,并且在内部完成撮合交易,在回测和复盘模式下,只能使用 TqSim 账户来交易。

    限价单要求报单价格达到或超过对手盘价格才能成交, 成交价为报单价格, 如果没有对手盘(涨跌停)则无法成交

    市价单使用对手盘价格成交, 如果没有对手盘(涨跌停)则自动撤单

    模拟交易不会有部分成交的情况, 要成交就是全部成交
    """

    def __init__(self, init_balance: float = 10000000.0, account_id: str = None) -> None:
        """
        Args:
            init_balance (float): [可选]初始资金, 默认为一千万

            account_id (str): [可选]帐号, 默认为 TQSIM

        Example::

            # 修改TqSim模拟帐号的初始资金为100000
            from tqsdk import TqApi, TqSim, TqAuth
            api = TqApi(TqSim(init_balance=100000), auth=TqAuth("信易账户", "账户密码"))

        """
        self.trade_log = {}  # 日期->交易记录及收盘时的权益及持仓
        self.tqsdk_stat = {}  # 回测结束后储存回测报告信息
        self._account_id = "TQSIM" if account_id is None else account_id
        self._account_type = "FUTURE"
        self._broker_id = "TQSIM" if self._account_type == "FUTURE" else "TQSIM_STOCK"
        self._account_key = str(id(self))
        self._init_balance = float(init_balance)
        if self._init_balance <= 0:
            raise Exception("初始资金(init_balance) %s 错误, 请检查 init_balance 是否填写正确" % (init_balance))
        self._current_datetime = "1990-01-01 00:00:00.000000"  # 当前行情时间(最新的 quote 时间)
        self._trading_day_end = "1990-01-01 18:00:00.000000"
        self._local_time_record = float("nan")  # 记录获取最新行情时的本地时间
        self._sim_trade = SimTrade(account_key=self._account_key, init_balance=self._init_balance,
                                     get_trade_timestamp=self._get_trade_timestamp,
                                     is_in_trading_time=self._is_in_trading_time)
        self._data = Entity()
        self._data._instance_entity([])
        self._prototype = {
            "quotes": {
                "#": Quote(self),  # 行情的数据原型
            }
        }
        self._quote_tasks = {}

    def set_commission(self, symbol: str, commission: float=float('nan')):
        """
        设置指定合约模拟交易的每手手续费。

        Args:
            symbol (str): 合约代码

            commission (float): 每手手续费

        Returns:
            float: 设置的每手手续费

        Example::

            from tqsdk import TqSim, TqApi, TqAuth

            sim = TqSim()
            api = TqApi(sim, auth=TqAuth("信易账户", "账户密码"))

            sim.set_commission("SHFE.cu2112", 50)

            print(sim.get_commission("SHFE.cu2112"))
        """
        if commission != commission:
            raise Exception("合约手续费不可以设置为 float('nan')")
        quote = _get_obj(self._data, ["quotes", symbol], Quote(self._api if hasattr(self, "_api") else None))
        quote["user_commission"] = commission
        if self._quote_tasks.get(symbol):
            self._quote_tasks[symbol]["quote_chan"].send_nowait({
                "quotes": {symbol: {"user_commission": commission}}
            })
        return commission

    def set_margin(self, symbol: str, margin: float=float('nan')):
        """
        设置指定合约模拟交易的每手保证金。

        Args:
            symbol (str): 合约代码 (只支持期货合约)

            margin (float): 每手保证金

        Returns:
            float: 设置的每手保证金

        Example::

            from tqsdk import TqSim, TqApi, TqAuth

            sim = TqSim()
            api = TqApi(sim, auth=TqAuth("信易账户", "账户密码"))

            sim.set_margin("SHFE.cu2112", 26000)

            print(sim.get_margin("SHFE.cu2112"))
        """
        if margin != margin:
            raise Exception("合约手续费不可以设置为 float('nan')")
        quote = _get_obj(self._data, ["quotes", symbol], Quote(self._api if hasattr(self, "_api") else None))
        quote["user_margin"] = margin
        if self._quote_tasks.get(symbol):
            self._quote_tasks[symbol]["quote_chan"].send_nowait({
                "quotes": {symbol: {"user_margin": margin}}
            })
            # 当用户设置保证金时,用户应该得到的效果是:
            # 在调用 sim.set_margin() 之后,立即调用 api.get_position(symbol),得到的 margin 字段应该按照新设置的保证金调整过,而且中间没有收到过行情更新包
            # 以下代码可以保证这个效果,说明:
            # 1. 持仓已经调整过:
            #   sim_trade 中持仓的 future_margin 字段更新,margin 会同时调整,那么 api 中持仓的 future_margin 更新时,margin 一定也已经更新
            # 2. 中间没有收到过行情更新包:
            #   前提1:根据 diff 协议,sim 收到 peek_message 时,会将缓存的 diffs 发给用户,当缓存的 diffs 为空,会转发 peek_message;
            #   前提2:api.wait_update() 会等到所有 task 都执行到 pending 状态,然后发送 peek_message 给 sim
            #   当用户代码执行到 sim.set_margin(),立即向 quote_chan 中发送一个数据包,quote_task 就会到 ready 状态,此时调用 wait_update(),
            #   到所有 task 执行到 pending 状态时,sim 的 diffs 中有数据了,此时收到 api 发来 peek_message 不会转发给上游,用户会先收到 sim 本身的账户数据,
            #   在下一次 wait_update,sim 的 diffs 为空,才会收到行情数据
            # 在回测时,以下代码应该只经历一次 wait_update
            while margin != self._api.get_position(symbol).get("future_margin"):
                self._api.wait_update()
        return margin

    def get_margin(self, symbol: str):
        """
        获取指定合约模拟交易的每手保证金。

        Args:
            symbol (str): 合约代码

        Returns:
            float: 返回合约模拟交易的每手保证金

        Example::

            from tqsdk import TqSim, TqApi, TqAuth

            sim = TqSim()
            api = TqApi(sim, auth=TqAuth("信易账户", "账户密码"))

            quote = api.get_quote("SHFE.cu2112")
            print(sim.get_margin("SHFE.cu2112"))
        """
        return _get_future_margin(self._data.get("quotes", {}).get(symbol, {}))

    def get_commission(self, symbol: str):
        """
        获取指定合约模拟交易的每手手续费

        Args:
            symbol (str): 合约代码

        Returns:
            float: 返回合约模拟交易的每手手续费

        Example::

            from tqsdk import TqSim, TqApi, TqAuth

            sim = TqSim()
            api = TqApi(sim, auth=TqAuth("信易账户", "账户密码"))

            quote = api.get_quote("SHFE.cu2112")
            print(sim.get_commission("SHFE.cu2112"))
        """
        return _get_commission(self._data.get("quotes", {}).get(symbol, {}))

    async def _run(self, api, api_send_chan, api_recv_chan, md_send_chan, md_recv_chan):
        """模拟交易task"""
        self._api = api
        self._tqsdk_backtest = {}  # 储存可能的回测信息
        self._logger = api._logger.getChild("TqSim")  # 调试信息输出
        self._api_send_chan = api_send_chan
        self._api_recv_chan = api_recv_chan
        self._md_send_chan = md_send_chan
        self._md_recv_chan = md_recv_chan
        self._pending_peek = False
        # True 下游发过 subscribe,但是没有转发给上游;False 表示下游发的 subscribe 都转发给上游
        self._pending_subscribe_downstream = False
        # True 发给上游 subscribe,但是没有收到过回复;False 如果行情不变,上游不会回任何包
        self._pending_subscribe_upstream = False
        self._diffs = []
        self._all_subscribe = set()  # 客户端+模拟交易模块订阅的合约集合
        # 是否已经发送初始账户信息
        self._has_send_init_account = False
        md_task = self._api.create_task(self._md_handler())  # 将所有 md_recv_chan 上收到的包投递到 api_send_chan 上
        try:
            async for pack in self._api_send_chan:
                if "_md_recv" in pack:
                    if pack["aid"] == "rtn_data":
                        self._md_recv(pack)  # md_recv 中会发送 wait_count 个 quotes 包给各个 quote_chan
                        await asyncio.gather(*[quote_task["quote_chan"].join() for quote_task in self._quote_tasks.values()])
                        await self._send_diff()
                elif pack["aid"] == "subscribe_quote":
                    await self._subscribe_quote(set(pack["ins_list"].split(",")))
                elif pack["aid"] == "peek_message":
                    self._pending_peek = True
                    await self._send_diff()
                    if self._pending_peek:  # 控制"peek_message"发送: 当没有新的事件需要用户处理时才推进到下一个行情
                        await self._md_send_chan.send(pack)
                elif pack["aid"] == "insert_order":
                    # 非该账户的消息包发送至下一个账户
                    if pack["account_key"] != self._account_key:
                        await self._md_send_chan.send(pack)
                    else:
                        symbol = pack["exchange_id"] + "." + pack["instrument_id"]
                        if symbol not in self._quote_tasks:
                            quote_chan = TqChan(self._api)
                            order_chan = TqChan(self._api)
                            self._quote_tasks[symbol] = {
                                "quote_chan": quote_chan,
                                "order_chan": order_chan,
                                "task": self._api.create_task(self._quote_handler(symbol, quote_chan, order_chan))
                            }
                        if "account_key" in pack:
                            pack.pop("account_key", None)
                        await self._quote_tasks[symbol]["order_chan"].send(pack)
                elif pack["aid"] == "cancel_order":
                    # 非该账户的消息包发送至下一个账户
                    if pack["account_key"] != self._account_key:
                        await self._md_send_chan.send(pack)
                    else:
                        # 发送至服务器的包需要去除 account_key 信息
                        if "account_key" in pack:
                            pack.pop("account_key", None)
                        # pack 里只有 order_id 信息,发送到每一个合约的 order_chan, 交由 quote_task 判断是不是当前合约下的委托单
                        for symbol in self._quote_tasks:
                            await self._quote_tasks[symbol]["order_chan"].send(pack)
                else:
                    await self._md_send_chan.send(pack)
                if self._tqsdk_backtest != {} and self._tqsdk_backtest["current_dt"] >= self._tqsdk_backtest["end_dt"] \
                        and not self.tqsdk_stat:
                    # 回测情况下,把 _send_stat_report 在循环中回测结束时执行
                    await self._send_stat_report()
        finally:
            if not self.tqsdk_stat:
                await self._send_stat_report()
            md_task.cancel()
            tasks = [md_task]
            for symbol in self._quote_tasks:
                self._quote_tasks[symbol]["task"].cancel()
                tasks.append(self._quote_tasks[symbol]["task"])
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _md_handler(self):
        async for pack in self._md_recv_chan:
            pack["_md_recv"] = True
            self._pending_subscribe_upstream = False
            await self._api_send_chan.send(pack)

    async def _send_diff(self):
        if self._pending_peek:
            if self._diffs:
                rtn_data = {
                    "aid": "rtn_data",
                    "data": self._diffs,
                }
                self._diffs = []
                self._pending_peek = False
                await self._api_recv_chan.send(rtn_data)
            if self._pending_subscribe_downstream:
                self._pending_subscribe_upstream = True
                self._pending_subscribe_downstream = False
                await self._md_send_chan.send({
                    "aid": "subscribe_quote",
                    "ins_list": ",".join(self._all_subscribe)
                })

    async def _subscribe_quote(self, symbols: [set, str]):
        """这里只会增加订阅合约,不会退订合约"""
        symbols = symbols if isinstance(symbols, set) else {symbols}
        if symbols - self._all_subscribe:
            self._all_subscribe |= symbols
            if self._pending_peek and not self._pending_subscribe_upstream:
                self._pending_subscribe_upstream = True
                self._pending_subscribe_downstream = False
                await self._md_send_chan.send({
                    "aid": "subscribe_quote",
                    "ins_list": ",".join(self._all_subscribe)
                })
            else:
                self._pending_subscribe_downstream = True

    async def _send_stat_report(self):
        self._settle()
        self._report()
        await self._api_recv_chan.send({
            "aid": "rtn_data",
            "data": [{
                "trade": {
                    self._account_key: {
                        "accounts": {
                            "CNY": {
                                "_tqsdk_stat": self.tqsdk_stat
                            }
                        }
                    }
                }
            }]
        })

    async def _ensure_quote_info(self, symbol, quote_chan):
        """quote收到合约信息后返回"""
        quote = _get_obj(self._data, ["quotes", symbol], Quote(self._api))
        if quote.get("price_tick") == quote.get("price_tick"):
            return quote.copy()
        if quote.get("price_tick") != quote.get("price_tick"):
            await self._md_send_chan.send(_query_for_quote(symbol))
        async for _ in quote_chan:
            quote_chan.task_done()
            if quote.get("price_tick") == quote.get("price_tick"):
                return quote.copy()

    async def _ensure_quote(self, symbol, quote_chan):
        """quote收到行情以及合约信息后返回"""
        quote = _get_obj(self._data, ["quotes", symbol], Quote(self._api))
        _register_update_chan(quote, quote_chan)
        if quote.get("datetime", "") and quote.get("price_tick") == quote.get("price_tick"):
            return quote.copy()
        if quote.get("price_tick") != quote.get("price_tick"):
            # 对于没有合约信息的 quote,发送查询合约信息的请求
            await self._md_send_chan.send(_query_for_quote(symbol))
        async for _ in quote_chan:
            quote_chan.task_done()
            if quote.get("datetime", "") and quote.get("price_tick") == quote.get("price_tick"):
                return quote.copy()

    async def _quote_handler(self, symbol, quote_chan, order_chan):
        try:
            await self._subscribe_quote(symbol)
            quote = await self._ensure_quote(symbol, quote_chan)
            if quote["ins_class"].endswith("INDEX") and quote["exchange_id"] == "KQ":
                # 指数可以交易,需要补充 margin commission
                if "margin" not in quote:
                    quote_m = await self._ensure_quote_info(symbol.replace("KQ.i", "KQ.m"), quote_chan)
                    quote_underlying = await self._ensure_quote_info(quote_m["underlying_symbol"], quote_chan)
                    self._data["quotes"][symbol]["margin"] = quote_underlying["margin"]
                    self._data["quotes"][symbol]["commission"] = quote_underlying["commission"]
                    quote.update(self._data["quotes"][symbol])
            underlying_quote = None
            if quote["ins_class"].endswith("OPTION"):
                # 如果是期权,订阅标的合约行情,确定收到期权标的合约行情
                underlying_symbol = quote["underlying_symbol"]
                await self._subscribe_quote(underlying_symbol)
                underlying_quote = await self._ensure_quote(underlying_symbol, quote_chan)  # 订阅合约
            # 在等待标的行情的过程中,quote_chan 可能有期权行情,把 quote_chan 清空,并用最新行情更新 quote
            while not quote_chan.empty():
                quote_chan.recv_nowait()
                quote_chan.task_done()

            # 用最新行情更新 quote
            quote.update(self._data["quotes"][symbol])
            if underlying_quote:
                underlying_quote.update(self._data["quotes"][underlying_symbol])
            task = self._api.create_task(self._forward_chan_handler(order_chan, quote_chan))
            quotes = {symbol: quote}
            if underlying_quote:
                quotes[underlying_symbol] = underlying_quote
            self._sim_trade.update_quotes(symbol, {"quotes": quotes})
            async for pack in quote_chan:
                if "aid" not in pack:
                    diffs, orders_events = self._sim_trade.update_quotes(symbol, pack)
                    self._handle_diffs(diffs, orders_events, "match order")
                elif pack["aid"] == "insert_order":
                    diffs, orders_events = self._sim_trade.insert_order(symbol, pack)
                    self._handle_diffs(diffs, orders_events, "insert order")
                    await self._send_diff()
                elif pack["aid"] == "cancel_order":
                    diffs, orders_events = self._sim_trade.cancel_order(symbol, pack)
                    self._handle_diffs(diffs, orders_events, "cancel order")
                    await self._send_diff()
                quote_chan.task_done()
        finally:
            await quote_chan.close()
            await order_chan.close()
            task.cancel()
            await asyncio.gather(task, return_exceptions=True)

    async def _forward_chan_handler(self, chan_from, chan_to):
        async for pack in chan_from:
            await chan_to.send(pack)

    def _md_recv(self, pack):
        for d in pack["data"]:
            self._diffs.append(d)
            # 在第一次收到 mdhis_more_data 为 False 的时候,发送账户初始截面信息,这样回测模式下,往后的模块才有正确的时间顺序
            if not self._has_send_init_account and not d.get("mdhis_more_data", True):
                self._diffs.append(self._sim_trade.init_snapshot())
                self._diffs.append({
                    "trade": {
                        self._account_key: {
                            "trade_more_data": False
                        }
                    }
                })
                self._has_send_init_account = True
            _tqsdk_backtest = d.get("_tqsdk_backtest", {})
            if _tqsdk_backtest:
                # 回测时,用 _tqsdk_backtest 对象中 current_dt 作为 TqSim 的 _current_datetime
                self._tqsdk_backtest.update(_tqsdk_backtest)
                self._current_datetime = datetime.fromtimestamp(
                    self._tqsdk_backtest["current_dt"] / 1e9).strftime("%Y-%m-%d %H:%M:%S.%f")
                self._local_time_record = float("nan")
                # 1. 回测时不使用时间差来模拟交易所时间的原因(_local_time_record始终为初始值nan):
                #   在sim收到行情后记录_local_time_record,然后下发行情到api进行merge_diff(),api需要处理完k线和quote才能结束wait_update(),
                #   若处理时间过长,此时下单则在判断下单时间时与测试用例中的预期时间相差较大,导致测试用例无法通过。
                # 2. 回测不使用时间差的方法来判断下单时间仍是可行的: 与使用了时间差的方法相比, 只对在每个交易时间段最后一笔行情时的下单时间判断有差异,
                #   若不使用时间差, 则在最后一笔行情时下单仍判断为在可交易时间段内, 且可成交.
            quotes_diff = d.get("quotes", {})
            # 先根据 quotes_diff 里的 datetime, 确定出 _current_datetime,再 _merge_diff(同时会发送行情到 quote_chan)
            for symbol, quote_diff in quotes_diff.items():
                if quote_diff is None:
                    continue
                # 若直接使用本地时间来判断下单时间是否在可交易时间段内 可能有较大误差,因此判断的方案为:(在接收到下单指令时判断 估计的交易所时间 是否在交易时间段内)
                # 在更新最新行情时间(即self._current_datetime)时,记录当前本地时间(self._local_time_record),
                # 在这之后若收到下单指令,则获取当前本地时间,判 "最新行情时间 + (当前本地时间 - 记录的本地时间)" 是否在交易时间段内。
                # 另外, 若在盘后下单且下单前未订阅此合约:
                # 因为从_md_recv()中获取数据后立即判断下单时间则速度过快(两次time.time()的时间差小于最后一笔行情(14:59:9995)到15点的时间差),
                # 则会立即成交,为处理此情况则将当前时间减去5毫秒(模拟发生5毫秒网络延迟,则两次time.time()的时间差增加了5毫秒)。
                # todo: 按交易所来存储 _current_datetime(issue: #277)
                if quote_diff.get("datetime", "") > self._current_datetime:
                    # 回测时,当前时间更新即可以由 quote 行情更新,也可以由 _tqsdk_backtest.current_dt 更新,
                    # 在最外层的循环里,_tqsdk_backtest.current_dt 是在 rtn_data.data 中数组位置中的最后一个,会在循环最后一个才更新 self.current_datetime
                    # 导致前面处理 order 时的 _current_datetime 还是旧的行情时间
                    self._current_datetime = quote_diff["datetime"]  # 最新行情时间
                    # 更新最新行情时间时的本地时间,回测时不使用时间差
                    self._local_time_record = (time.time() - 0.005) if not self._tqsdk_backtest else float("nan")
                if self._current_datetime > self._trading_day_end:  # 结算
                    self._settle()
                    # 若当前行情时间大于交易日的结束时间(切换交易日),则根据此行情时间更新交易日及交易日结束时间
                    trading_day = _get_trading_day_from_timestamp(self._get_current_timestamp())
                    self._trading_day_end = datetime.fromtimestamp(
                        (_get_trading_day_end_time(trading_day) - 999) / 1e9).strftime("%Y-%m-%d %H:%M:%S.%f")
            if quotes_diff:
                _merge_diff(self._data, {"quotes": quotes_diff}, self._prototype, False, True)

    def _handle_diffs(self, diffs, orders_events, msg):
        """
        处理 sim_trade 返回的 diffs
        orders_events 为持仓变更事件,依次屏幕输出信息,打印日志
        """
        self._diffs += diffs
        for order in orders_events:
            if order["status"] == "FINISHED":
                self._handle_on_finished(msg, order)
            else:
                assert order["status"] == "ALIVE"
                self._handle_on_alive(msg, order)

    def _handle_on_alive(self, msg, order):
        """
        在 order 状态变为 ALIVE 调用,屏幕输出信息,打印日志
        """
        symbol = f"{order['exchange_id']}.{order['instrument_id']}"
        self._api._print(
            f"模拟交易下单 {order['order_id']}: 时间: {_format_from_timestamp_nano(order['insert_date_time'])}, "
            f"合约: {symbol}, 开平: {order['offset']}, 方向: {order['direction']}, 手数: {order['volume_left']}, "
            f"价格: {order.get('limit_price', '市价')}")
        self._logger.debug(msg, order_id=order["order_id"], datetime=order["insert_date_time"],
                           symbol=symbol, offset=order["offset"], direction=order["direction"],
                           volume_left=order["volume_left"], limit_price=order.get("limit_price", "市价"))

    def _handle_on_finished(self, msg, order):
        """
        在 order 状态变为 FINISHED 调用,屏幕输出信息,打印日志
        """
        self._api._print(f"模拟交易委托单 {order['order_id']}: {order['last_msg']}")
        self._logger.debug(msg, order_id=order["order_id"], last_msg=order["last_msg"], status=order["status"],
                           volume_orign=order["volume_orign"], volume_left=order["volume_left"])

    def _settle(self):
        if self._trading_day_end[:10] == "1990-01-01":
            return
        # 结算并记录账户截面
        diffs, orders_events, trade_log = self._sim_trade.settle()
        self._handle_diffs(diffs, orders_events, "settle")
        self.trade_log[self._trading_day_end[:10]] = trade_log

    def _report(self):
        if not self.trade_log:
            return
        date_keys = sorted(self.trade_log.keys())
        self._api._print("模拟交易成交记录")
        for d in date_keys:
            for t in self.trade_log[d]["trades"]:
                symbol = t["exchange_id"] + "." + t["instrument_id"]
                self._api._print(f"时间: {_format_from_timestamp_nano(t['trade_date_time'])}, 合约: {symbol}, "
                                 f"开平: {t['offset']}, 方向: {t['direction']}, 手数: {t['volume']}, 价格: {t['price']:.3f},"
                                 f"手续费: {t['commission']:.2f}")

        self._api._print("模拟交易账户资金")
        for d in date_keys:
            account = self.trade_log[d]["account"]
            self._api._print(
                f"日期: {d}, 账户权益: {account['balance']:.2f}, 可用资金: {account['available']:.2f}, "
                f"浮动盈亏: {account['float_profit']:.2f}, 持仓盈亏: {account['position_profit']:.2f}, "
                f"平仓盈亏: {account['close_profit']:.2f}, 市值: {account['market_value']:.2f}, "
                f"保证金: {account['margin']:.2f}, 手续费: {account['commission']:.2f}, "
                f"风险度: {account['risk_ratio'] * 100:.2f}%")

        # TqReport 模块计算交易统计信息
        report = TqReport(report_id=self._account_id, trade_log=self.trade_log, quotes=self._data['quotes'])
        self.tqsdk_stat = report.default_metrics
        self._api._print(
            f"胜率: {self.tqsdk_stat['winning_rate'] * 100:.2f}%, 盈亏额比例: {self.tqsdk_stat['profit_loss_ratio']:.2f}, "
            f"收益率: {self.tqsdk_stat['ror'] * 100:.2f}%, 年化收益率: {self.tqsdk_stat['annual_yield'] * 100:.2f}%, "
            f"最大回撤: {self.tqsdk_stat['max_drawdown'] * 100:.2f}%, 年化夏普率: {self.tqsdk_stat['sharpe_ratio']:.4f},"
            f"年化索提诺比率: {self.tqsdk_stat['sortino_ratio']:.4f}")

        # 回测情况下,在计算报告之后,还会发送绘制图表请求,
        # 这样处理,用户不要修改代码,就能够看到报告图表
        if self._tqsdk_backtest:
            self._api.draw_report(report.full())

    def _get_current_timestamp(self):
        return int(datetime.strptime(self._current_datetime, "%Y-%m-%d %H:%M:%S.%f").timestamp() * 1e6) * 1000

    def _get_trade_timestamp(self):
        return _get_trade_timestamp(self._current_datetime, self._local_time_record)

    def _is_in_trading_time(self, quote):
        return _is_in_trading_time(quote, self._current_datetime, self._local_time_record)
Пример #12
0
 async def _run(self, api, sim_send_chan, sim_recv_chan, md_send_chan, md_recv_chan):
     """回测task"""
     self._api = api
     # 下载历史主连合约信息
     start_trading_day = _get_trading_day_from_timestamp(self._start_dt)  # 回测开始交易日
     end_trading_day = _get_trading_day_from_timestamp(self._end_dt)  # 回测结束交易日
     self._continuous_table = TqBacktestContinuous(start_dt=start_trading_day,
                                                   end_dt=end_trading_day,
                                                   headers=self._api._base_headers)
     self._stock_dividend = TqBacktestDividend(start_dt=start_trading_day,
                                               end_dt=end_trading_day,
                                               headers=self._api._base_headers)
     self._logger = api._logger.getChild("TqBacktest")  # 调试信息输出
     self._sim_send_chan = sim_send_chan
     self._sim_recv_chan = sim_recv_chan
     self._md_send_chan = md_send_chan
     self._md_recv_chan = md_recv_chan
     self._pending_peek = False
     self._data = Entity()  # 数据存储
     self._data._instance_entity([])
     self._prototype = {
         "quotes": {
             "#": BtQuote(self._api),  # 行情的数据原型
         },
         "klines": {
             "*": {
                 "*": {
                     "data": {
                         "@": Kline(self._api),  # K线的数据原型
                     }
                 }
             }
         },
         "ticks": {
             "*": {
                 "data": {
                     "@": Tick(self._api),  # Tick的数据原型
                 }
             }
         }
     }
     self._sended_to_api = {}  # 已经发给 api 的 rangeset  (symbol, dur),只记录了 kline
     self._serials = {}  # 所有用户请求的 chart 序列,如果用户订阅行情,默认请求 1 分钟 Kline
     # gc 是会循环 self._serials,来计算用户需要的数据,self._serials 不应该被删除,
     self._generators = {}  # 所有用户请求的 chart 序列相应的 generator 对象,创建时与 self._serials 一一对应,会在一个序列计算到最后一根 kline 时被删除
     self._had_any_generator = False  # 回测过程中是否有过 generator 对象
     self._sim_recv_chan_send_count = 0  # 统计向下游发送的 diff 的次数,每 1w 次执行一次 gc
     self._quotes = {}  # 记录 min_duration 记录某一合约的最小duration; sended_init_quote 是否已经过这个合约的初始行情
     self._diffs: List[Dict[str, Any]] = []
     self._is_first_send = True
     md_task = self._api.create_task(self._md_handler())
     try:
         await self._send_snapshot()
         async for pack in self._sim_send_chan:
             if pack["aid"] == "ins_query":
                 await self._md_send_chan.send(pack)
                 # 回测 query 不为空时需要ensure_query
                 # 1. 在api初始化时会发送初始化请求(2.5.0版本开始已经不再发送初始化请求),接着会发送peek_message,如果这里没有等到结果,那么在收到 peek_message 的时候,会发现没有数据需要发送,回测结束
                 # 2. api在发送请求后,会调用 wait_update 更新数据,如果这里没有等到结果,行情可能会被推进
                 # query 为空时,表示清空数据的请求,这个可以直接发出去,不需要等到收到回复
                 if pack["query"] != "":
                     await self._ensure_query(pack)
                 await self._send_diff()
             elif pack["aid"] == "subscribe_quote":
                 # todo: 回测时,用户如果先订阅日线,再订阅行情,会直接返回以日线 datetime 标识的行情信息,而不是当前真正的行情时间
                 self._diffs.append({
                     "ins_list": pack["ins_list"]
                 })
                 for ins in pack["ins_list"].split(","):
                     await self._ensure_quote(ins)
                 await self._send_diff()  # 处理上一次未处理的 peek_message
             elif pack["aid"] == "set_chart":
                 if pack["ins_list"]:
                     # 回测模块中已保证每次将一个行情时间的数据全部发送给api,因此更新行情时 保持与初始化时一样的charts信息(即不作修改)
                     self._diffs.append({
                         "charts": {
                             pack["chart_id"]: {
                                 # 两个id设置为0:保证api在回测中判断此值时不是-1,即直接通过对数据接收完全的验证
                                 "left_id": 0,
                                 "right_id": 0,
                                 "more_data": False,  # 直接发送False给api,表明数据发送完全,使api中通过数据接收完全的验证
                                 "state": pack
                             }
                         }
                     })
                     await self._ensure_serial(pack["ins_list"], pack["duration"], pack["chart_id"])
                 else:
                     self._diffs.append({
                         "charts": {
                             pack["chart_id"]: None
                         }
                     })
                 await self._send_diff()  # 处理上一次未处理的 peek_message
             elif pack["aid"] == "peek_message":
                 self._pending_peek = True
                 await self._send_diff()
     finally:
         # 关闭所有 generator
         for s in self._generators.values():
             await s.aclose()
         md_task.cancel()
         await asyncio.gather(md_task, return_exceptions=True)
Пример #13
0
class TqBacktest(object):
    """
    天勤回测类

    将该类传入 TqApi 的构造函数, 则策略就会进入回测模式。

    回测模式下 k线会在刚创建出来时和结束时分别更新一次, 在这之间 k线是不会更新的。

    回测模式下 quote 的更新频率由所订阅的 tick 和 k线周期确定:
        * 只要订阅了 tick, 则对应合约的 quote 就会使用 tick 生成, 更新频率也和 tick 一致, 但 **只有下字段** :
          datetime/ask&bid_price1/ask&bid_volume1/last_price/highest/lowest/average/volume/amount/open_interest/
          price_tick/price_decs/volume_multiple/max&min_limit&market_order_volume/underlying_symbol/strike_price

        * 如果没有订阅 tick, 但是订阅了 k线, 则对应合约的 quote 会使用 k线生成, 更新频率和 k线的周期一致, 如果订阅了某个合约的多个周期的 k线,
          则任一个周期的 k线有更新时, quote 都会更新. 使用 k线生成的 quote 的盘口由收盘价分别加/减一个最小变动单位, 并且 highest/lowest/average/amount
          始终为 nan, volume 始终为0

        * 如果即没有订阅 tick, 也没有订阅k线或 订阅的k线周期大于分钟线, 则 TqBacktest 会 **自动订阅分钟线** 来生成 quote

        * 如果没有订阅 tick, 但是订阅了 k线, 则对应合约的 quote **只有下字段** :
          datetime/ask&bid_price1/ask&bid_volume1/last_price/open_interest/
          price_tick/price_decs/volume_multiple/max&min_limit&market_order_volume/underlying_symbol/strike_price

    **注意** :如果未订阅 quote,模拟交易在下单时会自动为此合约订阅 quote ,根据回测时 quote 的更新规则,如果此合约没有订阅K线或K线周期大于分钟线 **则会自动订阅一个分钟线** 。

    模拟交易要求报单价格大于等于对手盘价格才会成交, 例如下买单, 要求价格大于等于卖一价才会成交, 如果不能立即成交则会等到下次行情更新再重新判断。

    回测模式下 wait_update 每次最多推进一个行情时间。

    回测结束后会抛出 BacktestFinished 例外。

    对 **组合合约** 进行回测时需注意:只能通过订阅 tick 数据来回测,不能订阅K线,因为K线是由最新价合成的,而交易所发回的组合合约数据中无最新价。
    """

    def __init__(self, start_dt: Union[date, datetime], end_dt: Union[date, datetime]) -> None:
        """
        创建天勤回测类

        Args:
            start_dt (date/datetime): 回测起始时间, 如果类型为 date 则指的是交易日, 如果为 datetime 则指的是具体时间点

            end_dt (date/datetime): 回测结束时间, 如果类型为 date 则指的是交易日, 如果为 datetime 则指的是具体时间点
        """
        if isinstance(start_dt, datetime):
            self._start_dt = int(start_dt.timestamp() * 1e9)
        elif isinstance(start_dt, date):
            self._start_dt = _get_trading_day_start_time(
                int(datetime(start_dt.year, start_dt.month, start_dt.day).timestamp()) * 1000000000)
        else:
            raise Exception("回测起始时间(start_dt)类型 %s 错误, 请检查 start_dt 数据类型是否填写正确" % (type(start_dt)))
        if isinstance(end_dt, datetime):
            self._end_dt = int(end_dt.timestamp() * 1e9)
        elif isinstance(end_dt, date):
            self._end_dt = _get_trading_day_end_time(
                int(datetime(end_dt.year, end_dt.month, end_dt.day).timestamp()) * 1000000000)
        else:
            raise Exception("回测结束时间(end_dt)类型 %s 错误, 请检查 end_dt 数据类型是否填写正确" % (type(end_dt)))
        self._current_dt = self._start_dt
        # 记录当前的交易日 开始时间/结束时间
        self._trading_day = _get_trading_day_from_timestamp(self._current_dt)
        self._trading_day_start = _get_trading_day_start_time(self._trading_day)
        self._trading_day_end = _get_trading_day_end_time(self._trading_day)

    async def _run(self, api, sim_send_chan, sim_recv_chan, md_send_chan, md_recv_chan):
        """回测task"""
        self._api = api
        # 下载历史主连合约信息
        start_trading_day = _get_trading_day_from_timestamp(self._start_dt)  # 回测开始交易日
        end_trading_day = _get_trading_day_from_timestamp(self._end_dt)  # 回测结束交易日
        self._continuous_table = TqBacktestContinuous(start_dt=start_trading_day,
                                                      end_dt=end_trading_day,
                                                      headers=self._api._base_headers)
        self._stock_dividend = TqBacktestDividend(start_dt=start_trading_day,
                                                  end_dt=end_trading_day,
                                                  headers=self._api._base_headers)
        self._logger = api._logger.getChild("TqBacktest")  # 调试信息输出
        self._sim_send_chan = sim_send_chan
        self._sim_recv_chan = sim_recv_chan
        self._md_send_chan = md_send_chan
        self._md_recv_chan = md_recv_chan
        self._pending_peek = False
        self._data = Entity()  # 数据存储
        self._data._instance_entity([])
        self._prototype = {
            "quotes": {
                "#": BtQuote(self._api),  # 行情的数据原型
            },
            "klines": {
                "*": {
                    "*": {
                        "data": {
                            "@": Kline(self._api),  # K线的数据原型
                        }
                    }
                }
            },
            "ticks": {
                "*": {
                    "data": {
                        "@": Tick(self._api),  # Tick的数据原型
                    }
                }
            }
        }
        self._sended_to_api = {}  # 已经发给 api 的 rangeset  (symbol, dur),只记录了 kline
        self._serials = {}  # 所有用户请求的 chart 序列,如果用户订阅行情,默认请求 1 分钟 Kline
        # gc 是会循环 self._serials,来计算用户需要的数据,self._serials 不应该被删除,
        self._generators = {}  # 所有用户请求的 chart 序列相应的 generator 对象,创建时与 self._serials 一一对应,会在一个序列计算到最后一根 kline 时被删除
        self._had_any_generator = False  # 回测过程中是否有过 generator 对象
        self._sim_recv_chan_send_count = 0  # 统计向下游发送的 diff 的次数,每 1w 次执行一次 gc
        self._quotes = {}  # 记录 min_duration 记录某一合约的最小duration; sended_init_quote 是否已经过这个合约的初始行情
        self._diffs: List[Dict[str, Any]] = []
        self._is_first_send = True
        md_task = self._api.create_task(self._md_handler())
        try:
            await self._send_snapshot()
            async for pack in self._sim_send_chan:
                if pack["aid"] == "ins_query":
                    await self._md_send_chan.send(pack)
                    # 回测 query 不为空时需要ensure_query
                    # 1. 在api初始化时会发送初始化请求(2.5.0版本开始已经不再发送初始化请求),接着会发送peek_message,如果这里没有等到结果,那么在收到 peek_message 的时候,会发现没有数据需要发送,回测结束
                    # 2. api在发送请求后,会调用 wait_update 更新数据,如果这里没有等到结果,行情可能会被推进
                    # query 为空时,表示清空数据的请求,这个可以直接发出去,不需要等到收到回复
                    if pack["query"] != "":
                        await self._ensure_query(pack)
                    await self._send_diff()
                elif pack["aid"] == "subscribe_quote":
                    # todo: 回测时,用户如果先订阅日线,再订阅行情,会直接返回以日线 datetime 标识的行情信息,而不是当前真正的行情时间
                    self._diffs.append({
                        "ins_list": pack["ins_list"]
                    })
                    for ins in pack["ins_list"].split(","):
                        await self._ensure_quote(ins)
                    await self._send_diff()  # 处理上一次未处理的 peek_message
                elif pack["aid"] == "set_chart":
                    if pack["ins_list"]:
                        # 回测模块中已保证每次将一个行情时间的数据全部发送给api,因此更新行情时 保持与初始化时一样的charts信息(即不作修改)
                        self._diffs.append({
                            "charts": {
                                pack["chart_id"]: {
                                    # 两个id设置为0:保证api在回测中判断此值时不是-1,即直接通过对数据接收完全的验证
                                    "left_id": 0,
                                    "right_id": 0,
                                    "more_data": False,  # 直接发送False给api,表明数据发送完全,使api中通过数据接收完全的验证
                                    "state": pack
                                }
                            }
                        })
                        await self._ensure_serial(pack["ins_list"], pack["duration"], pack["chart_id"])
                    else:
                        self._diffs.append({
                            "charts": {
                                pack["chart_id"]: None
                            }
                        })
                    await self._send_diff()  # 处理上一次未处理的 peek_message
                elif pack["aid"] == "peek_message":
                    self._pending_peek = True
                    await self._send_diff()
        finally:
            # 关闭所有 generator
            for s in self._generators.values():
                await s.aclose()
            md_task.cancel()
            await asyncio.gather(md_task, return_exceptions=True)

    async def _md_handler(self):
        async for pack in self._md_recv_chan:
            await self._md_send_chan.send({
                "aid": "peek_message"
            })
            recv_quotes = False
            for d in pack.get("data", []):
                _merge_diff(self._data, d, self._prototype, persist=False, reduce_diff=False)
                # 收到的 quotes 转发给下游
                quotes = d.get("quotes", {})
                if quotes:
                    recv_quotes = True
                    quotes = self._update_valid_quotes(quotes)  # 删去回测 quotes 不应该下发的字段
                    self._diffs.append({"quotes": quotes})
                # 收到的 symbols 应该转发给下游
                if d.get("symbols"):
                    self._diffs.append({"symbols": d["symbols"]})
            # 如果没有收到 quotes(合约信息),或者当前的 self._data.get('quotes', {}) 里没有股票,那么不应该向 _diffs 里添加元素
            if recv_quotes:
                quotes_stock = self._stock_dividend._get_dividend(self._data.get('quotes', {}), self._trading_day)
                if quotes_stock:
                    self._diffs.append({"quotes": quotes_stock})

    def _update_valid_quotes(self, quotes):
        # 从 quotes 返回只剩余合约信息的字段的 quotes,防止发生未来数据发送给下游
        # backtest 模块会生成的数据
        invalid_keys = {f"{d}{i+1}" for d in ['ask_price', 'ask_volume', 'bid_price', 'bid_volume'] for i in range(5)}
        invalid_keys.union({'datetime', 'last_price', 'highest', 'lowest', 'average', 'volume', 'amount', 'open_interest'})
        invalid_keys.union({'cash_dividend_ratio', 'stock_dividend_ratio'})  # 这两个字段完全由 self._stock_dividend 负责处理
        # backtest 模块不会生成的数据,下游服务也不应该收到的数据
        invalid_keys.union({'open', 'close', 'settlement', 'lowest', 'lower_limit', 'upper_limit', 'pre_open_interest', 'pre_settlement', 'pre_close', 'expired'})
        for symbol, quote in quotes.items():
            [quote.pop(k, None) for k in invalid_keys]
            if symbol.startswith("KQ.m"):
                quote.pop("underlying_symbol", None)
            if quote.get('expire_datetime'):
                # 先删除所有的 quote 的 expired 字段,只在有 expire_datetime 字段时才会添加 expired 字段
                quote['expired'] = quote.get('expire_datetime') * 1e9 <= self._trading_day_start
        return quotes

    async def _send_snapshot(self):
        """发送初始合约信息"""
        async with TqChan(self._api, last_only=True) as update_chan:  # 等待与行情服务器连接成功
            self._data["_listener"].add(update_chan)
            while self._data.get("mdhis_more_data", True):
                await update_chan.recv()
        # 发送初始行情(合约信息截面)时
        quotes = {}
        for ins, quote in self._data["quotes"].items():
            if not ins.startswith("_"):
                trading_time = quote.get("trading_time", {})
                quotes[ins] = {
                    "open": None,  # 填写None: 删除api中的这个字段
                    "close": None,
                    "settlement": None,
                    "lower_limit": None,
                    "upper_limit": None,
                    "pre_open_interest": None,
                    "pre_settlement": None,
                    "pre_close": None,
                    "ins_class": quote.get("ins_class", ""),
                    "instrument_id": quote.get("instrument_id", ""),
                    "exchange_id": quote.get("exchange_id", ""),
                    "margin": quote.get("margin"),  # 用于内部实现模拟交易, 不作为api对外可用数据(即 Quote 类中无此字段)
                    "commission": quote.get("commission"),  # 用于内部实现模拟交易, 不作为api对外可用数据(即 Quote 类中无此字段)
                    "price_tick": quote["price_tick"],
                    "price_decs": quote["price_decs"],
                    "volume_multiple": quote["volume_multiple"],
                    "max_limit_order_volume": quote["max_limit_order_volume"],
                    "max_market_order_volume": quote["max_market_order_volume"],
                    "min_limit_order_volume": quote["min_limit_order_volume"],
                    "min_market_order_volume": quote["min_market_order_volume"],
                    "underlying_symbol": quote["underlying_symbol"],
                    "strike_price": quote["strike_price"],
                    "expired": quote.get('expire_datetime', float('nan')) <= self._trading_day_start,  # expired 默认值就是 False
                    "trading_time": {"day": trading_time.get("day", []), "night": trading_time.get("night", [])},
                    "expire_datetime": quote.get("expire_datetime"),
                    "delivery_month": quote.get("delivery_month"),
                    "delivery_year": quote.get("delivery_year"),
                    "option_class": quote.get("option_class", ""),
                    "product_id": quote.get("product_id", ""),
                }
        # 修改历史主连合约信息
        cont_quotes = self._continuous_table._get_history_cont_quotes(self._trading_day)
        for k, v in cont_quotes.items():
            quotes.setdefault(k, {})  # 实际上,初始行情截面中只有下市合约,没有主连
            quotes[k].update(v)
        self._diffs.append({
            "quotes": quotes,
            "ins_list": "",
            "mdhis_more_data": False,
            "_tqsdk_backtest": self._get_backtest_time()
        })

    async def _send_diff(self):
        """发送数据到 api, 如果 self._diffs 不为空则发送 self._diffs, 不推进行情时间, 否则将时间推进一格, 并发送对应的行情"""
        if self._pending_peek:
            if not self._diffs:
                quotes = await self._generator_diffs(False)
            else:
                quotes = await self._generator_diffs(True)
            for ins, diff in quotes.items():
                self._quotes[ins]["sended_init_quote"] = True
                for d in diff:
                    self._diffs.append({
                        "quotes": {
                            ins: d
                        }
                    })
            if self._diffs:
                # 发送数据集中添加 backtest 字段,开始时间、结束时间、当前时间,表示当前行情推进是由 backtest 推进
                self._diffs.append({"_tqsdk_backtest": self._get_backtest_time()})

                # 切换交易日,将历史的主连合约信息添加的 diffs
                if self._current_dt > self._trading_day_end:
                    # 使用交易日结束时间,每个交易日切换只需要计算一次交易日结束时间
                    # 相比发送 diffs 前每次都用 _current_dt 计算当前交易日,计算次数更少
                    self._trading_day = _get_trading_day_from_timestamp(self._current_dt)
                    self._trading_day_start = _get_trading_day_start_time(self._trading_day)
                    self._trading_day_end = _get_trading_day_end_time(self._trading_day)
                    self._diffs.append({
                        "quotes": self._continuous_table._get_history_cont_quotes(self._trading_day)
                    })
                    self._diffs.append({
                        "quotes": self._stock_dividend._get_dividend(self._data.get('quotes'), self._trading_day)
                    })
                    self._diffs.append({
                        "quotes": {k: {'expired': v.get('expire_datetime', float('nan')) <= self._trading_day_start}
                                   for k, v in self._data.get('quotes').items()}
                    })

                self._sim_recv_chan_send_count += 1
                if self._sim_recv_chan_send_count > 10000:
                    self._sim_recv_chan_send_count = 0
                    self._diffs.append(self._gc_data())
                rtn_data = {
                    "aid": "rtn_data",
                    "data": self._diffs,
                }
                self._diffs = []
                self._pending_peek = False
                await self._sim_recv_chan.send(rtn_data)

    async def _generator_diffs(self, keep_current):
        """
        keep_current 为 True 表示不会推进行情,为 False 表示需要推进行情
        即 self._diffs 为 None 并且 keep_current = True 会推进行情
        """
        quotes = {}
        while self._generators:
            # self._generators 存储了 generator,self._serials 记录一些辅助的信息
            min_request_key = min(self._generators.keys(), key=lambda serial: self._serials[serial]["timestamp"])
            timestamp = self._serials[min_request_key]["timestamp"]  # 所有已订阅数据中的最小行情时间
            quotes_diff = self._serials[min_request_key]["quotes"]
            if timestamp < self._current_dt and self._quotes.get(min_request_key[0], {}).get("sended_init_quote"):
                # 先订阅 A 合约,再订阅 A 合约日线,那么 A 合约的行情时间会回退: 2021-01-04 09:31:59.999999 -> 2021-01-01 18:00:00.000000
                # 如果当前 timestamp 小于 _current_dt,那么这个 quote_diff 不需要发到下游
                # 如果先订阅 A 合约(有夜盘),时间停留在夜盘开始时间, 再订阅 B 合约(没有夜盘),那么 B 合约的行情(前一天收盘时间)应该发下去,
                # 否则 get_quote(B) 等到收到行情才返回,会直接把时间推进到第二天白盘。
                quotes_diff = None
            # 推进时间,一次只会推进最多一个(补数据时有可能是0个)行情时间,并确保<=该行情时间的行情都被发出
            # 如果行情时间大于当前回测时间 则 判断是否diff中已有数据;否则表明此行情时间的数据未全部保存在diff中,则继续append
            if timestamp > self._current_dt:
                if self._diffs or keep_current:  # 如果diffs中已有数据:退出循环并发送数据给下游api
                    break
                else:
                    self._current_dt = timestamp  # 否则将回测时间更新至最新行情时间
            diff = self._serials[min_request_key]["diff"]
            self._diffs.append(diff)
            # klines 请求,需要记录已经发送 api 的数据
            for symbol in diff.get("klines", {}):
                for dur in diff["klines"][symbol]:
                    for kid in diff["klines"][symbol][dur]["data"]:
                        rs = self._sended_to_api.setdefault((symbol, int(dur)), [])
                        kid = int(kid)
                        self._sended_to_api[(symbol, int(dur))] = _rangeset_range_union(rs, (kid, kid + 1))
            quote_info = self._quotes[min_request_key[0]]
            if quotes_diff and (quote_info["min_duration"] != 0 or min_request_key[1] == 0):
                quotes[min_request_key[0]] = quotes_diff
            await self._fetch_serial(min_request_key)
        if self._had_any_generator and not self._generators and not self._diffs:  # 当无可发送数据时则抛出BacktestFinished例外,包括未订阅任何行情 或 所有已订阅行情的最后一笔行情获取完成
            self._api._print("回测结束")
            self._logger.debug("backtest finished")
            if self._current_dt < self._end_dt:
                self._current_dt = 2145888000000000000  # 一个远大于 end_dt 的日期 20380101
            await self._sim_recv_chan.send({
                "aid": "rtn_data",
                "data": [{"_tqsdk_backtest": self._get_backtest_time()}]
            })
            await self._api._wait_until_idle()
            raise BacktestFinished(self._api) from None
        return quotes

    def _get_backtest_time(self) -> dict:
        if self._is_first_send:
            self._is_first_send = False
            return {
                    "start_dt": self._start_dt,
                    "current_dt": self._current_dt,
                    "end_dt": self._end_dt
                }
        else:
            return {
                "current_dt": self._current_dt
            }

    async def _ensure_serial(self, ins, dur, chart_id=None):
        if (ins, dur) not in self._serials:
            quote = self._quotes.setdefault(ins, {  # 在此处设置 min_duration: 每次生成K线的时候会自动生成quote, 记录某一合约的最小duration
                "min_duration": dur
            })
            quote["min_duration"] = min(quote["min_duration"], dur)
            self._serials[(ins, dur)] = {
                "chart_id_set": {chart_id} if chart_id else set()  # 记录当前 serial 对应的 chart_id
            }
            self._generators[(ins, dur)] = self._gen_serial(ins, dur)
            self._had_any_generator = True
            await self._fetch_serial((ins, dur))
        elif chart_id:
            self._serials[(ins, dur)]["chart_id_set"].add(chart_id)

    async def _ensure_query(self, pack):
        """一定收到了对应 query 返回的包"""
        query_pack = {"query": pack["query"]}
        if query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items():
            return
        async with TqChan(self._api, last_only=True) as update_chan:
            self._data["_listener"].add(update_chan)
            while not query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items():
                await update_chan.recv()

    async def _ensure_quote(self, ins):
        # 在接新版合约服务器后,合约信息程序运行过程中查询得到的,这里不再能保证合约一定存在,需要添加 quote 默认值
        quote = _get_obj(self._data, ["quotes", ins], BtQuote(self._api))
        if math.isnan(quote.get("price_tick")):
            query_pack = _query_for_quote(ins)
            await self._md_send_chan.send(query_pack)
            async with TqChan(self._api, last_only=True) as update_chan:
                quote["_listener"].add(update_chan)
                while math.isnan(quote.get("price_tick")):
                    await update_chan.recv()
        if ins not in self._quotes or self._quotes[ins]["min_duration"] > 60000000000:
            await self._ensure_serial(ins, 60000000000)

    async def _fetch_serial(self, key):
        s = self._serials[key]
        try:
            s["timestamp"], s["diff"], s["quotes"] = await self._generators[key].__anext__()
        except StopAsyncIteration:
            del self._generators[key]  # 删除一个行情时间超过结束时间的 generator

    async def _gen_serial(self, ins, dur):
        """k线/tick 序列的 async generator, yield 出来的行情数据带有时间戳, 因此 _send_diff 可以据此归并"""
        # 先定位左端点, focus_datetime 是 lower_bound ,这里需要的是 upper_bound
        # 因此将 view_width 和 focus_position 设置成一样,这样 focus_datetime 所对应的 k线刚好位于屏幕外
        # 使用两个长度为 8964 的 chart,去缓存/回收下游需要的数据
        chart_id_a = _generate_uuid("PYSDK_backtest")
        chart_id_b = _generate_uuid("PYSDK_backtest")
        chart_info = {
            "aid": "set_chart",
            "chart_id": chart_id_a,
            "ins_list": ins,
            "duration": dur,
            "view_width": 8964,  # 设为8964原因:可满足用户所有的订阅长度,并在backtest中将所有的 相同合约及周期 的K线用同一个serial存储
            "focus_datetime": int(self._current_dt),
            "focus_position": 8964,
        }
        chart_a = _get_obj(self._data, ["charts", chart_id_a])
        chart_b = _get_obj(self._data, ["charts", chart_id_b])
        symbol_list = ins.split(',')
        current_id = None  # 当前数据指针
        if dur == 0:
            serials = [_get_obj(self._data, ["ticks", symbol_list[0]])]
        else:
            serials = [_get_obj(self._data, ["klines", s, str(dur)]) for s in symbol_list]
        async with TqChan(self._api, last_only=True) as update_chan:
            for serial in serials:
                serial["_listener"].add(update_chan)
            chart_a["_listener"].add(update_chan)
            chart_b["_listener"].add(update_chan)
            await self._md_send_chan.send(chart_info.copy())
            try:
                async for _ in update_chan:
                    chart = _get_obj(self._data, ["charts", chart_info["chart_id"]])
                    if not (chart_info.items() <= _get_obj(chart, ["state"]).items()):
                        # 当前请求还没收齐回应, 不应继续处理
                        continue
                    left_id = chart.get("left_id", -1)
                    right_id = chart.get("right_id", -1)
                    if (left_id == -1 and right_id == -1) or chart.get("more_data", True):
                        continue  # 定位信息还没收到, 数据没有完全收到
                    last_id = serials[0].get("last_id", -1)
                    if last_id == -1:
                        continue  # 数据序列还没收到
                    if self._data.get("mdhis_more_data", True):
                        self._data["_listener"].add(update_chan)
                        continue
                    else:
                        self._data["_listener"].discard(update_chan)
                    if current_id is None:
                        current_id = max(left_id, 0)
                    # 发送下一段 chart 8964 根 kline
                    chart_info["chart_id"] = chart_id_b if chart_info["chart_id"] == chart_id_a else chart_id_a
                    chart_info["left_kline_id"] = right_id
                    chart_info.pop("focus_datetime", None)
                    chart_info.pop("focus_position", None)
                    await self._md_send_chan.send(chart_info.copy())
                    while True:
                        if current_id > last_id:
                            # 当前 id 已超过 last_id
                            return
                        # 将订阅的8964长度的窗口中的数据都遍历完后,退出循环,然后再次进入并处理下一窗口数据
                        if current_id > right_id:
                            break
                        item = {k: v for k, v in serials[0]["data"].get(str(current_id), {}).items()}
                        if dur == 0:
                            diff = {
                                "ticks": {
                                    ins: {
                                        "last_id": current_id,
                                        "data": {
                                            str(current_id): item,
                                            str(current_id - 8964): None,
                                        }
                                    }
                                }
                            }
                            if item["datetime"] > self._end_dt:  # 超过结束时间
                                return
                            yield item["datetime"], diff, self._get_quotes_from_tick(item)
                        else:
                            timestamp = item["datetime"] if dur < 86400000000000 else _get_trading_day_start_time(
                                item["datetime"])
                            if timestamp > self._end_dt:  # 超过结束时间
                                return
                            binding = serials[0].get("binding", {})
                            diff = {
                                "klines": {
                                    symbol_list[0]: {
                                        str(dur): {
                                            "last_id": current_id,
                                            "data": {
                                                str(current_id): {
                                                    "datetime": item["datetime"],
                                                    "open": item["open"],
                                                    "high": item["open"],
                                                    "low": item["open"],
                                                    "close": item["open"],
                                                    "volume": 0,
                                                    "open_oi": item["open_oi"],
                                                    "close_oi": item["open_oi"],
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            for chart_id in self._serials[(ins, dur)]["chart_id_set"]:
                                diff["charts"] = {
                                    chart_id: {
                                        "right_id": current_id  # api 中处理多合约 kline 需要 right_id 信息
                                    }
                                }
                            for i, symbol in enumerate(symbol_list):
                                if i == 0:
                                    diff_binding = diff["klines"][symbol_list[0]][str(dur)].setdefault("binding", {})
                                    continue
                                other_id = binding.get(symbol, {}).get(str(current_id), -1)
                                if other_id >= 0:
                                    diff_binding[symbol] = {str(current_id): str(other_id)}
                                    other_item = serials[i]["data"].get(str(other_id), {})
                                    diff["klines"][symbol] = {
                                        str(dur): {
                                            "last_id": other_id,
                                            "data": {
                                                str(other_id): {
                                                    "datetime": other_item["datetime"],
                                                    "open": other_item["open"],
                                                    "high": other_item["open"],
                                                    "low": other_item["open"],
                                                    "close": other_item["open"],
                                                    "volume": 0,
                                                    "open_oi": other_item["open_oi"],
                                                    "close_oi": other_item["open_oi"],
                                                }
                                            }
                                        }
                                    }
                            yield timestamp, diff, self._get_quotes_from_kline_open(
                                self._data["quotes"][symbol_list[0]],
                                timestamp,
                                item)  # K线刚生成时的数据都为开盘价
                            timestamp = item["datetime"] + dur - 1000 \
                                if dur < 86400000000000 else _get_trading_day_start_time(item["datetime"] + dur) - 1000
                            if timestamp > self._end_dt:  # 超过结束时间
                                return
                            diff = {
                                "klines": {
                                    symbol_list[0]: {
                                        str(dur): {
                                            "data": {
                                                str(current_id): item,
                                            }
                                        }
                                    }
                                }
                            }
                            for i, symbol in enumerate(symbol_list):
                                if i == 0:
                                    continue
                                other_id = binding.get(symbol, {}).get(str(current_id), -1)
                                if other_id >= 0:
                                    diff["klines"][symbol] = {
                                        str(dur): {
                                            "data": {
                                                str(other_id): {k: v for k, v in
                                                                serials[i]["data"].get(str(other_id), {}).items()}
                                            }
                                        }
                                    }
                            yield timestamp, diff, self._get_quotes_from_kline(self._data["quotes"][symbol_list[0]],
                                                                               timestamp,
                                                                               item)  # K线结束时生成quote数据
                        current_id += 1
            finally:
                # 释放chart资源
                chart_info["ins_list"] = ""
                await self._md_send_chan.send(chart_info.copy())
                chart_info["chart_id"] = chart_id_b if chart_info["chart_id"] == chart_id_a else chart_id_a
                await self._md_send_chan.send(chart_info.copy())

    def _gc_data(self):
        # api 应该删除的数据 diff
        need_rangeset = {}
        for ins, dur in self._serials:
            if dur == 0:  # tick 在发送数据过程中已经回收内存
                continue
            symbol_list = ins.split(',')
            for s in symbol_list:
                need_rangeset.setdefault((s, dur), [])
            main_serial = _get_obj(self._data, ["klines", symbol_list[0], str(dur)])
            main_serial_rangeset = self._sended_to_api.get((symbol_list[0], dur), [])  # 此 request 还没有给 api 发送过任何数据时为 []
            if not main_serial_rangeset:
                continue
            last_id = main_serial_rangeset[-1][-1] - 1
            assert last_id > -1
            need_rangeset[(symbol_list[0], dur)] = _rangeset_range_union(need_rangeset[(symbol_list[0], dur)],
                                                                         (last_id - 8963, last_id + 1))
            for symbol in symbol_list[1:]:
                symbol_need_rangeset = []
                symbol_binding = main_serial.get("binding", {}).get(symbol, {})
                if symbol_binding:
                    for i in range(last_id - 8963, last_id + 1):
                        other_id = symbol_binding.get(str(i))
                        if other_id:
                            symbol_need_rangeset = _rangeset_range_union(symbol_need_rangeset, (other_id, other_id + 1))
                if symbol_need_rangeset:
                    need_rangeset[(symbol, dur)] = _rangeset_union(need_rangeset[(symbol, dur)], symbol_need_rangeset)

        gc_rangeset = {}
        for key, rs in self._sended_to_api.items():
            gc_rangeset[key] = _rangeset_difference(rs, need_rangeset.get(key, []))

        # 更新 self._sended_to_api
        for key, rs in gc_rangeset.items():
            self._sended_to_api[key] = _rangeset_difference(self._sended_to_api[key], rs)

        gc_klines_diff = {}
        for (symbol, dur), rs in gc_rangeset.items():
            gc_klines_diff.setdefault(symbol, {})
            gc_klines_diff[symbol][str(dur)] = {"data": {}}
            serial = _get_obj(self._data, ["klines", symbol, str(dur)])
            serial_binding = serial.get("binding", None)
            if serial_binding:
                gc_klines_diff[symbol][str(dur)]["binding"] = {s: {} for s in serial_binding.keys()}
            for start_id, end_id in rs:
                for i in range(start_id, end_id):
                    gc_klines_diff[symbol][str(dur)]["data"][str(i)] = None
                    if serial_binding:
                        for s, s_binding in serial_binding.items():
                            gc_klines_diff[symbol][str(dur)]["binding"][s][str(i)] = None
        return {"klines": gc_klines_diff}

    @staticmethod
    def _get_quotes_from_tick(tick):
        quote = {k: v for k, v in tick.items()}
        quote["datetime"] = _timestamp_nano_to_str(tick["datetime"])
        return [quote]

    @staticmethod
    def _get_quotes_from_kline_open(info, timestamp, kline):
        return [
            {  # K线刚生成时的数据都为开盘价
                "datetime": _timestamp_nano_to_str(timestamp),
                "ask_price1": kline["open"] + info["price_tick"],
                "ask_volume1": 1,
                "bid_price1": kline["open"] - info["price_tick"],
                "bid_volume1": 1,
                "last_price": kline["open"],
                "highest": float("nan"),
                "lowest": float("nan"),
                "average": float("nan"),
                "volume": 0,
                "amount": float("nan"),
                "open_interest": kline["open_oi"],
            },
        ]

    @staticmethod
    def _get_quotes_from_kline(info, timestamp, kline):
        """
        分为三个包发给下游:
        1. 根据 diff 协议,对于用户收到的最终结果没有影响
        2. TqSim 撮合交易会按顺序处理收到的包,分别比较 high、low、close 三个价格对应的买卖价
        3. TqSim 撮合交易只用到了买卖价,所以最新价只产生一次 close,而不会发送三次
        """
        return [
            {
                "datetime": _timestamp_nano_to_str(timestamp),
                "ask_price1": kline["high"] + info["price_tick"],
                "ask_volume1": 1,
                "bid_price1": kline["high"] - info["price_tick"],
                "bid_volume1": 1,
                "last_price": kline["close"],
                "highest": float("nan"),
                "lowest": float("nan"),
                "average": float("nan"),
                "volume": 0,
                "amount": float("nan"),
                "open_interest": kline["close_oi"],
            },
            {
                "ask_price1": kline["low"] + info["price_tick"],
                "bid_price1": kline["low"] - info["price_tick"],
            },
            {
                "ask_price1": kline["close"] + info["price_tick"],
                "bid_price1": kline["close"] - info["price_tick"],
            }
        ]
Пример #14
0
class BaseSim(Tradeable):
    def __init__(
            self, account_id, init_balance,
            trade_class: Union[Type[SimTrade], Type[SimTradeStock]]) -> None:
        self._account_id = account_id
        super(BaseSim, self).__init__()

        self.trade_log = {}  # 日期->交易记录及收盘时的权益及持仓
        self.tqsdk_stat = {}  # 回测结束后储存回测报告信息
        self._init_balance = init_balance
        self._current_datetime = "1990-01-01 00:00:00.000000"  # 当前行情时间(最新的 quote 时间)
        self._trading_day_end = "1990-01-01 18:00:00.000000"
        self._local_time_record = float("nan")  # 记录获取最新行情时的本地时间
        self._sim_trade = trade_class(
            account_key=self._account_key,
            account_id=self._account_id,
            init_balance=self._init_balance,
            get_trade_timestamp=self._get_trade_timestamp,
            is_in_trading_time=self._is_in_trading_time)
        self._data = Entity()
        self._data._instance_entity([])
        self._prototype = {
            "quotes": {
                "#": Quote(self),  # 行情的数据原型
            }
        }
        self._quote_tasks = {}

    @property
    def _account_name(self):
        return self._account_id

    @property
    def _account_info(self):
        info = super(BaseSim, self)._account_info
        info.update({"account_id": self._account_id})
        return info

    async def _run(self, api, api_send_chan, api_recv_chan, md_send_chan,
                   md_recv_chan):
        """模拟交易task"""
        self._api = api
        self._tqsdk_backtest = {}  # 储存可能的回测信息
        self._logger = api._logger.getChild("TqSim")  # 调试信息输出
        self._api_send_chan = api_send_chan
        self._api_recv_chan = api_recv_chan
        self._md_send_chan = md_send_chan
        self._md_recv_chan = md_recv_chan
        # True 下游发过 subscribe,但是没有转发给上游;False 表示下游发的 subscribe 都转发给上游
        self._pending_subscribe_downstream = False
        # True 发给上游 subscribe,但是没有收到过回复;False 如果行情不变,上游不会回任何包
        self._pending_subscribe_upstream = False
        self._all_subscribe = set()  # 客户端+模拟交易模块订阅的合约集合
        # 是否已经发送初始账户信息
        self._has_send_init_account = False
        try:
            await super(BaseSim, self)._run(api, api_send_chan, api_recv_chan,
                                            md_send_chan, md_recv_chan)
        finally:
            self._handle_stat_report()
            for s in self._quote_tasks:
                self._quote_tasks[s]["task"].cancel()
            await asyncio.gather(
                *[self._quote_tasks[s]["task"] for s in self._quote_tasks],
                return_exceptions=True)

    async def _handle_recv_data(self, pack, chan):
        """
        处理所有上游收到的数据包,这里应该将需要发送给下游的数据 append 到 self._diffs
        pack: 收到的数据包
        chan: 收到此数据包的 channel
        """
        self._pending_subscribe_upstream = False
        if pack["aid"] == "rtn_data":
            self._md_recv(
                pack)  # md_recv 中会发送 wait_count 个 quotes 包给各个 quote_chan
            await asyncio.gather(*[
                quote_task["quote_chan"].join()
                for quote_task in self._quote_tasks.values()
            ])
        if self._tqsdk_backtest != {} and self._tqsdk_backtest[
                "current_dt"] >= self._tqsdk_backtest["end_dt"]:
            # 回测情况下,把 _handle_stat_report 在循环中回测结束时执行
            self._handle_stat_report()

    async def _handle_req_data(self, pack):
        """
        处理所有下游发送的非 peek_message 数据包
        这里应该将发送的请求转发到指定的某个上游 channel
        """
        if self._is_self_trade_pack(pack):
            if pack["aid"] == "insert_order":
                symbol = pack["exchange_id"] + "." + pack["instrument_id"]
                if symbol not in self._quote_tasks:
                    quote_chan = TqChan(self._api)
                    order_chan = TqChan(self._api)
                    self._quote_tasks[symbol] = {
                        "quote_chan":
                        quote_chan,
                        "order_chan":
                        order_chan,
                        "task":
                        self._api.create_task(
                            self._quote_handler(symbol, quote_chan,
                                                order_chan))
                    }
                await self._quote_tasks[symbol]["order_chan"].send(pack)
            else:
                # pack 里只有 order_id 信息,发送到每一个合约的 order_chan, 交由 quote_task 判断是不是当前合约下的委托单
                for symbol in self._quote_tasks:
                    await self._quote_tasks[symbol]["order_chan"].send(pack)
        elif pack["aid"] == "subscribe_quote":
            # 这里只会增加订阅合约,不会退订合约
            await self._subscribe_quote(set(pack["ins_list"].split(",")))
        else:
            await self._md_send_chan.send(pack)

    async def _on_send_diff(self, pending_peek):
        if pending_peek and self._pending_subscribe_downstream:
            await self._send_subscribe_quote()

    async def _subscribe_quote(self, symbols: [set, str]):
        """
        这里只会增加订阅合约,不会退订合约
        todo: 这里用到了 self._pending_peek ,父类的内部变量
        """
        symbols = symbols if isinstance(symbols, set) else {symbols}
        if symbols - self._all_subscribe:
            self._all_subscribe |= symbols
            if self._pending_peek and not self._pending_subscribe_upstream:
                await self._send_subscribe_quote()
            else:
                self._pending_subscribe_downstream = True

    async def _send_subscribe_quote(self):
        self._pending_subscribe_upstream = True
        self._pending_subscribe_downstream = False
        await self._md_send_chan.send({
            "aid": "subscribe_quote",
            "ins_list": ",".join(self._all_subscribe)
        })

    def _handle_stat_report(self):
        if self.tqsdk_stat:
            return
        self._settle()
        self._report()
        self._diffs.append({
            "trade": {
                self._account_key: {
                    "accounts": {
                        "CNY": {
                            "_tqsdk_stat": self.tqsdk_stat
                        }
                    }
                }
            }
        })

    async def _ensure_quote_info(self, symbol, quote_chan):
        """quote收到合约信息后返回"""
        quote = _get_obj(self._data, ["quotes", symbol], Quote(self._api))
        if quote.get("price_tick") == quote.get("price_tick"):
            return quote.copy()
        if quote.get("price_tick") != quote.get("price_tick"):
            await self._md_send_chan.send(_query_for_quote(symbol))
        async for _ in quote_chan:
            quote_chan.task_done()
            if quote.get("price_tick") == quote.get("price_tick"):
                return quote.copy()

    async def _ensure_quote(self, symbol, quote_chan):
        """quote收到行情以及合约信息后返回"""
        quote = _get_obj(self._data, ["quotes", symbol], Quote(self._api))
        _register_update_chan(quote, quote_chan)
        if quote.get(
                "datetime",
                "") and quote.get("price_tick") == quote.get("price_tick"):
            return quote.copy()
        if quote.get("price_tick") != quote.get("price_tick"):
            # 对于没有合约信息的 quote,发送查询合约信息的请求
            await self._md_send_chan.send(_query_for_quote(symbol))
        async for _ in quote_chan:
            quote_chan.task_done()
            if quote.get(
                    "datetime",
                    "") and quote.get("price_tick") == quote.get("price_tick"):
                return quote.copy()

    async def _quote_handler(self, symbol, quote_chan, order_chan):
        try:
            await self._subscribe_quote(symbol)
            quote = await self._ensure_quote(symbol, quote_chan)
            if quote["ins_class"].endswith(
                    "INDEX") and quote["exchange_id"] == "KQ":
                # 指数可以交易,需要补充 margin commission
                if "margin" not in quote:
                    quote_m = await self._ensure_quote_info(
                        symbol.replace("KQ.i", "KQ.m"), quote_chan)
                    quote_underlying = await self._ensure_quote_info(
                        quote_m["underlying_symbol"], quote_chan)
                    self._data["quotes"][symbol]["margin"] = quote_underlying[
                        "margin"]
                    self._data["quotes"][symbol][
                        "commission"] = quote_underlying["commission"]
                    quote.update(self._data["quotes"][symbol])
            underlying_quote = None
            if quote["ins_class"].endswith("OPTION"):
                # 如果是期权,订阅标的合约行情,确定收到期权标的合约行情
                underlying_symbol = quote["underlying_symbol"]
                await self._subscribe_quote(underlying_symbol)
                underlying_quote = await self._ensure_quote(
                    underlying_symbol, quote_chan)  # 订阅合约
            # 在等待标的行情的过程中,quote_chan 可能有期权行情,把 quote_chan 清空,并用最新行情更新 quote
            while not quote_chan.empty():
                quote_chan.recv_nowait()
                quote_chan.task_done()

            # 用最新行情更新 quote
            quote.update(self._data["quotes"][symbol])
            if underlying_quote:
                underlying_quote.update(
                    self._data["quotes"][underlying_symbol])
            task = self._api.create_task(
                self._forward_chan_handler(order_chan, quote_chan))
            quotes = {symbol: quote}
            if underlying_quote:
                quotes[underlying_symbol] = underlying_quote
            self._sim_trade.update_quotes(symbol, {"quotes": quotes})
            async for pack in quote_chan:
                if "aid" not in pack:
                    diffs, orders_events = self._sim_trade.update_quotes(
                        symbol, pack)
                    self._handle_diffs(diffs, orders_events, "match order")
                elif pack["aid"] == "insert_order":
                    diffs, orders_events = self._sim_trade.insert_order(
                        symbol, pack)
                    self._handle_diffs(diffs, orders_events, "insert order")
                elif pack["aid"] == "cancel_order":
                    diffs, orders_events = self._sim_trade.cancel_order(
                        symbol, pack)
                    self._handle_diffs(diffs, orders_events, "cancel order")
                quote_chan.task_done()
        finally:
            await quote_chan.close()
            await order_chan.close()
            task.cancel()
            await asyncio.gather(task, return_exceptions=True)

    async def _forward_chan_handler(self, chan_from, chan_to):
        async for pack in chan_from:
            await chan_to.send(pack)

    def _md_recv(self, pack):
        for d in pack["data"]:
            self._diffs.append(d)
            # 在第一次收到 mdhis_more_data 为 False 的时候,发送账户初始截面信息,这样回测模式下,往后的模块才有正确的时间顺序
            if not self._has_send_init_account and not d.get(
                    "mdhis_more_data", True):
                self._diffs.append(self._sim_trade.init_snapshot())
                self._diffs.append(
                    {"trade": {
                        self._account_key: {
                            "trade_more_data": False
                        }
                    }})
                self._has_send_init_account = True
            _tqsdk_backtest = d.get("_tqsdk_backtest", {})
            if _tqsdk_backtest:
                # 回测时,用 _tqsdk_backtest 对象中 current_dt 作为 TqSim 的 _current_datetime
                self._tqsdk_backtest.update(_tqsdk_backtest)
                self._current_datetime = _timestamp_nano_to_str(
                    self._tqsdk_backtest["current_dt"])
                self._local_time_record = float("nan")
                # 1. 回测时不使用时间差来模拟交易所时间的原因(_local_time_record始终为初始值nan):
                #   在sim收到行情后记录_local_time_record,然后下发行情到api进行merge_diff(),api需要处理完k线和quote才能结束wait_update(),
                #   若处理时间过长,此时下单则在判断下单时间时与测试用例中的预期时间相差较大,导致测试用例无法通过。
                # 2. 回测不使用时间差的方法来判断下单时间仍是可行的: 与使用了时间差的方法相比, 只对在每个交易时间段最后一笔行情时的下单时间判断有差异,
                #   若不使用时间差, 则在最后一笔行情时下单仍判断为在可交易时间段内, 且可成交.
            quotes_diff = d.get("quotes", {})
            # 先根据 quotes_diff 里的 datetime, 确定出 _current_datetime,再 _merge_diff(同时会发送行情到 quote_chan)
            for symbol, quote_diff in quotes_diff.items():
                if quote_diff is None:
                    continue
                # 若直接使用本地时间来判断下单时间是否在可交易时间段内 可能有较大误差,因此判断的方案为:(在接收到下单指令时判断 估计的交易所时间 是否在交易时间段内)
                # 在更新最新行情时间(即self._current_datetime)时,记录当前本地时间(self._local_time_record),
                # 在这之后若收到下单指令,则获取当前本地时间,判 "最新行情时间 + (当前本地时间 - 记录的本地时间)" 是否在交易时间段内。
                # 另外, 若在盘后下单且下单前未订阅此合约:
                # 因为从_md_recv()中获取数据后立即判断下单时间则速度过快(两次time.time()的时间差小于最后一笔行情(14:59:9995)到15点的时间差),
                # 则会立即成交,为处理此情况则将当前时间减去5毫秒(模拟发生5毫秒网络延迟,则两次time.time()的时间差增加了5毫秒)。
                # todo: 按交易所来存储 _current_datetime(issue: #277)
                if quote_diff.get("datetime", "") > self._current_datetime:
                    # 回测时,当前时间更新即可以由 quote 行情更新,也可以由 _tqsdk_backtest.current_dt 更新,
                    # 在最外层的循环里,_tqsdk_backtest.current_dt 是在 rtn_data.data 中数组位置中的最后一个,会在循环最后一个才更新 self.current_datetime
                    # 导致前面处理 order 时的 _current_datetime 还是旧的行情时间
                    self._current_datetime = quote_diff["datetime"]  # 最新行情时间
                    # 更新最新行情时间时的本地时间,回测时不使用时间差
                    self._local_time_record = (
                        time.time() -
                        0.005) if not self._tqsdk_backtest else float("nan")
                if self._current_datetime > self._trading_day_end:  # 结算
                    self._settle()
                    # 若当前行情时间大于交易日的结束时间(切换交易日),则根据此行情时间更新交易日及交易日结束时间
                    trading_day = _get_trading_day_from_timestamp(
                        self._get_current_timestamp())
                    self._trading_day_end = _timestamp_nano_to_str(
                        _get_trading_day_end_time(trading_day) - 999)
            if quotes_diff:
                _merge_diff(self._data, {"quotes": quotes_diff},
                            self._prototype,
                            persist=False,
                            reduce_diff=False,
                            notify_update_diff=True)

    def _handle_diffs(self, diffs, orders_events, msg):
        """
        处理 sim_trade 返回的 diffs
        orders_events 为持仓变更事件,依次屏幕输出信息,打印日志
        """
        self._diffs += diffs
        for order in orders_events:
            if order["status"] == "FINISHED":
                self._handle_on_finished(msg, order)
            else:
                assert order["status"] == "ALIVE"
                self._handle_on_alive(msg, order)

    def _settle(self):
        if self._trading_day_end[:10] == "1990-01-01":
            return
        # 结算并记录账户截面
        diffs, orders_events, trade_log = self._sim_trade.settle()
        self._handle_diffs(diffs, orders_events, "settle")
        self.trade_log[self._trading_day_end[:10]] = trade_log

    @abstractmethod
    def _handle_on_alive(self, msg, order):
        """
        在 order 状态变为 ALIVE 调用,屏幕输出信息,打印日志
        """
        pass

    @abstractmethod
    def _handle_on_finished(self, msg, order):
        """
        在 order 状态变为 FINISHED 调用,屏幕输出信息,打印日志
        """
        pass

    @abstractmethod
    def _report(self):
        pass

    def _get_current_timestamp(self):
        return _str_to_timestamp_nano(self._current_datetime)

    def _get_trade_timestamp(self):
        return _get_trade_timestamp(self._current_datetime,
                                    self._local_time_record)

    def _is_in_trading_time(self, quote):
        return _is_in_trading_time(quote, self._current_datetime,
                                   self._local_time_record)
Пример #15
0
class TqReconnect(object):
    def __init__(self, logger):
        self._logger = logger
        self._resend_request = {}  # 重连时需要重发的请求
        self._un_processed = False  # 重连后尚未处理完标志
        self._pending_diffs = []
        self._data = Entity()
        self._data._instance_entity([])

    async def _run(self, api, api_send_chan, api_recv_chan, ws_send_chan, ws_recv_chan):
        self._api = api
        send_task = self._api.create_task(self._send_handler(api_send_chan, ws_send_chan))
        try:
            async for pack in ws_recv_chan:
                self._record_upper_data(pack)
                if self._un_processed:  # 处理重连后数据
                    pack_data = pack.get("data", [])
                    self._pending_diffs.extend(pack_data)
                    for d in pack_data:
                        # _merge_diff 之后, self._data 会用于判断是否接收到了完整截面数据
                        _merge_diff(self._data, d, self._api._prototype, persist=False, reduce_diff=False)
                    if self._is_all_received():
                        # 重连后收到完整数据截面
                        self._un_processed = False
                        pack = {
                            "aid": "rtn_data",
                            "data": self._pending_diffs
                        }
                        await api_recv_chan.send(pack)
                        self._logger = self._logger.bind(status=self._status)
                        self._logger.debug("data completed", pack=pack)
                    else:
                        await ws_send_chan.send({"aid": "peek_message"})
                        self._logger.debug("wait for data completed", pack={"aid": "peek_message"})
                else:
                    is_reconnected = False
                    for i in range(len(pack.get("data", []))):
                        for _, notify in pack["data"][i].get("notify", {}).items():
                            if notify["code"] == 2019112902:  # 重连建立
                                is_reconnected = True
                                self._un_processed = True
                                self._logger = self._logger.bind(status=self._status)
                                if i > 0:
                                    ws_send_chan.send_nowait({
                                        "aid": "rtn_data",
                                        "data": pack.get("data", [])[0:i]
                                    })
                                self._pending_diffs = pack.get("data", [])[i:]
                                break
                    if is_reconnected:
                        self._data = Entity()
                        self._data._instance_entity([])
                        for d in self._pending_diffs:
                            _merge_diff(self._data, d, self._api._prototype, persist=False, reduce_diff=False)
                        # 发送所有 resend_request
                        for msg in self._resend_request.values():
                            # 这里必须用 send_nowait 而不是 send,因为如果使用异步写法,在循环中,代码可能执行到 send_task, 可能会修改 _resend_request
                            ws_send_chan.send_nowait(msg)
                            self._logger.debug("resend request", pack=msg)
                        await ws_send_chan.send({"aid": "peek_message"})
                    else:
                        await api_recv_chan.send(pack)
        finally:
            send_task.cancel()
            await asyncio.gather(send_task, return_exceptions=True)

    async def _send_handler(self, api_send_chan, ws_send_chan):
        async for pack in api_send_chan:
            self._record_lower_data(pack)
            await ws_send_chan.send(pack)

    @property
    def _status(self):
        return "WAIT_FOR_COMPLETED" if self._un_processed else "READY"

    @abstractmethod
    def _is_all_received(self):
        """在重连后判断是否收到了全部的数据,可以继续处理后续的数据包"""
        pass

    def _record_upper_data(self, pack):
        """从上游收到的数据中,记录下重连时需要的数据"""
        pass

    def _record_lower_data(self, pack):
        """从下游收到的数据中,记录下重连时需要的数据"""
        pass
Пример #16
0
class TqSim(object):
    """
    天勤模拟交易类

    该类实现了一个本地的模拟账户,并且在内部完成撮合交易,在回测和复盘模式下,只能使用 TqSim 账户来交易。

    限价单要求报单价格达到或超过对手盘价格才能成交, 成交价为报单价格, 如果没有对手盘(涨跌停)则无法成交

    市价单使用对手盘价格成交, 如果没有对手盘(涨跌停)则自动撤单

    模拟交易不会有部分成交的情况, 要成交就是全部成交
    """

    def __init__(self, init_balance: float = 10000000.0, account_id: str = "TQSIM") -> None:
        """
        Args:
            init_balance (float): [可选]初始资金, 默认为一千万

            account_id (str): [可选]帐号, 默认为 "TQSIM"

        Example::

            # 修改TqSim模拟帐号的初始资金为100000
            from tqsdk import TqApi, TqSim
            api = TqApi(TqSim(init_balance=100000))

        """
        self.trade_log = {}  # 日期->交易记录及收盘时的权益及持仓
        self._account_id = account_id
        self._init_balance = float(init_balance)
        if self._init_balance <= 0:
            raise Exception("初始资金(init_balance) %s 错误, 请检查 init_balance 是否填写正确" % (init_balance))
        self._current_datetime = "1990-01-01 00:00:00.000000"  # 当前行情时间(最新的 quote 时间)
        self._trading_day_end = "1990-01-01 18:00:00.000000"
        self._local_time_record = float("nan")  # 记录获取最新行情时的本地时间

    async def _run(self, api, api_send_chan, api_recv_chan, md_send_chan, md_recv_chan):
        """模拟交易task"""
        self._api = api
        self._tqsdk_backtest = {}  # 储存可能的回测信息
        self._tqsdk_stat = {}  # 回测结束后储存回测报告信息
        self._logger = api._logger.getChild("TqSim")  # 调试信息输出
        self._api_send_chan = api_send_chan
        self._api_recv_chan = api_recv_chan
        self._md_send_chan = md_send_chan
        self._md_recv_chan = md_recv_chan
        self._pending_peek = False
        self._diffs = []
        self._account = {
            "currency": "CNY",
            "pre_balance": self._init_balance,
            "static_balance": self._init_balance,
            "balance": self._init_balance,
            "available": self._init_balance,
            "float_profit": 0.0,
            "position_profit": 0.0,  # 期权没有持仓盈亏
            "close_profit": 0.0,
            "frozen_margin": 0.0,
            "margin": 0.0,
            "frozen_commission": 0.0,
            "commission": 0.0,
            "frozen_premium": 0.0,
            "premium": 0.0,
            "deposit": 0.0,
            "withdraw": 0.0,
            "risk_ratio": 0.0,
            "market_value": 0.0,
            "ctp_balance": float("nan"),
            "ctp_available": float("nan"),
        }
        self._positions = {}
        self._orders = {}
        self._data = Entity()
        self._data._instance_entity([])
        self._prototype = {
            "quotes": {
                "#": Quote(self),  # 行情的数据原型
            }
        }
        self._quote_tasks = {}
        self._all_subscribe = set()  # 客户端+模拟交易模块订阅的合约集合
        # 是否已经发送初始账户信息
        self._has_send_init_account = False
        md_task = self._api.create_task(self._md_handler())  # 将所有 md_recv_chan 上收到的包投递到 api_send_chan 上
        try:
            async for pack in self._api_send_chan:
                self._logger.debug("TqSim message received: %s", pack)
                if "_md_recv" in pack:
                    if pack["aid"] == "rtn_data":
                        self._md_recv(pack)  # md_recv 中会发送 wait_count 个 quotes 包给各个 quote_chan
                        await asyncio.gather(*[quote_task["quote_chan"].join() for quote_task in self._quote_tasks.values()])
                        await self._send_diff()
                elif pack["aid"] == "subscribe_quote":
                    await self._subscribe_quote(set(pack["ins_list"].split(",")))
                elif pack["aid"] == "peek_message":
                    self._pending_peek = True
                    await self._send_diff()
                    if self._pending_peek:  # 控制"peek_message"发送: 当没有新的事件需要用户处理时才推进到下一个行情
                        await self._md_send_chan.send(pack)
                elif pack["aid"] == "insert_order":
                    symbol = pack["exchange_id"] + "." + pack["instrument_id"]
                    if symbol not in self._quote_tasks:
                        quote_chan = TqChan(self._api)
                        order_chan = TqChan(self._api)
                        self._quote_tasks[symbol] = {
                            "quote_chan": quote_chan,
                            "order_chan": order_chan,
                            "task": self._api.create_task(self._quote_handler(symbol, quote_chan, order_chan))
                        }
                    await self._quote_tasks[symbol]["order_chan"].send(pack)
                elif pack["aid"] == "cancel_order":
                    # pack 里只有 order_id 信息,发送到每一个合约的 order_chan, 交由 quote_task 判断是不是当前合约下的委托单
                    for symbol in self._quote_tasks:
                        await self._quote_tasks[symbol]["order_chan"].send(pack)
                else:
                    await self._md_send_chan.send(pack)
                if self._tqsdk_backtest != {} and self._tqsdk_backtest["current_dt"] >= self._tqsdk_backtest["end_dt"] \
                        and not self._tqsdk_stat:
                    # 回测情况下,把 _send_stat_report 在循环中回测结束时执行
                    await self._send_stat_report()
        finally:
            if not self._tqsdk_stat:
                await self._send_stat_report()
            md_task.cancel()
            tasks = [md_task]
            for symbol in self._quote_tasks:
                self._quote_tasks[symbol]["task"].cancel()
                tasks.append(self._quote_tasks[symbol]["task"])
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _md_handler(self):
        async for pack in self._md_recv_chan:
            pack["_md_recv"] = True
            await self._api_send_chan.send(pack)

    async def _send_diff(self):
        if self._pending_peek and self._diffs:
            rtn_data = {
                "aid": "rtn_data",
                "data": self._diffs,
            }
            self._diffs = []
            self._pending_peek = False
            self._logger.debug("TqSim message send: %s", rtn_data)
            await self._api_recv_chan.send(rtn_data)

    async def _subscribe_quote(self, symbols: [set, str]):
        """这里只会增加订阅合约,不会退订合约"""
        symbols = symbols if isinstance(symbols, set) else {symbols}
        if symbols - self._all_subscribe:
            self._all_subscribe |= symbols
            await self._md_send_chan.send({
                "aid": "subscribe_quote",
                "ins_list": ",".join(self._all_subscribe)
            })

    async def _send_stat_report(self):
        self._settle()
        self._report()
        await self._api_recv_chan.send({
            "aid": "rtn_data",
            "data": [{
                "trade": {
                    self._account_id: {
                        "accounts": {
                            "CNY": {
                                "_tqsdk_stat": self._tqsdk_stat
                            }
                        }
                    }
                }
            }]
        })

    async def _ensure_quote(self, symbol, quote_chan):
        """quote收到行情后返回"""
        quote = _get_obj(self._data, ["quotes", symbol], Quote(self._api))
        _register_update_chan(quote, quote_chan)
        if quote.get("datetime", ""):
            return quote.copy()
        async for _ in quote_chan:
            quote_chan.task_done()
            if quote.get("datetime", ""):
                return quote.copy()

    async def _quote_handler(self, symbol, quote_chan, order_chan):
        try:
            orders = self._orders.setdefault(symbol, {})
            position = self._normalize_position(symbol)
            self._positions[symbol] = position
            await self._subscribe_quote(symbol)
            quote = await self._ensure_quote(symbol, quote_chan)
            underlying_quote = None
            if quote["ins_class"] in ["FUTURE_OPTION", "OPTION"]:
                # 如果是期权,订阅标的合约行情,确定收到期权标的合约行情
                underlying_symbol = quote["underlying_symbol"]
                await self._subscribe_quote(underlying_symbol)
                underlying_quote = await self._ensure_quote(underlying_symbol, quote_chan)  # 订阅合约
            # 在等待标的行情的过程中,quote_chan 可能有期权行情,把 quote_chan 清空,并用最新行情更新 quote
            while not quote_chan.empty():
                quote_chan.recv_nowait()
                quote_chan.task_done()
            quote.update(self._data["quotes"][symbol])
            if underlying_quote:
                underlying_quote.update(self._data["quotes"][underlying_symbol])
            task = self._api.create_task(self._forward_chan_handler(order_chan, quote_chan))
            async for pack in quote_chan:
                if "aid" not in pack:
                    _simple_merge_diff(quote, pack.get("quotes", {}).get(symbol, {}))
                    if underlying_quote:
                        _simple_merge_diff(underlying_quote, pack.get("quotes", {}).get(underlying_symbol, {}))
                    for order_id in list(orders.keys()):
                        assert orders[order_id]["insert_date_time"] > 0
                        match_msg = self._match_order(orders[order_id], symbol, quote, underlying_quote)
                        if match_msg:
                            self._del_order(symbol, orders[order_id], match_msg)
                            del orders[order_id]
                    self._adjust_position(symbol, price=quote["last_price"])  # 按照合约最新价调整持仓
                elif pack["aid"] == "insert_order":
                    order = self._normalize_order(pack)  # 调整 pack 为 order 对象需要的字段
                    orders[pack["order_id"]] = order  # 记录 order 在 orders 里
                    insert_result = self._insert_order(order, symbol, quote, underlying_quote)
                    if insert_result:
                        self._del_order(symbol, order, insert_result)
                        del orders[pack["order_id"]]
                    else:
                        match_msg = self._match_order(order, symbol, quote, underlying_quote)
                        if match_msg:
                            self._del_order(symbol, orders[pack["order_id"]], match_msg)
                            del orders[pack["order_id"]]
                    # 按照合约最新价调整持仓
                    self._adjust_position(symbol, price=quote["last_price"])
                    await self._send_diff()
                elif pack["aid"] == "cancel_order":
                    if pack["order_id"] in orders:
                        self._del_order(symbol, orders[pack["order_id"]], "已撤单")
                        del orders[pack["order_id"]]
                    await self._send_diff()
                quote_chan.task_done()
        finally:
            await quote_chan.close()
            await order_chan.close()
            task.cancel()
            await asyncio.gather(task, return_exceptions=True)

    async def _forward_chan_handler(self, chan_from, chan_to):
        async for pack in chan_from:
            await chan_to.send(pack)

    def _md_recv(self, pack):
        for d in pack["data"]:
            d.pop("trade", None)
            self._diffs.append(d)
            # 在第一次收到 mdhis_more_data 为 False 的时候,发送账户初始截面信息,这样回测模式下,往后的模块才有正确的时间顺序
            if not self._has_send_init_account and not d.get("mdhis_more_data", True):
                self._send_account()
                self._diffs.append({
                    "trade": {
                        self._account_id: {
                            "orders": {},
                            "positions": {},
                            "trade_more_data": False
                        }
                    }
                })
                self._has_send_init_account = True
            _tqsdk_backtest = d.get("_tqsdk_backtest", {})
            if _tqsdk_backtest:
                # 回测时,用 _tqsdk_backtest 对象中 current_dt 作为 TqSim 的 _current_datetime
                self._tqsdk_backtest.update(_tqsdk_backtest)
                self._current_datetime = datetime.fromtimestamp(
                    self._tqsdk_backtest["current_dt"] / 1e9).strftime("%Y-%m-%d %H:%M:%S.%f")
                self._local_time_record = float("nan")
                # 1. 回测时不使用时间差来模拟交易所时间的原因(_local_time_record始终为初始值nan):
                #   在sim收到行情后记录_local_time_record,然后下发行情到api进行merge_diff(),api需要处理完k线和quote才能结束wait_update(),
                #   若处理时间过长,此时下单则在判断下单时间时与测试用例中的预期时间相差较大,导致测试用例无法通过。
                # 2. 回测不使用时间差的方法来判断下单时间仍是可行的: 与使用了时间差的方法相比, 只对在每个交易时间段最后一笔行情时的下单时间判断有差异,
                #   若不使用时间差, 则在最后一笔行情时下单仍判断为在可交易时间段内, 且可成交.
            quotes_diff = d.get("quotes", {})
            if not quotes_diff:
                continue
            # 先根据 quotes_diff 里的 datetime, 确定出 _current_datetime,再 _merge_diff(同时会发送行情到 quote_chan)
            for symbol, quote_diff in quotes_diff.items():
                if quote_diff is None:
                    continue
                # 若直接使用本地时间来判断下单时间是否在可交易时间段内 可能有较大误差,因此判断的方案为:(在接收到下单指令时判断 估计的交易所时间 是否在交易时间段内)
                # 在更新最新行情时间(即self._current_datetime)时,记录当前本地时间(self._local_time_record),
                # 在这之后若收到下单指令,则获取当前本地时间,判 "最新行情时间 + (当前本地时间 - 记录的本地时间)" 是否在交易时间段内。
                # 另外, 若在盘后下单且下单前未订阅此合约:
                # 因为从_md_recv()中获取数据后立即判断下单时间则速度过快(两次time.time()的时间差小于最后一笔行情(14:59:9995)到15点的时间差),
                # 则会立即成交,为处理此情况则将当前时间减去5毫秒(模拟发生5毫秒网络延迟,则两次time.time()的时间差增加了5毫秒)。
                # todo: 按交易所来存储 _current_datetime(issue: #277)
                if quote_diff.get("datetime", "") > self._current_datetime:
                    # 回测时,当前时间更新即可以由 quote 行情更新,也可以由 _tqsdk_backtest.current_dt 更新,
                    # 在最外层的循环里,_tqsdk_backtest.current_dt 是在 rtn_data.data 中数组位置中的最后一个,会在循环最后一个才更新 self.current_datetime
                    # 导致前面处理 order 时的 _current_datetime 还是旧的行情时间
                    self._current_datetime = quote_diff["datetime"]  # 最新行情时间
                    # 更新最新行情时间时的本地时间,回测时不使用时间差
                    self._local_time_record = (time.time() - 0.005) if not self._tqsdk_backtest else float("nan")
                if self._current_datetime > self._trading_day_end:  # 结算
                    self._settle()
                    # 若当前行情时间大于交易日的结束时间(切换交易日),则根据此行情时间更新交易日及交易日结束时间
                    trading_day = _get_trading_day_from_timestamp(self._get_current_timestamp())
                    self._trading_day_end = datetime.fromtimestamp(
                        (_get_trading_day_end_time(trading_day) - 999) / 1e9).strftime("%Y-%m-%d %H:%M:%S.%f")
            _merge_diff(self._data, {"quotes": quotes_diff}, self._prototype, False, True)

    def _normalize_order(self, order):
        order["exchange_order_id"] = order["order_id"]
        order["volume_orign"] = order["volume"]
        order["volume_left"] = order["volume"]
        order["frozen_margin"] = 0.0
        order["frozen_premium"] = 0.0
        order["last_msg"] = "报单成功"
        order["status"] = "ALIVE"
        order["insert_date_time"] = 0  # 初始化为0:保持 order 的结构不变(所有字段都有,只是值不同)
        del order["aid"]
        del order["volume"]
        return order

    def _normalize_position(self, symbol):
        return {
            "exchange_id": symbol.split(".", maxsplit=1)[0],
            "instrument_id": symbol.split(".", maxsplit=1)[1],
            "pos_long_his": 0,
            "pos_long_today": 0,
            "pos_short_his": 0,
            "pos_short_today": 0,
            "volume_long_today": 0,
            "volume_long_his": 0,
            "volume_long": 0,
            "volume_long_frozen_today": 0,
            "volume_long_frozen_his": 0,
            "volume_long_frozen": 0,
            "volume_short_today": 0,
            "volume_short_his": 0,
            "volume_short": 0,
            "volume_short_frozen_today": 0,
            "volume_short_frozen_his": 0,
            "volume_short_frozen": 0,
            "open_price_long": float("nan"),
            "open_price_short": float("nan"),
            "open_cost_long": 0.0,
            "open_cost_short": 0.0,
            "position_price_long": float("nan"),
            "position_price_short": float("nan"),
            "position_cost_long": 0.0,
            "position_cost_short": 0.0,
            "float_profit_long": 0.0,
            "float_profit_short": 0.0,
            "float_profit": 0.0,
            "position_profit_long": 0.0,
            "position_profit_short": 0.0,
            "position_profit": 0.0,
            "margin_long": 0.0,
            "margin_short": 0.0,
            "margin": 0.0,
            "last_price": None,
            "market_value_long": 0.0,  # 权利方市值(始终 >= 0)
            "market_value_short": 0.0,  # 义务方市值(始终 <= 0)
            "market_value": 0.0,
        }

    def _del_order(self, symbol, order, msg):
        self._logger.info(f"模拟交易委托单 {order['order_id']}: {msg}")
        if order["offset"].startswith("CLOSE"):
            volume_long_frozen = 0 if order["direction"] == "BUY" else -order["volume_left"]
            volume_short_frozen = 0 if order["direction"] == "SELL" else -order["volume_left"]
            if order["exchange_id"] == "SHFE" or order["exchange_id"] == "INE":
                priority = "H" if order["offset"] == "CLOSE" else "T"
            else:
                priority = "HT"
            self._adjust_position(symbol, volume_long_frozen=volume_long_frozen,
                                  volume_short_frozen=volume_short_frozen, priority=priority)
        else:
            self._adjust_account(frozen_margin=-order["frozen_margin"], frozen_premium=-order["frozen_premium"])
            order["frozen_margin"] = 0.0
            order["frozen_premium"] = 0.0
        order["last_msg"] = msg
        order["status"] = "FINISHED"
        self._send_order(order)
        return order

    def _insert_order(self, order, symbol, quote, underlying_quote=None):
        if order["offset"].startswith("CLOSE"):
            volume_long_frozen = 0 if order["direction"] == "BUY" else order["volume_left"]
            volume_short_frozen = 0 if order["direction"] == "SELL" else order["volume_left"]
            if order["exchange_id"] == "SHFE" or order["exchange_id"] == "INE":
                priority = "H" if order["offset"] == "CLOSE" else "T"
            else:
                priority = "TH"
            if not self._adjust_position(symbol, volume_long_frozen=volume_long_frozen,
                                         volume_short_frozen=volume_short_frozen, priority=priority):
                return "平仓手数不足"
        else:
            if ("commission" not in quote or "margin" not in quote) \
                    and quote["ins_class"] not in ["OPTION", "FUTURE_OPTION"]:
                # 除了期权外,主连、指数和组合没有这两个字段
                return "合约不存在"
            if quote["ins_class"] in ["OPTION", "FUTURE_OPTION"]:
                if order["price_type"] == "ANY" and order["exchange_id"] != "CZCE":
                    return f"此交易所({order['exchange_id']}) 不支持期权市价单"
                elif order["direction"] == "SELL":  # 期权的SELL义务仓
                    if quote["option_class"] == "CALL":
                        # 认购期权义务仓开仓保证金=[合约最新价 + Max(12% × 合约标的最新价 - 认购期权虚值, 7% × 合约标的前收盘价)] × 合约单位
                        # 认购期权虚值=Max(行权价 - 合约标的前收盘价,0);
                        order["frozen_margin"] = (quote["last_price"] + max(
                            0.12 * underlying_quote["last_price"] - max(
                                quote["strike_price"] - underlying_quote["last_price"], 0),
                            0.07 * underlying_quote["last_price"])) * quote["volume_multiple"]
                    else:
                        # 认沽期权义务仓开仓保证金=Min[合约最新价+ Max(12% × 合约标的前收盘价 - 认沽期权虚值,7%×行权价),行权价] × 合约单位
                        # 认沽期权虚值=Max(合约标的前收盘价 - 行权价,0)
                        order["frozen_margin"] = min(quote["last_price"] + max(
                            0.12 * underlying_quote["last_price"] - max(
                                underlying_quote["last_price"] - quote["strike_price"], 0),
                            0.07 * quote["strike_price"]), quote["strike_price"]) * quote["volume_multiple"]
                elif order["price_type"] != "ANY":  # 期权的BUY权利仓(市价单立即成交且没有limit_price字段,frozen_premium默认为0)
                    order["frozen_premium"] = order["volume_orign"] * quote["volume_multiple"] * order[
                        "limit_price"]
            else:  # 期货
                # 市价单立即成交或不成, 对api来说普通市价单没有 冻结_xxx 数据存在的状态
                order["frozen_margin"] = quote["margin"] * order["volume_orign"]
            if not self._adjust_account(frozen_margin=order["frozen_margin"],
                                        frozen_premium=order["frozen_premium"]):
                return "开仓资金不足"
        # 可以模拟交易下单
        order["insert_date_time"] = _get_trade_timestamp(self._current_datetime, self._local_time_record)
        self._send_order(order)
        self._logger.info("模拟交易下单 %s: 时间:%s,合约:%s,开平:%s,方向:%s,手数:%s,价格:%s", order["order_id"],
                          datetime.fromtimestamp(order["insert_date_time"] / 1e9).strftime(
                              "%Y-%m-%d %H:%M:%S.%f"), symbol, order["offset"], order["direction"],
                          order["volume_left"], order.get("limit_price", "市价"))
        if not _is_in_trading_time(quote, self._current_datetime, self._local_time_record):
            return "下单失败, 不在可交易时间段内"

    def _match_order(self, order, symbol, quote, underlying_quote=None):
        ask_price = quote["ask_price1"]
        bid_price = quote["bid_price1"]
        if quote["ins_class"] == "FUTURE_INDEX":
            # 在指数交易时,使用 tick 进行回测时,backtest 发的 quote 没有买一卖一价;或者在实时行情中,指数的 quote 也没有买一卖一价
            if ask_price != ask_price:
                ask_price = quote["last_price"] + quote["price_tick"]
            if bid_price != bid_price:
                bid_price = quote["last_price"] - quote["price_tick"]

        if "limit_price" not in order:
            price = ask_price if order["direction"] == "BUY" else bid_price
            if price != price:
                return "市价指令剩余撤销"
        elif order["direction"] == "BUY" and order["limit_price"] >= ask_price:
            price = order["limit_price"]
        elif order["direction"] == "SELL" and order["limit_price"] <= bid_price:
            price = order["limit_price"]
        elif order["time_condition"] == "IOC":  # IOC 立即成交,限价下单且不能成交的价格,直接撤单
            return "已撤单报单已提交"
        else:
            return ""
        trade = {
            "user_id": order["user_id"],
            "order_id": order["order_id"],
            "trade_id": order["order_id"] + "|" + str(order["volume_left"]),
            "exchange_trade_id": order["order_id"] + "|" + str(order["volume_left"]),
            "exchange_id": order["exchange_id"],
            "instrument_id": order["instrument_id"],
            "direction": order["direction"],
            "offset": order["offset"],
            "price": price,
            "volume": order["volume_left"],
            # todo: 可能导致测试结果不确定
            "trade_date_time": _get_trade_timestamp(self._current_datetime, self._local_time_record),
            # 期权quote没有commission字段, 设为固定10元一张
            "commission": (quote["commission"] if quote["ins_class"] not in ["OPTION", "FUTURE_OPTION"] else 10) *
                          order["volume_left"],
        }
        trade_log = self._ensure_trade_log()
        trade_log["trades"].append(trade)
        self._send_trade(trade)
        if order["exchange_id"] == "SHFE" or order["exchange_id"] == "INE":
            priority = "H" if order["offset"] == "CLOSE" else "T"
        else:
            priority = "TH"
        if order["offset"].startswith("CLOSE"):
            volume_long = 0 if order["direction"] == "BUY" else -order["volume_left"]
            volume_short = 0 if order["direction"] == "SELL" else -order["volume_left"]
            self._adjust_position(symbol, volume_long_frozen=volume_long, volume_short_frozen=volume_short,
                                  priority=priority)
        else:
            volume_long = 0 if order["direction"] == "SELL" else order["volume_left"]
            volume_short = 0 if order["direction"] == "BUY" else order["volume_left"]
        self._adjust_position(symbol, volume_long=volume_long, volume_short=volume_short, price=price,
                              priority=priority)
        premium = -order["frozen_premium"] if order["direction"] == "BUY" else order["frozen_premium"]
        self._adjust_account(commission=trade["commission"], premium=premium)
        order["volume_left"] = 0
        return "全部成交"

    def _settle(self):
        if self._trading_day_end[:10] == "1990-01-01":
            return
        trade_log = self._ensure_trade_log()
        # 记录账户截面
        trade_log["account"] = self._account.copy()
        trade_log["positions"] = {k: v.copy() for k, v in self._positions.items()}
        # 为下一交易日调整账户
        self._account["pre_balance"] = self._account["balance"]
        self._account["static_balance"] = self._account["balance"]
        self._account["position_profit"] = 0
        self._account["close_profit"] = 0
        self._account["commission"] = 0
        self._account["premium"] = 0
        self._send_account()
        self._adjust_account()
        # 对于持仓的结算放在这里,没有放在 quote_handler 里的原因:
        # 1. 异步发送的话,会造成如果此时 sim 未收到 pending_peek, 就没法把结算的账户信息发送出去,此时用户代码中 api.get_postion 得到的持仓和 sim 里面的持仓是不一致的
        # set_target_pos 下单时就会产生错单。而且结算时一定是已经收到过行情的数据包,在同步代码的最后一步,会发送出去这个行情包 peeding_peek,
        # quote_handler 处理 settle 的时候, 所以在结算的时候 pending_peek 一定是 False, 要 api 处理过之后,才会收到 peek_message
        # 2. 同步发送的话,就可以和产生切换交易日的数据包同时发送出去
        # 对 order 的处理发生在下一次回复 peek_message
        for position in self._positions.values():
            position["pos_long_his"] = position["volume_long"]
            position["pos_long_today"] = 0
            position["pos_short_his"] = position["volume_short"]
            position["pos_short_today"] = 0
            position["volume_long_today"] = 0
            position["volume_long_his"] = position["volume_long"]
            position["volume_short_today"] = 0
            position["volume_short_his"] = position["volume_short"]
            position["position_price_long"] = position["last_price"]
            position["position_price_short"] = position["last_price"]
            position["position_cost_long"] = position["open_cost_long"] + position["float_profit_long"]
            position["position_cost_short"] = position["open_cost_short"] + position["float_profit_short"]
            position["position_profit_long"] = 0
            position["position_profit_short"] = 0
            position["position_profit"] = 0
            self._send_position(position)
        for symbol in self._orders.keys():
            order_ids = list(self._orders[symbol].keys()).copy()
            for order_id in order_ids:
                self._del_order(symbol, self._orders[symbol][order_id], "交易日结束,自动撤销当日有效的委托单(GFD)")
                del self._orders[symbol][order_id]

    def _report(self):
        if not self.trade_log:
            return
        self._logger.warning("模拟交易成交记录")
        self._tqsdk_stat["init_balance"] = self._init_balance  # 起始资金
        self._tqsdk_stat["balance"] = self._account["balance"]  # 结束资金
        self._tqsdk_stat["max_drawdown"] = 0  # 最大回撤
        max_balance = 0
        daily_yield = []
        # 胜率 盈亏额比例
        trades_logs = {}
        profit_logs = []  # 盈利记录
        loss_logs = []  # 亏损记录
        for d in sorted(self.trade_log.keys()):
            balance = self.trade_log[d]["account"]["balance"]
            if balance > max_balance:
                max_balance = balance
            drawdown = (max_balance - balance) / max_balance
            if drawdown > self._tqsdk_stat["max_drawdown"]:
                self._tqsdk_stat["max_drawdown"] = drawdown
            daily_yield.append(
                self.trade_log[d]["account"]["balance"] / self.trade_log[d]["account"]["pre_balance"] - 1)
            for t in self.trade_log[d]["trades"]:
                symbol = t["exchange_id"] + "." + t["instrument_id"]
                self._logger.warning("时间:%s,合约:%s,开平:%s,方向:%s,手数:%d,价格:%.3f,手续费:%.2f",
                                     datetime.fromtimestamp(t["trade_date_time"] / 1e9).strftime(
                                         "%Y-%m-%d %H:%M:%S.%f"), symbol, t["offset"], t["direction"], t["volume"],
                                     t["price"], t["commission"])
                if symbol not in trades_logs:
                    trades_logs[symbol] = {
                        "BUY": [],
                        "SELL": [],
                    }
                if t["offset"] == "OPEN":
                    # 开仓成交 记录下买卖方向、价格、手数
                    trades_logs[symbol][t["direction"]].append({
                        "volume": t["volume"],
                        "price": t["price"]
                    })
                else:
                    opposite_dir = "BUY" if t["direction"] == "SELL" else "SELL"  # 开仓时的方向
                    opposite_list = trades_logs[symbol][opposite_dir]  # 开仓方向对应 trade log
                    cur_close_volume = t["volume"]
                    cur_close_price = t["price"]
                    cur_close_dir = 1 if t["direction"] == "SELL" else -1
                    while cur_close_volume > 0 and opposite_list[0]:
                        volume = min(cur_close_volume, opposite_list[0]["volume"])
                        profit = (cur_close_price - opposite_list[0]["price"]) * cur_close_dir
                        if profit >= 0:
                            profit_logs.append({
                                "symbol": symbol,
                                "profit": profit,
                                "volume": volume
                            })
                        else:
                            loss_logs.append({
                                "symbol": symbol,
                                "profit": profit,
                                "volume": volume
                            })
                        cur_close_volume -= volume
                        opposite_list[0]["volume"] -= volume
                        if opposite_list[0]["volume"] == 0:
                            opposite_list.pop(0)

        self._tqsdk_stat["profit_volumes"] = sum(p["volume"] for p in profit_logs)  # 盈利手数
        self._tqsdk_stat["loss_volumes"] = sum(l["volume"] for l in loss_logs)  # 亏损手数
        self._tqsdk_stat["profit_value"] = sum(
            p["profit"] * p["volume"] * self._data["quotes"][p["symbol"]]["volume_multiple"] for p in profit_logs)  # 盈利额
        self._tqsdk_stat["loss_value"] = sum(
            l["profit"] * l["volume"] * self._data["quotes"][l["symbol"]]["volume_multiple"] for l in loss_logs)  # 亏损额

        mean = statistics.mean(daily_yield)
        rf = 0.0001
        stddev = statistics.pstdev(daily_yield, mu=mean)
        self._tqsdk_stat["sharpe_ratio"] = 250 ** (1 / 2) * (mean - rf) / stddev if stddev else float("inf")  # 年化夏普率

        _ror = self._tqsdk_stat["balance"] / self._tqsdk_stat["init_balance"]
        self._tqsdk_stat["ror"] = _ror - 1  # 收益率
        self._tqsdk_stat["annual_yield"] = _ror ** (250 / len(self.trade_log)) - 1  # 年化收益率

        self._logger.warning("模拟交易账户资金")
        for d in sorted(self.trade_log.keys()):
            account = self.trade_log[d]["account"]
            self._logger.warning(
                "日期:%s,账户权益:%.2f,可用资金:%.2f,浮动盈亏:%.2f,持仓盈亏:%.2f,平仓盈亏:%.2f,市值:%.2f,保证金:%.2f,手续费:%.2f,风险度:%.2f%%",
                d, account["balance"], account["available"], account["float_profit"], account["position_profit"],
                account["close_profit"], account["market_value"], account["margin"], account["commission"],
                account["risk_ratio"] * 100)

        self._tqsdk_stat["winning_rate"] = (self._tqsdk_stat["profit_volumes"] / (
                self._tqsdk_stat["profit_volumes"] + self._tqsdk_stat["loss_volumes"])) \
            if self._tqsdk_stat["profit_volumes"] + self._tqsdk_stat["loss_volumes"] else 0
        profit_pre_volume = self._tqsdk_stat["profit_value"] / self._tqsdk_stat["profit_volumes"] if self._tqsdk_stat[
            "profit_volumes"] else 0
        loss_pre_volume = self._tqsdk_stat["loss_value"] / self._tqsdk_stat["loss_volumes"] if self._tqsdk_stat[
            "loss_volumes"] else 0
        self._tqsdk_stat["profit_loss_ratio"] = abs(profit_pre_volume / loss_pre_volume) if loss_pre_volume else float(
            "inf")
        self._logger.warning("胜率:%.2f%%,盈亏额比例:%.2f,收益率:%.2f%%,年化收益率:%.2f%%,最大回撤:%.2f%%,年化夏普率:%.4f",
                             self._tqsdk_stat["winning_rate"] * 100,
                             self._tqsdk_stat["profit_loss_ratio"],
                             self._tqsdk_stat["ror"] * 100,
                             self._tqsdk_stat["annual_yield"] * 100,
                             self._tqsdk_stat["max_drawdown"] * 100,
                             self._tqsdk_stat["sharpe_ratio"])

    def _ensure_trade_log(self):
        return self.trade_log.setdefault(self._trading_day_end[:10], {
            "trades": []
        })

    def _adjust_position(self, symbol, volume_long_frozen=0, volume_short_frozen=0, volume_long=0, volume_short=0,
                         price=None, priority=None):
        quote = self._data.get("quotes", {}).get(symbol, {})
        underlying_quote = self._data.get("quotes", {}).get(quote["underlying_symbol"], {}) if "underlying_symbol" in quote else None
        position = self._positions[symbol]
        volume_multiple = quote["volume_multiple"]
        if volume_long_frozen:
            position["volume_long_frozen"] += volume_long_frozen
            if priority[0] == "T":
                position["volume_long_frozen_today"] += volume_long_frozen
                if len(priority) > 1:
                    if position["volume_long_frozen_today"] < 0:
                        position["volume_long_frozen_his"] += position["volume_long_frozen_today"]
                        position["volume_long_frozen_today"] = 0
                    elif position["volume_long_today"] < position["volume_long_frozen_today"]:
                        position["volume_long_frozen_his"] += position["volume_long_frozen_today"] - position[
                            "volume_long_today"]
                        position["volume_long_frozen_today"] = position["volume_long_today"]
            else:
                position["volume_long_frozen_his"] += volume_long_frozen
                if len(priority) > 1:
                    if position["volume_long_frozen_his"] < 0:
                        position["volume_long_frozen_today"] += position["volume_long_frozen_his"]
                        position["volume_long_frozen_his"] = 0
                    elif position["volume_long_his"] < position["volume_long_frozen_his"]:
                        position["volume_long_frozen_today"] += position["volume_long_frozen_his"] - position[
                            "volume_long_his"]
                        position["volume_long_frozen_his"] = position["volume_long_his"]
        if volume_short_frozen:
            position["volume_short_frozen"] += volume_short_frozen
            if priority[0] == "T":
                position["volume_short_frozen_today"] += volume_short_frozen
                if len(priority) > 1:
                    if position["volume_short_frozen_today"] < 0:
                        position["volume_short_frozen_his"] += position["volume_short_frozen_today"]
                        position["volume_short_frozen_today"] = 0
                    elif position["volume_short_today"] < position["volume_short_frozen_today"]:
                        position["volume_short_frozen_his"] += position["volume_short_frozen_today"] - position[
                            "volume_short_today"]
                        position["volume_short_frozen_today"] = position["volume_short_today"]
            else:
                position["volume_short_frozen_his"] += volume_short_frozen
                if len(priority) > 1:
                    if position["volume_short_frozen_his"] < 0:
                        position["volume_short_frozen_today"] += position["volume_short_frozen_his"]
                        position["volume_short_frozen_his"] = 0
                    elif position["volume_short_his"] < position["volume_short_frozen_his"]:
                        position["volume_short_frozen_today"] += position["volume_short_frozen_his"] - position[
                            "volume_short_his"]
                        position["volume_short_frozen_his"] = position["volume_short_his"]
        if price is not None and price == price:
            if position["last_price"] is not None:
                float_profit_long = (price - position["last_price"]) * position["volume_long"] * volume_multiple
                float_profit_short = (position["last_price"] - price) * position["volume_short"] * volume_multiple
                float_profit = float_profit_long + float_profit_short
                position["float_profit_long"] += float_profit_long
                position["float_profit_short"] += float_profit_short
                position["float_profit"] += float_profit
                if quote["ins_class"] in ["OPTION", "FUTURE_OPTION"]:  # 期权市值 = 权利金 + 期权持仓盈亏
                    position["market_value_long"] += float_profit_long  # 权利方市值(始终 >= 0)
                    position["market_value_short"] += float_profit_short  # 义务方市值(始终 <= 0)
                    position["market_value"] += float_profit
                    self._adjust_account(float_profit=float_profit, market_value=float_profit)
                else:  # 期权没有持仓盈亏
                    position["position_profit_long"] += float_profit_long
                    position["position_profit_short"] += float_profit_short
                    position["position_profit"] += float_profit
                    self._adjust_account(float_profit=float_profit, position_profit=float_profit)
            position["last_price"] = price
        if volume_long:  # volume_long > 0:买开,  < 0:卖平
            close_profit = 0 if volume_long > 0 else (position["last_price"] - position[
                "position_price_long"]) * -volume_long * volume_multiple
            float_profit = 0 if volume_long > 0 else position["float_profit_long"] / position[
                "volume_long"] * volume_long
            position["open_cost_long"] += volume_long * position["last_price"] * volume_multiple if volume_long > 0 else \
                position["open_cost_long"] / position["volume_long"] * volume_long
            position["position_cost_long"] += volume_long * position[
                "last_price"] * volume_multiple if volume_long > 0 else position["position_cost_long"] / position[
                "volume_long"] * volume_long
            market_value = 0
            margin = 0.0
            if quote["ins_class"] in ["OPTION", "FUTURE_OPTION"]:
                # 期权市值 = 权利金 + 期权持仓盈亏
                market_value = position["last_price"] * volume_long * volume_multiple
                position["market_value_long"] += market_value
                position["market_value"] += market_value
            else:
                margin = volume_long * quote["margin"]
                position["position_profit_long"] -= close_profit
                position["position_profit"] -= close_profit
            position["volume_long"] += volume_long
            position["open_price_long"] = position["open_cost_long"] / volume_multiple / position["volume_long"] if \
                position["volume_long"] else float("nan")
            position["position_price_long"] = position["position_cost_long"] / volume_multiple / position[
                "volume_long"] if position["volume_long"] else float("nan")
            position["float_profit_long"] += float_profit
            position["float_profit"] += float_profit
            position["margin_long"] += margin
            position["margin"] += margin
            if priority[0] == "T":
                position["volume_long_today"] += volume_long
                if len(priority) > 1:
                    if position["volume_long_today"] < 0:
                        position["volume_long_his"] += position["volume_long_today"]
                        position["volume_long_today"] = 0
            else:
                position["volume_long_his"] += volume_long
                if len(priority) > 1:
                    if position["volume_long_his"] < 0:
                        position["volume_long_today"] += position["volume_long_his"]
                        position["volume_long_his"] = 0

            if priority[0] == "T":
                position["pos_long_today"] += volume_long
                if len(priority) > 1:
                    if position["pos_long_today"] < 0:
                        position["pos_long_his"] += position["pos_long_today"]
                        position["pos_long_today"] = 0
            else:
                position["pos_long_his"] += volume_long
                if len(priority) > 1:
                    if position["pos_long_his"] < 0:
                        position["pos_long_today"] += position["pos_long_his"]
                        position["pos_long_his"] = 0

            self._adjust_account(float_profit=float_profit,
                                 position_profit=-close_profit if quote["ins_class"] not in ["OPTION",
                                                                                             "FUTURE_OPTION"] else 0,
                                 close_profit=close_profit, margin=margin, market_value=market_value)
        if volume_short:  # volume_short > 0: 卖开,  < 0:买平
            close_profit = 0 if volume_short > 0 else (position["position_price_short"] - position[
                "last_price"]) * -volume_short * volume_multiple
            float_profit = 0 if volume_short > 0 else position["float_profit_short"] / position[
                "volume_short"] * volume_short
            # 期权: open_cost_short > 0, open_cost_long > 0
            position["open_cost_short"] += volume_short * position[
                "last_price"] * volume_multiple if volume_short > 0 else position["open_cost_short"] / position[
                "volume_short"] * volume_short
            position["position_cost_short"] += volume_short * position[
                "last_price"] * volume_multiple if volume_short > 0 else position["position_cost_short"] / position[
                "volume_short"] * volume_short
            market_value = 0
            margin = 0
            if quote["ins_class"] in ["OPTION", "FUTURE_OPTION"]:
                market_value = -(position["last_price"] * volume_short * volume_multiple)
                position["market_value_short"] += market_value
                position["market_value"] += market_value
                if volume_short > 0:
                    if quote["option_class"] == "CALL":
                        margin = (quote["last_price"] + max(0.12 * underlying_quote["last_price"] - max(
                            quote["strike_price"] - underlying_quote["last_price"], 0),
                                                            0.07 * underlying_quote["last_price"])) * quote[
                                     "volume_multiple"]
                    else:
                        margin = min(quote["last_price"] + max(0.12 * underlying_quote["last_price"] - max(
                            underlying_quote["last_price"] - quote["strike_price"], 0), 0.07 * quote["strike_price"]),
                                     quote["strike_price"]) * quote["volume_multiple"]
            else:
                margin = volume_short * quote["margin"]
                position["position_profit_short"] -= close_profit
                position["position_profit"] -= close_profit
            position["volume_short"] += volume_short
            position["open_price_short"] = position["open_cost_short"] / volume_multiple / position["volume_short"] if \
                position["volume_short"] else float("nan")
            position["position_price_short"] = position["position_cost_short"] / volume_multiple / position[
                "volume_short"] if position["volume_short"] else float("nan")
            position["float_profit_short"] += float_profit
            position["float_profit"] += float_profit
            position["margin_short"] += margin
            position["margin"] += margin
            if priority[0] == "T":
                position["volume_short_today"] += volume_short
                if len(priority) > 1:
                    if position["volume_short_today"] < 0:
                        position["volume_short_his"] += position["volume_short_today"]
                        position["volume_short_today"] = 0
            else:
                position["volume_short_his"] += volume_short
                if len(priority) > 1:
                    if position["volume_short_his"] < 0:
                        position["volume_short_today"] += position["volume_short_his"]
                        position["volume_short_his"] = 0

            if priority[0] == "T":
                position["pos_short_today"] += volume_short
                if len(priority) > 1:
                    if position["pos_short_today"] < 0:
                        position["pos_short_his"] += position["pos_short_today"]
                        position["pos_short_today"] = 0
            else:
                position["pos_short_his"] += volume_short
                if len(priority) > 1:
                    if position["pos_short_his"] < 0:
                        position["pos_short_today"] += position["pos_short_his"]
                        position["pos_short_his"] = 0
            self._adjust_account(float_profit=float_profit,
                                 position_profit=-close_profit if quote["ins_class"] not in ["OPTION",
                                                                                             "FUTURE_OPTION"] else 0,
                                 close_profit=close_profit,
                                 margin=margin, market_value=market_value)
        self._send_position(position)
        return position["volume_long_his"] - position["volume_long_frozen_his"] >= 0 and position["volume_long_today"] - \
               position["volume_long_frozen_today"] >= 0 and \
               position["volume_short_his"] - position["volume_short_frozen_his"] >= 0 and position[
                   "volume_short_today"] - position["volume_short_frozen_today"] >= 0

    def _adjust_account(self, commission=0.0, frozen_margin=0.0, frozen_premium=0.0, float_profit=0.0,
                        position_profit=0.0, close_profit=0.0, margin=0.0, premium=0.0, market_value=0.0):
        # 权益 += 持仓盈亏 + 平仓盈亏 - 手续费 + 权利金(收入为负值,支出为正值) + 市值
        self._account["balance"] += position_profit + close_profit - commission + premium + market_value
        # 可用资金 += 权益 - 冻结保证金 - 保证金 - 冻结权利金 - 市值
        self._account[
            "available"] += position_profit + close_profit - commission + premium - frozen_margin - margin - frozen_premium
        self._account["float_profit"] += float_profit
        self._account["position_profit"] += position_profit
        self._account["close_profit"] += close_profit
        self._account["frozen_margin"] += frozen_margin
        self._account["frozen_premium"] += frozen_premium
        self._account["margin"] += margin
        # premium变量的值有正负,正数表示收入的权利金,负数表示付出的权利金;account["premium"]为累计值
        self._account["premium"] += premium
        self._account["market_value"] += market_value
        self._account["commission"] += commission
        self._account["risk_ratio"] = self._account["margin"] / self._account[
            "balance"] if self._account["balance"] else 0.0
        self._send_account()
        return self._account["available"] >= 0

    def _send_trade(self, trade):
        self._diffs.append({
            "trade": {
                self._account_id: {
                    "trades": {
                        trade["trade_id"]: trade.copy()
                    }
                }
            }
        })

    def _send_order(self, order):
        self._diffs.append({
            "trade": {
                self._account_id: {
                    "orders": {
                        order["order_id"]: order.copy()
                    }
                }
            }
        })

    def _send_position(self, position):
        self._diffs.append({
            "trade": {
                self._account_id: {
                    "positions": {
                        position["exchange_id"] + "." + position["instrument_id"]: position.copy()
                    }
                }
            }
        })

    def _send_account(self):
        self._diffs.append({
            "trade": {
                self._account_id: {
                    "accounts": {
                        "CNY": self._account.copy()
                    }
                }
            }
        })

    def _get_current_timestamp(self):
        return int(datetime.strptime(self._current_datetime, "%Y-%m-%d %H:%M:%S.%f").timestamp() * 1e6) * 1000
Пример #17
0
class TradeExtension():
    """
    为持仓、委托单、成交对象添加 合约信息

    * 为期权合约相应的持仓、委托单、成交,添加以下字段
        + option_class 代表期权方向 CALL or PUT,非期权合约该处显示为NONE
        + underlying_symbol
        + strike_price
        + expire_rest_days 距离到期日剩余天数

    """

    def __init__(self, api):
        self._api = api
        self._data = Entity()  # 交易业务信息截面,需要定于数据原型,使用 Entity 类型 和 _merge_diff
        self._data._instance_entity([])
        self._new_objs_list = []
        self._prototype = {
            "trade": {
                "*": {
                    "@": CustomDict(self._api, self._new_objs_list)
                }
            }
        }
        self._data_quotes = {}  # 行情信息截面,只需要 quotes 数据。这里不需要定义数据原型,使用普通 dict 和 _simple_merge_diff
        self._diffs = []
        self._all_trade_symbols = set()  # 所有持仓、委托、成交中的合约
        self._query_symbols = set()  # 已经发送合约信息请求 + 已经知道合约信息的合约
        self._need_wait_symbol_info = set()  # 需要发送合约信息请求 + 不知道合约信息的合约

    async def _run(self, api_send_chan, api_recv_chan, md_send_chan, md_recv_chan):
        self._logger = self._api._logger.getChild("TradeExtension")
        self._api_send_chan = api_send_chan
        self._api_recv_chan = api_recv_chan
        self._md_send_chan = md_send_chan
        self._md_recv_chan = md_recv_chan
        self._datetime_state = TqDatetimeState()
        self._trading_day_end = None
        md_task = self._api.create_task(self._md_handler())
        self._pending_peek = False  # True 表示收到下游的 peek_message ,并且没有发给过下游回复;False 表示发给过下游回复,没有 pending_peek_message
        self._pending_peek_md = False  # True 表示发给过上游 peek_message;False 表示对上游没有 pending_peek_message
        try:
            async for pack in api_send_chan:
                if "_md_recv" in pack:
                    self._pending_peek_md = False
                    await self._md_recv(pack)
                    await self._send_diff()
                    if self._pending_peek and self._pending_peek_md is False:
                        self._pending_peek_md = True
                        await self._md_send_chan.send({"aid": "peek_message"})
                elif pack["aid"] == "peek_message":
                    self._pending_peek = True
                    await self._send_diff()
                    if self._pending_peek and self._pending_peek_md is False:
                        self._pending_peek_md = True
                        await self._md_send_chan.send(pack)
                else:
                    await self._md_send_chan.send(pack)
        finally:
            md_task.cancel()

    async def _md_handler(self):
        """0 接收上游数据包 """
        async for pack in self._md_recv_chan:
            pack["_md_recv"] = True
            await self._api_send_chan.send(pack)

    async def _md_recv(self, pack):
        """ 处理下行数据包
        0 将行情数据和交易数据合并至 self._data
        1 生成增量业务截面, 该截面包含期权补充的字段
        """
        for d in pack.get("data", {}):
            self._datetime_state.update_state(d)
            _simple_merge_diff(self._data_quotes, d.get('quotes', {}))
            _merge_diff(self._data, {"trade": d.get('trade', {})}, prototype=self._prototype, persist=False, reduce_diff=False)
            self._diffs.append(d)  # 添加至 self._diff 等待被发送

        for obj in self._new_objs_list:
            # 新添加的 Position / Order / Trade  节点
            if hasattr(obj, '_path') and obj['_path'][2] in ['positions', 'trades', 'orders']:
                symbol = f"{obj.get('exchange_id', '')}.{obj.get('instrument_id', '')}"
                if symbol not in self._all_trade_symbols:
                    self._all_trade_symbols.add(symbol)
                    self._need_wait_symbol_info.add(symbol)  # 需要发送合约信息请求

        for s in self._need_wait_symbol_info.copy():
            if self._data_quotes.get(s, {}).get("price_tick", 0) > 0:
                self._need_wait_symbol_info.remove(s)  # 需要发送合约信息请求 + 不知道合约信息的合约

        # 不知道合约信息 并且未发送请求查询合约信息
        unknown_symbols = self._need_wait_symbol_info - self._query_symbols
        if len(unknown_symbols) > 0:
            self._query_symbols = self._query_symbols.union(unknown_symbols)  # 所有发送过ins_query的合约
            query_pack = _query_for_quote(list(unknown_symbols))
            await self._md_send_chan.send(query_pack)

    def _generate_pend_diff(self):
        """"
        补充期权额外字段
        此函数在 send_diff() 才会调用, self._datetime_state.data_ready 一定为 True,
        调用 self._datetime_state.get_current_dt() 一定有正确的当前时间
        """
        pend_diff = {}
        account_keys = list(self._data.get('trade', {}).keys())
        objs_keys = ['positions', 'trades', 'orders']

        # 如果有新添加的合约, 只填充一次即可
        if self._new_objs_list:
            pend_diff.setdefault('trade', {k: {o_k: {} for o_k in objs_keys} for k in account_keys})
            for obj in self._new_objs_list:
                # 新添加的 Position / Order / Trade  节点
                if hasattr(obj, '_path') and obj['_path'][2] in objs_keys:
                    account_key = obj['_path'][1]
                    obj_key = obj['_path'][2]
                    item_id = obj['_path'][3]
                    quote = self._data_quotes.get(f"{obj.get('exchange_id', '')}.{obj.get('instrument_id', '')}", {})
                    if quote.get('ins_class', '').endswith('OPTION'):
                        pend_diff_item = pend_diff['trade'][account_key][obj_key].setdefault(item_id, {})
                        pend_diff_item['option_class'] = quote.get('option_class')
                        pend_diff_item['strike_price'] = quote.get('strike_price')
                        pend_diff_item['underlying_symbol'] = quote.get('underlying_symbol')
                        if quote.get('expire_datetime'):
                            pend_diff_item['expire_rest_days'] = _get_expire_rest_days(quote.get('expire_datetime'),
                                                                                       self._datetime_state.get_current_dt() / 1e9)
            self._new_objs_list.clear()

        # 如果有切换交易日,所有合约都需要修改 expire_rest_days
        current_dt = self._datetime_state.get_current_dt()
        if self._trading_day_end is None or current_dt > self._trading_day_end:
            pend_diff.setdefault('trade', {k: {o_k: {} for o_k in objs_keys} for k in account_keys})
            for account_key, account_node in self._data.get('trade', {}).items():
                for k in objs_keys:
                    for item_id, item in account_node.get(k, {}).items():
                        quote = self._data_quotes.get(f"{item['exchange_id']}.{item['instrument_id']}", {})
                        if quote.get('ins_class', '').endswith('OPTION') and quote.get('expire_datetime'):
                            pend_diff_item = pend_diff['trade'][account_key][k].setdefault(item_id, {})
                            # 剩余到期日字段,每天都会更新,每次都重新计算
                            pend_diff_item['expire_rest_days'] = _get_expire_rest_days(quote.get('expire_datetime'),
                                                                                       current_dt / 1e9)
            self._trading_day_end = _get_trading_day_end_time(_get_trading_day_from_timestamp(current_dt))
        return pend_diff

    async def _send_diff(self):
        if self._datetime_state.data_ready and self._pending_peek and self._diffs and len(self._need_wait_symbol_info) == 0:
            # 生成增量业务截面, 该截面包含期权补充的字段,只在真正需要给下游发送数据时,才将需要发送的数据放在 _diffs 中
            pend_diff = self._generate_pend_diff()
            self._diffs.append(pend_diff)
            rtn_data = {
                "aid": "rtn_data",
                "data": self._diffs,
            }
            self._diffs = []
            self._pending_peek = False
            await self._api_recv_chan.send(rtn_data)
Пример #18
0
 async def _run(self, api, api_send_chan, api_recv_chan, md_send_chan, md_recv_chan):
     """模拟交易task"""
     self._api = api
     self._tqsdk_backtest = {}  # 储存可能的回测信息
     self._tqsdk_stat = {}  # 回测结束后储存回测报告信息
     self._logger = api._logger.getChild("TqSim")  # 调试信息输出
     self._api_send_chan = api_send_chan
     self._api_recv_chan = api_recv_chan
     self._md_send_chan = md_send_chan
     self._md_recv_chan = md_recv_chan
     self._pending_peek = False
     self._diffs = []
     self._account = {
         "currency": "CNY",
         "pre_balance": self._init_balance,
         "static_balance": self._init_balance,
         "balance": self._init_balance,
         "available": self._init_balance,
         "float_profit": 0.0,
         "position_profit": 0.0,  # 期权没有持仓盈亏
         "close_profit": 0.0,
         "frozen_margin": 0.0,
         "margin": 0.0,
         "frozen_commission": 0.0,
         "commission": 0.0,
         "frozen_premium": 0.0,
         "premium": 0.0,
         "deposit": 0.0,
         "withdraw": 0.0,
         "risk_ratio": 0.0,
         "market_value": 0.0,
         "ctp_balance": float("nan"),
         "ctp_available": float("nan"),
     }
     self._positions = {}
     self._orders = {}
     self._data = Entity()
     self._data._instance_entity([])
     self._prototype = {
         "quotes": {
             "#": Quote(self),  # 行情的数据原型
         }
     }
     self._quote_tasks = {}
     self._all_subscribe = set()  # 客户端+模拟交易模块订阅的合约集合
     # 是否已经发送初始账户信息
     self._has_send_init_account = False
     md_task = self._api.create_task(self._md_handler())  # 将所有 md_recv_chan 上收到的包投递到 api_send_chan 上
     try:
         async for pack in self._api_send_chan:
             self._logger.debug("TqSim message received: %s", pack)
             if "_md_recv" in pack:
                 if pack["aid"] == "rtn_data":
                     self._md_recv(pack)  # md_recv 中会发送 wait_count 个 quotes 包给各个 quote_chan
                     await asyncio.gather(*[quote_task["quote_chan"].join() for quote_task in self._quote_tasks.values()])
                     await self._send_diff()
             elif pack["aid"] == "subscribe_quote":
                 await self._subscribe_quote(set(pack["ins_list"].split(",")))
             elif pack["aid"] == "peek_message":
                 self._pending_peek = True
                 await self._send_diff()
                 if self._pending_peek:  # 控制"peek_message"发送: 当没有新的事件需要用户处理时才推进到下一个行情
                     await self._md_send_chan.send(pack)
             elif pack["aid"] == "insert_order":
                 symbol = pack["exchange_id"] + "." + pack["instrument_id"]
                 if symbol not in self._quote_tasks:
                     quote_chan = TqChan(self._api)
                     order_chan = TqChan(self._api)
                     self._quote_tasks[symbol] = {
                         "quote_chan": quote_chan,
                         "order_chan": order_chan,
                         "task": self._api.create_task(self._quote_handler(symbol, quote_chan, order_chan))
                     }
                 await self._quote_tasks[symbol]["order_chan"].send(pack)
             elif pack["aid"] == "cancel_order":
                 # pack 里只有 order_id 信息,发送到每一个合约的 order_chan, 交由 quote_task 判断是不是当前合约下的委托单
                 for symbol in self._quote_tasks:
                     await self._quote_tasks[symbol]["order_chan"].send(pack)
             else:
                 await self._md_send_chan.send(pack)
             if self._tqsdk_backtest != {} and self._tqsdk_backtest["current_dt"] >= self._tqsdk_backtest["end_dt"] \
                     and not self._tqsdk_stat:
                 # 回测情况下,把 _send_stat_report 在循环中回测结束时执行
                 await self._send_stat_report()
     finally:
         if not self._tqsdk_stat:
             await self._send_stat_report()
         md_task.cancel()
         tasks = [md_task]
         for symbol in self._quote_tasks:
             self._quote_tasks[symbol]["task"].cancel()
             tasks.append(self._quote_tasks[symbol]["task"])
         await asyncio.gather(*tasks, return_exceptions=True)
Пример #19
0
class TqStockProfit():
    """
    股票盈亏计算模块

    * 订阅已有持仓股票合约和行情
    * 计算股票持仓与资产的盈亏

    """
    def __init__(self, api):
        self._api = api
        self._data = Entity()  # 业务信息截面
        self._data._instance_entity([])
        self._diffs = []
        self._all_subscribe = set()

    async def _run(self, api_send_chan, api_recv_chan, md_send_chan,
                   md_recv_chan):

        self._logger = self._api._logger.getChild("TqStockProfit")
        self._api_send_chan = api_send_chan
        self._api_recv_chan = api_recv_chan
        self._md_send_chan = md_send_chan
        self._md_recv_chan = md_recv_chan
        md_task = self._api.create_task(self._md_handler())
        self._pending_peek = False
        try:
            async for pack in api_send_chan:
                if "_md_recv" in pack:
                    await self._md_recv(pack)
                    await self._send_diff()
                    if not self._is_diff_complete():
                        await self._md_send_chan.send({"aid": "peek_message"})
                elif pack["aid"] == "subscribe_quote":
                    await self._subscribe_quote(
                        set(pack["ins_list"].split(",")))
                elif pack["aid"] == "peek_message":
                    self._pending_peek = True
                    await self._send_diff()
                    if self._pending_peek:
                        await self._md_send_chan.send(pack)
                else:
                    await self._md_send_chan.send(pack)
        finally:
            md_task.cancel()

    async def _md_handler(self):
        """0 接收上游数据包 """
        async for pack in self._md_recv_chan:
            pack["_md_recv"] = True
            await self._api_send_chan.send(pack)

    async def _md_recv(self, pack):
        """ 处理下行数据包
        0 将行情数据和交易数据合并至 self._data
        1 生成增量业务截面, 该截面包含 持仓盈亏和资产盈亏信息
        """
        for d in pack.get("data", {}):
            if "quotes" in d:
                # 行情数据仅仅合并沪深两市的行情数据
                stock_quote = {
                    k: v
                    for k, v in d.get('quotes').items()
                    if k.startswith("SSE") or k.startswith("SZSE")
                }
                _simple_merge_diff(self._data, {"quotes": stock_quote})
            if "trade" in d:
                _simple_merge_diff(self._data, d)
            # 添加至 self._diff 等待被发送
            self._diffs.append(d)

        # 计算持仓和账户资产的盈亏增量截面
        pend_diff = await self._generate_pend_diff()
        self._diffs.append(pend_diff)

    async def _generate_pend_diff(self):
        """" 盈亏计算 """
        pend_diff = {}
        pend_diff.setdefault(
            'trade', {
                k: {
                    'accounts': {
                        'CNY': {}
                    },
                    'positions': {}
                }
                for k in self._data.get('trade', {})
            })
        # 计算持仓盈亏
        for account_key in self._data.get('trade', {}):
            # 盈亏计算仅仅计算股票账户
            if self._data['trade'].get(account_key,
                                       {}).get("account_type",
                                               "FUTURE") == "FUTURE":
                continue
            for symbol, _ in self._data['trade'][account_key].get(
                    'positions', {}).items():
                await self._subscribe_quote(symbol)
                last_price = self._data["quotes"].get(symbol, {}).get(
                    'last_price', float("nan"))
                if not math.isnan(last_price):
                    diff = self._update_position(account_key, symbol,
                                                 last_price)
                    pend_diff['trade'][account_key]['positions'][symbol] = diff
                    _simple_merge_diff(
                        self._data["trade"][account_key]["positions"],
                        {symbol: diff})

        # 当截面完整时, 全量刷新所有账户的资产盈亏
        if self._is_diff_complete():
            for account_key in self._data.get('trade', {}):
                if self._data['trade'].get(account_key,
                                           {}).get("account_type",
                                                   "FUTURE") == "FUTURE":
                    continue
                all_position = self._data["trade"][account_key].get(
                    "positions", {})
                pend_diff['trade'][account_key]['accounts']['CNY']['float_profit'] = \
                    sum([v.get('float_profit', 0) for k, v in all_position.items()])

        return pend_diff

    async def _send_diff(self):
        if self._pending_peek and self._is_diff_complete() and self._diffs:
            rtn_data = {
                "aid": "rtn_data",
                "data": self._diffs,
            }
            self._diffs = []
            self._pending_peek = False
            await self._api_recv_chan.send(rtn_data)

    async def _subscribe_quote(self, symbols: [set, str]):
        """这里只会增加订阅合约,不会退订合约"""
        symbols = symbols if isinstance(symbols, set) else {symbols}
        if symbols - self._all_subscribe:
            self._all_subscribe |= symbols
            await self._md_send_chan.send({
                "aid":
                "subscribe_quote",
                "ins_list":
                ",".join(self._all_subscribe)
            })

    def _update_position(self, key, symbol, last_price):
        """更新持仓盈亏"""
        diff = {}
        position = self._data["trade"][key]["positions"][symbol]
        diff["last_price"] = last_price
        diff["cost"] = position['cost_price'] * position['volume']
        diff["float_profit"] = (last_price -
                                position['cost_price']) * position['volume']
        return diff

    def _is_diff_complete(self):
        """当前账户截面是否已经完全处理完整, 即当所有股票的最新价不为空时"""
        for account_key in self._data.get('trade', {}):
            for symbol, _ in self._data['trade'][account_key].get(
                    'positions', {}).items():
                quote = self._data["quotes"].get(symbol, {})
                if math.isnan(quote.get('last_price', float("nan"))):
                    return False
        return True