def refresh_tree(self) -> None:
    """
    Rebuild the tree rows from the bar data reported by the engine.

    The tree is cleared first. For every entry returned by
    ``self.engine.get_bar_data_available()`` a row is looked up in
    ``self.tree_items`` by ``(symbol, exchange, interval)``; a missing row is
    created, attached under the minute/hour/daily parent and given its
    show/export/delete buttons. Count, start and end columns are refreshed
    for every entry, and the three parent nodes are expanded at the end.
    """
    self.clear_tree()

    for info in self.engine.get_bar_data_available():
        symbol = info["symbol"]
        exchange_name = info["exchange"]
        interval_name = info["interval"]

        key = (symbol, exchange_name, interval_name)
        item = self.tree_items.get(key, None)

        if not item:
            item = QtWidgets.QTreeWidgetItem()
            self.tree_items[key] = item

            item.setText(1, ".".join((symbol, exchange_name)))
            item.setText(2, symbol)
            item.setText(3, exchange_name)

            # Anything that is neither minute nor hour goes under daily.
            if interval_name == Interval.MINUTE.value:
                parent = self.minute_child
            elif interval_name == Interval.HOUR.value:
                parent = self.hour_child
            else:
                parent = self.daily_child
            parent.addChild(item)

            exchange = Exchange(exchange_name)
            interval = Interval(interval_name)
            range_args = (symbol, exchange, interval, info["start"], info["end"])

            export_button = QtWidgets.QPushButton("导出")
            export_button.clicked.connect(partial(self.output_data, *range_args))

            view_button = QtWidgets.QPushButton("查看")
            view_button.clicked.connect(partial(self.show_data, *range_args))

            remove_button = QtWidgets.QPushButton("删除")
            remove_button.clicked.connect(
                partial(self.delete_data, symbol, exchange, interval)
            )

            self.tree.setItemWidget(item, 7, view_button)
            self.tree.setItemWidget(item, 8, export_button)
            self.tree.setItemWidget(item, 9, remove_button)

        # These columns are refreshed for new and existing rows alike.
        item.setText(4, str(info["count"]))
        item.setText(5, info["start"].strftime("%Y-%m-%d %H:%M:%S"))
        item.setText(6, info["end"].strftime("%Y-%m-%d %H:%M:%S"))

    for parent_node in (self.minute_child, self.hour_child, self.daily_child):
        parent_node.setExpanded(True)
def get_bar_data_available(self) -> List[Dict]:
    """
    Return the database bar statistics with start/end datetimes added.

    Each dict from ``database_manager.get_bar_data_statistics()`` gets a
    ``"start"`` key (datetime of the oldest stored bar) and an ``"end"`` key
    (datetime of the newest stored bar) for its symbol/exchange/interval.
    """
    statistics = database_manager.get_bar_data_statistics()

    for entry in statistics:
        symbol = entry["symbol"]
        exchange = Exchange(entry["exchange"])
        interval = Interval(entry["interval"])

        oldest = database_manager.get_oldest_bar_data(symbol, exchange, interval)
        entry["start"] = oldest.datetime

        newest = database_manager.get_newest_bar_data(symbol, exchange, interval)
        entry["end"] = newest.datetime

    return statistics
def set_parameters(
    self,
    vt_symbol: str,
    interval: Interval,
    rate: float,
    slippage: float,
    size: float,
    pricetick: float,
    enginetype: EngineType,
    capital: int = 0,
    mode: BacktestingMode = BacktestingMode.BAR,
    inverse: bool = False,
    risk_free: float = 0,
    annual_days: int = 240
):
    """
    Store the backtesting parameters on the engine.

    ``vt_symbol`` is expected in ``"<symbol>.<exchange>"`` form; it is split
    into ``self.symbol`` and ``self.exchange``. ``interval`` is normalised
    through ``Interval(...)``, so either a member or its value is accepted.

    Raises:
        ValueError: if ``vt_symbol`` contains no ``"."`` separator, or if the
            interval/exchange value is not a valid enum value.
    """
    # Parse and validate before touching any attribute, so a bad argument
    # does not leave the engine half-configured.
    interval_value = Interval(interval)
    # Split on the last dot only: the symbol part may itself contain dots.
    symbol, exchange_str = vt_symbol.rsplit(".", 1)
    exchange = Exchange(exchange_str)

    self.mode = mode
    self.vt_symbol = vt_symbol
    self.interval = interval_value
    self.rate = rate
    self.slippage = slippage
    self.size = size
    self.pricetick = pricetick
    self.enginetype = enginetype

    self.symbol = symbol
    self.exchange = exchange

    self.capital = capital
    self.inverse = inverse
    self.risk_free = risk_free
    self.annual_days = annual_days
