class StrategyEngine(object):
    """Strategy engine.

    Keeps the bar-data cache, the order / deal / position bookkeeping and the
    registered strategies, and forwards events through an EventEngine.
    """

    # Deque length used for a (symbol, time_frame) cache whose requested
    # max_length is 0 (see initialize()).
    CACHE_MAXLEN = 10000

    # ----------------------------------------------------------------------
    def __init__(self, backtesting=False):
        """Constructor

        :param backtesting: True for a backtest; market orders are then
            filled locally from cached bar data (see __send_order_to_broker).
        """
        self.__event_engine = EventEngine()  # event processing engine
        self.__account_manager = AccountManager()  # account management
        self.__backtesting = backtesting  # whether this is a backtest
        self.__orders_done = {}  # all processed orders
        self.__orders_todo = {}  # all unprocessed (pending) orders, keyed by order id
        self.__deals = {}  # all deals, keyed by deal id
        self.__positions = {}  # key: position id, value: position with that id
        self.__strategys = {}  # key: strategy id (see add_strategy), value: strategy object
        self.__data = {}  # unified data view: symbol -> time_frame -> field -> deque
        self.__symbols = {}  # key: (symbol, time_frame), value: maxlen
        self.start_time = None
        self.end_time = None
        self.__current_positions = {}  # key: symbol, value: current position
        self.__initial_positions = {}  # key: symbol, value: initial position

    # TODO move these accessors into utils
    symbols = property(partial(get_attr, attr='symbols'), None, None)
    start_time = property(partial(get_attr, attr='start_time'),
                          partial(set_attr, attr='start_time'), None)
    end_time = property(partial(get_attr, attr='end_time'),
                        partial(set_attr, attr='end_time'), None)

    # ----------------------------------------------------------------------
    def get_current_contracts(self):
        # TODO number of lots currently held
        return ()

    # ----------------------------------------------------------------------
    def get_current_positions(self):
        # TODO read the effective Position of every symbol
        return self.__current_positions

    def get_deals(self):
        return self.__deals

    def get_positions(self):
        return self.__positions

    # ----------------------------------------------------------------------
    def get_data(self):
        return self.__data

    # ----------------------------------------------------------------------
    def get_profit_records(self):
        """Return the closed-trade profit records from the account manager."""
        return self.__account_manager.get_profit_records()

    # ----------------------------------------------------------------------
    def get_traceback(self):
        pass

    # ----------------------------------------------------------------------
    def get_position_records(self):
        """Build chart annotations for every symbol's position history.

        Walks each symbol's position chain (starting after its initial
        position). It emits a 'point' record for every entry or exit deal,
        and a 'line' record linking each opening position to the position
        that closed it (FIFO-by-stack matching of volumes).

        :return: list of dicts with 'type' == 'point' or 'line'.
        """

        def get_point(deal, entry, volume):
            # Marker for one deal; label depends on entry direction and deal type.
            if entry == DEAL_ENTRY_IN:
                if deal.type == DEAL_TYPE_BUY:
                    return ({
                        'type': 'point',
                        'x': deal.time + deal.time_msc / (10**6),
                        'y': deal.price,
                        'color': 'buy',
                        'text': 'Buy %s' % volume
                    })
                elif deal.type == DEAL_TYPE_SELL:
                    return ({
                        'type': 'point',
                        'x': deal.time + deal.time_msc / (10**6),
                        'y': deal.price,
                        'color': 'short',
                        'text': 'Short %s' % volume
                    })
            elif entry == DEAL_ENTRY_OUT:
                if deal.type == DEAL_TYPE_BUY:
                    return ({
                        'type': 'point',
                        'x': deal.time + deal.time_msc / (10**6),
                        'y': deal.price,
                        'color': 'cover',
                        'text': 'Cover %s' % volume
                    })
                elif deal.type == DEAL_TYPE_SELL:
                    return ({
                        'type': 'point',
                        'x': deal.time + deal.time_msc / (10**6),
                        'y': deal.price,
                        'color': 'sell',
                        'text': 'Sell %s' % volume
                    })

        def get_lines(position_start, position_end):
            # Segment from the deal that opened position_start to the deal
            # that produced position_end.
            deal_start = self.__deals[position_start.deal]
            deal_end = self.__deals[position_end.deal]
            start_time = deal_start.time + deal_start.time_msc / (10**6)
            end_time = deal_end.time + deal_end.time_msc / (10**6)
            line = {
                'type': 'line',
                'x_start': start_time,
                'x_end': end_time,
                'y_start': deal_start.price,
                'y_end': deal_end.price
            }
            # NOTE(review): with a BUY closing deal (covering a short), a lower
            # exit price is a gain, yet this expression yields 'lose' in that
            # case -- confirm the intended win/lose convention.
            if (deal_end.type == DEAL_TYPE_BUY) ^ (deal_start.price >= deal_end.price):
                line['color'] = 'win'
            else:
                line['color'] = 'lose'
            return line

        def next_position(position):
            return self.__positions.get(position.next_id, None)

        result = []
        for symbol in {symbol for (symbol, _) in self.__symbols}:
            # Open legs awaiting a matching close; kept per symbol so closes
            # are never matched against another symbol's entries.
            stack = []
            position = next_position(self.__initial_positions[symbol])
            while position is not None:
                deal = self.__deals[position.deal]
                if deal.entry == DEAL_ENTRY_IN:  # open or overweight position
                    result.append(get_point(deal, DEAL_ENTRY_IN, deal.volume))
                    stack.append((position, deal.volume))
                else:
                    if deal.entry == DEAL_ENTRY_INOUT:  # reverse position
                        # Part of the deal closes the old side, the rest
                        # (position.volume) opens the new side.
                        volume_left = deal.volume - position.volume
                        result.append(
                            get_point(deal, DEAL_ENTRY_IN, position.volume))
                    else:  # close or underweight position
                        volume_left = deal.volume
                    result.append(get_point(deal, DEAL_ENTRY_OUT, volume_left))
                    # Match the closed volume against the most recent open
                    # legs; `stack` guard avoids popping an empty list when
                    # float drift leaves a tiny positive remainder.
                    while volume_left > 0 and stack:
                        position_start, volume = stack.pop()
                        result.append(get_lines(position_start, position))
                        volume_left -= volume
                    if volume_left < 0:
                        # Last matched leg was only partially closed; put the
                        # remainder back.
                        stack.append((position_start, -volume_left))
                    elif deal.entry == DEAL_ENTRY_INOUT and position.volume > 0:
                        # The reversal opened a new leg on the other side.
                        stack.append((position, position.volume))
                position = next_position(position)
        return result

    # ----------------------------------------------------------------------
    def set_capital_base(self, base):
        self.__account_manager.set_capital_base(base)

    # ----------------------------------------------------------------------
    def add_symbols(self, symbols, time_frame, max_length=0):
        """Register bar caches for `symbols` on `time_frame`.

        The stored cache length is the largest max_length requested so far;
        0 means "use CACHE_MAXLEN" at initialize() time.
        """
        for symbol in symbols:
            if (symbol, time_frame) not in self.__symbols:
                self.__symbols[(symbol, time_frame)] = max_length
            self.__symbols[(symbol, time_frame)] = max(
                max_length, self.__symbols[(symbol, time_frame)])
            # NOTE(review): registers the handler on every call, even for an
            # already-known (symbol, time_frame) -- confirm EventEngine dedups.
            self.register_event(EVENT_BAR_SYMBOL[symbol][time_frame],
                                self.update_bar_data)

    # ----------------------------------------------------------------------
    def initialize(self):
        """Reset deals/positions/data and build empty bar caches and an
        initial Position for every registered symbol."""
        # TODO the data structures still need revising
        self.__deals.clear()
        self.__positions.clear()
        self.__data.clear()
        # TODO auto_inc is module-level here; it should be per-object.
        Deal.set_auto_inc(0)
        Position.set_auto_inc(0)
        self.__current_positions.clear()
        for (symbol, time_frame), maxlen in self.__symbols.items():
            if symbol not in self.__data:
                self.__data[symbol] = {}
            if time_frame not in self.__data[symbol]:
                self.__data[symbol][time_frame] = {}
            if maxlen == 0:
                maxlen = self.CACHE_MAXLEN
            for field in ['open', 'high', 'low', 'close', 'time', 'volume']:
                self.__data[symbol][time_frame][field] = deque(maxlen=maxlen)
            if symbol not in self.__current_positions:
                position = Position(symbol)
                self.__current_positions[symbol] = position
                self.__initial_positions[symbol] = position
                self.__positions[position.get_id()] = position

    # ----------------------------------------------------------------------
    def add_file(self, file):
        self.__event_engine.add_file(file)

    # ----------------------------------------------------------------------
    def add_strategy(self, strategy):
        """Add an already-created strategy instance."""
        self.__strategys[strategy.get_id()] = strategy
        strategy.engine = self

    # ----------------------------------------------------------------------
    def update_market_data(self, event):
        """Market data update."""
        # TODO market data
        pass

    # ----------------------------------------------------------------------
    def update_bar_data(self, event):
        """Push the bar in event.content['data'] to the front of its caches
        (index 0 is the newest bar)."""
        bar = event.content['data']
        symbol = bar.symbol
        time_frame = bar.time_frame
        for field in ['open', 'high', 'low', 'close', 'time', 'volume']:
            self.__data[symbol][time_frame][field].appendleft(
                getattr(bar, field))

    # ----------------------------------------------------------------------
    def __process_order(self, tick):
        """Process stop orders."""
        pass

    # ----------------------------------------------------------------------
    def update_order(self, event):
        """Order update."""
        # TODO order update
        pass

    # ----------------------------------------------------------------------
    def update_trade(self, event):
        """Trade update."""
        # TODO trade update
        pass

    # ----------------------------------------------------------------------
    def __update_position(self, deal):
        """Apply `deal` to its symbol's current position.

        Creates a new Position linked (prev_id/next_id) to the previous one.
        It classifies the deal as entry IN / OUT / INOUT and sets deal.profit
        on closes. It then records the deal and the new position, and
        notifies the account manager when the deal has a non-zero profit.
        """

        def sign(num):
            # Sign with a small tolerance so float noise counts as zero.
            if abs(num) <= 10**-7:
                return 0
            elif num > 0:
                return 1
            else:
                return -1

        if deal.volume == 0:
            return
        position_prev = self.__current_positions[deal.symbol]
        position_now = Position(deal.symbol, deal.strategy, deal.handle)
        position_now.prev_id = position_prev.get_id()
        position_prev.next_id = position_now.get_id()
        position = position_prev.type
        # XXX if the constant definitions change, this mapping may change too
        if deal.type * position >= 0:
            deal.entry = DEAL_ENTRY_IN
            if position == 0:  # open position
                position_now.price_open = deal.price
                position_now.time_open = deal.time
                position_now.time_open_msc = deal.time_msc
            else:  # overweight position
                # NOTE(review): price_open is not carried over from
                # position_prev here -- confirm that is intended.
                position_now.time_open = position_prev.time_open
                position_now.time_open_msc = position_prev.time_open_msc
            position_now.volume = deal.volume + position_prev.volume
            position_now.type = deal.type
            # Volume-weighted average holding price.
            position_now.price_current = (
                position_prev.price_current * position_prev.volume +
                deal.price * deal.volume) / position_now.volume
        else:
            contracts = position_prev.volume - deal.volume
            position_now.volume = abs(contracts)
            position_now.type = position * sign(contracts)
            if position_now.type == 0:  # close position
                deal.entry = DEAL_ENTRY_OUT
                deal.profit = (deal.price - position_prev.price_current
                               ) * position * position_prev.volume
                position_now.price_current = 0
                position_now.volume = 0  # guard against float precision residue
                position_now.time_open = position_prev.time_open
                position_now.time_open_msc = position_prev.time_open_msc
            elif position_now.type != position:  # reverse position
                deal.entry = DEAL_ENTRY_INOUT
                deal.profit = (deal.price - position_prev.price_current
                               ) * position * position_prev.volume
                position_now.price_current = deal.price
                position_now.time_open = deal.time
                position_now.time_open_msc = deal.time_msc
                position_now.price_open = position_now.price_current
            else:  # underweight position
                # XXX should partial-close profit be realized now or kept floating?
                deal.entry = DEAL_ENTRY_OUT
                deal.profit = (deal.price - position_prev.price_current
                               ) * position * deal.volume
                position_now.price_current = position_prev.price_current
                position_now.time_open = position_prev.time_open
                position_now.time_open_msc = position_prev.time_open_msc
        position_now.time_update = deal.time
        position_now.time_update_msc = deal.time_msc
        deal.position = position_now.get_id()
        position_now.deal = deal.get_id()
        self.__current_positions[deal.symbol] = position_now
        self.__positions[position_now.get_id()] = position_now
        self.__deals[deal.get_id()] = deal
        if deal.profit != 0:
            self.__account_manager.update_deal(deal)

    # ----------------------------------------------------------------------
    @staticmethod
    def check_order(order):
        """Return True if `order` is an Order instance."""
        if not isinstance(order, Order):
            return False
        # TODO more order validity checks
        return True

    # ----------------------------------------------------------------------
    def __send_order_to_broker(self, order):
        """Fill `order` and return ([deal], {}).

        In backtest mode the order is filled at the close of the newest cached
        bar, stamped one bar period after that bar's time. Live trading is not
        implemented; this then returns None.
        """
        if self.__backtesting:
            time_frame = SymbolsListener.get_by_id(
                order.handle).get_time_frame()
            time_ = self.__data[order.symbol][time_frame]["time"][
                0] + time_frame_to_seconds(time_frame)
            order.time_done = int(time_)
            order.time_done_msc = int((time_ - int(time_)) * (10**6))
            order.volume_current = order.volume_initial
            deal = Deal(order.symbol, order.strategy, order.handle)
            deal.volume = order.volume_current
            deal.time = order.time_done
            deal.time_msc = order.time_done_msc
            # see the ENUM_ORDER_TYPE and ENUM_DEAL_TYPE definitions
            deal.type = 1 - ((order.type & 1) << 1)
            deal.price = self.__data[order.symbol][time_frame]["close"][0]
            # TODO add commission etc.
            order.deal = deal.get_id()
            deal.order = order.get_id()
            return [deal], {}
            # TODO market-order fills
        else:
            pass  # TODO live trading

    # ----------------------------------------------------------------------
    def send_order(self, order):
        """
        Send an order.

        Market orders (order.type <= 1) are filled immediately and applied to
        the position; other orders are stored as pending.

        :return: True if the order passed check_order, otherwise False.
        """
        # TODO handle more attributes
        if self.check_order(order):
            if order.type <= 1:  # market order
                # NOTE(review): in live mode __send_order_to_broker returns
                # None, so the indexing below would fail.
                result = self.__send_order_to_broker(order)
                self.__update_position(*result[0])
            else:
                self.__orders_todo[order.get_id()] = order
            return True
        else:
            return False

    # ----------------------------------------------------------------------
    def cancel_order(self, order_id):
        """
        Cancel a pending order; order_id == 0 cancels all pending orders.
        """
        if order_id == 0:
            self.__orders_todo = {}
        else:
            if order_id in self.__orders_todo:
                del self.__orders_todo[order_id]

    # ----------------------------------------------------------------------
    def put_event(self, event):
        # TODO add validation
        # TODO this adds an extra call layer; try binding directly
        self.__event_engine.put(event)

    # ----------------------------------------------------------------------
    def register_event(self, event_type, handle):
        """Register an event listener."""
        # TODO add validation
        self.__event_engine.register(event_type, handle)

    def unregister_event(self, event_type, handle):
        """Unregister an event listener."""
        self.__event_engine.unregister(event_type, handle)

    # ----------------------------------------------------------------------
    def writeLog(self, log):
        """Write a log entry by putting an EVENT_LOG event."""
        event = Event(type_=EVENT_LOG)
        event.content['log'] = log
        self.__event_engine.put(event)

    # ----------------------------------------------------------------------
    def start(self):
        """Start the event engine, then all strategies."""
        self.__event_engine.start()
        for strategy in self.__strategys.values():
            strategy.start()

    # ----------------------------------------------------------------------
    def stop(self):
        """Stop the event engine, then all strategies."""
        self.__event_engine.stop()
        for strategy in self.__strategys.values():
            strategy.stop()

    def wait(self):
        """Wait until all events are processed, then stop."""
        self.__event_engine.wait()
        self.stop()

    # TODO support limit orders
    # ----------------------------------------------------------------------
    def sell(self, symbol, volume=1, price=None, stop=False, limit=False,
             strategy=None, listener=None):
        """Sell `volume` to reduce/close a long position; no-op unless long."""
        if volume == 0:
            return
        position = self.__current_positions.get(symbol, None)
        if not position or position.type <= 0:
            return  # XXX possible return value
        order = Order(symbol, ORDER_TYPE_SELL, strategy, listener)
        order.volume_initial = volume
        if self.__backtesting:
            time_ = self.__data[symbol][SymbolsListener.get_by_id(
                listener).get_time_frame()]['time'][0]
        else:
            time_ = time.time()
        order.time_setup = int(time_)
        order.time_setup_msc = int((time_ - int(time_)) * (10**6))
        return self.send_order(order)

    # ----------------------------------------------------------------------
    def buy(self, symbol, volume=1, price=None, stop=False, limit=False,
            strategy=None, listener=None):
        """Go long `volume`; if currently short, first buy back the whole
        short volume with a separate order."""
        if self.__backtesting:
            time_ = self.__data[symbol][SymbolsListener.get_by_id(
                listener).get_time_frame()]['time'][0]
        else:
            time_ = time.time()
        position = self.__current_positions.get(symbol, None)
        if position and position.type < 0:
            order = Order(symbol, ORDER_TYPE_BUY, strategy, listener)
            order.volume_initial = position.volume
            order.time_setup = int(time_)
            order.time_setup_msc = int((time_ - int(time_)) * (10**6))
            # TODO this should support transactional order placement
            self.send_order(order)
        if volume == 0:
            return
        order = Order(symbol, ORDER_TYPE_BUY, strategy, listener)
        order.volume_initial = volume
        order.time_setup = int(time_)
        order.time_setup_msc = int((time_ - int(time_)) * (10**6))
        return self.send_order(order)

    # ----------------------------------------------------------------------
    def cover(self, symbol, volume=1, price=None, stop=False, limit=False,
              strategy=None, listener=None):
        """Buy `volume` to reduce/close a short position; no-op unless short."""
        if volume == 0:
            return
        position = self.__current_positions.get(symbol, None)
        order = Order(symbol, ORDER_TYPE_BUY, strategy, listener)
        if not position or position.type >= 0:
            return  # XXX possible return value
        order.volume_initial = volume
        if self.__backtesting:
            time_ = self.__data[symbol][SymbolsListener.get_by_id(
                listener).get_time_frame()]['time'][0]
        else:
            time_ = time.time()
        order.time_setup = int(time_)
        order.time_setup_msc = int((time_ - int(time_)) * (10**6))
        return self.send_order(order)

    # ----------------------------------------------------------------------
    def short(self, symbol, volume=1, price=None, stop=False, limit=False,
              strategy=None, listener=None):
        """Go short `volume`; if currently long, first sell the whole long
        volume with a separate order."""
        if self.__backtesting:
            time_ = self.__data[symbol][SymbolsListener.get_by_id(
                listener).get_time_frame()]['time'][0]
        else:
            time_ = time.time()
        position = self.__current_positions.get(symbol, None)
        if position and position.type > 0:
            order = Order(symbol, ORDER_TYPE_SELL, strategy, listener)
            order.volume_initial = position.volume
            order.time_setup = int(time_)
            order.time_setup_msc = int((time_ - int(time_)) * (10**6))
            # TODO this should support transactional order placement
            self.send_order(order)
        if volume == 0:
            return
        order = Order(symbol, ORDER_TYPE_SELL, strategy, listener)
        order.volume_initial = volume
        order.time_setup = int(time_)
        order.time_setup_msc = int((time_ - int(time_)) * (10**6))
        return self.send_order(order)
