def download_minute_bar(name, start_date, end_date, ac):
    """Download 1-minute bars for one contract and save them to the database.

    Args:
        name: Contract name. It is matched against the ``name`` column of
            ``ac`` and is also used as the symbol of the generated bars.
        start_date: Start of the range, forwarded to ``jq.get_price``.
        end_date: End of the range, forwarded to ``jq.get_price``.
        ac: Contract table with a ``name`` column; the index value of the
            matching row is the security code passed to ``jq.get_price``.

    Returns:
        None. Progress and timing are printed.
    """
    print(f"开始下载合约数据{name}")

    symbol_info = ac[ac.name == name]
    if len(symbol_info) == 0:
        # No matching row means there is no security code to query.
        print(f"未找到合约{name},跳过下载")
        return

    symbol = name
    vt_symbol = symbol_info.index[0]
    # The last four characters of the code are looked up in ex_jq2vn.
    exchange = ex_jq2vn.get(vt_symbol[-4:])
    print(f"{symbol, exchange}")

    start = time()
    df = jq.get_price(
        vt_symbol,
        start_date=start_date,
        end_date=end_date,
        frequency="1m",
        fields=FIELDS,
    )

    if df is None or len(df) == 0:
        # df.index[0] below would raise on an empty result.
        print(f"合约{symbol}没有可下载的分钟K线数据")
        print(jq.get_query_count())
        return

    bars = [generateVtBar(row, symbol, exchange) for _, row in df.iterrows()]
    database_manager.save_bar_data(bars)

    end = time()
    cost = (end - start) * 1000
    print("合约%s的分钟K线数据下载完成%s - %s,耗时%s毫秒" % (symbol, df.index[0], df.index[-1], cost))
    print(jq.get_query_count())
def save_data_to_db(symbol, alias, count=5000):
    """Fetch daily bars for one symbol and store them in the database.

    Args:
        symbol: Symbol code without the exchange suffix.
        alias: Exchange alias; appended to ``symbol`` as ``symbol.alias`` for
            the data request and resolved to an exchange with
            ``const.Exchange.get_exchange_by_alias``.
        count: Number of daily bars to request (default 5000).
    """
    import os

    # SECURITY: account credentials are embedded in source code. They are
    # kept only as a backward-compatible fallback; set DATA_FEED_USERNAME /
    # DATA_FEED_PASSWORD in the environment and remove the literals.
    auth(
        os.environ.get('DATA_FEED_USERNAME', '13277099856'),
        os.environ.get('DATA_FEED_PASSWORD', '1221gzcC'),
    )

    exchange = const.Exchange.get_exchange_by_alias(alias)
    frame = get_bars(
        symbol + '.' + alias,
        count,
        unit='1d',
        fields=['date', 'open', 'high', 'low', 'close', 'volume'],
        include_now=False,
        end_dt=None,
        fq_ref_date=None,
        df=True,
    )

    bars = []
    # A separate name is used for the row so the frame being iterated is not
    # rebound inside its own loop.
    for _, row in frame.iterrows():
        bar = BarData(
            gateway_name='test',
            symbol=symbol,
            exchange=exchange,
            datetime=row['date'],
            interval=const.Interval.DAILY,
            volume=row['volume'],
        )
        # open_interest is left at the BarData default.
        bar.open_price = row['open']
        bar.high_price = row['high']
        bar.low_price = row['low']
        bar.close_price = row['close']
        bars.append(bar)

    database_manager.save_bar_data(bars)
def download_bar_data(self, symbol: str, exchange: Exchange, interval: str, start: datetime) -> int:
    """
    Download bars from `start` until now and store them in the database.

    The gateway of the matching contract is used when that contract
    advertises history data; otherwise RQData is queried (and initialised
    first if needed).

    Returns the number of bars saved, or 0 when nothing was returned.
    """
    request = HistoryRequest(
        symbol=symbol,
        exchange=exchange,
        interval=Interval(interval),
        start=start,
        end=datetime.now(),
    )
    contract = self.main_engine.get_contract(f"{symbol}.{exchange.value}")

    if contract and contract.history_data:
        # The gateway itself can serve history.
        bars = self.main_engine.query_history(request, contract.gateway_name)
    else:
        # Fall back to RQData.
        if not rqdata_client.inited:
            rqdata_client.init()
        bars = rqdata_client.query_history(request)

    if not bars:
        return 0

    database_manager.save_bar_data(bars)
    return len(bars)
def download_bar_data(
    self,
    symbol: str,
    exchange: Exchange,
    interval: str,
    start: datetime
) -> int:
    """
    Download bars from `start` until now through the contract's gateway.

    Args:
        symbol: Symbol code.
        exchange: Exchange of the symbol; its value forms the vt_symbol.
        interval: Interval value, converted with ``Interval(interval)``.
        start: Start of the requested range; the end is ``datetime.now()``.

    Returns:
        Number of bars saved, or 0 when the contract is unknown, offers no
        history data, or the query returned nothing.
    """
    req = HistoryRequest(
        symbol=symbol,
        exchange=exchange,
        interval=Interval(interval),
        start=start,
        end=datetime.now()
    )

    vt_symbol = f"{symbol}.{exchange.value}"
    contract = self.main_engine.get_contract(vt_symbol)

    # Start from "no data" so the checks below are safe when no query is
    # made; previously ``data`` could be referenced without being assigned.
    data = None

    # If history data provided in gateway, then query
    if contract and contract.history_data:
        data = self.main_engine.query_history(
            req, contract.gateway_name
        )

    if data:
        database_manager.save_bar_data(data)
        return len(data)

    return 0