def load_bar_data(self, filter: dict) -> List[BarData]:
    """
    Load bars matching ``filter`` from ``self.bar_collection``.

    The collection is queried once. Each returned document is converted to a
    ``BarData`` (exchange/interval turned into enums, ``gateway_name`` set to
    ``"DB"``, ``_id`` dropped) and the result is sorted by datetime.

    Returns:
        The bars in ascending datetime order. When nothing matches, a message
        is written through ``self.output`` and an empty list is returned
        (falsy, and consistent with the declared return type).
    """
    bars: List[BarData] = []

    cursor: Cursor = self.bar_collection.find(filter)
    for doc in cursor:
        doc["exchange"] = Exchange(doc["exchange"])
        doc["interval"] = Interval(doc["interval"])
        doc["gateway_name"] = "DB"
        doc.pop("_id")
        bars.append(BarData(**doc))

    # Emptiness is decided from the loaded result instead of a separate
    # count_documents() round trip.
    if not bars:
        self.output("无符合条件的历史数据")
        return bars

    bars.sort(key=lambda bar: bar.datetime)

    self.output(
        f"已从{self.history_db.name}库{self.bar_collection.name}集合加载{bars[0].datetime}至{bars[-1].datetime}K线数据{len(bars)}条"
    )
    return bars
def run_downloading(
    self,
    vt_symbol: str,
    interval: str,
    start: datetime,
    end: datetime
):
    """
    Query bar data through ``rqdata_client`` and save it to the database.

    Progress is reported through ``self.write_log``. Data is saved and the
    completion message logged only when the query returned something;
    otherwise only the failure message is logged. ``self.thread`` is reset to
    ``None`` on every exit path, including when an exception propagates.
    """
    self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")

    try:
        symbol, exchange = extract_vt_symbol(vt_symbol)

        data = rqdata_client.query_bar(
            symbol, exchange, Interval(interval), start, end
        )

        if data:
            database_manager.save_bar_data(data)
            self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
        else:
            # Nothing to save: report the failure instead of claiming success.
            self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
    finally:
        # Clear thread object handler.
        self.thread = None
def run_backtesting_load_data_df(self):
    """
    Feed the first ``self.days`` days of history to ``self.callback``.

    Rows of ``self.history_data_df`` whose ``datetime`` is not later than the
    earliest datetime plus ``Day(self.days)`` are converted to ``BarData`` one
    by one and pushed to ``self.callback``. If the callback raises, the
    traceback is written through ``self.output`` and the method returns.
    """
    history_df = self.history_data_df

    # Use the first [days] of history data for initializing strategy
    self.load_bar_end_timestamp = history_df['datetime'].min() + Day(self.days)
    warmup_df = history_df[history_df['datetime'] <= self.load_bar_end_timestamp]

    for position in range(len(warmup_df)):
        row = warmup_df.iloc[position]
        bar_time = self.datetime_set_timezone(row['datetime'].to_pydatetime())

        bar = BarData(
            symbol=row['symbol'],
            exchange=Exchange(row['exchange']),
            datetime=bar_time,
            interval=Interval(row['interval']),
            volume=row['volume'],
            open_price=row['open_price'],
            high_price=row['high_price'],
            open_interest=row['open_interest'],
            low_price=row['low_price'],
            close_price=row['close_price'],
            gateway_name="DB",
        )
        self.datetime = bar.datetime

        # Original author's note: self.callback is the strategy template's
        # on_bar; it is assigned in load_bar(self) further down, which the
        # strategy template calls (load_bar/tick).
        # This is where the data is pushed into the strategy.
        try:
            self.callback(bar)
        except Exception:
            self.output("触发异常,回测终止")
            self.output(traceback.format_exc())
            return