class StrategyEngine(object):
    """Strategy engine.

    Facade over the event engine, the account manager, the trade manager and
    the data cache; strategies registered here are started/stopped with it.
    """

    CACHE_MAXLEN = 10000

    # ----------------------------------------------------------------------
    def __init__(self, is_backtest=False, **config):
        """Constructor

        :param is_backtest: selects AccountManager (backtest) or
            FDTAccountManager (otherwise); also forwarded to TradeManager and
            DataCache.
        :param config: extra settings forwarded to the managers; config['user']
            is read for the non-backtest account.
        """
        self.__config = config
        self.__event_engine = EventEngine()  # event processing engine
        if is_backtest:
            self.__account_manager = AccountManager(self, **config)  # account management
        else:
            self.__account_manager = FDTAccountManager(self, **config)
            # NOTE(review): indentation of this line was lost in the source;
            # placed in the non-backtest branch -- confirm.
            self.mongo_user = MongoUser(self.__config['user'])
        self.__trade_manager = TradeManager(self, is_backtest, **config)  # trade manager
        self.__data_cache = DataCache(self, is_backtest)  # data relay / cache
        self.__strategys = {}  # strategy id -> strategy object
        self.__profit_records = []  # account net-value records

    def set_account(self, account):
        assert isinstance(account, AccountManager)
        self.__account_manager = account

    def get_data(self):
        return self.__data_cache.data

    def get_symbol_pool(self):
        return self.__data_cache.symbol_pool

    def get_current_positions(self):
        return self.__trade_manager.current_positions

    def get_current_time(self):
        return self.__data_cache.current_time

    def get_positions(self):
        return self.__trade_manager.positions

    def get_deals(self):
        return self.__trade_manager.deals

    def get_strategys(self):
        return self.__strategys

    def get_profit_records(self):
        """Return the list of profit (net value) records."""
        return self.__profit_records

    def get_symbol_timeframe(self):
        return self.__data_cache.get_cache_info().keys()

    def get_capital_cash(self):
        return self.__account_manager.capital_cash

    def get_capital_net(self):
        return self.__account_manager.capital_net

    def get_capital_available(self):
        return self.__account_manager.capital_available

    # XXX The get_XXX methods are kept for direct access because it is unclear
    # whether going through a property costs performance.
    # property:
    current_time = property(get_current_time)
    symbol_pool = property(get_symbol_pool)
    data = property(get_data)
    current_positions = property(get_current_positions)
    positions = property(get_positions)
    deal = property(get_deals)  # historical name, kept for compatibility
    deals = property(get_deals)  # correctly named alias of `deal`
    strategys = property(get_strategys)
    profit_records = property(get_profit_records)
    symbol_timeframe = property(get_symbol_timeframe)
    capital_cash = property(get_capital_cash)
    capital_net = property(get_capital_net)
    capital_available = property(get_capital_available)

    def _currency_usd_price(self, currency, time_frame, index):
        """Price of one unit of `currency` in USD from the newest close.

        Uses <currency>USD directly if it is in `index`, else the inverse of
        USD<currency>.

        :raises ValueError: if neither pair is in `index`.
        """
        if currency + 'USD' in index:
            return self.data[time_frame]['close'][currency + 'USD'][0]
        elif 'USD' + currency in index:
            return 1 / self.data[time_frame]['close']['USD' + currency][0]
        else:
            raise ValueError('找不到基准报价:%s' % currency)

    def get_counter_price(self, code, time_frame):
        """
        Price of the pair's counter (quote) currency in USD.

        :param code: symbol code, e.g. 'EURJPY'
        :param time_frame: time frame whose newest close is used
        :return: USD value of one unit of the counter currency
        :raises ValueError: if no USD pair for the counter currency exists
        """
        symbol = self.symbol_pool[code]
        if symbol.code.endswith('USD'):
            # counter currency is USD itself
            base_price = 1
        elif symbol.code.startswith('USD'):
            # USDxxx: invert the pair's own close
            base = symbol.code[-3:]
            base_price = 1 / self.data[time_frame]['close']['USD' + base][0]
        else:
            # cross pair: go through the counter currency's USD pair
            base_price = self._currency_usd_price(
                symbol.code[-3:], time_frame, symbol.ALL.index)
        return base_price

    def get_base_price(self, code, time_frame):
        """
        Price of the pair's base currency in USD.

        :param code: symbol code, e.g. 'EURJPY'
        :param time_frame: time frame whose newest close is used
        :return: USD value of one unit of the base currency
        :raises ValueError: if no USD pair for the base currency exists
        """
        symbol = self.symbol_pool[code]
        if symbol.code.startswith('USD'):
            # base currency is USD itself
            base_price = 1
        elif symbol.code.endswith('USD'):
            # xxxUSD: the pair's own close
            base = symbol.code[:3]
            base_price = self.data[time_frame]['close'][base + 'USD'][0]
        else:
            # cross pair: go through the base currency's USD pair
            base_price = self._currency_usd_price(
                symbol.code[:3], time_frame, symbol.ALL.index)
        return base_price

    def get_capital(self):
        return self.__account_manager.get_api()

    def profit_record(self, *args, **kwargs):
        return self.__account_manager.profit_record(*args, **kwargs)

    def update_cash(self, deal):
        return self.__account_manager.update_cash(deal)

    def send_order_to_broker(self, order):
        return self.__account_manager.send_order_to_broker(order)

    def order_status(self):
        return self.__account_manager.order_status()

    def position_status(self, *args, **kwargs):
        return self.__account_manager.position_status(*args, **kwargs)

    def open_position(self, *args, **kwargs):
        return self.__trade_manager.open_position(*args, **kwargs)

    def close_position(self, *args, **kwargs):
        return self.__trade_manager.close_position(*args, **kwargs)

    def set_capital_base(self, base):
        self.__account_manager.capital_base = base

    # ----------------------------------------------------------------------
    def add_cache_info(self, *args, **kwargs):
        self.__data_cache.add_cache_info(*args, **kwargs)
        # TODO look up from the global symbol pool

    # ----------------------------------------------------------------------
    def add_file(self, file):
        self.__event_engine.add_file(file)

    # ----------------------------------------------------------------------
    def add_strategy(self, strategy):
        """Add an already-created strategy instance."""
        self.__strategys[strategy.get_id()] = strategy
        strategy.engine = self

    # ----------------------------------------------------------------------
    def put_event(self, event):
        # TODO add validation
        # TODO this adds an extra call layer; try binding directly
        self.__event_engine.put(event)

    # ----------------------------------------------------------------------
    def register_event(self, event_type, handle):
        """Register an event listener."""
        # TODO add validation
        self.__event_engine.register(event_type, handle)

    def unregister_event(self, event_type, handle):
        """Unregister an event listener."""
        self.__event_engine.unregister(event_type, handle)

    # ----------------------------------------------------------------------
    def write_log(self, log):
        """Write a log entry by putting an EVENT_LOG event."""
        self.__event_engine.put(Event(type=EVENT_LOG, log=log))

    # ----------------------------------------------------------------------
    def start(self):
        """Start all strategies, then reset records and start the managers."""
        for strategy in self.__strategys.values():
            strategy.start()
        self.__profit_records.clear()
        self.__data_cache.start()
        self.__trade_manager.init()
        self.__event_engine.start()
        self.__account_manager.initialize()

    # ----------------------------------------------------------------------
    def stop(self):
        """Stop the event engine, the data cache and all strategies, then
        release resources."""
        self.__event_engine.stop()
        self.__data_cache.stop()
        for strategy in self.__strategys.values():
            strategy.stop()
        self._recycle()  # release resources

    # ----------------------------------------------------------------------
    def _recycle(self):
        self.__data_cache.stop()
        self.__trade_manager.recycle()

    def wait(self, call_back=None, finished=True, *args, **kwargs):
        """Wait until all queued events are processed.

        :param call_back: optional callable invoked with *args/**kwargs once
            the queue is drained; its return value is returned.
        :param finished: kept for backward compatibility. If True, the engine
            is stopped afterwards; if False only the callback runs and the
            engine stays alive.
        :return: the callback's return value, or None without a callback.
        """
        self.__event_engine.wait()
        if call_back:
            result = call_back(*args, **kwargs)
        else:
            result = None
        if finished:
            self.stop()
        return result