def move_df_to_sql(imported_data: pd.DataFrame):
    """Convert every row of a DataFrame into a bar and save the bars.

    Each row must expose the attributes ``symbol``, ``exchange``, ``dtime``,
    ``interval``, ``volume``, ``open``, ``high``, ``low``, ``close`` and
    ``o_interest``.

    An empty frame is reported with a count of 0 and nothing is written.
    """
    bars = [
        BarData(
            symbol=row.symbol,
            exchange=row.exchange,
            datetime=row.dtime,
            interval=row.interval,
            volume=row.volume,
            open_price=row.open,
            high_price=row.high,
            low_price=row.low,
            close_price=row.close,
            open_interest=row.o_interest,
            gateway_name="DB",
        )
        for row in imported_data.itertuples()
    ]

    # do some statistics
    count = len(bars)
    start = bars[0].datetime if bars else None
    end = bars[-1].datetime if bars else None

    # insert into database (skipped when there is nothing to insert)
    if bars:
        database_manager.save_bar_data(bars)

    print(f'Insert Bar: {count} from {start} - {end}')
def update_all(self, symbol, exchange, trade_datas, start, end=None):
    """Load bars for a range and redraw the whole chart with them.

    The symbol, exchange and start are remembered on the instance. Bars are
    first read from the database; if they look too sparse, the 'IB' gateway
    is asked for history instead and the result is saved back.

    NOTE(review): the nesting of the statements in the fallback branch was
    reconstructed from a whitespace-collapsed source - confirm against the
    original file.
    """
    self._symbol = symbol
    self._exchange = exchange
    self._start = start
    interval = self._interval
    # presumably the timedelta covered by one bar of this interval
    tp = self.interval2timdelta(self._interval)
    # Default window: the larger of 60 intervals and 25 hours.
    backward_n = max(60 * tp, dt.timedelta(hours=25))
    end = start + backward_n if end is None else end
    history_data = database_manager.load_bar_data(symbol,
                                                  exchange,
                                                  interval,
                                                  start=start,
                                                  end=end)
    self.trade_datas = trade_datas

    # Coverage check: loaded bars per minute of the range must exceed 0.7.
    # NOTE(review): dividing by minutes assumes roughly one bar per minute.
    if len(history_data) > 0 and len(history_data) / (
            (end - start).total_seconds() / 60) > 0.7:
        self.chart.update_all(history_data, trade_datas, [])
    else:
        req = HistoryRequest(symbol, exchange, start, end, interval)
        gateway = self.main_engine.get_gateway('IB')
        if gateway and gateway.api.status:
            self.history_data = history_data = gateway.query_history(req)
        self.chart.update_all(history_data, trade_datas, [])
        database_manager.save_bar_data(history_data)

    # self.history_data is only assigned in the gateway branch, hence getattr.
    if len(getattr(self, 'history_data', [])) > 0:
        self._end = self.history_data[-1].datetime
def download_minute_bar(vt_symbol, start_date='20100416', end_date='20190416'):
    """Download 1-minute bars for one contract and save them to the database.

    Args:
        vt_symbol: ``"symbol.exchange"`` string; it is split on ``"."`` and
            the symbol part is passed to ``rq.get_price``.
        start_date: Start of the requested range (default ``'20100416'``).
        end_date: End of the requested range (default ``'20190416'``).

    Returns:
        None. Progress and timing are printed.
    """
    print(f"开始下载合约数据{vt_symbol}")
    symbol, exchange = vt_symbol.split(".")

    start = time()
    df = rq.get_price(
        symbol,
        frequency="1m",
        fields=FIELDS,
        start_date=start_date,
        end_date=end_date
    )

    if df is None or len(df) == 0:
        # df.index[0] below would raise when nothing was returned.
        print(f"合约{symbol}没有可下载的分钟K线数据")
        return

    bars = [
        generate_bar_from_row(row, symbol, exchange)
        for _, row in df.iterrows()
    ]
    database_manager.save_bar_data(bars)

    end = time()
    cost = (end - start) * 1000

    print(
        "合约%s的分钟K线数据下载完成%s - %s,耗时%s毫秒"
        % (symbol, df.index[0], df.index[-1], cost)
    )
def run_downloading(self, vt_symbol: str, interval: str, start: datetime, end: datetime):
    """
    Download bars for `vt_symbol` and store them in the database.

    The gateway of the matching contract is queried when that contract
    advertises history data; otherwise RQData is used. Progress is reported
    through ``write_log``.

    The thread handle is cleared in a ``finally`` block, so it is released
    even when the query or the save raises; the exception still propagates.
    """
    self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")

    try:
        symbol, exchange = extract_vt_symbol(vt_symbol)

        req = HistoryRequest(symbol=symbol,
                             exchange=exchange,
                             interval=Interval(interval),
                             start=start,
                             end=end)

        contract = self.main_engine.get_contract(vt_symbol)

        # If history data provided in gateway, then query
        if contract and contract.history_data:
            data = self.main_engine.query_history(req, contract.gateway_name)
        # Otherwise use RQData to query data
        else:
            data = rqdata_client.query_history(req)

        if data:
            database_manager.save_bar_data(data)
            self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
        else:
            self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
    finally:
        # Clear thread object handler.
        self.thread = None
def run_downloading(
    self,
    vt_symbol: str,
    interval: str,
    start: datetime,
    end: datetime
):
    """
    Query bar data from RQData and store it in the database.

    When the query returns nothing, only the failure is logged: nothing is
    saved and the completion message is not written. The thread handle is
    cleared in every case, including when an exception propagates.
    """
    self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")

    try:
        symbol, exchange = extract_vt_symbol(vt_symbol)

        data = rqdata_client.query_bar(
            symbol, exchange, Interval(interval), start, end
        )

        if not data:
            self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
            # Stop here: previously the empty result was still passed to
            # the database and reported as a completed download.
            return

        database_manager.save_bar_data(data)
    finally:
        # Clear thread object handler.
        self.thread = None

    self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