def download_bar_data(
    self,
    symbol: str,
    exchange: Exchange,
    interval: str,
    start: datetime
) -> int:
    """
    Download bars from the contract's gateway and save them to the database.

    A history request from ``start`` until now is built and sent through
    ``self.main_engine.query_history`` when the contract is known and its
    ``history_data`` flag is set.

    Returns:
        The number of bars saved, or 0 when the contract is unknown, offers
        no gateway history, or the query returned nothing.
    """
    req = HistoryRequest(
        symbol=symbol,
        exchange=exchange,
        interval=Interval(interval),
        start=start,
        end=datetime.now()
    )

    vt_symbol = f"{symbol}.{exchange.value}"
    contract = self.main_engine.get_contract(vt_symbol)

    # Without gateway-provided history there is nothing to query; returning
    # here also guarantees `data` is never referenced while unbound.
    if not contract or not contract.history_data:
        return 0

    data = self.main_engine.query_history(req, contract.gateway_name)
    if not data:
        return 0

    database_manager.save_bar_data(data)
    return len(data)
def download_bar_data(self, symbol: str, exchange: Exchange, interval: str, start: datetime) -> int:
    """
    Download bars from the gateway or RQData and save them to the database.

    The request covers ``start`` until now. The contract's gateway is used
    when the contract is known and its ``history_data`` flag is set;
    otherwise ``rqdata_client`` is used (initialised first if needed).

    Returns:
        The number of bars saved, or 0 when the query returned nothing.
    """
    request = HistoryRequest(
        symbol=symbol,
        exchange=exchange,
        interval=Interval(interval),
        start=start,
        end=datetime.now(),
    )

    contract = self.main_engine.get_contract(f"{symbol}.{exchange.value}")

    if contract and contract.history_data:
        # If history data provided in gateway, then query
        bars = self.main_engine.query_history(request, contract.gateway_name)
    else:
        # Otherwise use RQData to query data
        if not rqdata_client.inited:
            rqdata_client.init()
        bars = rqdata_client.query_history(request)

    if not bars:
        return 0

    database_manager.save_bar_data(bars)
    return len(bars)
def init_bar_overview(self) -> None:
    """
    Populate the overview shelf from the stored bar data.

    Runs a grouped count query over ``bar_data`` and, for every group
    returned, writes a ``BarOverview`` (symbol, exchange, interval, count,
    start, end) into the shelf at ``self.overview_filepath`` under the key
    ``"<vt_symbol>_<interval value>"``.
    """
    # The context manager closes the shelf even if the query or a lookup
    # raises, so the file handle is never leaked.
    with shelve.open(self.overview_filepath) as f:
        query: str = "select count(close_price) from bar_data group by *"
        result = self.client.query(query)

        for k, v in result.items():
            # The tags are read from the second element of the result key.
            tags = k[1]
            data = list(v)[0]

            vt_symbol = tags["vt_symbol"]
            symbol, exchange = extract_vt_symbol(vt_symbol)
            interval = Interval(tags["interval"])

            overview = BarOverview(
                symbol=symbol,
                exchange=exchange,
                interval=interval,
                count=data["count"]
            )
            overview.start = self.get_bar_datetime(vt_symbol, interval, 1)
            overview.end = self.get_bar_datetime(vt_symbol, interval, -1)

            key = f"{vt_symbol}_{interval.value}"
            f[key] = overview
def set_parameters(self, vt_symbol: str, interval: Interval, start: datetime, rate: float, slippage: float,
                   size: float, pricetick: float, capital: int = 0, end: datetime = None,
                   mode: BacktestingMode = BacktestingMode.BAR, inverse: bool = False):
    """
    Store the backtesting parameters on the engine and update the log path.

    ``vt_symbol`` is expected in ``"<symbol>.<exchange>"`` form; it is split
    into ``self.symbol`` and ``self.exchange``. ``interval`` is normalised
    through ``Interval(...)``. ``self.logs_path`` is set to
    ``<cwd>/log/<self.test_name>`` as an absolute path.

    Raises:
        ValueError: if ``vt_symbol`` contains no ``"."`` separator, or if the
            interval/exchange value is not a valid enum value.
    """
    # Parse and validate before touching any attribute, so a bad argument
    # does not leave the engine half-configured.
    interval_value = Interval(interval)
    # Split on the last dot only: the symbol part may itself contain dots.
    symbol, exchange_str = vt_symbol.rsplit(".", 1)
    exchange = Exchange(exchange_str)

    self.mode = mode
    self.vt_symbol = vt_symbol
    self.interval = interval_value
    self.rate = rate
    self.slippage = slippage
    self.size = size
    self.pricetick = pricetick
    self.start = start

    self.symbol = symbol
    self.exchange = exchange

    self.capital = capital
    self.end = end
    self.inverse = inverse

    # Update the log directory.
    self.logs_path = os.path.abspath(
        os.path.join(os.getcwd(), 'log', self.test_name))
