def _parse_cwbbzy_csv(text):
    """Parse one financial-statement-summary CSV payload into a DataFrame.

    The payload is split into CRLF-separated rows whose comma-separated fields
    end with a trailing comma; everything after the last comma of the whole
    payload is discarded. The first field of each row becomes a column name
    and the remaining fields (truncated to the length of the shortest row)
    become that column's values. The frame is sorted by its first column.

    Returns None when the payload contains no usable rows.
    """
    last_comma = text.rfind(',')
    if last_comma < 0:
        return None
    # Cut at the *position* of the last comma. The former
    # `text.replace(text.split(',')[-1], '')` deleted every occurrence of that
    # trailing text, corrupting the data when it also appeared earlier.
    text = text[:last_comma + 1]
    # [:-1] drops the empty field produced by each row's trailing comma;
    # empty rows are skipped because `row[0]` below would fail on them.
    rows = [line.split(',')[:-1] for line in text.split('\r\n')]
    rows = [row for row in rows if row]
    if not rows:
        return None
    n = min(len(row) for row in rows)
    header = [row[0] for row in rows]
    df = DataFrame({row[0]: row[1:n] for row in rows}, columns=header)
    return df.sort_values(by=header[0])


def load_fin_data_cwbbzy():
    """Download the financial statement summary of every listed SH/SZ stock.

    Configuration is read from ``config.ini``:
    ``[fin_data] cwbbzy_url`` is a URL template formatted with the 6-digit
    stock code; the output directory is ``[factor_db] db_path`` joined with
    ``[fin_data] cwbbzy_path``. Each stock's parsed table is written to
    ``<output dir>/<utils.code_to_symbol(code)>.csv``.

    Returns None. Stocks whose download or parsing fails are reported and
    skipped; if the stock list cannot be read the function returns early.
    """
    cfg = ConfigParser()
    cfg.read('config.ini')
    cwbbzy_url = cfg.get('fin_data', 'cwbbzy_url')
    cwbbzy_path = os.path.join(cfg.get('factor_db', 'db_path'),
                               cfg.get('fin_data', 'cwbbzy_path'))

    # Read the codes of all listed SH/SZ stocks.
    # NOTE(review): the login name and token are hard-coded credentials;
    # consider moving them into config.ini.
    data_api = DataApi(addr='tcp://data.tushare.org:8910')
    data_api.login(
        '13811931480',
        'eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MTI4Nzk0NTI2MjkiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTM4MTE5MzE0ODAifQ.I0SXsA1bK--fbGu0B5Is2xdKOjALAeWBJRX6GdVmUL8'
    )
    df_stock_basics, msg = data_api.query(
        view='jz.instrumentInfo',
        fields='status,list_date,name,market',
        filter='inst_type=1&status=1&market=SH,SZ&symbol=',
        data_format='pandas')
    if msg != '0,':
        print('读取市场个股代码失败。')
        return
    # Keep only the code part of the symbol ('600000.SH' -> '600000').
    df_stock_basics.symbol = df_stock_basics.symbol.map(
        lambda x: x.split('.')[0])

    # Iterate over the stocks and download each financial statement summary.
    for _, stock_info in df_stock_basics.iterrows():
        url = cwbbzy_url % stock_info.symbol
        try:
            # A timeout keeps one unresponsive request from stalling the run.
            resp = requests.get(url, timeout=30)
        except requests.RequestException:
            resp = None
        if resp is None or resp.status_code != requests.codes.ok:
            print('%s的财务报表摘要数据下载失败!' % stock_info.symbol)
            continue
        print('下载%s的财务报表摘要数据.' % stock_info.symbol)
        df_fin_data = _parse_cwbbzy_csv(resp.text)
        if df_fin_data is None:
            print('%s的财务报表摘要数据下载失败!' % stock_info.symbol)
            continue
        df_fin_data.to_csv(os.path.join(
            cwbbzy_path, '%s.csv' % utils.code_to_symbol(stock_info.symbol)),
            index=False)
def _parse_ipo_table(soup):
    """Extract the 'IPO资料' (IPO information) table from a parsed stock page.

    Finds the first ``<h2>`` whose stripped text is 'IPO资料', takes the next
    ``<table>`` and reads the first two ``<td>`` of each row as a name/value
    pair. Spaces, CR and LF are removed from both; commas (presumably
    thousands separators) are also removed from values.

    Returns a pandas Series indexed by name, or None when the heading or its
    table is missing.
    """
    for tag in soup.find_all(name='h2'):
        if tag.get_text().strip() != 'IPO资料':
            continue
        ipo_table = tag.find_next(name='table')
        if ipo_table is None:
            return None
        names, values = [], []
        for tr in ipo_table.find_all(name='tr'):
            tds = tr.find_all(name='td')
            if len(tds) < 2:
                # Rows without a name/value pair used to raise IndexError.
                continue
            names.append(tds[0].get_text().replace(' ', '')
                         .replace('\n', '').replace('\r', ''))
            values.append(tds[1].get_text().replace(' ', '').replace(',', '')
                          .replace('\n', '').replace('\r', ''))
        return Series(values, index=names)
    return None


def load_ipo_info():
    """Download IPO information for every SH/SZ stock from NetEase Finance.

    Configuration is read from ``config.ini``: ``[ipo_info] ipo_info_url`` is
    a URL template formatted with the 6-digit stock code; the output
    directory is ``[factor_db] db_path`` joined with ``[ipo_info] db_path``.
    One ``<secu_code>.csv`` is written per stock (with an added '代码' entry),
    plus a combined ``ipo_info.csv``.

    Returns None; returns early if the stock list cannot be read.
    """
    cfg = ConfigParser()
    cfg.read('config.ini')
    ipo_info_url = cfg.get('ipo_info', 'ipo_info_url')
    db_path = Path(cfg.get('factor_db', 'db_path'),
                   cfg.get('ipo_info', 'db_path'))

    # Read the codes of all SH/SZ stocks (the empty `status=` filter does not
    # restrict the listing status).
    # NOTE(review): hard-coded credentials; consider moving them to config.ini.
    data_api = DataApi(addr='tcp://data.tushare.org:8910')
    data_api.login(
        '13811931480',
        'eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MTI4Nzk0NTI2MjkiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTM4MTE5MzE0ODAifQ.I0SXsA1bK--fbGu0B5Is2xdKOjALAeWBJRX6GdVmUL8')
    df_stock_basics, msg = data_api.query(
        view='jz.instrumentInfo',
        fields='status,list_date,name,market',
        filter='inst_type=1&status=&market=SH,SZ&symbol=',
        data_format='pandas')
    if msg != '0,':
        print('读取市场个股代码失败。')
        return
    # Keep only the code part of the symbol ('600000.SH' -> '600000').
    df_stock_basics.symbol = df_stock_basics.symbol.map(
        lambda x: x.split('.')[0])

    # Iterate over the stocks and download each one's IPO information.
    records = []
    for _, stock_info in df_stock_basics.iterrows():
        print('下载%s的IPO数据.' % stock_info.symbol)
        secu_code = utils.code_to_symbol(stock_info.symbol)
        url = ipo_info_url % stock_info.symbol
        try:
            resp = requests.get(url, timeout=30)
        except requests.RequestException:
            resp = None
        if resp is None or resp.status_code != requests.codes.ok:
            print('%s的IPO数据下载失败!' % stock_info.symbol)
            continue
        ipo_info = _parse_ipo_table(BeautifulSoup(resp.content, 'html.parser'))
        if ipo_info is None:
            continue
        ipo_info['代码'] = secu_code
        ipo_info.to_csv(db_path.joinpath('%s.csv' % secu_code))
        records.append(ipo_info)
    # DataFrame.append was removed in pandas 2.0; build the frame once from
    # the collected rows instead (this also avoids repeated copying).
    df_ipo_info = DataFrame(records)
    df_ipo_info.to_csv(db_path.joinpath('ipo_info.csv'), index=False)
def get_index_basic_information(index_symbol=None):
    """Look up basic information of an index through the ``lb.indexInfo`` view.

    Parameters
    ----------
    index_symbol : str, optional
        Index symbol to query. When omitted, the module-level name ``index``
        is used, which is what the original zero-argument version relied on.

    Returns
    -------
    tuple
        ``(symbol, name, start_date, end_date)``. ``start_date`` is the
        index's ``listdate``. ``end_date`` is today as a 'YYYYMMDD' string,
        or the index's ``expire_date`` (as a string) if that is earlier.

    Raises
    ------
    ValueError
        If the query returns no row for the symbol.
    """
    if index_symbol is None:
        index_symbol = index
    api = DataApi(addr=data_config.get('remote.data.address'))
    api.login(data_config.get('remote.data.username'),
              data_config.get('remote.data.password'))
    df, msg = api.query(
        view="lb.indexInfo",
        fields="symbol,name,listdate,expire_date",
        filter="symbol=" + index_symbol,
        data_format='pandas')
    if df is None or df.empty:
        raise ValueError(
            "no index info for {}: {}".format(index_symbol, msg))
    row = df.iloc[0]

    # If today is past the index's expire date, use the expire date as the
    # end date; the start date is the index's list date.
    today = dt.today().strftime('%Y%m%d')
    expire_date = row['expire_date']
    if expire_date:
        # Compare as strings: on Python 3, min() of a str and an int raises.
        end_date = min(today, str(expire_date))
    else:
        end_date = today
    return row['symbol'], row['name'], row['listdate'], end_date