def save_to_database(data: List[dict], vt_symbol: str, rq_interval: str):
    """Convert row dictionaries into bars and save them to the database.

    Args:
        data: Rows with the keys ``datetime`` (``"%Y-%m-%d %H:%M:%S"``),
            ``open``, ``high``, ``low``, ``close`` and ``volume``. ``None``
            is treated as "no rows".
        vt_symbol: Symbol with exchange suffix, split by
            ``extract_vt_symbol``.
        rq_interval: Interval key, translated through ``INTERVAL_RQ2VT``.

    Returns:
        None. Nothing is saved when the interval is empty or has no mapping,
        or when there are no rows.
    """
    interval = INTERVAL_RQ2VT.get(rq_interval)
    # A key missing from the mapping yields None, which must not be stored
    # as a bar interval; checking only ``rq_interval`` let that through.
    if not rq_interval or interval is None:
        return None

    symbol, exchange = extract_vt_symbol(vt_symbol)
    exchange = Exchange(exchange)
    dt_format = "%Y-%m-%d %H:%M:%S"

    res_list: List[BarData] = []
    if data is not None:
        res_list = [
            BarData(
                symbol=symbol,
                exchange=exchange,
                interval=interval,
                datetime=datetime.strptime(row['datetime'], dt_format),
                open_price=row["open"],
                high_price=row["high"],
                low_price=row["low"],
                close_price=row["close"],
                volume=row["volume"],
                gateway_name="RQ_WEB",
            )
            for row in data
        ]

    # Skip the database call when there is nothing to write.
    if res_list:
        database_manager.save_bar_data(res_list)
def download_all(self):
    """
    Download daily bars for every symbol in ``self.symbols`` via tushare.

    For each (ts_code, list_date) pair the history from the listing date
    until now is requested and, when non-empty, saved to the database.
    A failing save is logged and does not stop the loop.
    :return:
    """
    log.info("开始下载A股股票全市场日线数据")

    symbols = self.symbols
    if symbols is not None:
        with tqdm(total=len(symbols)) as progress:
            for ts_code, listed_on in zip(symbols['ts_code'],
                                          symbols['list_date']):
                symbol, exchange = to_split_ts_codes(ts_code)
                progress.set_description_str("下载A股日线数据股票代码:" + ts_code)

                request = HistoryRequest(
                    symbol=symbol,
                    exchange=exchange,
                    start=datetime.strptime(listed_on, TS_DATE_FORMATE),
                    end=datetime.now(),
                    interval=Interval.DAILY,
                )
                bars = self.tushare_client.query_history(req=request)

                if bars:
                    try:
                        database_manager.save_bar_data(bars)
                    except Exception as ex:
                        # Keep going with the next symbol after logging.
                        log.error(ts_code + "数据存入数据库异常")
                        log.error(ex)
                        traceback.print_exc()

                progress.update(1)
                log.info(progress.desc)

    log.info("A股股票全市场日线数据下载完毕")
def rq_download(
    self,
    vt_symbol: str,
    interval: str,
    start: datetime,
    end: datetime,
):
    """
    Initialise RQData, query bars for `vt_symbol` and save them.

    The outcome (success or failure) is printed.
    """
    rqdata_client.init()

    symbol, exchange = extract_vt_symbol(vt_symbol)
    bars = rqdata_client.query_history(
        HistoryRequest(
            symbol=symbol,
            exchange=exchange,
            interval=Interval(interval),
            start=start,
            end=end,
        )
    )

    if not bars:
        print(f"数据下载失败,无法得到 {vt_symbol} 的数据")
        return

    database_manager.save_bar_data(bars)
    print(f"{vt_symbol}-{interval} 历史数据下载完成")
def load_by_handle(
    self,
    data,
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    datetime_head: str,
    open_head: str,
    high_head: str,
    low_head: str,
    close_head: str,
    volume_head: str,
    open_interest_head: str,
    datetime_format: str,
    progress_bar_dict: dict,
    opt_str: str
):
    """
    Normalise the datetime column of `data`, then save or export it.

    Args:
        data: DataFrame holding the bar columns named by the ``*_head``
            arguments. Its datetime column is converted in place when it
            contains strings.
        datetime_format: ``strptime`` format for string datetimes; when
            empty, ``datetime.fromisoformat`` is used.
        progress_bar_dict: Forwarded to ``database_manager.save_bar_data``.
        opt_str: ``"to_db"`` converts rows with ``self.to_bar_data`` and
            saves them; ``"to_csv"`` writes ``<exchange>_<symbol>.csv`` into
            the ``csv_files`` folder. Any other value only normalises the
            datetimes.

    Returns:
        ``(start, end, count)``: first and last datetime and the row count,
        or ``(None, None, 0)`` for an empty frame.
    """
    if len(data) == 0:
        # There is no first/last row to inspect or report.
        return None, None, 0

    start_time = time.time()

    # Positional access: a label lookup with [0] fails whenever the index
    # does not contain the label 0 (e.g. after filtering).
    first_value = data[datetime_head].iloc[0]
    if isinstance(first_value, str):
        # Decide the parser once instead of once per row.
        if datetime_format:
            data[datetime_head] = data[datetime_head].apply(
                lambda x: datetime.strptime(x, datetime_format))
        else:
            data[datetime_head] = data[datetime_head].apply(
                datetime.fromisoformat)
    elif isinstance(first_value, pd.Timestamp):
        self.main_engine.write_log("datetime 格式为 pd.Timestamp, 不用处理.")
    else:
        self.main_engine.write_log("未知datetime类型, 请检查")

    self.main_engine.write_log(f'df apply 处理日期时间 cost {time.time() - start_time:.2f}s')

    if opt_str == "to_db":
        start_time = time.time()
        bars = data.apply(
            self.to_bar_data,
            args=(
                symbol,
                exchange,
                interval,
                datetime_head,
                open_head,
                high_head,
                low_head,
                close_head,
                volume_head,
                open_interest_head
            ),
            axis=1).tolist()
        self.main_engine.write_log(f'df apply 处理bars时间 cost {time.time() - start_time:.2f}s')

        # insert into database
        database_manager.save_bar_data(bars, progress_bar_dict)

    elif opt_str == "to_csv":
        csv_file_dir = get_folder_path("csv_files")
        data.to_csv(f'{csv_file_dir}/{exchange.value}_{symbol}.csv', index=False)

    start = data[datetime_head].iloc[0]
    end = data[datetime_head].iloc[-1]
    count = len(data)

    return start, end, count
