Ejemplo n.º 1
0
    def is_trading_time(self, cur_datetime: datetime) -> bool:
        """
        Tell whether ``cur_datetime`` falls inside the morning or the
        afternoon trading session (both session bounds are inclusive).

        Only the time of day (hour/minute/second) is examined; the calendar
        date is not checked.

        :param cur_datetime: the moment to check
        :return: True if the time of day lies within a session, else False
        """
        # TODO: check whether the date is a trading day first
        clock = Time(hour=cur_datetime.hour,
                     minute=cur_datetime.minute,
                     second=cur_datetime.second)
        morning = self.TRADING_HOURS_AM
        if morning[0] <= clock <= morning[1]:
            return True
        afternoon = self.TRADING_HOURS_PM
        return afternoon[0] <= clock <= afternoon[1]
Ejemplo n.º 2
0
    def is_trading_time(self, cur_datetime: datetime) -> bool:
        """
        Tell whether ``cur_datetime`` is inside a trading session.

        The date (formatted "%Y-%m-%d") must appear in
        ``self.trading_days_list``. The time of day must then fall within the
        morning or afternoon session, with inclusive bounds.

        :param cur_datetime: the moment to check
        :return: True for a trading moment, False otherwise
        """
        day_key = cur_datetime.strftime("%Y-%m-%d")
        if day_key not in self.trading_days_list:
            return False
        clock = Time(hour=cur_datetime.hour,
                     minute=cur_datetime.minute,
                     second=cur_datetime.second)
        morning = self.TRADING_HOURS_AM
        if morning[0] <= clock <= morning[1]:
            return True
        afternoon = self.TRADING_HOURS_PM
        return afternoon[0] <= clock <= afternoon[1]
Ejemplo n.º 3
0
    def next_trading_datetime(self, cur_datetime: datetime, security: Stock) -> datetime:
        """
        Find the next moment, one ``TIME_STEP`` after ``cur_datetime`` or
        later, that lies inside a trading session of ``security``.

        :param cur_datetime: current moment
        :param security: security whose trading days (``self.trading_days``)
            are consulted
        :return: ``cur_datetime + TIME_STEP`` if that is already trading
            time; otherwise the opening of the next session (AM or PM) on the
            same day, or the AM opening of the next known trading day.
            None if no later trading day exists in the loaded data.
        """
        am_open = self.TRADING_HOURS_AM[0]
        pm_open = self.TRADING_HOURS_PM[0]
        pm_close = self.TRADING_HOURS_PM[1]
        trading_days = self.trading_days[security]

        # Step forward one time unit and check whether that is trading time
        next_datetime = cur_datetime + relativedelta(seconds=self.TIME_STEP)
        next_time = Time(hour=next_datetime.hour,
                         minute=next_datetime.minute,
                         second=next_datetime.second)

        # Not a trading day, or already past the PM close: return the AM
        # opening of the first trading day not earlier than next_datetime.
        if (next_datetime.strftime("%Y-%m-%d") not in trading_days
                or next_time > pm_close):
            for trading_day in trading_days:
                year, month, day = trading_day.split("-")
                day_open = datetime(int(year), int(month), int(day),
                                    am_open.hour, am_open.minute,
                                    am_open.second)
                if day_open >= next_datetime:
                    return day_open
            return None

        # Trading day and inside a session: return the stepped moment itself
        # as a full datetime (not just its time of day).
        if self.is_trading_time(next_datetime):
            return next_datetime

        # Trading day but between sessions: jump to the next session opening.
        if next_time < am_open:
            session_open = am_open
        elif next_time < pm_open:
            session_open = pm_open
        else:
            return None
        return datetime(next_datetime.year, next_datetime.month,
                        next_datetime.day, session_open.hour,
                        session_open.minute, session_open.second)