def run_single_backtesting(args):
    """
    Run one backtest described by parsed command-line ``args``.

    Reads symbol, exchange, interval, date range, cost settings, the strategy
    class name and the strategy setting file from ``args``; then loads data,
    runs the backtest, computes results and statistics, and shows the chart.
    """
    engine = BacktestingEngine()

    date_format = '%Y-%m-%d'
    req = HistoryRequest(
        symbol=args.symbol,
        exchange=Exchange(args.exchange),
        interval=Interval(args.interval),
        start=datetime.strptime(args.startdate, date_format),
        end=datetime.strptime(args.enddate, date_format),
    )

    setting = txt_to_dic(args.backtesting_setting_file)
    strategy_class = STRATEGIES[args.strategy_class]

    engine.set_parameters(
        vt_symbol=req.vt_symbol,
        interval=req.interval,
        start=req.start,
        end=req.end,
        rate=args.rate,
        slippage=args.slippage,
        size=args.size,
        pricetick=args.pricetick,
        capital=args.capital,
    )
    engine.add_strategy(strategy_class, setting)

    engine.load_data()
    engine.run_backtesting()

    result_df = engine.calculate_result()
    engine.calculate_statistics(result_df)
    engine.show_chart(result_df)
def run_downloading(self, vt_symbol: str, interval: str, start: datetime, end: datetime):
    """
    Download bar data from the gateway or RQData and save it.

    The contract's gateway is queried when the contract is known and its
    ``history_data`` flag is set; otherwise ``rqdata_client`` is used. The
    outcome is reported through ``self.write_log``. ``self.thread`` is reset
    to ``None`` on every exit path, including when an exception propagates,
    so a failed run does not leave a stale thread handle behind.
    """
    self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")

    try:
        symbol, exchange = extract_vt_symbol(vt_symbol)

        req = HistoryRequest(
            symbol=symbol,
            exchange=exchange,
            interval=Interval(interval),
            start=start,
            end=end
        )

        contract = self.main_engine.get_contract(vt_symbol)

        # If history data provided in gateway, then query
        if contract and contract.history_data:
            data = self.main_engine.query_history(req, contract.gateway_name)
        # Otherwise use RQData to query data
        else:
            data = rqdata_client.query_history(req)

        if data:
            database_manager.save_bar_data(data)
            self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
        else:
            self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
    finally:
        # Clear thread object handler.
        self.thread = None
def load_bar_data(
    self,
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    start: datetime,
    end: datetime
) -> List[BarData]:
    """
    Load bars for one symbol/exchange/interval within ``[start, end]``.

    Rows are selected from ``DbBarData`` ordered by datetime. Each row object
    is adjusted in place (datetime localised with ``DB_TZ``, exchange and
    interval turned into enums, ``gateway_name`` and ``vt_symbol`` set) and
    the row objects themselves are returned.
    """
    conditions = (
        (DbBarData.symbol == symbol)
        & (DbBarData.exchange == exchange.value)
        & (DbBarData.interval == interval.value)
        & (DbBarData.datetime >= start)
        & (DbBarData.datetime <= end)
    )
    query: ModelSelect = (
        DbBarData.select().where(conditions).order_by(DbBarData.datetime)
    )

    vt_symbol = f"{symbol}.{exchange.value}"

    result: List[BarData] = []
    for row in query:
        row.datetime = DB_TZ.localize(row.datetime)
        row.exchange = Exchange(row.exchange)
        row.interval = Interval(row.interval)
        row.gateway_name = "DB"
        row.vt_symbol = vt_symbol
        result.append(row)

    return result