def import_data_from_csv(
    self,
    file_path: str,
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    datetime_head: str,
    open_head: str,
    high_head: str,
    low_head: str,
    close_head: str,
    volume_head: str,
    open_interest_head: str,
    datetime_format: str
) -> Tuple:
    """
    Read bars from a CSV file and save them to the database.

    NUL characters are stripped from every line before parsing. Datetimes
    are parsed with `datetime_format` when given, otherwise with
    ``datetime.fromisoformat``. A missing open-interest column counts as 0.

    Returns:
        ``(start, end, count)``: datetime of the first and last bar and the
        number of bars. For a file without data rows this is
        ``(None, None, 0)`` and nothing is written.
    """
    with open(file_path, "rt") as f:
        buf = [line.replace("\0", "") for line in f]

    reader = csv.DictReader(buf, delimiter=",")

    bars = []
    for item in reader:
        if datetime_format:
            dt = datetime.strptime(item[datetime_head], datetime_format)
        else:
            dt = datetime.fromisoformat(item[datetime_head])

        open_interest = item.get(open_interest_head, 0)

        bars.append(BarData(
            symbol=symbol,
            exchange=exchange,
            datetime=dt,
            interval=interval,
            volume=float(item[volume_head]),
            open_price=float(item[open_head]),
            high_price=float(item[high_head]),
            low_price=float(item[low_head]),
            close_price=float(item[close_head]),
            open_interest=float(open_interest),
            gateway_name="DB",
        ))

    if not bars:
        # No data rows: there is no last bar to report and nothing to save.
        return None, None, 0

    # insert into database
    database_manager.save_bar_data(bars)

    return bars[0].datetime, bars[-1].datetime, len(bars)
def check_update_all(self):
    """
    Check every symbol for missing daily bars and fill the gaps via tushare.

    This method is very slow; calling it is not recommended. It is meant for
    a local database that already exists but may be missing some data.

    For each symbol, every calendar date from the listing date onwards is
    compared with the dates stored locally. Each missing date is requested
    from tushare on its own; when tushare returns bars they are saved, and
    when it returns nothing the date is skipped (the original author notes
    this means trading was suspended that day).
    :return:
    """
    log.info("开始检查更新所有的A股股票全市场日线数据")
    if self.symbols is not None:
        with tqdm(total=len(self.symbols)) as pbar:
            for tscode, list_date in zip(self.symbols['ts_code'],
                                         self.symbols['list_date']):
                pbar.set_description_str("正在检查A股日线数据,股票代码:" + tscode)
                symbol, exchange = to_split_ts_codes(tscode)

                local_bar = database_manager.load_bar_data(
                    symbol=symbol,
                    exchange=exchange,
                    interval=Interval.DAILY,
                    start=datetime.strptime(list_date, TS_DATE_FORMATE),
                    end=datetime.now())
                # A set keeps the per-date membership test below O(1);
                # a list made the scan quadratic in the number of days.
                local_bar_dates = {
                    bar.datetime.strftime(TS_DATE_FORMATE)
                    for bar in local_bar
                }

                exchange_cal = self.trade_cal[exchange.value]
                index = exchange_cal[exchange_cal.cal_date == list_date]
                # NOTE(review): the index label is used as a position, which
                # assumes the calendar has a default 0..n-1 index - confirm.
                trade_cal = exchange_cal.iloc[index.index[0]:]

                for trade_date in trade_cal['cal_date']:
                    if trade_date in local_bar_dates:
                        continue

                    day = datetime.strptime(trade_date, TS_DATE_FORMATE)
                    req = HistoryRequest(symbol=symbol,
                                         exchange=exchange,
                                         start=day,
                                         end=day,
                                         interval=Interval.DAILY)
                    bardata = self.tushare_client.query_history(req=req)
                    if bardata:
                        log.info(tscode + "本地数据库缺失:" + trade_date)
                        try:
                            database_manager.save_bar_data(bardata)
                        except Exception as ex:
                            # Log and continue with the next date.
                            log.error(tscode + "数据存入数据库异常")
                            log.error(ex)
                            traceback.print_exc()

                pbar.update(1)
                log.info(pbar.desc)

    log.info("A股股票全市场日线数据检查更新完毕")
def load_by_handle(
    self,
    f: TextIO,
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    datetime_head: str,
    open_head: str,
    high_head: str,
    low_head: str,
    close_head: str,
    volume_head: str,
    datetime_format: str,
):
    """
    Load bar data from a text-mode CSV file handle and save it.

    Datetimes are parsed with `datetime_format` when given, otherwise with
    ``datetime.fromisoformat``. Price and volume fields are converted to
    ``float`` before they are stored.

    Returns:
        ``(start, end, count)``: datetime of the first and last bar and the
        number of bars. For input without data rows this is
        ``(None, None, 0)`` and nothing is written.
    """
    reader = csv.DictReader(f)

    # Build one bar per CSV row.
    bars = []
    for item in reader:
        if datetime_format:
            dt = datetime.strptime(item[datetime_head], datetime_format)
        else:
            dt = datetime.fromisoformat(item[datetime_head])

        bars.append(BarData(
            symbol=symbol,
            exchange=exchange,
            datetime=dt,
            interval=interval,
            # csv yields strings; the bar fields are numeric.
            volume=float(item[volume_head]),
            open_price=float(item[open_head]),
            high_price=float(item[high_head]),
            low_price=float(item[low_head]),
            close_price=float(item[close_head]),
            gateway_name="DB",
        ))

    if not bars:
        # No rows: there is no last bar to report and nothing to save.
        return None, None, 0

    # insert into database
    database_manager.save_bar_data(bars)

    return bars[0].datetime, bars[-1].datetime, len(bars)