Ejemplo n.º 4
0
class BacktestGateway(BaseGateway):
    """
    Backtest gateway: replays historical bars per security and fills every
    order immediately and completely at the order price.
    """

    # Trading sessions (Hong Kong equities)
    TRADING_HOURS_AM = [Time(9, 30, 0), Time(12, 0, 0)]
    TRADING_HOURS_PM = [Time(13, 0, 0), Time(16, 0, 0)]

    # Smallest time step (seconds)
    TIME_STEP = 60

    # Parameters
    SHORT_INTEREST_RATE = 0.0098  # interest on short selling (securities borrowing)

    def __init__(
        self,
        securities: List[Stock],
        start: datetime,
        end: datetime,
        dtype: List[str] = None,
    ) -> None:
        """
        Historical data dispatcher.

        Builds one bar iterator per security. For each security it also
        records the trading days: the distinct date parts of its
        ``time_key`` column.

        :param securities: securities to replay
        :param start: backtest start; also the initial ``market_datetime``
        :param end: backtest end; requests after it no longer advance caches
        :param dtype: data columns to load; defaults to
            ["open", "high", "low", "close", "volume"]
        """
        if dtype is None:
            # Fresh list per call instead of a shared mutable default.
            dtype = ["open", "high", "low", "close", "volume"]
        super().__init__(securities)
        self.data_iterators = {}
        self.trading_days = {}
        for security in securities:
            full_data = _get_full_data(security=security,
                                       start=start,
                                       end=end,
                                       dtype=dtype)
            self.data_iterators[security] = _get_data_iterator(
                security=security, full_data=full_data)
            # Keep the date part (before the first space) of each time_key.
            self.trading_days[security] = sorted(
                {t.split(" ")[0] for t in full_data["time_key"].values})
        # Union of all securities' trading days, sorted ascending.
        self.trading_days_list = sorted(
            set().union(*self.trading_days.values()))
        # prev_cache: last bar with datetime <= market time;
        # next_cache: first bar not yet reached (None once exhausted).
        self.prev_cache = {s: None for s in securities}
        self.next_cache = {s: None for s in securities}
        self.start = start
        self.end = end
        self.market_datetime = start

    def set_trade_mode(self, trade_mode: TradeMode):
        """Set the trade mode."""
        self.trade_mode = trade_mode

    def is_trading_time(self, cur_datetime: datetime) -> bool:
        """
        Tell whether ``cur_datetime`` is inside a trading session.

        :param cur_datetime: moment to check; its date must be in
            ``trading_days_list`` and its time of day within AM or PM hours
        :return: True for a trading moment, False otherwise
        """
        is_trading_day = cur_datetime.strftime(
            "%Y-%m-%d") in self.trading_days_list
        if not is_trading_day:
            return False
        cur_time = Time(hour=cur_datetime.hour,
                        minute=cur_datetime.minute,
                        second=cur_datetime.second)
        return (
            self.TRADING_HOURS_AM[0] <= cur_time <= self.TRADING_HOURS_AM[1]
        ) or (self.TRADING_HOURS_PM[0] <= cur_time <= self.TRADING_HOURS_PM[1])

    def next_trading_datetime(self, cur_datetime: datetime,
                              security: Stock) -> datetime:
        """
        Find the next moment, one ``TIME_STEP`` after ``cur_datetime`` or
        later, that lies inside a trading session of ``security``.

        :param cur_datetime: current moment
        :param security: security whose trading days are consulted
        :return: ``cur_datetime + TIME_STEP`` if that is already trading
            time; otherwise the next session opening (same day) or the AM
            opening of the next known trading day. None if no later
            trading day exists in the loaded data.
        """
        am_open = self.TRADING_HOURS_AM[0]
        pm_open = self.TRADING_HOURS_PM[0]
        pm_close = self.TRADING_HOURS_PM[1]
        trading_days = self.trading_days[security]

        # Step forward one time unit and check whether that is trading time
        next_datetime = cur_datetime + relativedelta(seconds=self.TIME_STEP)
        next_time = Time(hour=next_datetime.hour,
                         minute=next_datetime.minute,
                         second=next_datetime.second)

        # Not a trading day, or already past the PM close: return the AM
        # opening of the first trading day not earlier than next_datetime.
        if (next_datetime.strftime("%Y-%m-%d") not in trading_days
                or next_time > pm_close):
            for trading_day in trading_days:
                year, month, day = trading_day.split("-")
                day_open = datetime(int(year), int(month), int(day),
                                    am_open.hour, am_open.minute,
                                    am_open.second)
                if day_open >= next_datetime:
                    return day_open
            return None

        # Trading day and inside a session: return the stepped moment itself
        # as a full datetime (not just its time of day).
        if self.is_trading_time(next_datetime):
            return next_datetime

        # Trading day but between sessions: jump to the next session opening.
        if next_time < am_open:
            session_open = am_open
        elif next_time < pm_open:
            session_open = pm_open
        else:
            return None
        return datetime(next_datetime.year, next_datetime.month,
                        next_datetime.day, session_open.hour,
                        session_open.minute, session_open.second)

    def get_recent_bar(self, security: Stock, cur_datetime: datetime):
        """
        Return the latest bar whose datetime is <= ``cur_datetime``.

        Bars are consumed from the security's iterator as time moves forward.
        Once the iterator is exhausted, the last bar keeps being returned.

        :param security: security to query
        :param cur_datetime: current backtest time; must not go backwards
        :return: the most recent bar, or None if no bar has been reached yet
        :raises AssertionError: if ``cur_datetime`` precedes the time of the
            previous request
        """
        assert cur_datetime >= self.market_datetime, f"历史不能回头,当前时间{cur_datetime}在dispatcher的系统时间{self.market_datetime}之前了"

        # After the backtest end the caches are left untouched.
        if cur_datetime <= self.end:
            data_it = self.data_iterators[security]
            prev_bar = self.prev_cache[security]
            next_bar = self.next_cache[security]
            if prev_bar is None and next_bar is None:
                # First request (or no data at all): pull the first bar.
                next_bar = next(data_it, None)
            # Consume every bar not later than cur_datetime. A None next_bar
            # means the iterator is exhausted, so prev_bar never regresses.
            while next_bar is not None and next_bar.datetime <= cur_datetime:
                prev_bar = next_bar
                next_bar = next(data_it, None)
            self.prev_cache[security] = prev_bar
            self.next_cache[security] = next_bar

        self.market_datetime = cur_datetime
        return self.prev_cache[security]

    def place_order(self, order: Order):
        """Simplest possible handling: assume the order is fully filled."""
        order.filled_time = self.market_datetime
        order.filled_quantity = order.quantity
        order.filled_avg_price = order.price
        order.status = OrderStatus.FILLED
        orderid = "bt-order-" + str(uuid.uuid4())
        dealid = "bt-deal-" + str(uuid.uuid4())
        self.orders.put(orderid, order)

        deal = Deal(security=order.security,
                    direction=order.direction,
                    offset=order.offset,
                    order_type=order.order_type,
                    updated_time=self.market_datetime,
                    filled_avg_price=order.price,
                    filled_quantity=order.quantity,
                    dealid=dealid,
                    orderid=orderid)
        self.deals.put(dealid, deal)

        return orderid

    def cancel_order(self, orderid):
        """Cancel an order unless it is already filled, cancelled or failed."""
        order = self.orders.get(orderid)
        if order.status in (OrderStatus.FILLED, OrderStatus.CANCELLED,
                            OrderStatus.FAILED):
            print(f"不能取消订单{orderid},因为订单状态已经为{order.status}")
            return
        order.status = OrderStatus.CANCELLED
        self.orders.put(orderid, order)