def update_data(self) -> None:
    """
    Download newer bars for every stored data set, showing a progress dialog.

    For each entry from ``self.engine.get_bar_data_available()`` the engine
    is asked to download bars starting at the entry's ``"end"`` datetime.
    The loop stops early once the dialog reports cancellation.
    """
    entries = self.engine.get_bar_data_available()
    total = len(entries)

    dialog = QtWidgets.QProgressDialog(
        "历史数据更新中",
        "取消",
        0,
        100
    )
    dialog.setWindowTitle("更新进度")
    dialog.setWindowModality(QtCore.Qt.WindowModal)
    dialog.setValue(0)

    for finished, entry in enumerate(entries, start=1):
        if dialog.wasCanceled():
            break

        self.engine.download_bar_data(
            entry["symbol"],
            Exchange(entry["exchange"]),
            Interval(entry["interval"]),
            entry["end"]
        )

        # Percentage of entries processed so far.
        dialog.setValue(int(round(finished / total * 100, 0)))

    dialog.close()
def load_bar_data(
    self,
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    start: datetime,
    end: datetime
) -> List[BarData]:
    """
    Load bars for one symbol/exchange/interval within ``[start, end]``.

    Documents are read from ``self.bar_collection`` and each is converted to
    a ``BarData`` after turning exchange/interval into enums, setting
    ``gateway_name`` to ``"DB"`` and dropping ``_id``.
    """
    datetime_range = {"$gte": start, "$lte": end}
    filter = {
        "symbol": symbol,
        "exchange": exchange.value,
        "interval": interval.value,
        "datetime": datetime_range,
    }

    cursor: Cursor = self.bar_collection.find(filter)

    result = []
    for doc in cursor:
        doc["exchange"] = Exchange(doc["exchange"])
        doc["interval"] = Interval(doc["interval"])
        doc["gateway_name"] = "DB"
        doc.pop("_id")
        result.append(BarData(**doc))

    return result
def rq_download(
    self,
    vt_symbol: str,
    interval: str,
    start: datetime,
    end: datetime,
):
    """
    Download bars for ``vt_symbol`` through ``rqdata_client`` and save them.

    The client is initialised, a history request is built from the arguments
    and the outcome is printed to stdout.
    """
    rqdata_client.init()

    symbol, exchange = extract_vt_symbol(vt_symbol)
    request = HistoryRequest(
        symbol=symbol,
        exchange=exchange,
        interval=Interval(interval),
        start=start,
        end=end,
    )

    bars = rqdata_client.query_history(request)
    if not bars:
        print(f"数据下载失败,无法得到 {vt_symbol} 的数据")
        return

    database_manager.save_bar_data(bars)
    print(f"{vt_symbol}-{interval} 历史数据下载完成")
def set_parameters(self, vt_symbol: str, interval: Interval, start: datetime, rate: float, slippage: float,
                   size: float, pricetick: float, capital: int = 0, end: datetime = None,
                   mode: BacktestingMode = BacktestingMode.BAR, inverse: bool = False, risk_free: float = 0):
    """
    Store the backtesting parameters on the engine.

    ``vt_symbol`` is expected in ``"<symbol>.<exchange>"`` form; it is split
    into ``self.symbol`` and ``self.exchange``. ``interval`` is normalised
    through ``Interval(...)``, so either a member or its value is accepted.

    Raises:
        ValueError: if ``vt_symbol`` contains no ``"."`` separator, or if the
            interval/exchange value is not a valid enum value.
    """
    # Parse and validate before touching any attribute, so a bad argument
    # does not leave the engine half-configured.
    interval_value = Interval(interval)
    # Split on the last dot only: the symbol part may itself contain dots.
    symbol, exchange_str = vt_symbol.rsplit(".", 1)
    exchange = Exchange(exchange_str)

    self.mode = mode
    self.vt_symbol = vt_symbol
    self.interval = interval_value
    self.rate = rate
    self.slippage = slippage
    self.size = size
    self.pricetick = pricetick
    self.start = start

    self.symbol = symbol
    self.exchange = exchange

    self.capital = capital
    self.end = end
    self.inverse = inverse
    self.risk_free = risk_free