def load_by_handle(
    self,
    f: TextIO,
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    datetime_head: str,
    open_head: str,
    high_head: str,
    low_head: str,
    close_head: str,
    volume_head: str,
    open_interest_head: str,
    datetime_format: str,
    progress_bar_dict
):
    """
    Load bar data from a text-mode CSV file handle and save it.

    NUL characters are stripped from every line before parsing. Datetimes
    are parsed with `datetime_format` when given, otherwise with
    ``datetime.fromisoformat``. Numeric fields are converted to ``float``.
    `progress_bar_dict` is forwarded to ``database_manager.save_bar_data``.

    Returns:
        ``(start, end, count)``: datetime of the first and last bar and the
        number of bars. For input without data rows this is
        ``(None, None, 0)`` and nothing is written.
    """
    buf = [line.replace("\0", "") for line in f]
    reader = csv.DictReader(buf, delimiter=",")

    bars = []
    for item in reader:
        if datetime_format:
            dt = datetime.strptime(item[datetime_head], datetime_format)
        else:
            dt = datetime.fromisoformat(item[datetime_head])

        bars.append(BarData(
            symbol=symbol,
            exchange=exchange,
            datetime=dt,
            interval=interval,
            # csv yields strings; the bar fields are numeric.
            volume=float(item[volume_head]),
            open_interest=float(item[open_interest_head]),
            open_price=float(item[open_head]),
            high_price=float(item[high_head]),
            low_price=float(item[low_head]),
            close_price=float(item[close_head]),
            gateway_name="DB",
        ))

    if not bars:
        # No rows: there is no last bar to report and nothing to save.
        return None, None, 0

    # insert into database
    database_manager.save_bar_data(bars, progress_bar_dict)

    return bars[0].datetime, bars[-1].datetime, len(bars)
def data_record(start, end, vt_symbol):
    """Download 1-minute bars from the IB gateway and save them.

    Args:
        start: Start of the range as a string understood by
            ``dateutil.parser``. When both `start` and `end` are empty, the
            range is the most recent completed 17:00-to-17:00 window.
        end: End of the range as a string, or empty.
        vt_symbol: ``"symbol.exchange"``; the exchange part must be a valid
            ``Exchange`` value.

    The gateway and the event engine are always shut down, even when an
    exception propagates.
    """
    from vnpy.trader.database import database_manager
    from vnpy.gateway.ib.ib_gateway import IbGateway
    from vnpy.trader.utility import load_json
    from vnpy.trader.object import HistoryRequest
    from vnpy.trader.constant import Interval, Exchange
    from dateutil import parser
    from vnpy.event.engine import EventEngine
    from vnpy.trader.event import EVENT_LOG

    symbol, exchange = vt_symbol.split('.')

    if not start and not end:
        # After 17:00 today's window is complete; before it, use yesterday's.
        offset = 0 if dt.datetime.now().time() > dt.time(17, 0) else 1
        start = (dt.datetime.today() - dt.timedelta(days=offset + 1)).replace(
            hour=17, minute=0, second=0, microsecond=0)
        end = (dt.datetime.today() - dt.timedelta(days=offset)).replace(
            hour=17, minute=0, second=0, microsecond=0)
    else:
        start = parser.parse(start)
        end = parser.parse(end) if end else end

    ib_settings = load_json('connect_ib.json')
    # The configured "客户号" setting is shifted by 4 for this connection
    # (presumably to avoid clashing with another session - confirm).
    ib_settings["客户号"] += 4

    recorder_engine = EventEngine()

    def log(event):
        data = event.data
        print(data.level, data.msg)

    recorder_engine.register(EVENT_LOG, log)
    ib = IbGateway(recorder_engine)

    try:
        recorder_engine.start()
        ib.connect(ib_settings)

        if ib.api.client.isConnected():
            req = HistoryRequest(symbol, Exchange(exchange), start, end,
                                 Interval.MINUTE)
            ib.write_log(f'发起请求#{vt_symbol}, {start}至{end}')
            his_data = ib.query_history(req)
            if his_data:
                ib.write_log(
                    f'获得数据#{vt_symbol}, {his_data[0].datetime}至{his_data[-1].datetime}, 共{len(his_data)}条'
                )
                database_manager.save_bar_data(his_data)
                ib.write_log(f'成功入库')
            else:
                # Indexing his_data[0] on an empty result would raise.
                ib.write_log(f'未获得数据#{vt_symbol}, {start}至{end}')
        else:
            ib.write_log('连接失败!请检查客户号是否被占用或IP是否正确')
    finally:
        ib.close()
        recorder_engine.stop()
def write():
    """Query 1-minute bars from RQData and save the result to the database.

    ``symbol``, ``exchange``, ``start_date`` and ``end_date`` are free
    variables resolved from the surrounding scope.
    """
    bars = rqdata_client.query_history(
        HistoryRequest(
            symbol=symbol,
            exchange=exchange,
            interval=Interval.MINUTE,
            start=start_date,
            end=end_date,
        )
    )
    database_manager.save_bar_data(bars)
def move_df_to_db(imported_data: pd.DataFrame, future_download: bool):
    """Convert DataFrame rows into bars and save them in batches of 10000.

    Args:
        imported_data: Rows exposing ``datetime`` (a POSIX timestamp, since
            it is passed to ``datetime.fromtimestamp``), ``symbol``,
            ``exchange``, ``interval``, ``open_price``, ``high_price``,
            ``low_price``, ``close_price`` and ``volume``.
        future_download: When true, bar datetimes are shifted as described
            in the loop below.

    A summary line is printed at the end; for an empty frame the start and
    end times are reported as ``None``.
    """
    print("move_df_to_db 函数")

    bars = []
    count = 0
    time_consuming_start = time()
    tmpsymbol = None
    start_time = None
    # Defined up front so the summary also works when there are no rows.
    end_time = None

    for row in imported_data.itertuples():
        bar = BarData(
            # A standard datetime object, not datetime64[ns] or a timestamp.
            datetime=datetime.fromtimestamp(row.datetime),
            symbol=row.symbol,
            exchange=row.exchange,
            interval=row.interval,
            open_price=row.open_price,
            high_price=row.high_price,
            low_price=row.low_price,
            close_price=row.close_price,
            # open_interest is not imported.
            volume=row.volume,
            gateway_name="DB",
        )

        if not tmpsymbol:
            tmpsymbol = bar.symbol

        if future_download:
            # Night session 21:00 - 02:30: move the date back by one day.
            if bar.datetime.time() >= dtime(21, 0) or bar.datetime.time() <= dtime(2, 30):
                bar.datetime -= timedelta(days=1)
            # Shift every bar back by one minute.
            # NOTE(review): the source comment marks this rule with "???"
            # and the original nesting was lost - confirm it belongs here.
            bar.datetime -= timedelta(minutes=1)

        if not start_time:
            start_time = bar.datetime

        bars.append(bar)

        # do some statistics
        count += 1
        end_time = bar.datetime

    # insert into database in batches
    for bar_data in chunked(bars, 10000):
        database_manager.save_bar_data(bar_data)

    time_consuming_end = time()
    print(f'载入通达信标的:{tmpsymbol} 分钟数据,开始时间:{start_time},结束时间:{end_time},数据量:{count},耗时:{round(time_consuming_end-time_consuming_start,3)}秒')