Ejemplo n.º 5
0
class FutuGateway(BaseGateway):
    """Live/simulated trading gateway backed by FutuOpenD (HK equities)."""

    # Trading sessions (Hong Kong equities)
    TRADING_HOURS_AM = [Time(9, 30, 0), Time(12, 0, 0)]
    TRADING_HOURS_PM = [Time(13, 0, 0), Time(16, 0, 0)]

    # Smallest time step (seconds)
    TIME_STEP = 60

    # Parameters
    SHORT_INTEREST_RATE = 0.0098  # interest on short selling (securities borrowing)

    # Gateway name
    NAME = "FUTU"

    def __init__(
        self,
        securities: List[Stock],
        start: datetime = None,
        end: datetime = None,
        fees: BaseFees = FutuHKEquityFees,
    ):
        """
        Open the quote context (handlers + subscriptions), then the HK trade
        context (handlers + unlock).

        :param securities: securities to subscribe to
        :param start: optional start time, stored as-is
        :param end: optional end time, stored as-is
        :param fees: fee model, stored on ``self.fees``
        """
        super().__init__(securities)
        self.fees = fees
        self.start = start
        self.end = end

        self.trade_mode = None

        self.quote_ctx = OpenQuoteContext(host=FUTU["host"], port=FUTU["port"])
        self.connect_quote()
        self.subscribe()

        self.trd_ctx = OpenHKTradeContext(host=FUTU["host"], port=FUTU["port"])
        self.connect_trade()

    def close(self):
        """Close the quote and trade connections.

        The trade context is closed even if closing the quote context raises.
        """
        try:
            # Close this connection; per the original note, FutuOpenD cancels
            # the corresponding subscriptions about one minute later.
            self.quote_ctx.close()
        finally:
            self.trd_ctx.close()  # close the trade channel

    def connect_quote(self):
        """
        Register quote and order-book push handlers and start the quote context.

        Successfully parsed pushes are forwarded to ``process_quote`` /
        ``process_orderbook``.
        """
        class QuoteHandler(StockQuoteHandlerBase):
            gateway = self

            def on_recv_rsp(self, rsp_str):
                ret_code, content = super(QuoteHandler,
                                          self).on_recv_rsp(rsp_str)
                if ret_code != RET_OK:
                    return RET_ERROR, content
                self.gateway.process_quote(content)
                return RET_OK, content

        class OrderBookHandler(OrderBookHandlerBase):
            gateway = self

            def on_recv_rsp(self, rsp_str):
                ret_code, content = super(OrderBookHandler,
                                          self).on_recv_rsp(rsp_str)
                if ret_code != RET_OK:
                    return RET_ERROR, content
                self.gateway.process_orderbook(content)
                return RET_OK, content

        self.quote_ctx.set_handler(QuoteHandler())
        self.quote_ctx.set_handler(OrderBookHandler())
        self.quote_ctx.start()
        print("行情接口连接成功")

    def connect_trade(self):
        """
        Register order and deal push handlers, unlock trading and start the
        trade context.

        Successfully parsed pushes are forwarded to ``process_order`` /
        ``process_deal``.
        """
        class TradeOrderHandler(TradeOrderHandlerBase):
            gateway = self

            def on_recv_rsp(self, rsp_str):
                ret_code, content = super(TradeOrderHandler,
                                          self).on_recv_rsp(rsp_str)
                if ret_code != RET_OK:
                    return RET_ERROR, content
                self.gateway.process_order(content)
                return RET_OK, content

        class TradeDealHandler(TradeDealHandlerBase):
            gateway = self

            def on_recv_rsp(self, rsp_str):
                ret_code, content = super(TradeDealHandler,
                                          self).on_recv_rsp(rsp_str)
                if ret_code != RET_OK:
                    return RET_ERROR, content
                self.gateway.process_deal(content)
                return RET_OK, content

        self.trd_ctx.set_handler(TradeOrderHandler())
        self.trd_ctx.set_handler(TradeDealHandler())
        print(self.trd_ctx.unlock_trade(FUTU["pwd_unlock"]))
        self.trd_ctx.start()
        print("交易接口连接成功")

    def process_quote(self, content: pd.DataFrame):
        """Update the cached quote from the first row of a quote push."""
        stock = self.get_stock(code=content['code'].values[0])
        if stock is None:
            return
        svr_datetime_str = content["data_date"].values[0] + " " + content[
            "data_time"].values[0]
        svr_datetime = try_parsing_datetime(svr_datetime_str)
        quote = Quote(
            security=stock,
            exchange=stock.exchange,
            datetime=svr_datetime,
            last_price=content['last_price'].values[0],
            open_price=content['open_price'].values[0],
            high_price=content['high_price'].values[0],
            # Read the low price from its own column (previously last_price).
            low_price=content['low_price'].values[0],
            prev_close_price=content['prev_close_price'].values[0],
            volume=content['volume'].values[0],
            turnover=content['turnover'].values[0],
            turnover_rate=content['turnover_rate'].values[0],
            amplitude=content['amplitude'].values[0],
            suspension=content['suspension'].values[0],
            price_spread=content['price_spread'].values[0],
            sec_status=content['sec_status'].values[0],
        )
        self.quote.put(stock, quote)

    def process_orderbook(self, content: Dict):
        """Update the cached order book from an order-book push.

        Each ``Bid``/``Ask`` entry is unpacked as (price, volume, num) into
        numbered attributes ``bid_price_1``, ``bid_volume_1``, ... .
        """
        stock = self.get_stock(code=content['code'])
        if stock is None:
            return
        # Use the later of the bid-side and ask-side server receive times.
        svr_datetime = max(
            try_parsing_datetime(content['svr_recv_time_bid']),
            try_parsing_datetime(content['svr_recv_time_ask']),
        )
        orderbook = OrderBook(security=stock,
                              exchange=stock.exchange,
                              datetime=svr_datetime)
        for i, bid in enumerate(content['Bid']):
            setattr(orderbook, f"bid_price_{i+1}", bid[0])
            setattr(orderbook, f"bid_volume_{i+1}", bid[1])
            setattr(orderbook, f"bid_num_{i+1}", bid[2])
        for i, ask in enumerate(content['Ask']):
            setattr(orderbook, f"ask_price_{i+1}", ask[0])
            setattr(orderbook, f"ask_volume_{i+1}", ask[1])
            setattr(orderbook, f"ask_num_{i+1}", ask[2])
        self.orderbook.put(stock, orderbook)

    def process_order(self, content: pd.DataFrame):
        """Update a stored order's status from an order push."""
        orderid = content["order_id"].values[0]
        order = self.orders.get(orderid)  # blocking
        order.updated_time = try_parsing_datetime(
            content["updated_time"].values[0])
        order.filled_avg_price = content["dealt_avg_price"].values[0]
        order.filled_quantity = content["dealt_qty"].values[0]
        order.status = convert_orderstatus_futu2qt(
            content["order_status"].values[0])
        # Futu's simulated environment does not push deals, so synthesize one
        # here for filled / partially filled orders.
        if self.trade_mode == TradeMode.SIMULATE and order.status in (
                QTOrderStatus.FILLED, QTOrderStatus.PART_FILLED):
            dealid = "futu-sim-deal-" + str(uuid.uuid4())
            deal = Deal(security=order.security,
                        direction=order.direction,
                        offset=order.offset,
                        order_type=order.order_type,
                        updated_time=order.updated_time,
                        filled_avg_price=order.filled_avg_price,
                        filled_quantity=order.filled_quantity,
                        dealid=dealid,
                        orderid=orderid)
            self.deals.put(dealid, deal)
        self.orders.put(orderid, order)

    def process_deal(self, content: pd.DataFrame):
        """Record a deal from a deal push, linked to its stored order."""
        orderid = content["order_id"].values[0]
        dealid = content["deal_id"].values[0]
        order = self.orders.get(orderid)  # blocking
        deal = Deal(security=order.security,
                    direction=order.direction,
                    offset=order.offset,
                    order_type=order.order_type,
                    updated_time=try_parsing_datetime(
                        content["create_time"].values[0]),
                    filled_avg_price=content["price"].values[0],
                    filled_quantity=content["qty"].values[0],
                    dealid=dealid,
                    orderid=orderid)
        self.deals.put(dealid, deal)

    @property
    def market_datetime(self):
        """Current local wall-clock time."""
        return datetime.now()

    def set_trade_mode(self, trade_mode: TradeMode):
        """Set the trade mode and the matching Futu trading environment."""
        self.trade_mode = trade_mode
        self.futu_trd_env = convert_trade_mode_qt2futu(trade_mode)

    def subscribe(self):
        """Subscribe all securities to 1-minute K-lines, quotes and order books.

        :raises ValueError: if the subscription fails
        """
        # TODO: the 1-minute K-line subscription is hard-coded for now
        codes = [s.code for s in self.securities]
        ret_sub, err_message = self.quote_ctx.subscribe(
            codes, [SubType.K_1M, SubType.QUOTE, SubType.ORDER_BOOK],
            subscribe_push=True)
        # subscribe_push=True asks FutuOpenD to forward pushes to this script
        # (False would subscribe without pushing to the script).
        if ret_sub == RET_OK:  # subscription succeeded
            print(f"成功订阅1min K线、报价和订单簿: {self.securities}")
        else:
            raise ValueError(f"订阅失败: {err_message}")

    def is_trading_time(self, cur_datetime: datetime) -> bool:
        """
        Tell whether the time of day of ``cur_datetime`` lies inside the AM
        or PM session (inclusive bounds). The date itself is not checked.

        :param cur_datetime: moment to check
        :return: True if inside a session, else False
        """
        # TODO: check whether the date is a trading day first
        cur_time = Time(hour=cur_datetime.hour,
                        minute=cur_datetime.minute,
                        second=cur_datetime.second)
        return (
            self.TRADING_HOURS_AM[0] <= cur_time <= self.TRADING_HOURS_AM[1]
        ) or (self.TRADING_HOURS_PM[0] <= cur_time <= self.TRADING_HOURS_PM[1])

    def get_recent_bar(self,
                       security: Stock,
                       cur_datetime: datetime = None,
                       num_of_bars: int = 1) -> Union[Bar, List[Bar]]:
        """
        Fetch the most recent 1-minute K-line bars from the quote context.

        :param security: security to query
        :param cur_datetime: not used by this implementation
        :param num_of_bars: number of most recent bars to request
        :return: a single Bar if exactly one bar came back, otherwise a list
            of Bars; None if the request failed
        """
        ret_code, data = self.quote_ctx.get_cur_kline(
            security.code, num_of_bars, SubType.K_1M, AuType.QFQ)
        if ret_code:
            print('error:', data)
            return
        bars = []
        for i in range(data.shape[0]):
            bar_time = datetime.strptime(data.loc[i, "time_key"],
                                         "%Y-%m-%d %H:%M:%S")
            bar = Bar(datetime=bar_time,
                      security=security,
                      open=data.loc[i, "open"],
                      high=data.loc[i, "high"],
                      low=data.loc[i, "low"],
                      close=data.loc[i, "close"],
                      volume=data.loc[i, "volume"])
            bars.append(bar)
        if len(bars) == 1:
            return bars[0]
        else:
            return bars

    def get_stock(self, code: str) -> Stock:
        """Return the subscribed security with the given code, or None."""
        for stock in self.securities:
            if stock.code == code:
                return stock
        return None

    def place_order(self, order: Order) -> str:
        """Submit an order; return its Futu order id, or "" on failure."""
        ret_code, data = self.trd_ctx.place_order(
            price=order.price,
            qty=order.quantity,
            code=order.security.code,
            trd_side=convert_direction_qt2futu(order.direction),
            trd_env=self.futu_trd_env)
        if ret_code:
            print(f"提交订单失败:{data}")
            return ""
        orderid = data["order_id"].values[0]  # a successful submission returns an order id
        order.status = QTOrderStatus.SUBMITTED  # mark as submitted
        self.orders.put(orderid, order)  # status is updated later by the order callback
        return orderid

    def cancel_order(self, orderid):
        """Request cancellation of an order; prints a message on failure."""
        ret_code, data = self.trd_ctx.modify_order(ModifyOrderOp.CANCEL,
                                                   orderid,
                                                   0,
                                                   0,
                                                   trd_env=self.futu_trd_env)
        if ret_code:
            print(f"撤单失败:{data}")

    def get_broker_balance(self) -> AccountBalance:
        """Query the broker account funds; None if the query fails."""
        ret_code, data = self.trd_ctx.accinfo_query(trd_env=self.futu_trd_env)
        if ret_code:
            print(f"获取券商资金失败:{data}")
            return
        balance = AccountBalance()
        balance.cash = data["cash"].values[0]
        balance.power = data["power"].values[0]
        balance.max_power_short = data["max_power_short"].values[0]
        balance.net_cash_power = data["net_cash_power"].values[0]
        return balance

    def get_broker_position(self, security: Stock,
                            direction: Direction) -> PositionData:
        """Return the broker position for (security, direction), or None.

        Also returns None when the position query itself fails.
        """
        positions = self.get_all_broker_positions()
        # get_all_broker_positions returns None on failure; don't iterate it.
        for position_data in positions or []:
            if position_data.security == security and position_data.direction == direction:
                return position_data
        return None

    def get_all_broker_positions(self) -> List[PositionData]:
        """Query all broker positions; None if the query fails."""
        ret_code, data = self.trd_ctx.position_list_query(
            trd_env=self.futu_trd_env)
        if ret_code:
            print(f"获取券商所有持仓失败:{data}")
            return
        positions = []
        for _, row in data.iterrows():
            security = self.get_stock(code=row["code"])
            if security is None:
                # Position in a security this gateway does not track.
                security = Stock(code=row["code"],
                                 stock_name=row["stock_name"])
            position_data = PositionData(
                security=security,
                direction=Direction.LONG
                if row["position_side"] == "LONG" else Direction.SHORT,
                holding_price=row["cost_price"],
                quantity=row["qty"],
                update_time=datetime.now(),
            )
            positions.append(position_data)
        return positions

    def get_quote(self, security: Stock) -> Quote:
        """Return the cached quote for ``security``."""
        return self.quote.get(security)

    def get_orderbook(self, security: Stock) -> OrderBook:
        """Return the cached order book for ``security``."""
        return self.orderbook.get(security)

    def get_capital_distribution(self, security: Stock) -> CapitalDistribution:
        """Query the capital distribution of ``security``; None on failure."""
        ret_code, data = self.quote_ctx.get_capital_distribution(security.code)
        if ret_code:
            print(f"获取资金分布失败:{data}")
            return
        cap_dist = CapitalDistribution(
            datetime=datetime.strptime(data["update_time"].values[0],
                                       "%Y-%m-%d %H:%M:%S"),
            security=security,
            capital_in_big=data["capital_in_big"].values[0],
            capital_in_mid=data["capital_in_mid"].values[0],
            capital_in_small=data["capital_in_small"].values[0],
            capital_out_big=data["capital_out_big"].values[0],
            capital_out_mid=data["capital_out_mid"].values[0],
            capital_out_small=data["capital_out_small"].values[0])
        return cap_dist