def test_data_api():
    """Smoke-test the remote DataApi: login, daily bars and 1-minute bars.

    Connection settings ``remote.address``, ``remote.username`` and
    ``remote.password`` are read from ``etc/data_config.json``.

    Raises
    ------
    ValueError
        If any connection setting is missing.
    AssertionError
        If a query reports an error code other than '0,' or returns a frame
        of unexpected shape.
    """
    dic = fileio.read_json(fileio.join_relative_path('etc/data_config.json'))
    address = dic.get("remote.address", None)
    username = dic.get("remote.username", None)
    password = dic.get("remote.password", None)
    if address is None or username is None or password is None:
        raise ValueError("no data service config available!")

    api = DataApi(address, use_jrpc=False)
    login_msg = api.login(username=username, password=password)
    # print() with a single argument behaves the same on Python 2 and 3.
    print(login_msg)

    daily, msg = api.daily(
        symbol="600030.SH,000002.SZ", start_date=20170103, end_date=20170708,
        fields="open,high,low,close,volume,last,trade_date,settle")
    daily2, msg2 = api.daily(
        symbol="600030.SH", start_date=20170103, end_date=20170708,
        fields="open,high,low,close,volume,last,trade_date,settle")
    assert msg == '0,'
    assert msg2 == '0,'
    # Two symbols give twice the rows of the single-symbol query.
    assert daily.shape == (248, 8)
    assert daily2.shape == (124, 8)

    df, msg3 = api.bar(symbol="600030.SH", trade_date=20170904, freq='1m',
                       start_time=90000, end_time=150000)
    # Check the error code first so a failed query is reported clearly
    # instead of surfacing as an AttributeError on `df.columns`.
    assert msg3 == '0,', msg3
    print(df.columns)
    assert df.shape == (240, 15)
    print("test passed")
# coding: utf-8
# Notebook export: fetch a real-time quote for one symbol via jaqs DataApi.

from jaqs.data.dataapi import DataApi

# Adjust the address and credentials to the local environment.
api = DataApi(addr="tcp://192.168.1.117:8910")
api.login("25", "25")

symbol = 'rb1805.SHF'
fields = ','.join(['open', 'high', 'low', 'last', 'volume'])

# Request the real-time quote and show the result frame and status message.
df, msg = api.quote(symbol=symbol, fields=fields)
for item in (df, msg):
    print(item)