def run(self):
    """
    Consume queued tasks while ``self.active`` is true.

    Each task is a ``(task_type, data)`` pair: ``"tick"`` items go to
    ``save_tick_data`` and ``"bar"`` items to ``save_bar_data``, each wrapped
    in a one-element list. A queue timeout (1 second) restarts the loop.
    """
    while self.active:
        try:
            kind, payload = self.queue.get(timeout=1)

            if kind == "tick":
                database_manager.save_tick_data([payload])
            elif kind == "bar":
                database_manager.save_bar_data([payload])
        except Empty:
            # Nothing arrived within the timeout; re-check ``active``.
            pass
def run_downloading(self, vt_symbol: str, interval: str, start: datetime, end: datetime):
    """
    Download bars for `vt_symbol` and store them in the database.

    Source order: the contract's gateway when it advertises history data,
    then RQData when ``rqdata.username`` is set, then TQData when
    ``tqdata.username`` is set; otherwise the result is empty. Failures are
    reported through ``write_log`` and the thread handle is cleared.
    """
    self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")

    try:
        symbol, exchange = extract_vt_symbol(vt_symbol)
    except ValueError:
        self.write_log(f"{vt_symbol}解析失败,请检查交易所后缀")
        self.thread = None
        return

    request = HistoryRequest(
        symbol=symbol,
        exchange=exchange,
        interval=Interval(interval),
        start=start,
        end=end,
    )
    contract = self.main_engine.get_contract(vt_symbol)

    try:
        if contract and contract.history_data:
            # The gateway itself can serve history.
            bars = self.main_engine.query_history(request, contract.gateway_name)
        elif SETTINGS["rqdata.username"]:
            bars = rqdata_client.query_history(request)
        elif SETTINGS["tqdata.username"]:
            bars = tqdata_client.query_history(request)
        else:
            bars = []

        if not bars:
            self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
        else:
            database_manager.save_bar_data(bars)
            self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
    except Exception:
        self.write_log(f"数据下载失败,触发异常:\n{traceback.format_exc()}")

    # Clear thread object handler.
    self.thread = None
def update_newest(self):
    """
    Bring the local daily bars of every symbol up to date via tushare.

    For each symbol the download starts one day after the newest local bar,
    or at the listing date when there is no local bar. Data older than the
    newest local bar is taken to be complete.
    :return:
    """
    log.info("开始更新最新的A股股票全市场日线数据")

    symbols = self.symbols
    if symbols is not None:
        with tqdm(total=len(symbols)) as progress:
            for ts_code, listed_on in zip(symbols['ts_code'],
                                          symbols['list_date']):
                symbol, exchange = to_split_ts_codes(ts_code)
                latest = self.get_newest_bar_data(symbol=symbol,
                                                  exchange=exchange,
                                                  interval=Interval.DAILY)

                if latest is None:
                    progress.set_description_str("正在处理股票代码:" + ts_code + "无本地数据")
                    begin = datetime.strptime(listed_on, TS_DATE_FORMATE)
                else:
                    progress.set_description_str(
                        "正在处理股票代码:" + ts_code + "本地最新数据:"
                        + latest.datetime.strftime(TS_DATE_FORMATE))
                    begin = latest.datetime + timedelta(days=1)

                request = HistoryRequest(symbol=symbol,
                                         exchange=exchange,
                                         start=begin,
                                         end=datetime.now(),
                                         interval=Interval.DAILY)
                bars = self.tushare_client.query_history(req=request)

                if bars:
                    try:
                        database_manager.save_bar_data(bars)
                    except Exception as ex:
                        # Keep going with the next symbol after logging.
                        log.error(ts_code + "数据存入数据库异常")
                        log.error(ex)
                        traceback.print_exc()

                progress.update(1)
                log.info(progress.desc)

    log.info("A股股票全市场日线数据更新完毕")