Ejemplo n.º 6
0
class BacktestGateway(BaseGateway):
    """
    Backtest gateway: replays historical data per security and data field
    (e.g. K-lines, capital distribution) and fills every order immediately
    and completely at the order price.
    """

    # Trading sessions (Hong Kong equities)
    TRADING_HOURS_AM = [Time(9, 30, 0), Time(12, 0, 0)]
    TRADING_HOURS_PM = [Time(13, 0, 0), Time(16, 0, 0)]

    # Smallest time step (seconds)
    TIME_STEP = 60

    # Parameters
    SHORT_INTEREST_RATE = 0.0098  # interest on short selling (securities borrowing)

    # Gateway name
    NAME = "BACKTEST"

    # Requested data fields (set per instance in __init__)
    DTYPES = None

    def __init__(
            self,
            securities: List[Stock],
            start: datetime,
            end: datetime,
            dtypes: Dict[str, List[str]] = None,
            fees: BaseFees = FutuHKEquityFees,  # defaults to Futu HK equity fees
    ) -> None:
        """
        Historical data dispatcher.

        :param securities: securities to replay
        :param start: backtest start; also the initial ``market_datetime``
        :param end: backtest end; requests after it no longer advance caches
        :param dtypes: columns to load per data field; keys must match
            DATA_PATH, e.g. {"k1m": ["open", "high", "low", "close", "volume"]}.
            Defaults to
            {"k1m": ["time_key", "open", "high", "low", "close", "volume"]}
        :param fees: fee model, stored on ``self.fees``
        :raises AssertionError: if the keys of ``dtypes`` differ from DATA_PATH
        """
        if dtypes is None:
            # Fresh dict per call instead of a shared mutable default.
            dtypes = dict(
                k1m=["time_key", "open", "high", "low", "close", "volume"])
        self.DTYPES = dtypes
        assert set(dtypes.keys()) == set(DATA_PATH.keys()), (
            f"在{self.__class__.__name__}的__init__函数里,"
            f"输入参数dtypes的键值必须与DATA_PATH里的设定一致,dtypes需输入以下数据:{','.join(DATA_PATH.keys())},"
            f"但目前只有:{','.join(dtypes.keys())}")
        super().__init__(securities)
        self.fees = fees
        self.data_iterators = {}
        # prev_cache: last item with datetime <= market time;
        # next_cache: first item not yet reached (None once exhausted).
        self.prev_cache = {}
        self.next_cache = {}
        self.trading_days = {}
        for security in securities:
            self.data_iterators[security] = {}
            self.prev_cache[security] = {}
            self.next_cache[security] = {}
            for dfield in DATA_PATH.keys():  # K-line data | capital distribution data
                data = _get_data(security=security,
                                 start=start,
                                 end=end,
                                 dfield=dfield,
                                 dtype=dtypes[dfield])
                self.data_iterators[security][dfield] = _get_data_iterator(
                    security=security,
                    full_data=data,
                    class_name=DATA_MODEL[dfield])
                self.prev_cache[security][dfield] = None
                self.next_cache[security][dfield] = None
                # Trading days follow the 1-minute K-line calendar.
                if dfield == "k1m":
                    self.trading_days[security] = sorted(
                        {t.split(" ")[0] for t in data["time_key"].values})
        # Union of all securities' trading days, sorted ascending.
        self.trading_days_list = sorted(
            set().union(*self.trading_days.values()))

        self.start = start
        self.end = end
        self.market_datetime = start

    def set_trade_mode(self, trade_mode: TradeMode):
        """Set the trade mode."""
        self.trade_mode = trade_mode

    def is_trading_time(self, cur_datetime: datetime) -> bool:
        """
        Tell whether ``cur_datetime`` is inside a trading session.

        :param cur_datetime: moment to check; its date must be in
            ``trading_days_list`` and its time of day within AM or PM hours
        :return: True for a trading moment, False otherwise
        """
        is_trading_day = cur_datetime.strftime(
            "%Y-%m-%d") in self.trading_days_list
        if not is_trading_day:
            return False
        cur_time = Time(hour=cur_datetime.hour,
                        minute=cur_datetime.minute,
                        second=cur_datetime.second)
        return (
            self.TRADING_HOURS_AM[0] <= cur_time <= self.TRADING_HOURS_AM[1]
        ) or (self.TRADING_HOURS_PM[0] <= cur_time <= self.TRADING_HOURS_PM[1])

    def next_trading_datetime(self, cur_datetime: datetime,
                              security: Stock) -> datetime:
        """
        Find the next moment, one ``TIME_STEP`` after ``cur_datetime`` or
        later, that lies inside a trading session of ``security``.

        :param cur_datetime: current moment
        :param security: security whose trading days are consulted
        :return: ``cur_datetime + TIME_STEP`` if that is already trading
            time; otherwise the next session opening (same day) or the AM
            opening of the next known trading day. None if no later
            trading day exists in the loaded data.
        """
        am_open = self.TRADING_HOURS_AM[0]
        pm_open = self.TRADING_HOURS_PM[0]
        pm_close = self.TRADING_HOURS_PM[1]
        trading_days = self.trading_days[security]

        # Step forward one time unit and check whether that is trading time
        next_datetime = cur_datetime + relativedelta(seconds=self.TIME_STEP)
        next_time = Time(hour=next_datetime.hour,
                         minute=next_datetime.minute,
                         second=next_datetime.second)

        # Not a trading day, or already past the PM close: return the AM
        # opening of the first trading day not earlier than next_datetime.
        if (next_datetime.strftime("%Y-%m-%d") not in trading_days
                or next_time > pm_close):
            for trading_day in trading_days:
                year, month, day = trading_day.split("-")
                day_open = datetime(int(year), int(month), int(day),
                                    am_open.hour, am_open.minute,
                                    am_open.second)
                if day_open >= next_datetime:
                    return day_open
            return None

        # Trading day and inside a session: return the stepped moment itself
        # as a full datetime (not just its time of day).
        if self.is_trading_time(next_datetime):
            return next_datetime

        # Trading day but between sessions: jump to the next session opening.
        if next_time < am_open:
            session_open = am_open
        elif next_time < pm_open:
            session_open = pm_open
        else:
            return None
        return datetime(next_datetime.year, next_datetime.month,
                        next_datetime.day, session_open.hour,
                        session_open.minute, session_open.second)

    def _advance_cache(self, security: Stock, dfield: str,
                       cur_datetime: datetime) -> None:
        """Move the (prev, next) cache of one data field forward to cur_datetime."""
        data_it = self.data_iterators[security][dfield]
        prev_data = self.prev_cache[security][dfield]
        next_data = self.next_cache[security][dfield]
        if prev_data is None and next_data is None:
            # First request (or no data at all): pull the first item.
            next_data = next(data_it, None)
        # Consume every item not later than cur_datetime. A None next_data
        # means the iterator is exhausted, so prev_data never regresses.
        while next_data is not None and next_data.datetime <= cur_datetime:
            prev_data = next_data
            next_data = next(data_it, None)
        self.prev_cache[security][dfield] = prev_data
        self.next_cache[security][dfield] = next_data

    def get_recent_data(
        self, security: Stock, cur_datetime: datetime, **kwargs
    ) -> Union[Dict[str, Union[Bar, CapitalDistribution]], Bar,
               CapitalDistribution]:
        """
        Return the latest data point(s) with datetime <= ``cur_datetime``.

        :param security: security to query
        :param cur_datetime: current backtest time; must not go backwards
        :param kwargs: optionally ``dfield=<key of DATA_PATH>`` to query a
            single data field
        :return: the latest item of that field when exactly one field is
            queried; otherwise the dict of latest items keyed by field.
            Items are None until the first data point is reached.
        :raises AssertionError: if ``cur_datetime`` goes backwards, or kwargs
            are given without ``dfield``
        """
        assert cur_datetime >= self.market_datetime, f"历史不能回头,当前时间{cur_datetime}在dispatcher的系统时间{self.market_datetime}之前了"
        if kwargs:
            assert "dfield" in kwargs, f"`dfield` should be passed in as kwargs, but kwargs={kwargs}"
            dfields = [kwargs["dfield"]]
        else:
            dfields = list(DATA_PATH)

        # After the backtest end the caches are left untouched.
        if cur_datetime <= self.end:
            for dfield in dfields:
                self._advance_cache(security, dfield, cur_datetime)

        self.market_datetime = cur_datetime
        if len(dfields) == 1:
            return self.prev_cache[security][dfields[0]]
        return self.prev_cache[security]

    def place_order(self, order: Order) -> str:
        """Simplest possible handling: assume the order is fully filled."""
        order.filled_time = self.market_datetime
        order.filled_quantity = order.quantity
        order.filled_avg_price = order.price
        order.status = OrderStatus.FILLED
        orderid = "bt-order-" + str(uuid.uuid4())
        dealid = "bt-deal-" + str(uuid.uuid4())
        self.orders.put(orderid, order)

        deal = Deal(security=order.security,
                    direction=order.direction,
                    offset=order.offset,
                    order_type=order.order_type,
                    updated_time=self.market_datetime,
                    filled_avg_price=order.price,
                    filled_quantity=order.quantity,
                    dealid=dealid,
                    orderid=orderid)
        self.deals.put(dealid, deal)

        return orderid

    def cancel_order(self, orderid):
        """Cancel an order unless it is already filled, cancelled or failed."""
        order = self.orders.get(orderid)
        if order.status in (OrderStatus.FILLED, OrderStatus.CANCELLED,
                            OrderStatus.FAILED):
            print(f"不能取消订单{orderid},因为订单状态已经为{order.status}")
            return
        order.status = OrderStatus.CANCELLED
        self.orders.put(orderid, order)

    def get_broker_balance(self) -> AccountBalance:
        """Broker account funds (not available in backtests): always None."""
        return None

    def get_broker_position(self, security: Stock,
                            direction: Direction) -> PositionData:
        """Broker position (not available in backtests): always None."""
        return None

    def get_all_broker_positions(self) -> List[PositionData]:
        """All broker positions (not available in backtests): always None."""
        return None

    def get_quote(self, security: Stock) -> Quote:
        """Quote (not available in backtests): always None."""
        return None

    def get_orderbook(self, security: Stock) -> OrderBook:
        """Order book (not available in backtests): always None."""
        return None