class RemoteDataService(DataService):
    """
    RemoteDataService is a concrete class using data from remote server's database.

    On construction it reads ``remote.address``, ``remote.username`` and
    ``remote.password`` from ``etc/data_config.json``, creates a ``DataApi``
    client (stored as ``self.api``, 60 s timeout) and logs in. ``Singleton``
    is set as the metaclass through the Python 2 ``__metaclass__`` attribute.
    """
    __metaclass__ = Singleton

    # TODO no validity check for input parameters

    def __init__(self):
        DataService.__init__(self)

        dic = fileio.read_json(
            fileio.join_relative_path('etc/data_config.json'))
        address = dic.get("remote.address", None)
        username = dic.get("remote.username", None)
        password = dic.get("remote.password", None)
        if address is None or username is None or password is None:
            raise ValueError("no address, username or password available!")

        self.api = DataApi(address, use_jrpc=False)
        self.api.set_timeout(60)
        r, msg = self.api.login(username=username, password=password)
        if not r:
            print(msg)
        else:
            # The former message had no '{}' placeholder, so the address was
            # silently dropped.
            print("DataAPI login success: {}".format(address))

        self.REPORT_DATE_FIELD_NAME = 'report_date'

    def daily(self, symbol, start_date, end_date,
              fields="", adjust_mode=None):
        """
        Query daily bars through ``DataApi.daily``.

        Returns
        -------
        df : pd.DataFrame or None
            Duplicate rows removed; passed through unchanged if the API
            returned None.
        err_msg : str
        """
        df, err_msg = self.api.daily(symbol=symbol, start_date=start_date, end_date=end_date,
                                     fields=fields, adjust_mode=adjust_mode, data_format="")
        # trade_status performance warning
        # TODO there will be duplicate entries when on stocks' IPO day
        if df is not None:
            df = df.drop_duplicates()
        return df, err_msg

    def bar(self, symbol, start_time=200000, end_time=160000, trade_date=None,
            freq='1m', fields=""):
        """
        Query intraday bars through ``DataApi.bar``.

        ``freq`` is forwarded to the API; previously '1m' was always sent
        regardless of the argument (the default is unchanged).
        """
        df, msg = self.api.bar(symbol=symbol, fields=fields,
                               start_time=start_time, end_time=end_time, trade_date=trade_date,
                               freq=freq, data_format="")
        return df, msg

    def query(self, view, filter="", fields="", **kwargs):
        """
        Get various reference data.

        Parameters
        ----------
        view : str
            data source.
        fields : str
            Separated by ','
        filter : str
            filter expressions.
        kwargs
            Passed through to ``DataApi.query``.

        Returns
        -------
        df : pd.DataFrame
        msg : str
            error code and error message, joined by ','

        Examples
        --------
        res3, msg3 = ds.query("lb.secDailyIndicator", fields="price_level,high_52w_adj,low_52w_adj",
                              filter="start_date=20170907&end_date=20170907",
                              orderby="trade_date",
                              data_format='pandas')
            view does not change. fields can be any field predefined in reference data api.

        """
        df, msg = self.api.query(view, fields=fields,
                                 filter=filter,
                                 data_format="",
                                 **kwargs)
        return df, msg

    def get_suspensions(self):
        """Not implemented; always returns None."""
        return None

    # TODO use Calendar instead
    def get_trade_date(self, start_date, end_date, symbol=None, is_datetime=False):
        """
        Return trade dates in [start_date, end_date].

        The dates are the 'trade_date' column of daily bars of ``symbol``
        (default '000300.SH'); with ``is_datetime`` they are converted by
        ``dtutil.convert_int_to_datetime``.
        """
        if symbol is None:
            symbol = '000300.SH'
        df, msg = self.daily(symbol, start_date, end_date, fields="close")
        res = df.loc[:, 'trade_date'].values
        if is_datetime:
            res = dtutil.convert_int_to_datetime(res)
        return res

    @staticmethod
    def _dic2url(d):
        """
        Convert a dict to str like 'k1=v1&k2=v2'

        Parameters
        ----------
        d : dict

        Returns
        -------
        str

        """
        l = ['='.join([key, str(value)]) for key, value in d.items()]
        return '&'.join(l)

    def query_lb_fin_stat(self, type_, symbol, start_date, end_date, fields=""):
        """
        Helper function to call data_api.query with 'lb.income' more conveniently.

        Parameters
        ----------
        type_ : {'income', 'balance_sheet', 'cash_flow', 'fin_indicator'}
        symbol : str
            separated by ','
        start_date : int
            Annoucement date in results will be no earlier than start_date
        end_date : int
            Annoucement date in results will be no later than end_date
        fields : str, optional
            separated by ',', default ""

        Returns
        -------
        df : pd.DataFrame
            index date, columns fields
        msg : str

        Raises
        ------
        NotImplementedError
            For an unknown ``type_``.

        """
        view_map = {
            'income': 'lb.income',
            'cash_flow': 'lb.cashFlow',
            'balance_sheet': 'lb.balanceSheet',
            'fin_indicator': 'lb.finIndicator'
        }
        view_name = view_map.get(type_, None)
        if view_name is None:
            raise NotImplementedError("type_ = {:s}".format(type_))

        dic_argument = {
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date,
            'update_flag': '0'  # 0 means first time, not update
        }
        if view_name != 'lb.finIndicator':
            # We do not use the single-quarter report type since it contains zeros:
            #   408001000: joint
            #   408002000: joint (single quarter)
            dic_argument.update({'report_type': '408001000'})

        filter_argument = self._dic2url(dic_argument)

        res, msg = self.query(view_name, fields=fields, filter=filter_argument,
                              order_by=self.REPORT_DATE_FIELD_NAME)

        # Best effort: cast date columns to int; leave dtypes unchanged on failure
        # (including res being None). Narrowed from a bare `except:` so that
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        try:
            cols = list(
                set.intersection({'ann_date', 'report_date'}, set(res.columns)))
            dic_dtype = {col: int for col in cols}
            res = res.astype(dtype=dic_dtype)
        except (AttributeError, KeyError, TypeError, ValueError):
            pass

        return res, msg

    def query_lb_dailyindicator(self, symbol, start_date, end_date, fields=""):
        """
        Helper function to call data_api.query with 'lb.secDailyIndicator' more conveniently.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        fields : str, optional
            separated by ',', default ""

        Returns
        -------
        df : pd.DataFrame
            index date, columns fields
        msg : str

        """
        filter_argument = self._dic2url({
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date
        })

        return self.query("lb.secDailyIndicator",
                          fields=fields,
                          filter=filter_argument,
                          orderby="trade_date")

    def _get_index_comp(self, index, start_date, end_date):
        """
        Query all constituent records of index during start_date and end_date.

        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        df_io : pd.DataFrame
            Raw rows of the 'lb.indexCons' view, ordered by symbol.
        msg : str

        """
        filter_argument = self._dic2url({
            'index_code': index,
            'start_date': start_date,
            'end_date': end_date
        })

        df_io, msg = self.query("lb.indexCons",
                                fields="",
                                filter=filter_argument,
                                orderby="symbol")
        return df_io, msg

    def get_index_comp(self, index, start_date, end_date):
        """
        Return list of symbols that have been in index during start_date and end_date.

        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        list
            Unique symbols, sorted (np.unique).

        """
        df_io, msg = self._get_index_comp(index, start_date, end_date)
        if msg != '0,':
            print(msg)
        return list(np.unique(df_io.loc[:, 'symbol']))

    def get_index_comp_df(self, index, start_date, end_date):
        """
        Get index components on each day during start_date and end_date.

        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        res : pd.DataFrame
            index dates, columns all securities that have ever been components,
            values are 0 (not in) or 1 (in)

        """
        df_io, msg = self._get_index_comp(index, start_date, end_date)
        if msg != '0,':
            print(msg)

        def str2int(s):
            # An empty date string becomes 99999999, i.e. an open-ended date.
            if isinstance(s, (str, unicode)):
                return int(s) if s else 99999999
            # np.floating replaces np.float, which newer NumPy removed.
            elif isinstance(s, (int, np.integer, float, np.floating)):
                return s
            else:
                raise NotImplementedError("type s = {}".format(type(s)))

        df_io.loc[:, 'in_date'] = df_io.loc[:, 'in_date'].apply(str2int)
        df_io.loc[:, 'out_date'] = df_io.loc[:, 'out_date'].apply(str2int)

        dates = self.get_trade_date(start_date=start_date, end_date=end_date, symbol=index)

        dic = dict()
        gp = df_io.groupby(by='symbol')
        for sec, df in gp:
            mask = np.zeros_like(dates, dtype=int)
            for idx, row in df.iterrows():
                # NOTE(review): both bounds are strict, so in_date itself is
                # marked as "not in"; confirm this matches the view's semantics.
                bool_index = np.logical_and(dates > row['in_date'], dates < row['out_date'])
                mask[bool_index] = 1
            dic[sec] = mask

        res = pd.DataFrame(index=dates, data=dic)

        return res

    @staticmethod
    def _group_df_to_dict(df, by):
        """Return {group key: sub-DataFrame} for ``df.groupby(by)``."""
        gp = df.groupby(by=by)
        res = {key: value for key, value in gp}
        return res

    def get_industry_daily(self, symbol, start_date, end_date, type_='SW'):
        """
        Get the level-1 industry code of securities on each trade date.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        type_ : {'SW', 'ZZ'}

        Returns
        -------
        res : pd.DataFrame
            index dates, columns symbols
            values are industry code (as str)

        """
        df_raw = self.get_industry_raw(symbol, type_=type_)

        dic_sec = self._group_df_to_dict(df_raw, by='symbol')
        dic_sec = {sec: df.sort_values(by='in_date', axis=0).reset_index()
                   for sec, df in dic_sec.items()}

        df_ann = pd.concat([df.loc[:, 'in_date'].rename(sec) for sec, df in dic_sec.items()], axis=1)
        df_value = pd.concat([df.loc[:, 'industry1_code'].rename(sec) for sec, df in dic_sec.items()], axis=1)

        dates_arr = self.get_trade_date(start_date, end_date)
        df_industry = align.align(df_value, df_ann, dates_arr)

        # TODO before industry classification is available, we assume they belong to their first group.
        df_industry = df_industry.bfill()
        df_industry = df_industry.astype(str)

        return df_industry

    def get_industry_raw(self, symbol, type_='ZZ'):
        """
        Get daily industry of securities from ShenWanHongYuan or ZhongZhengZhiShu.

        Parameters
        ----------
        symbol : str
            separated by ','
        type_ : {'SW', 'ZZ'}

        Returns
        -------
        df : pd.DataFrame

        Raises
        ------
        ValueError
            For an unknown ``type_``.

        """
        # Sources are UTF-8 encoded; on Python 2 this yields a str that
        # _dic2url embeds as raw text.
        # NOTE(review): on Python 3 this would be bytes and str() would give "b'...'".
        if type_ == 'SW':
            src = u'申万研究所'.encode('utf-8')
        elif type_ == 'ZZ':
            src = u'中证指数有限公司'.encode('utf-8')
        else:
            raise ValueError("type_ must be one of SW or ZZ")

        filter_argument = self._dic2url({
            'symbol': symbol,
            'industry_src': src
        })
        fields_list = ['symbol', 'industry1_code', 'industry1_name']

        df_raw, msg = self.query("lb.secIndustry", fields=','.join(fields_list),
                                 filter=filter_argument, orderby="symbol")
        if msg != '0,':
            print(msg)

        # NOTE(review): 'in_date' is not in fields_list but is cast here and used
        # by get_industry_daily; this relies on the service returning it anyway.
        df_raw = df_raw.astype(dtype={'in_date': int})
        return df_raw.drop_duplicates()

    def get_adj_factor_daily(self, symbol, start_date, end_date, div=False):
        """
        Get adjust factors of securities on each trade date.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        div : bool
            False for normal adjust factor, True for the ratio to the previous
            day's factor (first row filled with 1.0).

        Returns
        -------
        res : pd.DataFrame
            index dates, columns symbols
            values are adjust factors

        """
        df_raw = self.get_adj_factor_raw(symbol)

        dic_sec = self._group_df_to_dict(df_raw, by='symbol')
        dic_sec = {sec: df.loc[:, ['trade_date', 'adjust_factor']].set_index('trade_date').iloc[:, 0]
                   for sec, df in dic_sec.items()}

        res = pd.concat(dic_sec, axis=1)

        # align to every trade date
        s, e = df_raw.loc[:, 'trade_date'].min(), df_raw.loc[:, 'trade_date'].max()
        dates_arr = self.get_trade_date(s, e)
        res = res.reindex(dates_arr)

        res = res.ffill().bfill()

        if div:
            res = res.div(res.shift(1, axis=0)).fillna(1.0)

        res = res.loc[start_date: end_date, :]

        return res

    def get_adj_factor_raw(self, symbol, start_date=None, end_date=None):
        """
        Query adjust factor for symbols.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        df : pd.DataFrame

        """
        if start_date is None:
            start_date = ""
        if end_date is None:
            end_date = ""

        filter_argument = self._dic2url({
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date
        })
        fields_list = ['symbol', 'trade_date', 'adjust_factor']

        df_raw, msg = self.query("lb.secAdjFactor", fields=','.join(fields_list),
                                 filter=filter_argument, orderby="symbol")
        if msg != '0,':
            print(msg)
        df_raw = df_raw.astype(dtype={'symbol': str,
                                      'trade_date': int,
                                      'adjust_factor': float
                                      })
        return df_raw.drop_duplicates()

    def query_inst_info(self, symbol, inst_type="", fields=""):
        """
        Query 'jz.instrumentInfo' for symbols.

        An empty ``inst_type`` means "1,2,3,4,5,101,102,103,104". Columns
        'symbol', 'list_date' and 'delist_date', when present, are cast to
        str/int/int.
        """
        if inst_type == "":
            inst_type = "1,2,3,4,5,101,102,103,104"

        filter_argument = self._dic2url({
            'symbol': symbol,
            'inst_type': inst_type
        })

        df_raw, msg = self.query("jz.instrumentInfo", fields=fields,
                                 filter=filter_argument, orderby="symbol")
        if msg != '0,':
            print(msg)

        dtype_map = {'symbol': str, 'list_date': int, 'delist_date': int}
        cols = set(df_raw.columns)
        dtype_map = {k: v for k, v in dtype_map.items() if k in cols}

        df_raw = df_raw.astype(dtype=dtype_map)
        return df_raw, msg