def update_backward_bars(self, n):
    """Extend the chart past its last bar by a window of about `n` intervals.

    The window starts at the later of the chart's last bar and the stored
    ``self._end``, and is at least 60 minutes long. Bars are read from the
    database; when they look too sparse, the 'IB' gateway is queried and its
    result saved. The chart's trades, position and PnL are then refreshed.

    NOTE(review): statement nesting was reconstructed from a
    whitespace-collapsed source - confirm against the original file.
    """
    chart = self.chart
    last_bar = chart._manager.get_bar(chart.last_ix)
    if last_bar:
        symbol = last_bar.symbol
        exchange = last_bar.exchange
        if self._end:
            start = max(last_bar.datetime, self._end)
        else:
            start = last_bar.datetime

        # Nothing to load once the window would start in the future.
        if start >= dt.datetime.now(get_localzone()):
            return

        # presumably the timedelta covered by one bar of this interval
        tp = self.interval2timdelta(self._interval)
        backward_n = max(tp * n, dt.timedelta(minutes=60))
        end = start + backward_n
        # NOTE(review): checkTradeTime is defined elsewhere; the branch runs
        # when it returns a falsy value for the window's end time.
        if not self.checkTradeTime(end.time()):
            history_data = database_manager.load_bar_data(symbol,
                                                          exchange,
                                                          self._interval,
                                                          start=start,
                                                          end=end)

            # Coverage check: fewer than 0.7 bars per minute of the window
            # triggers a gateway query.
            if len(history_data) == 0 or len(history_data) / (
                    (end - start).total_seconds() / 60) < 0.7:
                req = HistoryRequest(symbol, exchange, start, end,
                                     self._interval)
                gateway = self.main_engine.get_gateway('IB')
                if gateway and gateway.api.status:
                    history_data = gateway.query_history(req)
                    database_manager.save_bar_data(history_data)

            for bar in history_data:
                self.chart.update_bar(bar)

            # Only trades up to the new last bar are drawn.
            last_bar_after_update = chart._manager.get_bar(chart.last_ix)
            self.chart.clear_trades()
            self.chart.update_trades([
                t for t in self.trade_datas
                if t.datetime <= last_bar_after_update.datetime
            ])
            self.chart.update_pos()
            self.chart.update_pnl()

        self._end = end
def download_data_from_tdx(download_futures, from_date=None, back_days=None):
    """
    Download 1-minute bars for each contract and save them to the database.

    :param download_futures: list of ``"symbol.EXCHANGE"`` strings,
        e.g. ["rb2009.SHFE"]; the suffix must be an ``Exchange`` member name
    :param from_date: start date as "%Y-%m-%d", e.g. 2020-7-8; when omitted
        the newest stored bar's datetime is used, or 2012-01-01 if none
    :param back_days: when truthy, the start is moved back so that at least
        the last three days are downloaded again
    """
    if tdxdata_client.init():
        print("数据服务器登录成功")
    else:
        print("数据服务器登录失败")
        return

    for future in download_futures:
        _future = future.split(".")
        symbol = _future[0]
        # Look the member up by name through the enum itself rather than
        # through the class ``__dict__``.
        exchange = Exchange[_future[1]]
        interval = Interval.MINUTE

        if from_date:
            start = datetime.datetime.strptime(from_date, "%Y-%m-%d")
        else:
            bar = database_manager.get_newest_bar_data(symbol, exchange, interval)
            if bar:
                start = bar.datetime
            else:
                start = datetime.datetime(2012, 1, 1)

        if back_days:
            # Drop tzinfo so the value compares with the naive "now" below.
            start = start.replace(tzinfo=None)
            # NOTE(review): ``back_days`` only acts as a switch; the window
            # is fixed at 3 days. Confirm whether ``days=back_days`` was
            # intended.
            _start = datetime.datetime.now() - datetime.timedelta(days=3)
            start = min(start, _start)

        # download
        req = HistoryRequest(symbol, exchange, start, datetime.datetime.now(), interval=interval)
        data = tdxdata_client.query_history(req)

        # write to database
        if data:
            database_manager.save_bar_data(data)
            print(f"{symbol}更新完成:{data[0].datetime} -- {data[-1].datetime}")

    print("数据全部更新完毕")
def run_downloading(
    self,
    vt_symbol: str,
    interval: str,
    start: datetime,
    end: datetime
):
    """
    Download bars for `vt_symbol` and store them in the database.

    The contract's gateway is used when it advertises history data,
    otherwise RQData. Any exception raised while querying or saving is
    logged with its traceback. The thread handle is cleared afterwards.
    """
    self.write_log(f"{vt_symbol}-{interval} start downloading historical data ")

    symbol, exchange = extract_vt_symbol(vt_symbol)
    request = HistoryRequest(
        symbol=symbol,
        exchange=exchange,
        interval=Interval(interval),
        start=start,
        end=end
    )
    contract = self.main_engine.get_contract(vt_symbol)

    try:
        use_gateway = contract and contract.history_data
        if use_gateway:
            bars = self.main_engine.query_history(
                request, contract.gateway_name
            )
        else:
            bars = rqdata_client.query_history(request)

        if not bars:
            self.write_log(f" data download failed , unable to get {vt_symbol} historical data ")
        else:
            database_manager.save_bar_data(bars)
            self.write_log(f"{vt_symbol}-{interval} historical data download is complete ")
    except Exception:
        self.write_log(
            f" data download failed , trigger abnormal :\n{traceback.format_exc()}"
        )

    # Clear thread object handler.
    self.thread = None
def main(symbol, exchange, start, end):
    """Fit a 6-state Gaussian HMM on 1-minute bars and plot the result.

    Bars are loaded from the database; when none are found they are
    downloaded through ``tsdata_client``, saved and loaded again. The fitted
    model is pickled to ``hmm.pkl`` and two figures are written:
    ``colored_k_bar.png`` and ``return_curve.png``.
    """
    # 1) prepare observation X
    bars = database_manager.load_bar_data(symbol, exchange, Interval.MINUTE, start, end)
    if len(bars) == 0:
        # download data if not presented
        if not tsdata_client.inited:
            print('登录tushare')
            succeed = tsdata_client.init()
            if False == succeed:
                print('tushare登录失败')
                return
        request = HistoryRequest(symbol=symbol, exchange=exchange,
                                 interval=Interval.MINUTE, start=start, end=end)
        downloaded = tsdata_client.query_history(request)
        database_manager.save_bar_data(downloaded)
        bars = database_manager.load_bar_data(symbol, exchange, Interval.MINUTE, start, end)

    # Features per bar: 1-bar log return, 5-bar log return, log high/low range.
    # X.shape = (len(bars) - 5, 3)
    X = []
    for i in range(5, len(bars)):
        log_close = log(bars[i].close_price)
        X.append([
            log_close - log(bars[i - 1].close_price),
            log_close - log(bars[i - 5].close_price),
            log(bars[i].high_price) - log(bars[i].low_price),
        ])

    # 2) learn the HMM model
    model = GaussianHMM(n_components=6, covariance_type='diag', n_iter=5000).fit(X)
    print('hidden markov model %s converged' % ('is' if model.monitor_.converged else 'is not'))
    with open('hmm.pkl', 'wb') as f:
        pickle.dump(model, f)

    # 3) visualize
    states = model.predict(X)
    plt.figure(figsize=(15, 8))
    dates = [bars[i].datetime.strftime('%Y-%m-%d') for i in range(5, len(bars))]
    close_prices = [bars[i].close_price for i in range(5, len(bars))]
    date_array = np.array(dates)
    price_array = np.array(close_prices)
    for state in range(model.n_components):
        mask = (states == state)  # bars labeled with this state
        plt.plot(date_array[mask], price_array[mask], '.',
                 label='latent idx %d' % state, lw=1)
    plt.legend()
    plt.grid(1)
    plt.axis([0, len(close_prices), min(close_prices), max(close_prices)])
    plt.savefig('colored_k_bar.png')
    plt.show()

    features = np.array(X)
    for state in range(model.n_components):
        mask = (states == state)  # bars labeled with this state
        # Shift by one: a position taken on a labeled bar earns the return
        # of the following bar.
        mask = np.append(False, mask[:-1])
        plt.plot(np.exp(features[mask, 0].cumsum()), label='latent_state %d' % state)
    plt.legend()
    plt.grid(1)
    plt.savefig('return_curve.png')
    plt.show()