Ejemplo n.º 7
0
class FutuGateway(BaseGateway):
    """
    Trading gateway for Hong Kong stocks backed by a running FutuOpenD.

    On construction it does two things:

    * It opens a HK trade context, registers order/deal push handlers and
      unlocks trading.
    * It opens a quote context and subscribes to 1-minute bars for every
      security.

    ``set_trade_mode`` must be called before any trade call. It is the only
    place that sets ``self.futu_trd_env``.
    """

    # Trading sessions (Hong Kong stocks)
    TRADING_HOURS_AM = [Time(9, 30, 0), Time(12, 0, 0)]
    TRADING_HOURS_PM = [Time(13, 0, 0), Time(16, 0, 0)]

    # Smallest time unit (seconds)
    TIME_STEP = 60

    # Parameters
    SHORT_INTEREST_RATE = 0.0098  # interest rate for short selling (securities borrowing)

    def __init__(self,
                 securities: List[Stock],
                 start: datetime = None,
                 end: datetime = None,
                 ):
        """
        :param securities: securities to trade; 1-minute bars are subscribed for each
        :param start: stored as ``self.start``
        :param end: stored as ``self.end``
        """
        super().__init__(securities)
        self.start = start
        self.end = end

        self.trade_mode = None

        self.trd_ctx = OpenHKTradeContext(host=FUTU["host"], port=FUTU["port"])
        self.connect_trade()

        self.quote_ctx = OpenQuoteContext(host=FUTU["host"], port=FUTU["port"])
        self.subscribe()

    def close(self):
        """Close the quote and trade connections."""
        # Closing the quote connection: FutuOpenD cancels the corresponding
        # subscriptions automatically after about one minute.
        self.quote_ctx.close()
        self.trd_ctx.close()   # close the trade channel

    def connect_trade(self):
        """Register order/deal push handlers on the trade context and unlock trading."""
        class TradeOrderHandler(TradeOrderHandlerBase):
            # Class attribute bound to the enclosing gateway so the callback can reach it.
            gateway = self
            def on_recv_rsp(self, rsp_str):
                ret_code, content = super(TradeOrderHandler, self).on_recv_rsp(
                    rsp_str
                )
                if ret_code != RET_OK:
                    return RET_ERROR, content
                self.gateway.process_order(content)
                return RET_OK, content

        class TradeDealHandler(TradeDealHandlerBase):
            gateway = self
            def on_recv_rsp(self, rsp_str):
                ret_code, content = super(TradeDealHandler, self).on_recv_rsp(
                    rsp_str
                )
                if ret_code != RET_OK:
                    return RET_ERROR, content
                self.gateway.process_deal(content)
                return RET_OK, content

        self.trd_ctx.set_handler(TradeOrderHandler())
        self.trd_ctx.set_handler(TradeDealHandler())
        print(self.trd_ctx.unlock_trade(FUTU["pwd_unlock"]))

    def process_order(self, content: pd.DataFrame):
        """
        Update a stored order from an order push.

        Only the first row of ``content`` is read.

        :param content: order push data from FutuOpenD
        """
        orderid = content["order_id"].values[0]
        order = self.orders.get(orderid)  # blocking
        order.updated_time = try_parsing_datetime(content["updated_time"].values[0])
        order.filled_avg_price = content["dealt_avg_price"].values[0]
        order.filled_quantity = content["dealt_qty"].values[0]
        order.status = convert_orderstatus_futu2qt(content["order_status"].values[0])
        self.orders.put(orderid, order)

    def process_deal(self, content: pd.DataFrame):
        """
        Record a deal (fill) from a deal push.

        Only the first row of ``content`` is read. The security, direction,
        offset and order type are copied from the stored order.

        :param content: deal push data from FutuOpenD
        """
        orderid = content["order_id"].values[0]
        dealid = content["deal_id"].values[0]
        order = self.orders.get(orderid)  # blocking
        deal = Deal(
            security=order.security,
            direction=order.direction,
            offset=order.offset,
            order_type=order.order_type,
            updated_time=try_parsing_datetime(content["create_time"].values[0]),
            filled_avg_price=content["price"].values[0],
            filled_quantity=content["qty"].values[0],
            dealid=dealid,
            orderid=orderid
        )
        self.deals.put(dealid, deal)

    @property
    def market_datetime(self):
        """Current market time: the local wall clock (``datetime.now()``)."""
        return datetime.now()

    def set_trade_mode(self, trade_mode: TradeMode):
        """Set the trade mode and the matching Futu trade environment."""
        self.trade_mode = trade_mode
        self.futu_trd_env = convert_trade_mode_qt2futu(trade_mode)

    def subscribe(self):
        """Subscribe to 1-minute bars for every security."""
        # TODO: the subscription is hard-coded to 1-minute bars for now.
        codes = [s.code for s in self.securities]
        # subscribe_push=False: FutuOpenD keeps receiving server pushes after a
        # successful subscription, but they are not forwarded to this script.
        ret_sub, err_message = self.quote_ctx.subscribe(codes, [SubType.K_1M], subscribe_push=False)
        if ret_sub == RET_OK:
            print(f"成功订阅1M k线: {self.securities}")
        else:
            print(f"订阅失败: {err_message}")

    def is_trading_time(self, cur_datetime: datetime) -> bool:
        """
        Tell whether ``cur_datetime`` falls inside the AM or PM session.

        Both session ends are inclusive. The date is not checked.

        :param cur_datetime: moment to check
        :return: True if inside a session
        """
        # TODO: check whether the date is a trading day first.
        cur_time = Time(hour=cur_datetime.hour, minute=cur_datetime.minute, second=cur_datetime.second)
        return (self.TRADING_HOURS_AM[0] <= cur_time <= self.TRADING_HOURS_AM[1]) or (self.TRADING_HOURS_PM[0] <= cur_time <= self.TRADING_HOURS_PM[1])

    def get_recent_bar(self, security: Stock, cur_datetime: datetime = None, num_of_bars: int = 1) -> Union[Bar, List[Bar]]:
        """
        Fetch the most recent 1-minute bar(s) of ``security`` from FutuOpenD.

        Bars are requested forward-adjusted (``AuType.QFQ``) from the 1-minute
        subscription set up by ``subscribe``.

        :param security: security to query
        :param cur_datetime: currently unused (kept for interface compatibility)
        :param num_of_bars: number of most recent bars to request
        :return: a single Bar when exactly one bar comes back, otherwise a list
            of Bars; None (after printing the error) if the request fails
        """
        ret, data = self.quote_ctx.get_cur_kline(security.code, num_of_bars, SubType.K_1M, AuType.QFQ)
        if ret != RET_OK:
            print('error:', data)
            return None
        bars = []
        for i in range(data.shape[0]):
            bar_time = datetime.strptime(data.loc[i, "time_key"], "%Y-%m-%d %H:%M:%S")
            bar = Bar(
                datetime=bar_time,
                security=security,
                open=data.loc[i, "open"],
                high=data.loc[i, "high"],
                low=data.loc[i, "low"],
                close=data.loc[i, "close"],
                volume=data.loc[i, "volume"]
            )
            bars.append(bar)
        if len(bars) == 1:
            return bars[0]
        return bars

    def place_order(self, order: Order):
        """
        Submit an order to Futu.

        The order is stored as SUBMITTED. Its later status is updated by the
        order push handler.

        :param order: order to submit
        :return: the Futu order id, or "" if submission failed
        """
        code, data = self.trd_ctx.place_order(
            price=order.price,
            qty=order.quantity,
            code=order.security.code,
            trd_side=convert_direction_qt2futu(order.direction),
            trd_env=self.futu_trd_env
        )
        if code:
            print(f"提交订单失败:{data}")
            return ""
        orderid = data["order_id"].values[0]   # a successful submission always returns an order id
        order.status = QTOrderStatus.SUBMITTED  # mark as submitted
        self.orders.put(orderid, order)         # updated later via the push callback
        return orderid

    def cancel_order(self, orderid):
        """Ask Futu to cancel an order; prints the error if the request fails."""
        code, data = self.trd_ctx.modify_order(
            ModifyOrderOp.CANCEL,
            orderid,
            0,
            0,
            trd_env=self.futu_trd_env
        )
        if code:
            print(f"撤单失败:{data}")

    def get_balance(self):
        """
        Query account funds.

        Every column of the first result row except ``risk_level`` and
        ``risk_status`` is copied onto an ``AccountBalance``.

        :return: the AccountBalance, or None (after printing the error) if the
            query fails
        """
        # accinfo_query returns a (ret_code, data) pair, not a bare DataFrame.
        ret, data = self.trd_ctx.accinfo_query(trd_env=self.futu_trd_env)
        if ret != RET_OK:
            print(f"查询资金失败:{data}")
            return None
        balance = AccountBalance()
        for col in data.columns:
            if col in ('risk_level', 'risk_status'):
                continue
            setattr(balance, col, data[col].values[0])
        return balance

    def get_position(self):
        """Get positions (not implemented; returns None)."""
        return