jaqs_fxdayu.patch_all() from jaqs.data import DataView from jaqs.data import RemoteDataService from jaqs_fxdayu.data.dataservice import LocalDataService import os import numpy as np import alpha32_, alpha42_, alpha56_, alpha62_, alpha64_, alpha194, alpha195, alpha197, Beta3 import pandas as pd import matplotlib.pyplot as plt from jaqs_fxdayu.research import SignalDigger from jaqs_fxdayu.research.signaldigger import analysis from jaqs_fxdayu.research.signaldigger import multi_factor api = DataApi(addr='tcp://data.tushare.org:8910') api.login( "18523827661", 'eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MjIxMTc0NDY1MzAiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTg1MjM4Mjc2NjEifQ.AO9Rp8jG_IWc6crPrBOC-ujMP0-g1S1c5kUlTs5qwrk' ) start = 20100101 end = 20180401 SH_id = dp.index_cons(api, "000300.SH", start, end) SZ_id = dp.index_cons(api, "000905.SH", start, end) stock_symbol = list(set(SH_id.symbol) | set(SZ_id.symbol)) factor_list = ['volume', 'float_mv', 'pe', 'ps'] check_factor = ','.join(factor_list) dataview_folder = '/Users/adam/Desktop/intern/test5/fxdayu_adam/data' dataview_folder2 = 'muti_factor/' dv = DataView() #ds = LocalDataService(fp=dataview_folder) data_config = {
"eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MzU1OTg2MTI0NzYiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTU1NjYwMjg1NjgifQ.ToTAnVWpKtweGj4yoXhVW0pzHds7a9qQzXC8qLBui2g" #token是api令牌 } trade_config = { "remote.trade.address": "tcp://gw.quantos.org:8901", "remote.trade.username": "******", "remote.trade.password": "******" } dataview_dir_path = '***' #dataview存储路径 backtest_result_dir_path = '***' #回测结果存储路径 api = DataApi(addr='tcp://data.quantos.org:8910') api.login("phone", "token") dv = DataView() dv.load_dataview(folder_path=dataview_dir_path) dv.add_field('turnover_ratio', ds) dv.save_dataview(folder_path=dataview_dir_path) dv.update_snapshot() #在已经保存dataview之后,需要添加某项指标(dataview.py里面可以找这些指标,不是因子!),如turnover_ratio,使用这段代码,添加完之后下一次运行就可以注释掉 def my_selector(context, user_options=None): #筛选僵尸股和ST股 df = context.dataview.data_inst['name'] result = pd.DataFrame(df, columns=['ST'], index=df.index, dtype='bool') selector_volume = context.snapshot[
import pandas as pd
from jaqs_fxdayu.util import dp
from jaqs.data.dataapi import DataApi
import numpy as np
import talib as ta
from jaqs_fxdayu.research.signaldigger.process import neutralize

# Log in to the tushare/jaqs data service.
# NOTE(review): the login name and token below are hard-coded credentials;
# consider loading them from a config file or environment variable instead.
api = DataApi(addr='tcp://data.tushare.org:8910')
api.login(
    "13662241013",
    'eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MTc2NDQzMzg5MTIiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTM2NjIyNDEwMTMifQ.sVIzI5VLqq8fbZCW6yZZW0ClaCkcZpFqpiK944AHEow'
)

# Research window, as YYYYMMDD integers.
start = 20130101
end = 20180101
# Constituents of the indices 000300.SH and 000905.SH over [start, end]
# (presumably CSI 300 and CSI 500 -- note the SH_/SZ_ prefixes do not denote
# exchanges). The results are used below through their `symbol` column.
SH_id = dp.index_cons(api, "000300.SH", start, end)
SZ_id = dp.index_cons(api, "000905.SH", start, end)
# Stock universe: union of both constituent lists (set order, not sorted).
stock_symbol = list(set(SH_id.symbol) | set(SZ_id.symbol))

# Fields to load; adjust the number and content of the factors as needed.
factor_list = [
    'oper_rev', 'oper_rev_ttm', 'oper_rev_lyr', "total_oper_rev",
    'less_gerl_admin_exp', 'AdminiExpenseRate', 'volume', "gainvariance120",
    'index', "cash_recp_sg_and_rs"
]
# Comma-joined string of the factor names.
check_factor = ','.join(factor_list)

# NOTE(review): patch_all() runs after DataApi was already imported above;
# confirm the patch does not need to precede that import.
import jaqs_fxdayu
jaqs_fxdayu.patch_all()
from jaqs.data import DataView
from jaqs.data import RemoteDataService
from jaqs_fxdayu.data.dataservice import LocalDataService
# coding: utf-8 # In[2]: import pandas as pd import numpy as np # In[3]: from jaqs.data.dataapi import DataApi api = DataApi(addr='tcp://data.tushare.org:8910') api.login( '18810695562', 'eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MTMzMTU1MTE3MDEiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTg4MTA2OTU1NjIifQ.2qvBxyHzcIFfC4_5lP_MFkNwDGLYD2gwqxMLWMpOLw0' ) # In[4]: df, msg = api.quote("000001.SH, cu1802.SHF", fields="open,high,low,last,volume") # In[5]: df # In[5]: df1, msg1 = api.query(view="jz.secTradeCal", fields="date,market", filter="start_date=20170901&end_date=20171101",
class MyQuantosDataApi():
    """
    Download daily reference data from quantos (jaqs) into MongoDB.

    Implemented flow: ``_prepare_data`` decides between a first download and
    an incremental update, logs into the data services, loads the SH/SZ
    instrument list and computes ``start_date``/``end_date`` for
    ``self.symbol``. ``_download_data`` fetches data through a jaqs
    ``DataView``. ``_dataframe_to_dbase_props`` converts it into a list of
    record dicts. ``_write_data_to_dbase`` is still a stub.
    """

    def __init__(self):
        self.symbol = ""
        self.fields = ""
        self.start_date = 0
        self.end_date = 0
        self._client = None  # MongoDB client, created by _login(mongodb=True)
        self._db = None
        self._col = None
        self._db_name = ""
        self._col_name = ""
        self._db_list = []
        self._col_list = []
        self._data_init_config = {}
        self._dataapi = None
        self._remote_data_service = None
        self._instrument_info = pd.DataFrame()
        self._symbol_set_all_A = set()
        self._dataview_data = pd.DataFrame()
        self._dbase_props = {}
        self._db_init_config = ['reference_daily_fields']
        # 1 = first download, 2 = incremental update (set by _set_data_status).
        self._data_status = 0

    def _login(self, dataapi=False, remote=False, mongodb=False):
        """Create the requested clients: jaqs DataApi, RemoteDataService and/or MongoDB."""
        self._data_init_config = {
            "remote.data.address": "tcp://data.tushare.org:8910",
            "remote.data.username": "******",
            "remote.data.password": "******"
        }
        if dataapi:
            self._dataapi = DataApi(
                addr=self._data_init_config.get('remote.data.address'))
            self._dataapi.login(
                self._data_init_config.get('remote.data.username'),
                self._data_init_config.get('remote.data.password'))
        if remote:
            self._remote_data_service = RemoteDataService()
            self._remote_data_service.init_from_config(self._data_init_config)
        if mongodb:
            # Connect to the local MongoDB server.
            self._client = pymongo.MongoClient(host='localhost', port=27017)

    def _logout(self):
        "Close the MongoDB connection, if one was opened."
        if self._client is not None:
            self._client.close()

    def _set_data_status(self):
        '''
        Decide between a first download and an incremental update.

        Sets ``self._data_status`` to 1 when MongoDB holds no user databases
        (first download) and to 2 otherwise (incremental update). Stores the
        user database names in ``self._db_list`` and, for an update, binds
        ``self._db`` to the reference database and stores its collection
        names in ``self._col_list``.
        '''
        self._login(mongodb=True)
        # Ignore MongoDB's built-in databases. The former code assigned the
        # result of list.remove() (always None) back to the list.
        system_dbs = ('admin', 'config', 'local')
        self._db_list = [name for name in self._client.list_database_names()
                         if name not in system_dbs]
        if not self._db_list:
            self._data_status = 1
        else:
            # self._db used to be read here without ever being assigned.
            self._db = self._client[self._db_init_config[0]]
            self._col_list = self._db.list_collection_names()
            self._data_status = 2
        self._client.close()

    def _prepare_data(self):
        """Log in, load the instrument list and compute the date range to download."""
        self._set_data_status()
        self._set_fields()
        self._login(dataapi=True, remote=True)
        self._get_instrumentInfo()
        self._get_symbol_set_all_A()
        # Both modes use the same database/collection naming.
        self._db_name = self._db_init_config[0]
        self._col_name = self.symbol
        if self._data_status == 1:
            self._get_start_end_date_first_time()
        elif self._data_status == 2:
            self._get_start_end_date_update()

    def _set_fields(self):
        "Set the fields to query (their config key is also the MongoDB db_name)."
        dv = DataView()
        fields_init_config = {
            # More field groups supported by the quantos DataView may be added here.
            'reference_daily_fields': dv.reference_daily_fields
        }
        self.fields = fields_init_config.get('reference_daily_fields')

    def _get_instrumentInfo(self):
        '''
        Load basic information of SH/SZ stocks into ``self._instrument_info``.

        Filter: inst_type=1 (stocks), status=1 (listed).
        '''
        df, msg = self._dataapi.query(
            view="jz.instrumentInfo",
            fields="status,list_date,delist_date,name,market,symbol",
            filter="inst_type=1&status=1",
            data_format='pandas')
        # Keep only the SH and SZ markets; the raw data also includes HK stocks.
        df = df[(df['market'] == 'SH') | (df['market'] == 'SZ')]
        self._instrument_info = df

    def _get_symbol_set_all_A(self):
        "Store the set of all listed A-share symbols."
        self._symbol_set_all_A = set(self._instrument_info['symbol'])

    def _get_start_end_date_first_time(self):
        "Set the date range for a first download: list date up to min(today, delist date)."
        df = self._instrument_info
        row = df[df['symbol'] == self.symbol].iloc[0]
        today = int(datetime.today().strftime('%Y%m%d'))
        self.start_date = int(row['list_date'])
        # Compare as integers; mixing a str `today` with a numeric
        # delist_date raises TypeError on Python 3.
        self.end_date = min(today, int(row['delist_date']))

    def _get_start_end_date_update(self):
        "Set the date range for an incremental update: day after the latest stored date up to today."
        if self._col is None:
            # self._col used to be read without ever being assigned; open it here.
            self._login(mongodb=True)
            db_name = self._db_name or self._db_init_config[0]
            col_name = self._col_name or self.symbol
            self._db = self._client[db_name]
            self._col = self._db[col_name]
        stored = pd.DataFrame(list(self._col.find({}, {'trade_date': 1, '_id': 0})))
        if stored.empty or 'trade_date' not in stored.columns:
            # Nothing stored for this collection yet: use the full range.
            self._get_start_end_date_first_time()
            return
        latest_date = str(stored['trade_date'].max())
        today = datetime.today().strftime('%Y%m%d')
        next_date = (datetime.strptime(latest_date, '%Y%m%d')
                     + timedelta(days=1)).strftime('%Y%m%d')
        # 'YYYYMMDD' strings compare in date order.
        if today >= next_date:
            self.start_date = int(next_date)
            self.end_date = int(today)
        else:
            print("Trade date in Dbase is latest.")

    def _download_data(self):
        "Download self.fields of self.symbol via a quantos DataView into self._dataview_data."
        dv = DataView()
        props = {
            'symbol': self.symbol,
            'fields': self.fields,
            'start_date': self.start_date,
            'end_date': self.end_date,
            'freq': 1
        }
        dv.init_from_config(props=props, data_api=self._remote_data_service)
        dv.prepare_data()
        self._dataview_data = dv.data_d

    def _dataframe_to_dbase_props(self):
        "Convert self._dataview_data into a list of record dicts in self._dbase_props."
        # Work on a copy (the attribute name was misspelled '_dataview_date'
        # before); mutating the stored frame made a second call fail.
        df = self._dataview_data.copy()
        df.columns = df.columns.droplevel()
        df.reset_index(inplace=True)
        # Restrict to the requested period, since the DataView widens the
        # date range by a few days on each side.
        df = df[(df['trade_date'] >= self.start_date) & (df['trade_date'] <= self.end_date)]
        # If nothing but trade_date has data, the stored data is already up to date.
        is_df = df.drop('trade_date', axis=1).dropna(how='all')
        if is_df.empty:
            print("\n\nTrade date is latest.\n\n")
            return None
        self._dbase_props = df.to_dict(orient='records')

    def _write_data_to_dbase(self):
        "Not implemented yet."
        pass