def csv_load(file_path):
    """
    Read every file in a directory and write its rows to the database.

    Each file is tab-separated with one header line. Column 0 is the symbol,
    column 1 a POSIX timestamp in seconds and columns 4-8 are open, high,
    low, close and volume. Bars are stored as 1-minute HUOBI bars with
    datetimes in the local time zone.

    For each file the first/last bar datetime and the number of bars
    inserted are printed; files without data rows are reported with a count
    of 0 and nothing is written.
    """
    local_tz = get_localzone()

    for _file in os.listdir(file_path):
        full_path = os.path.join(file_path, _file)
        with open(full_path) as f:
            print(f"开始载入文件:{_file}")
            lines = f.readlines()

        buf = []
        # lines[0] is the header.
        for row in lines[1:]:
            stripped = row.strip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing newline).
                continue
            items = stripped.split("\t")

            # Build the aware datetime directly: replace(tzinfo=...) can
            # attach a wrong offset when the zone object is pytz-based.
            _datetime = datetime.fromtimestamp(int(items[1]), tz=local_tz)

            buf.append(BarData(
                gateway_name="DB",
                symbol=items[0],
                exchange=Exchange.HUOBI,
                interval=Interval.MINUTE,
                datetime=_datetime,
                open_price=float(items[4]),
                high_price=float(items[5]),
                low_price=float(items[6]),
                close_price=float(items[7]),
                volume=float(items[8]),
            ))

        # Count bars only; the header line is not part of the total.
        count = len(buf)
        start = buf[0].datetime if buf else None
        end = buf[-1].datetime if buf else None

        if buf:
            database_manager.save_bar_data(buf)
        print("插入数据", start, "-", end, "总数量:", count)
def txt_load(loadpath, file):
    """
    Read one exported text file and write its bars to the database.

    The file name starts with ``<exchange>#<symbol>``, where the exchange is
    "SH" or "SZ". The first line names the bar type in its third
    whitespace-separated field (only "日线", daily, is supported), the second
    line holds the column names, and the last line is discarded. Data lines
    are tab-separated.

    Returns None. Files with fewer than four lines are ignored.

    Raises:
        Exception: for an unknown exchange prefix or bar type.
    """
    print("载入文件:", file)
    fullfile = os.path.join(loadpath, file)

    exchange, symbol = file.replace(".", "#").split("#")[:2]
    if exchange == "SH":
        exchange = Exchange.SSE
    elif exchange == "SZ":
        # A stray ``raise`` here used to reject every SZ file.
        exchange = Exchange.SZSE
    else:
        raise Exception("未匹配的交易所类型")

    with open(fullfile, "r") as f:
        filelines = f.readlines()

    if len(filelines) < 4:
        return None

    res = filelines[:2]
    if res[0].split()[2] == "日线":
        interval = Interval.DAILY
    else:
        raise Exception("未匹配的间隔类型")

    # Rows between the two header lines and the final line.
    tlist = [[col.strip() for col in item.split("\t")] for item in filelines[2:-1]]
    df = pd.DataFrame(tlist)
    df.columns = res[1].split()

    data = []
    for _, row in df.iterrows():
        date = datetime.strptime(row["日期"], '%Y/%m/%d')
        data.append(BarData(
            symbol=symbol,
            exchange=exchange,
            interval=interval,
            datetime=date,
            # Cells are text; the bar fields are numeric.
            open_price=float(row["开盘"]),
            high_price=float(row["最高"]),
            low_price=float(row["最低"]),
            close_price=float(row["收盘"]),
            volume=float(row["成交量"]),
            gateway_name="XTP",
        ))

    if data:
        database_manager.save_bar_data(data)
    print("单标的 插入数据 数量:", df.shape[0])
def download_data_from_tiingo(self, symbol: str, interval: Interval, exchange: Exchange, start: datetime):
    """
    Download daily bars from tiingo from `start` until today and save them.

    Only ``Interval.DAILY`` is supported; any other interval downloads
    nothing.

    Returns:
        Number of bars saved, or 0 when the interval is not daily or no
        bars were produced.
    """
    if interval is not Interval.DAILY:
        # No request is made for other intervals, so there is nothing to
        # save; previously this path referenced an unassigned variable.
        return 0

    date_format = '%Y-%m-%d'
    payload = tiingo_client.get_ticker_price(
        symbol,
        startDate=start.strftime(date_format),
        endDate=datetime.today().strftime(date_format),
        frequency='daily')

    data = self.translateJsonToBarDataList(symbol, exchange, payload)
    if not data:
        return 0

    database_manager.save_bar_data(data)
    return len(data)