class StrategyEngine(object):
    """Strategy engine.

    Facade over the event engine, the account manager, the trade manager and
    the data cache; strategies registered here are started/stopped with it.
    """

    CACHE_MAXLEN = 10000

    # ----------------------------------------------------------------------
    def __init__(self, is_backtest=False):
        """Constructor

        :param is_backtest: forwarded to TradeManager.
        """
        self.__event_engine = EventEngine()  # event processing engine
        self.__account_manager = AccountManager()  # account management
        self.__trade_manager = TradeManager(self, is_backtest)  # trade manager
        self.__data_cache = DataCache(self)  # data relay / cache
        self.__strategys = {}  # strategy id -> strategy object

    def get_data(self):
        return self.__data_cache.data

    def get_symbol_pool(self):
        return self.__data_cache.symbol_pool

    def get_current_positions(self):
        return self.__trade_manager.current_positions

    def get_current_time(self):
        return self.__data_cache.current_time

    def get_positions(self):
        return self.__trade_manager.positions

    def get_deals(self):
        return self.__trade_manager.deals

    def get_strategys(self):
        return self.__strategys

    def get_profit_records(self):
        """Return the closed-trade profit records from the account manager."""
        return self.__account_manager.get_profit_records()

    def get_symbol_timeframe(self):
        return self.__data_cache.get_cache_info().keys()

    # XXX The get_XXX methods are kept for direct access because it is unclear
    # whether going through a property costs performance.
    # property:
    current_time = property(get_current_time)
    symbol_pool = property(get_symbol_pool)
    data = property(get_data)
    current_positions = property(get_current_positions)
    positions = property(get_positions)
    deal = property(get_deals)
    strategys = property(get_strategys)
    profit_records = property(get_profit_records)
    symbol_timeframe = property(get_symbol_timeframe)

    def update_deal(self, deal):
        return self.__account_manager.update_deal(deal)

    def open_position(self, *args, **kwargs):
        # Return the trade manager's result instead of discarding it.
        return self.__trade_manager.open_position(*args, **kwargs)

    def close_position(self, *args, **kwargs):
        # Return the trade manager's result instead of discarding it.
        return self.__trade_manager.close_position(*args, **kwargs)

    def set_capital_base(self, base):
        self.__account_manager.set_capital_base(base)

    # ----------------------------------------------------------------------
    def add_cache_info(self, *args, **kwargs):
        self.__data_cache.add_cache_info(*args, **kwargs)
        # TODO look up from the global symbol pool

    # ----------------------------------------------------------------------
    def add_file(self, file):
        self.__event_engine.add_file(file)

    # ----------------------------------------------------------------------
    def add_strategy(self, strategy):
        """Add an already-created strategy instance."""
        self.__strategys[strategy.get_id()] = strategy
        strategy.engine = self

    # ----------------------------------------------------------------------
    def put_event(self, event):
        # TODO add validation
        # TODO this adds an extra call layer; try binding directly
        self.__event_engine.put(event)

    # ----------------------------------------------------------------------
    def register_event(self, event_type, handle):
        """Register an event listener."""
        # TODO add validation
        self.__event_engine.register(event_type, handle)

    def unregister_event(self, event_type, handle):
        """Unregister an event listener."""
        self.__event_engine.unregister(event_type, handle)

    # ----------------------------------------------------------------------
    def write_log(self, log):
        """Write a log entry by putting an EVENT_LOG event."""
        self.__event_engine.put(Event(type=EVENT_LOG, log=log))

    # ----------------------------------------------------------------------
    def start(self):
        """Start all strategies, then the data cache, trade manager and
        event engine."""
        for strategy in self.__strategys.values():
            strategy.start()
        self.__data_cache.start()
        self.__trade_manager.init()
        self.__event_engine.start()

    # ----------------------------------------------------------------------
    def stop(self):
        """Stop the event engine, the data cache and all strategies, then
        release resources."""
        self.__event_engine.stop()
        self.__data_cache.stop()
        for strategy in self.__strategys.values():
            strategy.stop()
        self._recycle()  # release resources

    # ----------------------------------------------------------------------
    def _recycle(self):
        self.__data_cache.stop()
        self.__trade_manager.recycle()
        self.__account_manager.initialize()

    def wait(self, call_back=None, finished=True, *args, **kwargs):
        """Wait until all queued events are processed.

        :param call_back: optional callable invoked with *args/**kwargs once
            the queue is drained; its return value is returned.
        :param finished: kept for backward compatibility. If True, the event
            engine is marked finished and the engine is stopped; if False only
            the callback runs and the engine stays alive.
        :return: the callback's return value, or None without a callback.
        """
        self.__event_engine.wait()
        # call_back defaults to None, so only invoke it when one was given.
        result = call_back(*args, **kwargs) if call_back else None
        if finished:
            self._set_finished()
            self.stop()
        return result

    def _set_finished(self):
        # Mark that no new data will arrive.
        self.__event_engine.set_finished()