class Calendar(object):
    """
    A calendar for managing trade dates.

    Attributes
    ----------
    data_api : DataApi
        Client used to query the ``jz.secTradeCal`` view. Either passed in,
        or created from ``remote.address``/``remote.username``/
        ``remote.password`` (and optional ``timeout``, default 60) in
        ``etc/data_config.json``.

    """
    def __init__(self, data_api=None):
        if data_api is not None:
            self.data_api = data_api
        else:
            props = jutil.read_json(
                jutil.join_relative_path('etc/data_config.json'))
            address = props.get("remote.address", "")
            username = props.get("remote.username", "")
            password = props.get("remote.password", "")
            # The defaults above are "", so the former `is None` check could
            # never fire; treat missing, null or empty values as absent.
            if not address or not username or not password:
                raise ValueError("no address, username or password available!")
            time_out = props.get("timeout", 60)

            self.data_api = DataApi(address, use_jrpc=False)
            self.data_api.set_timeout(timeout=time_out)
            r, msg = self.data_api.login(username=username, password=password)
            if not r:
                print("DataAPI login failed: msg = '{}'".format(msg))
            else:
                print("DataAPI login success : {}@{}".format(username, address))

    @staticmethod
    def _dic2url(d):
        """
        Convert a dict to str like 'k1=v1&k2=v2'

        Parameters
        ----------
        d : dict

        Returns
        -------
        str

        """
        l = ['='.join([key, str(value)]) for key, value in d.items()]
        return '&'.join(l)

    def get_trade_date_range(self, start_date, end_date):
        """
        Get array of trade dates within given range.
        Return zero size array if no trade dates within range or the query
        returned no frame.

        Parameters
        ----------
        start_date : int
            YYYYmmdd
        end_date : int

        Returns
        -------
        trade_dates_arr : np.ndarray
            dtype = int

        """
        filter_argument = self._dic2url({'start_date': start_date,
                                         'end_date': end_date})

        df_raw, msg = self.data_api.query("jz.secTradeCal", fields="trade_date",
                                          filter=filter_argument, orderby="")
        # A failed query may yield no DataFrame at all.
        if df_raw is None or df_raw.empty:
            return np.array([], dtype=int)

        trade_dates_arr = df_raw['trade_date'].values.astype(int)
        return trade_dates_arr

    def get_last_trade_date(self, date):
        """
        Return the latest trade date strictly before ``date``.

        Only the two weeks before ``date`` are searched; IndexError is raised
        if that window contains no earlier trade date.

        Parameters
        ----------
        date : int

        Returns
        -------
        res : int

        """
        dt = jutil.convert_int_to_datetime(date)
        delta = pd.Timedelta(weeks=2)
        dt_old = dt - delta
        date_old = jutil.convert_datetime_to_int(dt_old)

        dates = self.get_trade_date_range(date_old, date)
        mask = dates < date
        res = dates[mask][-1]

        return res

    def is_trade_date(self, date):
        """
        Check whether date is a trade date.

        Parameters
        ----------
        date : int

        Returns
        -------
        bool

        """
        dates = self.get_trade_date_range(date, date)
        return len(dates) > 0

    def get_next_trade_date(self, date):
        """
        Return the earliest trade date strictly after ``date``.

        Only the two weeks after ``date`` are searched; IndexError is raised
        if that window contains no later trade date.

        Parameters
        ----------
        date : int

        Returns
        -------
        res : int

        """
        dt = jutil.convert_int_to_datetime(date)
        delta = pd.Timedelta(weeks=2)
        dt_new = dt + delta
        date_new = jutil.convert_datetime_to_int(dt_new)

        dates = self.get_trade_date_range(date, date_new)
        mask = dates > date
        res = dates[mask][0]

        return res