Ejemplo n.º 8
0
class BacktestGateway:
    """
    Historical data dispatcher for backtesting.

    Bars are replayed per security through iterators. The gateway also keeps
    the trading days found in the loaded data.
    """

    # Trading sessions (Hong Kong stocks)
    TRADING_HOURS_AM = [Time(9, 30, 0), Time(12, 0, 0)]
    TRADING_HOURS_PM = [Time(13, 0, 0), Time(16, 0, 0)]

    # Smallest time unit (seconds)
    TIME_STEP = 60

    # Parameters
    SHORT_INTEREST_RATE = 0.0098  # interest rate for short selling (securities borrowing)

    def __init__(
        self,
        securities: List[Stock],
        start: datetime,
        end: datetime,
        dtype: List[str] = None
    ) -> None:
        """
        Load historical data for every security and prepare bar iterators.

        :param securities: securities to load
        :param start: start of the data range; also the initial clock (``self.datetime``)
        :param end: end of the data range; ``get_recent_data`` does not advance past it
        :param dtype: fields to load; defaults to
            ["open", "high", "low", "close", "volume"]
        """
        # None sentinel instead of a shared mutable default list.
        if dtype is None:
            dtype = ["open", "high", "low", "close", "volume"]

        data_iterators = {}
        trading_days = {}
        all_trading_days = set()
        for security in securities:
            full_data = _get_full_data(security=security,
                                       start=start,
                                       end=end,
                                       dtype=dtype)
            data_iterators[security] = _get_data_iterator(security=security,
                                                          full_data=full_data)
            # Keep the date part of each "time_key" (presumably
            # "YYYY-MM-DD HH:MM:SS"); sorted unique dates per security.
            days = sorted(
                set(t.split(" ")[0] for t in full_data["time_key"].values))
            trading_days[security] = days
            all_trading_days.update(days)

        self.data_iterators = data_iterators
        self.trading_days = trading_days
        # Union of every security's trading days, sorted ascending.
        self.trading_days_list = sorted(all_trading_days)
        # Per-security cursor: latest bar reached / one look-ahead bar.
        self.prev_cache = {s: None for s in securities}
        self.next_cache = {s: None for s in securities}
        self.start = start
        self.end = end
        self.datetime = start

    def is_trading_time(self, cur_datetime: datetime) -> bool:
        """
        判断当前时间是否属于交易时间段

        :param cur_datetime:
        :return:
        """
        is_trading_day = cur_datetime.strftime(
            "%Y-%m-%d") in self.trading_days_list
        if not is_trading_day:
            return False
        cur_time = Time(hour=cur_datetime.hour,
                        minute=cur_datetime.minute,
                        second=cur_datetime.second)
        return (
            self.TRADING_HOURS_AM[0] <= cur_time <= self.TRADING_HOURS_AM[1]
        ) or (self.TRADING_HOURS_PM[0] <= cur_time <= self.TRADING_HOURS_PM[1])

    def next_trading_datetime(self, cur_datetime: datetime,
                              security: Stock) -> datetime:
        """
        根据已有数据寻找下一个属于交易时间的时间点,如果找不到,返回None

        :param cur_datetime:
        :param security:
        :return:
        """
        # 移动一个时间单位,看是否属于交易时间
        next_datetime = cur_datetime + relativedelta(seconds=self.TIME_STEP)
        next_time = Time(hour=next_datetime.hour,
                         minute=next_datetime.minute,
                         second=next_datetime.second)
        next_trading_daytime = None

        # 如果下一个时间点,不属于交易日;或者已经超出pm交易时间,找到并返回下一个交易日的开盘时间
        if (next_datetime.strftime("%Y-%m-%d")
                not in self.trading_days[security]) or (
                    next_time > self.TRADING_HOURS_PM[1]):
            for trading_day in self.trading_days[security]:
                year, month, day = trading_day.split("-")
                trade_datetime = datetime(int(year), int(month), int(day),
                                          self.TRADING_HOURS_AM[0].hour,
                                          self.TRADING_HOURS_AM[0].minute,
                                          self.TRADING_HOURS_AM[0].second)
                if trade_datetime >= next_datetime:
                    next_trading_daytime = trade_datetime
                    break
        # 如果下一个时间点属于交易日,并且没有超出pm交易时间,则找到并返回上午或者下午的开盘时间
        elif (not self.is_trading_time(next_datetime)):
            if next_time < self.TRADING_HOURS_AM[0]:
                next_trading_daytime = datetime(
                    next_datetime.year, next_datetime.month, next_datetime.day,
                    self.TRADING_HOURS_AM[0].hour,
                    self.TRADING_HOURS_AM[0].minute,
                    self.TRADING_HOURS_AM[0].second)
            elif next_time < self.TRADING_HOURS_PM[0]:
                next_trading_daytime = datetime(
                    next_datetime.year, next_datetime.month, next_datetime.day,
                    self.TRADING_HOURS_PM[0].hour,
                    self.TRADING_HOURS_PM[0].minute,
                    self.TRADING_HOURS_PM[0].second)
        # 如果下一个时间点属于交易日,并且属于交易时间段内,则直接返回下一个时间点
        else:
            next_trading_daytime = next_time
        return next_trading_daytime

    def get_recent_data(self, cur_datetime: datetime, security: Stock):
        """
        获取最接近当前时间的数据点

        :param security:
        :param cur_time:
        :return:
        """
        assert cur_datetime >= self.datetime, f"历史不能回头,当前时间{cur_datetime}在dispatcher的系统时间{self.datetime}之前了"
        data_it = self.data_iterators[security]
        data_prev = self.prev_cache[security]
        data_next = self.next_cache[security]

        if cur_datetime > self.end:
            pass

        elif (data_prev is None) and (data_next is None):
            bar = next(data_it)
            if bar.datetime > cur_datetime:
                self.next_cache[security] = bar
            else:
                while bar.datetime <= cur_datetime:
                    self.prev_cache[security] = bar
                    bar = next(data_it)
                self.next_cache[security] = bar

        else:
            if self.next_cache[security].datetime <= cur_datetime:
                self.prev_cache[security] = self.next_cache[security]
                try:
                    bar = next(data_it)
                    while bar.datetime <= cur_datetime:
                        self.prev_cache[security] = bar
                        bar = next(data_it)
                    self.next_cache[security] = bar
                except StopIteration:
                    pass

        self.datetime = cur_datetime
        return self.prev_cache[security]

    def process_order(self, order: Order):
        """最简单的处理,假设全部成交"""
        order.filled_time = self.datetime
        return order