def load_bar_data(self, symbol: str, exchange: Exchange, interval: Interval, start: datetime,
                  end: datetime) -> List[BarData]:
    """
    Load bars for one symbol/exchange/interval within ``[start, end]``.

    Rows are selected from ``DbBarData`` ordered by datetime and each is
    copied into a new ``BarData`` with its datetime converted to ``DB_TZ``
    and ``gateway_name`` set to ``"DB"``.
    """
    conditions = (
        (DbBarData.symbol == symbol)
        & (DbBarData.exchange == exchange.value)
        & (DbBarData.interval == interval.value)
        & (DbBarData.datetime >= start)
        & (DbBarData.datetime <= end)
    )
    query: ModelSelect = (
        DbBarData.select().where(conditions).order_by(DbBarData.datetime)
    )

    bars: List[BarData] = [
        BarData(
            symbol=row.symbol,
            exchange=Exchange(row.exchange),
            datetime=row.datetime.astimezone(DB_TZ),
            interval=Interval(row.interval),
            volume=row.volume,
            turnover=row.turnover,
            open_interest=row.open_interest,
            open_price=row.open_price,
            high_price=row.high_price,
            low_price=row.low_price,
            close_price=row.close_price,
            gateway_name="DB",
        )
        for row in query
    ]
    return bars
def load_bar_data(
    self,
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    start: datetime,
    end: datetime
) -> List[BarData]:
    """
    Load bars for one symbol/exchange/interval within ``[start, end]``.

    ``start`` and ``end`` are passed through ``convert_tz`` before querying
    ``DbBarData.objects``. Each returned document is adjusted in place
    (datetime localised with ``DB_TZ``, exchange and interval turned into
    enums, ``gateway_name`` and ``vt_symbol`` set) and the document objects
    themselves are returned.
    """
    query: QuerySet = DbBarData.objects(
        symbol=symbol,
        exchange=exchange.value,
        interval=interval.value,
        datetime__gte=convert_tz(start),
        datetime__lte=convert_tz(end),
    )

    vt_symbol = f"{symbol}.{exchange.value}"

    result: List[BarData] = []
    for document in query:
        document.datetime = DB_TZ.localize(document.datetime)
        document.exchange = Exchange(document.exchange)
        document.interval = Interval(document.interval)
        document.gateway_name = "DB"
        document.vt_symbol = vt_symbol
        result.append(document)

    return result
def download(self):
    """
    Download bars for the settings chosen in the dialog and report the count.

    Symbol, exchange, interval and start date are read from the widgets; the
    start datetime is midnight of the selected date. The number of bars
    returned by the engine is shown in an information box.
    """
    symbol = self.symbol_edit.text()
    exchange = Exchange(self.exchange_combo.currentData())
    interval = Interval(self.interval_combo.currentData())

    picked = self.start_date_edit.date()
    start = datetime(picked.year(), picked.month(), picked.day())

    count = self.engine.download_bar_data(symbol, exchange, interval, start)

    QtWidgets.QMessageBox.information(self, "下载结束", f"下载总数据量:{count}条")
def load_data(vt_symbol: str, interval: str, start: datetime, end: datetime) -> pd.DataFrame:
    """
    Load bars for ``vt_symbol`` from the database as a DataFrame.

    The bars are read through ``database_manager.load_bar_data`` and
    converted with ``vt_bar_to_df``.
    """
    symbol, exchange = extract_vt_symbol(vt_symbol)
    bars = database_manager.load_bar_data(
        symbol,
        exchange,
        Interval(interval),
        start=start,
        end=end,
    )
    return vt_bar_to_df(bars)