class StrategyEngine(LoggerInterface, Runnable, ConfigInterface, APIInterface):
    """Strategy engine.

    Wires together the event engine, quotation manager, account manager and
    trading manager, and starts/stops the registered strategies.
    """

    def __init__(self, parent=None):
        """Constructor

        :param parent: forwarded to LoggerInterface and ConfigInterface.
        """
        LoggerInterface.__init__(self, parent=parent)
        Runnable.__init__(self)
        ConfigInterface.__init__(self, parent=parent)
        APIInterface.__init__(self)
        self.__event_engine = EventEngine(parent=self)  # event processing engine
        self.__quotation_manager = QuotationManager(self, parent=self)  # market data manager
        if self.config.running_mode == RunningMode.backtest:
            self.__account_manager = BfAccountManager(parent=self)  # account management
        else:
            self.__account_manager = FDTAccountManager(parent=self)
            # NOTE(review): indentation of this line was lost in the source;
            # placed in the non-backtest branch -- confirm.
            self.mongo_user = MongoUser(self.config.user)
        self.__trading_manager = TradingManager(self, self.__quotation_manager,
                                                self.__account_manager,
                                                parent=self)  # trading manager
        if self.config.running_mode == RunningMode.backtest:
            # The backtest account manager needs a back-reference to the
            # trading manager, which is only created above.
            self.__account_manager.set_trading_manager(self.__trading_manager)
        self.__strategys = {}  # strategy id -> strategy object
        self.__profit_records = []  # account net-value records
        self.logger_name = "StrategyEngine"

    def set_account(self, account):
        assert isinstance(account, AccountManager)
        self.__account_manager = account

    @property
    def current_time(self):
        return self.__quotation_manager.current_time

    @property
    def positions(self):
        return self.__trading_manager.positions

    @property
    def deals(self):
        return self.__trading_manager.deals

    @property
    def strategys(self):
        return self.__strategys

    @property
    def profit_records(self):
        """Return the list of profit (net value) records."""
        return self.__profit_records

    @property
    def max_margin(self):
        return self.__trading_manager.max_margin

    def profit_record(self, *args, **kwargs):
        return self.__account_manager.profit_record(*args, **kwargs)

    def realize_order(self):
        self.__trading_manager.realize_order()

    def add_cache_info(self, *args, **kwargs):
        self.__quotation_manager.add_cache_info(*args, **kwargs)
        # TODO look up from the global symbol pool

    def add_file(self, file):
        self.__event_engine.add_file(file)

    def add_strategy(self, strategy):
        """Add an already-created strategy instance."""
        self.__strategys[strategy.get_id()] = strategy
        strategy.engine = self

    def put_event(self, event):
        # TODO add validation
        # TODO this adds an extra call layer; try binding directly
        self.__event_engine.put(event)

    def register_event(self, event_type, handle):
        """Register an event listener."""
        # TODO add validation
        self.__event_engine.register(event_type, handle)

    def unregister_event(self, event_type, handle):
        """Unregister an event listener."""
        self.__event_engine.unregister(event_type, handle)

    def write_log(self, log):
        """Write a log entry by putting an EVENT_LOG event."""
        self.__event_engine.put(Event(type=EVENT_LOG, log=log))

    def _start(self):
        """Reset records, start the managers and event engine, then all
        strategies."""
        self.__profit_records.clear()
        self.__quotation_manager.start()
        self.__trading_manager.init()
        self.__event_engine.start()
        self.__account_manager.initialize()
        for strategy in self.__strategys.values():
            strategy.start()

    def _stop(self):
        """Stop all strategies, the event engine and the quotation manager,
        then release resources."""
        for strategy in self.__strategys.values():
            strategy.stop()
        self.__event_engine.stop()
        self.__quotation_manager.stop()
        self._recycle()  # release resources

    def _recycle(self):
        self.__quotation_manager.stop()
        self.__trading_manager.recycle()

    # TODO the design of the `finished` parameter is questionable
    def wait(self, call_back=None, finished=True, *args, **kwargs):
        """Wait until all queued events are processed.

        :param call_back: optional callable invoked with *args/**kwargs once
            the queue is drained; its return value is returned.
        :param finished: kept for backward compatibility. If True, the engine
            is stopped afterwards; if False only the callback runs and the
            engine stays alive.
        :return: the callback's return value, or None without a callback.
        """
        self.__event_engine.wait()
        if call_back:
            result = call_back(*args, **kwargs)
        else:
            result = None
        if finished:
            self.stop()
        return result

    def get_APIs(self, strategy=None, signal=None, symbols=None,
                 time_frame=None) -> Globals:
        """Merge the APIs exposed by the account, quotation and trading
        managers into one Globals object.

        :param symbols: sequence of symbols; its first element is passed to
            the trading manager as `symbol`. When empty or None (the
            default), symbol=None is passed instead of failing on indexing.
        """
        # NOTE(review): confirm TradingManager.get_APIs accepts symbol=None.
        symbol = symbols[0] if symbols else None
        APIs = Globals({}, {})
        APIs.update(self.__account_manager.get_APIs())
        APIs.update(self.__quotation_manager.get_APIs(symbols=symbols,
                                                      time_frame=time_frame))
        APIs.update(self.__trading_manager.get_APIs(strategy=strategy,
                                                    signal=signal,
                                                    symbol=symbol))
        return APIs