class RemoteDataService(DataService):
    """
    RemoteDataService is a concrete class using data from remote server's database.

    All queries go through ``self.data_api``, which is created by
    ``init_from_config``.

    """
    __metaclass__ = Singleton

    # TODO no validity check for input parameters

    def __init__(self):
        DataService.__init__(self)

        self.data_api = None

        self.REPORT_DATE_FIELD_NAME = 'report_date'
        self.calendar = None

    def __del__(self):
        # No connection exists if init_from_config was never called (or if
        # __init__ failed before data_api was set).
        data_api = getattr(self, 'data_api', None)
        if data_api is not None:
            data_api.close()

    def init_from_config(self, props=None):
        """
        Create the DataApi connection, log in and build the trade calendar.

        Parameters
        ----------
        props : dict, optional
            Connection settings (``remote.address``, ``remote.username``,
            ``remote.password``, ``timeout``). A key that is missing (or None)
            here is looked up in ``etc/data_config.json``. If a connection
            already exists and *props* is empty, nothing is done; otherwise
            the old connection is closed first.

        Raises
        ------
        ValueError
            If no non-empty address, username or password is available.

        """
        if props is None:
            props = dict()

        if self.data_api is not None:
            if len(props) == 0:
                # already connected and no new settings: keep the connection
                return
            else:
                self.data_api.close()

        def get_from_list_of_dict(l, key, default=None):
            """Return the first non-None value of *key* in the dicts *l*, else *default*."""
            res = None
            for dic in l:
                res = dic.get(key, None)
                if res is not None:
                    break
            if res is None:
                res = default
            return res

        props_default = jutil.read_json(
            jutil.join_relative_path('etc/data_config.json'))
        dic_list = [props, props_default]

        address = get_from_list_of_dict(dic_list, "remote.address", "")
        username = get_from_list_of_dict(dic_list, "remote.username", "")
        password = get_from_list_of_dict(dic_list, "remote.password", "")
        # Missing keys default to "", so test for emptiness, not only None.
        if not (address and username and password):
            raise ValueError("no address, username or password available!")
        time_out = get_from_list_of_dict(dic_list, "timeout", 60)

        self.data_api = DataApi(address, use_jrpc=False)
        self.data_api.set_timeout(timeout=time_out)
        print("\nDataApi login: {}@{}".format(username, address))
        r, msg = self.data_api.login(username=username, password=password)
        if not r:
            print(" login failed: msg = '{}'\n".format(msg))
        else:
            print(" login success \n")

        self.calendar = Calendar(self.data_api)

    # -----------------------------------------------------------------------------------
    # Basic APIs
    def daily(self, symbol, start_date, end_date,
              fields="", adjust_mode=None):
        """
        Query daily bars through ``data_api.daily``.

        Returns
        -------
        df : pd.DataFrame
            Duplicate rows removed.
        err_msg : str

        """
        df, err_msg = self.data_api.daily(symbol=symbol,
                                          start_date=start_date,
                                          end_date=end_date,
                                          fields=fields,
                                          adjust_mode=adjust_mode,
                                          data_format="")
        # trade_status performance warning
        # TODO there will be duplicate entries when on stocks' IPO day
        df = df.drop_duplicates()
        return df, err_msg

    def bar(self, symbol,
            start_time=200000, end_time=160000, trade_date=None,
            freq='1M', fields=""):
        """
        Query intraday bars through ``data_api.bar``.

        Parameters
        ----------
        symbol : str
        start_time : int
        end_time : int
        trade_date : int, optional
        freq : str, default '1M'
            Bar frequency, forwarded to ``data_api.bar``.
        fields : str, optional

        Returns
        -------
        df, msg
            As returned by ``data_api.bar``.

        """
        df, msg = self.data_api.bar(symbol=symbol,
                                    fields=fields,
                                    start_time=start_time,
                                    end_time=end_time,
                                    trade_date=trade_date,
                                    freq=freq,
                                    data_format="")
        return df, msg

    def query(self, view, filter="", fields="", **kwargs):
        """
        Get various reference data.

        Parameters
        ----------
        view : str
            data source.
        fields : str
            Separated by ','
        filter : str
            filter expressions.
        kwargs
            Forwarded to ``data_api.query`` (e.g. ``orderby``).

        Returns
        -------
        df : pd.DataFrame
        msg : str
            error code and error message, joined by ','

        Examples
        --------
        res3, msg3 = ds.query("lb.secDailyIndicator", fields="price_level,high_52w_adj,low_52w_adj",
                              filter="start_date=20170907&end_date=20170907",
                              orderby="trade_date",
                              data_format='pandas')
            view does not change. fields can be any field predefined in reference data api.

        """
        df, msg = self.data_api.query(view, fields=fields,
                                      filter=filter,
                                      data_format="",
                                      **kwargs)
        return df, msg

    # -----------------------------------------------------------------------------------
    # Convenient Functions
    def get_trade_date_range(self, start_date, end_date):
        """Delegate to ``Calendar.get_trade_date_range``."""
        return self.calendar.get_trade_date_range(start_date, end_date)

    @staticmethod
    def _dic2url(d):
        """
        Convert a dict to str like 'k1=v1&k2=v2'.

        Parameters
        ----------
        d : dict
            Keys must be str; values are converted with ``str``.

        Returns
        -------
        str

        """
        return '&'.join('='.join([key, str(value)]) for key, value in d.items())

    def query_lb_fin_stat(self, type_, symbol, start_date, end_date, fields="", drop_dup_cols=None):
        """
        Helper function to call data_api.query with 'lb.income' (and the other
        financial-statement views) more conveniently.

        Parameters
        ----------
        type_ : {'income', 'balance_sheet', 'cash_flow', 'fin_indicator'}
        symbol : str
            separated by ','
        start_date : int
            Annoucement date in results will be no earlier than start_date
        end_date : int
            Annoucement date in results will be no later than end_date
        fields : str, optional
            separated by ',', default ""
        drop_dup_cols : list or tuple
            Whether drop duplicate entries according to drop_dup_cols.

        Returns
        -------
        df : pd.DataFrame
            index date, columns fields
        msg : str

        Raises
        ------
        NotImplementedError
            If *type_* is not one of the supported statement types.

        """
        view_map = {
            'income': 'lb.income',
            'cash_flow': 'lb.cashFlow',
            'balance_sheet': 'lb.balanceSheet',
            'fin_indicator': 'lb.finIndicator'
        }
        view_name = view_map.get(type_, None)
        if view_name is None:
            raise NotImplementedError("type_ = {}".format(type_))

        dic_argument = {
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date,
            # 'update_flag': '0'  (0 means first time, not update)
        }
        if view_name != 'lb.finIndicator':
            # 408001000: joint; 408002000: joint (single quarter).
            # The single-quarter report is not used since it contains zeros.
            dic_argument.update({'report_type': '408001000'})

        filter_argument = self._dic2url(dic_argument)

        # ``orderby`` is the keyword spelling used by every other query here.
        res, msg = self.query(view_name, fields=fields,
                              filter=filter_argument,
                              orderby=self.REPORT_DATE_FIELD_NAME)

        # change data type of the date columns that are present
        try:
            cols = list(set.intersection({'ann_date', 'report_date'}, set(res.columns)))
            dic_dtype = {col: int for col in cols}
            res = res.astype(dtype=dic_dtype)
        except Exception:
            # best effort: keep the original column types if the cast fails
            pass

        if drop_dup_cols is not None:
            res = res.sort_values(by=drop_dup_cols, axis=0)
            res = res.drop_duplicates(subset=drop_dup_cols, keep='first')

        return res, msg

    def query_lb_dailyindicator(self, symbol, start_date, end_date, fields=""):
        """
        Helper function to call data_api.query with 'lb.secDailyIndicator' more conveniently.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        fields : str, optional
            separated by ',', default ""

        Returns
        -------
        df : pd.DataFrame
            index date, columns fields
        msg : str

        """
        filter_argument = self._dic2url({
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date
        })

        return self.query("lb.secDailyIndicator",
                          fields=fields,
                          filter=filter_argument,
                          orderby="trade_date")

    def get_index_weights(self, index, trade_date):
        """
        Return the constituent weights of *index* as of *trade_date*.

        Parameters
        ----------
        index : str
            separated by ','
        trade_date : int

        Returns
        -------
        pd.DataFrame
            Indexed by symbol; ``weight`` is divided by 100 (percent to
            fraction) and ``trade_date`` is int.

        """
        if index == '000300.SH':
            # NOTE(review): CSI 300 weights are queried under 399300.SZ;
            # presumably that is the code the data source stores them under.
            index = '399300.SZ'
        filter_argument = self._dic2url({
            'index_code': index,
            'trade_date': trade_date
        })

        df_io, msg = self.query("lb.indexWeight", fields="",
                                filter=filter_argument)
        if msg != '0,':
            print(msg)
        df_io = df_io.set_index('symbol')
        df_io = df_io.astype({'weight': float, 'trade_date': int})
        df_io.loc[:, 'weight'] = df_io['weight'] / 100.
        return df_io

    def get_index_weights_daily(self, index, start_date, end_date):
        """
        Return the daily constituent weights of *index* between start_date and end_date.

        Weights are sampled once per month (starting from the first trade
        date) and forward-filled to every trade date.

        Parameters
        ----------
        index : str
        start_date : int
        end_date : int

        Returns
        -------
        res : pd.DataFrame
            Index is trade_date, columns are symbols.

        """
        # TODO: temparary api
        trade_dates = self.get_trade_date_range(start_date, end_date)
        start_date, end_date = trade_dates[0], trade_dates[-1]

        td = start_date
        dic = dict()
        symbols_set = set()
        while td <= end_date:
            df = self.get_index_weights(index, td)
            update_date = df['trade_date'].iat[0]
            if start_date <= update_date <= end_date:
                symbols_set.update(set(df.index))
            dic[td] = df['weight']
            td = jutil.get_next_period_day(td, 'month', 1)

        merge = pd.concat(dic, axis=1).T
        merge = merge.fillna(0.0)  # for those which are not components

        res = pd.DataFrame(index=trade_dates, columns=sorted(list(symbols_set)), data=np.nan)
        res.update(merge)
        res = res.ffill()
        res = res.loc[start_date:end_date]
        return res

    def _get_index_comp(self, index, start_date, end_date):
        """
        Query the raw ``lb.indexCons`` records of *index* between start_date and end_date.

        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        df_io : pd.DataFrame
        msg : str

        """
        filter_argument = self._dic2url({
            'index_code': index,
            'start_date': start_date,
            'end_date': end_date
        })

        df_io, msg = self.query("lb.indexCons", fields="",
                                filter=filter_argument,
                                orderby="symbol")
        return df_io, msg

    def get_index_comp(self, index, start_date, end_date):
        """
        Return list of symbols that have been in index during start_date and end_date.

        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        list
            Unique symbols, sorted (``np.unique`` order).

        """
        df_io, msg = self._get_index_comp(index, start_date, end_date)
        if msg != '0,':
            print(msg)
        return list(np.unique(df_io.loc[:, 'symbol']))

    def get_index_comp_df(self, index, start_date, end_date):
        """
        Get index components on each day during start_date and end_date.

        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        res : pd.DataFrame
            index dates, columns all securities that have ever been components,
            values are 0 (not in) or 1 (in)

        """
        df_io, msg = self._get_index_comp(index, start_date, end_date)
        if msg != '0,':
            print(msg)

        text_types = (str, type(u''))  # (str, unicode) on Python 2

        def str2int(s):
            """Convert a date field to int; an empty string becomes 99999999 so
            an open-ended out_date compares greater than any trade date."""
            if isinstance(s, text_types):
                return int(s) if s else 99999999
            elif isinstance(s, (int, np.integer, float)):
                return s
            else:
                raise NotImplementedError("type s = {}".format(type(s)))

        df_io.loc[:, 'in_date'] = df_io.loc[:, 'in_date'].apply(str2int)
        df_io.loc[:, 'out_date'] = df_io.loc[:, 'out_date'].apply(str2int)

        dates = self.get_trade_date_range(start_date=start_date, end_date=end_date)

        dic = dict()
        gp = df_io.groupby(by='symbol')
        for sec, df in gp:
            mask = np.zeros_like(dates, dtype=int)
            for idx, row in df.iterrows():
                # NOTE(review): both bounds are strict, so in_date and out_date
                # themselves are marked as "not in" — confirm this is intended.
                bool_index = np.logical_and(dates > row['in_date'], dates < row['out_date'])
                mask[bool_index] = 1
            dic[sec] = mask

        res = pd.DataFrame(index=dates, data=dic)

        return res

    def get_industry_daily(self, symbol, start_date, end_date, type_='SW', level=1):
        """
        Get the industry code of each symbol on each trade date between start_date and end_date.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        type_ : {'SW', 'ZZ'}
        level : int
            Industry classification level, see ``get_industry_raw``.

        Returns
        -------
        res : pd.DataFrame
            index dates, columns symbols
            values are industry code (as str)

        """
        df_raw = self.get_industry_raw(symbol, type_=type_, level=level)

        dic_sec = jutil.group_df_to_dict(df_raw, by='symbol')
        dic_sec = {sec: df.sort_values(by='in_date', axis=0).reset_index()
                   for sec, df in dic_sec.items()}

        df_ann_tmp = pd.concat({sec: df.loc[:, 'in_date'] for sec, df in dic_sec.items()},
                               axis=1)
        df_value_tmp = pd.concat({sec: df.loc[:, 'industry{:d}_code'.format(level)]
                                  for sec, df in dic_sec.items()},
                                 axis=1)

        idx = np.unique(np.concatenate([df.index.values for df in dic_sec.values()]))
        symbol_arr = np.sort(symbol.split(','))
        df_ann = pd.DataFrame(index=idx, columns=symbol_arr, data=np.nan)
        df_ann.loc[df_ann_tmp.index, df_ann_tmp.columns] = df_ann_tmp
        df_value = pd.DataFrame(index=idx, columns=symbol_arr, data=np.nan)
        df_value.loc[df_value_tmp.index, df_value_tmp.columns] = df_value_tmp

        dates_arr = self.get_trade_date_range(start_date, end_date)
        df_industry = align.align(df_value, df_ann, dates_arr)

        # TODO before industry classification is available, we assume they belong to their first group.
        df_industry = df_industry.bfill()
        df_industry = df_industry.astype(str)

        return df_industry

    def get_industry_raw(self, symbol, type_='ZZ', level=1):
        """
        Get industry records of securities from ShenWanZhiShu or ZhongZhengZhiShu.

        Parameters
        ----------
        symbol : str
            separated by ','
        type_ : {'SW', 'ZZ'}
        level : {1, 2, 3, 4}
            Use which level of industry index classification.

        Returns
        -------
        df : pd.DataFrame
            Duplicate rows removed; ``in_date`` cast to int.

        Raises
        ------
        ValueError
            On an unknown *type_* or *level*.

        """
        if type_ == 'SW':
            src = u'申万研究所'.encode('utf-8')
            if level not in [1, 2, 3, 4]:
                raise ValueError("For [SW], level must be one of {1, 2, 3, 4}")
        elif type_ == 'ZZ':
            src = u'中证指数有限公司'.encode('utf-8')
            # NOTE(review): this accepts 1-4 although the message says {1, 2};
            # kept unchanged because callers may rely on levels 3/4.
            if level not in [1, 2, 3, 4]:
                raise ValueError("For [ZZ], level must be one of {1, 2}")
        else:
            raise ValueError("type_ must be one of SW of ZZ")

        filter_argument = self._dic2url({
            'symbol': symbol,
            'industry_src': src
        })
        fields_list = ['symbol',
                       'industry{:d}_code'.format(level),
                       'industry{:d}_name'.format(level)]

        df_raw, msg = self.query("lb.secIndustry", fields=','.join(fields_list),
                                 filter=filter_argument,
                                 orderby="symbol")
        if msg != '0,':
            print(msg)

        df_raw = df_raw.astype(dtype={'in_date': int,
                                      # 'out_date': int
                                      })
        return df_raw.drop_duplicates()

    def get_adj_factor_daily(self, symbol, start_date, end_date, div=False):
        """
        Get the adjust factor of each symbol on each trade date.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        div : bool
            False for normal adjust factor, True for the ratio of each
            day's factor to the previous day's (first row filled with 1.0).

        Returns
        -------
        res : pd.DataFrame
            index trade dates spanning the first to the last date returned by
            the query, columns symbols, values are adjust factors (missing
            values forward- then back-filled)

        """
        df_raw = self.get_adj_factor_raw(symbol, start_date=start_date, end_date=end_date)

        dic_sec = jutil.group_df_to_dict(df_raw, by='symbol')
        dic_sec = {sec: df.set_index('trade_date').loc[:, 'adjust_factor']
                   for sec, df in dic_sec.items()}

        # TODO: duplicate codes with dataview.py: line 512
        res = pd.concat(dic_sec, axis=1)  # TODO: fillna ?

        idx = np.unique(np.concatenate([df.index.values for df in dic_sec.values()]))
        symbol_arr = np.sort(symbol.split(','))
        res_final = pd.DataFrame(index=idx, columns=symbol_arr, data=np.nan)
        res_final.loc[res.index, res.columns] = res

        # align to every trade date
        s, e = df_raw.loc[:, 'trade_date'].min(), df_raw.loc[:, 'trade_date'].max()
        dates_arr = self.get_trade_date_range(s, e)
        if not len(dates_arr) == len(res_final.index):
            res_final = res_final.reindex(dates_arr)

            res_final = res_final.ffill().bfill()

        if div:
            res_final = res_final.div(res_final.shift(1, axis=0)).fillna(1.0)

        # res = res.loc[start_date: end_date, :]

        return res_final

    def get_adj_factor_raw(self, symbol, start_date=None, end_date=None):
        """
        Query adjust factor for symbols.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int, optional
            None means no lower bound in the filter.
        end_date : int, optional
            None means no upper bound in the filter.

        Returns
        -------
        df : pd.DataFrame
            Columns symbol (str), trade_date (int), adjust_factor (float);
            duplicate rows removed.

        """
        if start_date is None:
            start_date = ""
        if end_date is None:
            end_date = ""

        filter_argument = self._dic2url({
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date
        })
        fields_list = ['symbol', 'trade_date', 'adjust_factor']

        df_raw, msg = self.query("lb.secAdjFactor", fields=','.join(fields_list),
                                 filter=filter_argument,
                                 orderby="symbol")
        if msg != '0,':
            print(msg)
        df_raw = df_raw.astype(dtype={'symbol': str,
                                      'trade_date': int,
                                      'adjust_factor': float
                                      })
        return df_raw.drop_duplicates()

    def query_inst_info(self, symbol, inst_type="", fields=""):
        """
        Query ``jz.instrumentInfo`` for *symbol*.

        Parameters
        ----------
        symbol : str
        inst_type : str, optional
            Comma-separated instrument types; "" means all of
            "1,2,3,4,5,101,102,103,104".
        fields : str, optional

        Returns
        -------
        pd.DataFrame
            Indexed by symbol; date and type columns cast to int when present.

        """
        if inst_type == "":
            inst_type = "1,2,3,4,5,101,102,103,104"

        filter_argument = self._dic2url({
            'symbol': symbol,
            'inst_type': inst_type
        })

        df_raw, msg = self.query("jz.instrumentInfo", fields=fields,
                                 filter=filter_argument,
                                 orderby="symbol")
        if msg != '0,':
            print(msg)

        dtype_map = {'symbol': str, 'list_date': int, 'delist_date': int, 'inst_type': int}
        cols = set(df_raw.columns)
        dtype_map = {k: v for k, v in dtype_map.items() if k in cols}

        df_raw = df_raw.astype(dtype=dtype_map)

        res = df_raw.set_index('symbol')
        return res

    # -----------------------------------------------------------------------------------
    # subscribe for real time trading
    def subscribe(self, symbols):
        """
        Subscribe to real-time quotes; each quote is handled by ``mkt_data_callback``.

        Parameters
        ----------
        symbols : str
            Separated by ,

        """
        self.data_api.subscribe(symbols, func=self.mkt_data_callback)

    def mkt_data_callback(self, key, quote):
        """Wrap *quote* in a MARKET_DATA event and put it on the context's instance."""
        # NOTE(review): self.ctx is not set in this class; presumably it is
        # assigned by the framework before subscribe() is used.
        e = Event(EVENT_TYPE.MARKET_DATA)
        e.dic = {'quote': quote}
        self.ctx.instance.put(e)