def process_current_contract_event(self, event) -> None:
    """
    Handle a "current contract" event: reload bars and refresh the history.

    The widget state is cleaned and cleared first. When ``event.data`` is not
    a ``str`` it is stored as ``self.current_contract`` and 300 bars are
    requested through ``self.get_jq_bars`` for the current interval
    ("d", "1h", "1m", "5m" or "15m"). If bars were loaded, all but the last
    bar are passed to ``self.update_history``; for daily bars a bar built
    from ``self.current_tick`` is appended first when a tick is available.

    NOTE(review): this block was reconstructed from whitespace-collapsed
    source, so the nesting of the branches below is a best guess - confirm
    against the original file before relying on it.
    """
    self.process_clean(event)
    self.clear_all()
    self._bars = []
    if type(event.data) != str:
        self.current_contract = event.data
        # Every supported interval requests the same number of bars (300).
        if self.current_interval == "d":
            self._bars = self.get_jq_bars(self.current_contract.symbol, "d", 300)
        elif self.current_interval == "1h":
            self._bars = self.get_jq_bars(self.current_contract.symbol, "1h", 300)
        elif self.current_interval == "1m":
            self._bars = self.get_jq_bars(self.current_contract.symbol, "1m", 300)
        elif self.current_interval == "5m":
            self._bars = self.get_jq_bars(self.current_contract.symbol, "5m", 300)
        elif self.current_interval == "15m":
            self._bars = self.get_jq_bars(self.current_contract.symbol, "15m", 300)
    if self._bars and self.current_contract:
        if self._bars[0].interval.value == "d":
            if self.current_tick:
                # Build a daily bar stamped at today's midnight from the
                # latest tick and append it to the loaded bars.
                now = datetime.now()
                current_bar = BarData(
                    gateway_name='JQ',
                    symbol=self.current_tick.symbol,
                    exchange=self.current_tick.exchange,
                    interval=Interval('d'),
                    datetime=datetime(now.year, now.month, now.day, 0, 0),
                    volume=self.current_tick.volume,
                    open_price=self.current_tick.open_price,
                    high_price=self.current_tick.high_price,
                    low_price=self.current_tick.low_price,
                    close_price=self.current_tick.last_price,
                )
                self._bars.append(current_bar)
            # The last bar is left out of the history update.
            self.update_history(self._bars[:-1])
        elif self._bars[0].interval.value == "1h":
            self.update_history(self._bars[:-1])
        elif self._bars[0].interval.value == "1m":
            self.update_history(self._bars[:-1])
            # self.update_trades(self._bars,"BK",1,"close_price")
        elif self._bars[0].interval.value == "5m":
            self.update_history(self._bars[:-1])
            # self.update_trades(trades)
        elif self._bars[0].interval.value == "15m":
            self.update_history(self._bars[:-1])
            # self.update_trades(trades)
        # Share the same bars (last one excluded) through the main engine.
        self.main_engine.current_bars = self._bars[:-1]
def get_bar_overview(self) -> List[BarOverview]:
    """
    Return the bar overviews stored in ``self.overview_collection``.

    Every document is converted to a ``BarOverview`` after turning its
    exchange/interval values into enums and dropping ``_id``.
    """
    cursor: Cursor = self.overview_collection.find()

    result = []
    for doc in cursor:
        doc["exchange"] = Exchange(doc["exchange"])
        doc["interval"] = Interval(doc["interval"])
        doc.pop("_id")
        result.append(BarOverview(**doc))

    return result
def downloading_history_data(args):
    """
    Start a history download described by parsed command-line ``args``.

    Builds a history request from ``args``, creates the event and main
    engines with the CTP gateway and the CTA backtester app, and asks the
    backtester engine to start downloading the requested range.
    """
    date_format = '%Y-%m-%d'
    req = HistoryRequest(
        symbol=args.symbol,
        exchange=Exchange(args.exchange),
        interval=Interval(args.interval),
        start=datetime.strptime(args.startdate, date_format),
        end=datetime.strptime(args.enddate, date_format),
    )

    main_engine = MainEngine(EventEngine())
    main_engine.add_gateway(CtpGateway)

    backtester_engine = main_engine.add_app(CtaBacktesterApp)
    backtester_engine.start_downloading(
        req.vt_symbol, req.interval, req.start, req.end
    )
def get_bar_overview(self) -> List[BarOverview]:
    """
    Return an overview of the bar data stored in the database.

    Every ``DbBarOverview`` row is copied into a ``BarOverview`` with its
    exchange/interval values turned into enums.
    """
    rows = self.db.query(DbBarOverview).all()

    return [
        BarOverview(
            symbol=row.symbol,
            exchange=Exchange(row.exchange),
            interval=Interval(row.interval),
            count=row.count,
            start=row.start,
            end=row.end
        )
        for row in rows
    ]
def download_history_data(symbol, exchange):
    """
    Query 1-minute history for ``symbol`` from 2019-09-19 until now.

    The request is sent through ``mddata_client.query_history``; the symbol,
    exchange and date range are printed to stdout. The query result is not
    used by this function.
    """
    print(symbol, exchange)

    first_day = datetime.strptime("2019-09-19", "%Y-%m-%d")

    # A single pass for now; an alternative end date one year after the start
    # was left disabled by the original author.
    for year_offset in range(1):
        start_date = first_day.replace(year=first_day.year + year_offset)
        end_date = datetime.now()
        print(start_date, end_date)

        request = HistoryRequest(
            symbol=symbol,
            exchange=Exchange(exchange),
            interval=Interval("1m"),
            start=start_date,
            end=end_date,
        )
        mddata_client.query_history(request)
def change_contract(self, vt_symbol):
    """
    Switch the chart to another contract.

    Events are unregistered and the chart cleared first. When the contract is
    found in ``self.contracts``, its historical bars (600, at the interval
    selected in the combo box), trades and orders are fetched from the visual
    engine, events are registered again and the chart is updated.
    """
    self.unregister_event()
    self.chart.clear_all()

    self.vt_symbol = vt_symbol
    contract: ContractData = self.contracts.get(self.vt_symbol)
    if not contract:
        return

    interval = Interval(self.interval_combo.currentText())
    history = self.visual_engine.get_historical_data(contract, '', 600, interval)
    trades = self.visual_engine.get_trades(contract)
    orders = self.visual_engine.get_orders(contract)

    self.register_event()
    self.chart.update_all(history, trades, orders)
def deal_func(x):
    """
    Convert one row ``x`` into a ``BarData``.

    ``x`` must support label lookup through ``.loc`` for the keys used below;
    ``x.loc['datetime']`` is parsed as a ``'%Y-%m-%d %H:%M:%S'`` string and
    made timezone-aware in ``DB_TZ``.
    """
    naive_dt = datetime.strptime(x.loc['datetime'], '%Y-%m-%d %H:%M:%S')

    # pytz zones must attach via localize(); replace(tzinfo=...) would pick
    # the zone's first (LMT) offset. Other tzinfo types have no localize()
    # and are attached directly, as before.
    localize = getattr(DB_TZ, "localize", None)
    if localize is not None:
        aware_dt = localize(naive_dt)
    else:
        aware_dt = naive_dt.replace(tzinfo=DB_TZ)

    bar = BarData(
        symbol=x.loc['symbol'],
        exchange=Exchange(x.loc['exchange']),
        datetime=aware_dt,
        interval=Interval(x.loc['interval']),
        volume=x.loc['volume'],
        open_price=x.loc['open_price'],
        high_price=x.loc['high_price'],
        open_interest=x.loc['open_interest'],
        low_price=x.loc['low_price'],
        close_price=x.loc['close_price'],
        gateway_name="DB",
    )
    return bar
def run_downloading(self, vt_symbol: str, interval: str, start: datetime, end: datetime):
    """
    Download bar data from the gateway, RQData or TQData and save it.

    The contract's gateway is preferred when the contract is known and its
    ``history_data`` flag is set; otherwise the client whose username is
    configured in ``SETTINGS`` is used (RQData first, then TQData). Progress
    and failures, including any exception raised while querying or saving,
    are reported through ``self.write_log``. ``self.thread`` is reset to
    ``None`` before returning.
    """
    self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")

    try:
        symbol, exchange = extract_vt_symbol(vt_symbol)
    except ValueError:
        self.write_log(f"{vt_symbol}解析失败,请检查交易所后缀")
        self.thread = None
        return

    request = HistoryRequest(
        symbol=symbol,
        exchange=exchange,
        interval=Interval(interval),
        start=start,
        end=end,
    )

    contract = self.main_engine.get_contract(vt_symbol)

    try:
        if contract and contract.history_data:
            # If history data provided in gateway, then query
            bars = self.main_engine.query_history(request, contract.gateway_name)
        elif SETTINGS["rqdata.username"]:
            bars = rqdata_client.query_history(request)
        elif SETTINGS["tqdata.username"]:
            bars = tqdata_client.query_history(request)
        else:
            # No data source configured.
            bars = []

        if bars:
            database_manager.save_bar_data(bars)
            self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
        else:
            self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
    except Exception:
        self.write_log(f"数据下载失败,触发异常:\n{traceback.format_exc()}")

    # Clear thread object handler.
    self.thread = None
def get_jq_bars(self, symbol: str, interval: str, count: int):
    """
    Return ``count`` bars at ``interval`` for the main engine's current contract.

    The request is built from the current contract's symbol and exchange (the
    ``symbol`` argument is not used) and sent to ``jqdata_client``. An empty
    list is returned when there is no current contract.
    """
    contract = self.main_engine.get_current_contract()
    if not contract:
        return []

    request = HistoryRequest(
        symbol=contract.symbol,
        count=count,
        exchange=contract.exchange,
        interval=Interval(interval),
        start="",
        end="",
    )

    # The client returns the bars both as a list and as a DataFrame; only the
    # list is handed back.
    bars_list, _ = jqdata_client.query_history(request)
    return bars_list