class MyTushareApi():
    """
    Download reference data from the Tushare DataApi into MongoDB.

    Every view listed in ``reference`` is stored in the MongoDB database of
    the same name, one collection per symbol. The first run for a symbol
    downloads everything from its listing date; later runs only download the
    trade dates after the latest one already stored.
    """

    # Databases MongoDB maintains for itself; they never hold downloaded data.
    _SYSTEM_DBS = ('admin', 'local', 'config')

    def __init__(self):
        self.universe = []
        self.symbol = ""
        self.fields = ""
        self.start_date = 0
        self.end_date = 0
        self._data_status = 0  # 1: first-time download, 2: incremental update
        self._data_init_config = {}
        self._dataapi = None
        self._client = None
        self._db = None
        self._col = None
        self._db_list = []
        self._col_list = []
        self._folder_path = './log/'
        self.reference = {
            "lb.secDailyIndicator": "pe, pe_ttm, pb, ps, ps_ttm, net_assets",
            # more views and their fields can be added here
        }

    @staticmethod
    def _to_date_str(value):
        """
        Return a YYYYMMDD date given as int, float or str as a str.

        Values accepted by ``int`` are normalised through it, so ``20180101.0``
        becomes ``'20180101'``; anything else is returned via ``str``.
        """
        try:
            return str(int(value))
        except (TypeError, ValueError):
            return str(value)

    def _login(self, dataapi=False, mongodb=False):
        """
        Open the requested connections.

        Parameters
        ----------
        dataapi : bool
            Log in to the Tushare DataApi (stored in ``self._dataapi``).
        mongodb : bool
            Connect to MongoDB on localhost:27017 (stored in ``self._client``).
        """
        self._data_init_config = {
            "remote.data.address": "tcp://data.tushare.org:8910",
            "remote.data.username": "******",
            "remote.data.password": "******"
        }
        if dataapi:
            self._dataapi = DataApi(
                addr=self._data_init_config.get('remote.data.address'))
            self._dataapi.login(
                self._data_init_config.get('remote.data.username'),
                self._data_init_config.get('remote.data.password'))
        if mongodb:
            # connect to MongoDB
            self._client = pymongo.MongoClient(host='localhost', port=27017)

    def _logout(self):
        """Close the MongoDB connection, if one is open."""
        if self._client is not None:
            self._client.close()
            self._client = None

    def _set_data_status(self):
        """
        Decide between a first-time download and an incremental update.

        ``self._data_status`` becomes 2 (incremental update) when the
        collection for ``self.symbol`` already exists in ``self._db``, and 1
        (first-time download) otherwise. Also refreshes ``self._db_list``
        (databases without MongoDB's system databases) and ``self._col_list``
        (collections of ``self._db``).
        """
        self._db_list = [name for name in self._client.list_database_names()
                         if name not in self._SYSTEM_DBS]
        if self._db is not None and self._db.name in self._db_list:
            self._col_list = self._db.list_collection_names()
        else:
            self._col_list = []
        self._data_status = 2 if self.symbol in self._col_list else 1

    def _get_instrumentInfo(self):
        """
        Get basic information of listed Shanghai/Shenzhen A-shares.

        Query filter: inst_type=1 (stocks), status=1 (listed).

        Returns
        -------
        pd.DataFrame
        """
        df, msg = self._dataapi.query(
            view="jz.instrumentInfo",
            fields="status,list_date,delist_date,name,market,symbol",
            filter="inst_type=1&status=1",
            data_format='pandas')
        # keep only the SH and SZ markets; the raw data also contains HK stocks
        df = df[(df['market'] == 'SH') | (df['market'] == 'SZ')]
        return df

    def _set_universe(self):
        "Set self.universe to the set of currently listed A-share symbols."
        self.universe = set(self._get_instrumentInfo()['symbol'])

    def _set_start_end_date_first_time(self):
        """
        Set the date range for a first-time download of ``self.symbol``.

        The range runs from the listing date to today, or to the delisting
        date if that is earlier. Dates are YYYYMMDD strings.

        Raises
        ------
        ValueError
            If ``self.symbol`` is not among the listed SH/SZ stocks.
        """
        df = self._get_instrumentInfo()
        rows = df[df['symbol'] == self.symbol]
        if rows.empty:
            raise ValueError(
                "{} is not a listed SH/SZ stock.".format(self.symbol))
        list_date = self._to_date_str(rows['list_date'].iloc[0])
        delist_date = self._to_date_str(rows['delist_date'].iloc[0])

        today = datetime.today().strftime('%Y%m%d')
        self.start_date = list_date
        # all three dates are strings now, so min() compares them consistently
        self.end_date = min(today, delist_date) if delist_date else today

    def _set_start_end_date_update(self):
        """
        Set the date range for an incremental update of ``self._col``.

        The range runs from the day after the latest stored ``trade_date`` to
        today. If nothing usable is stored, falls back to the first-time range.

        Returns
        -------
        bool
            False when the stored data is already up to date (nothing to
            download), True otherwise.
        """
        stored = pd.DataFrame(list(self._col.find({}, {'trade_date': 1, '_id': 0})))
        if stored.empty or 'trade_date' not in stored.columns:
            self._set_start_end_date_first_time()
            return True

        latest_date = self._to_date_str(stored['trade_date'].max())
        next_date = (datetime.strptime(latest_date, '%Y%m%d')
                     + timedelta(days=1)).strftime('%Y%m%d')
        today = datetime.today().strftime('%Y%m%d')

        if today >= next_date:
            self.start_date = next_date
            self.end_date = today
            return True
        print("Trade date in Dbase is latest.")
        return False

    def _download_data(self, view):
        """
        Download *view* for ``self.symbol`` between start_date and end_date.

        Parameters
        ----------
        view : str
            DataApi view name, e.g. "lb.secDailyIndicator".

        Returns
        -------
        pd.DataFrame or None
            None if the API reported an error.
        """
        # NOTE(review): update mode reads 'trade_date' back from MongoDB;
        # confirm the view returns it even when it is not listed in self.fields.
        filter_argument = 'symbol={}&start_date={}&end_date={}'.format(
            self.symbol, self.start_date, self.end_date)
        df, msg = self._dataapi.query(
            view=view,
            fields=self.fields,
            filter=filter_argument,
            data_format='pandas')
        if msg != '0,':
            print("Download of {} from {} failed: {}".format(self.symbol, view, msg))
            return None
        return df

    def _record_error(self):
        """Append ``self.symbol`` to ``<_folder_path>/error_list.csv``."""
        import os
        if not os.path.isdir(self._folder_path):
            os.makedirs(self._folder_path)
        with open(os.path.join(self._folder_path, 'error_list.csv'), 'a') as f:
            f.write(self.symbol + ',')

    def _save_data(self):
        """
        Download and store every view in ``self.reference`` for ``self.symbol``.

        Symbols whose download or insert fails are appended to the error list.
        """
        for db_name, fields in self.reference.items():
            self.fields = fields
            self._db = self._client[db_name]
            self._col = self._db[self.symbol]

            self._set_data_status()
            if self._data_status == 1:
                self._set_start_end_date_first_time()
            elif not self._set_start_end_date_update():
                continue  # already up to date

            df = self._download_data(db_name)
            if df is None:
                self._record_error()
                continue
            if df.empty:
                # insert_many rejects an empty list; nothing new to store
                print("No new data for {} in {}.".format(self.symbol, db_name))
                continue

            try:
                self._col.insert_many(df.to_dict(orient='records'))
            except Exception as exc:  # pymongo / bson encoding errors
                print("Saving {} to {} failed: {}".format(self.symbol, db_name, exc))
                self._record_error()
            else:
                print("\nSave data to DBase...\nSymbol: {}\n\n".format(self.symbol))

    def main(self, symbol):
        """
        Download or incrementally update every view in ``reference`` for *symbol*.

        Parameters
        ----------
        symbol : str
            e.g. "600000.SH".
        """
        self.symbol = symbol
        self._login(dataapi=True, mongodb=True)
        try:
            self._set_universe()
            self._save_data()
        finally:
            self._logout()