def get_index_stocks(self, index_symbol, date):
    """Return the constituents of an index on a given date.

    Membership is read from the ``IndexEntity`` row for the upper-cased
    symbol: its ``index_json`` decodes to ``(code, periods)`` pairs where
    ``periods`` is a flat ``[start, end, start, end, ...]`` list of date
    strings. A stock is a constituent on ``date`` when some pair satisfies
    ``start <= date < end`` (the end day itself is excluded).

    :param index_symbol: index code string (upper-cased before lookup).
    :param date: ``datetime.date``, ``datetime.datetime`` or a
        ``'YYYY-MM-DD'`` string.
    :return: list of stock codes in stored order; ``[]`` when the index is
        known but has no membership data.
    :raises ParamsError: when there is no membership data and the symbol is
        not a known index, or when ``date`` has an unsupported type.
    """
    assert isinstance(index_symbol, six.string_types)
    symbol = index_symbol.upper()
    entity = get_session().query(IndexEntity).get(symbol)
    members = json.loads(entity.index_json) if entity is not None else []

    if not members:
        known = SecurityStore.instance().get_all_indexs().keys()
        if symbol in known:
            return []
        raise ParamsError("指数'%s'不存在" % symbol)

    if isinstance(date, (datetime.date, datetime.datetime)):
        day = date.strftime("%Y-%m-%d")
    elif isinstance(date, six.string_types):
        day = date
    else:
        raise ParamsError(
            "date参数必须是(datetime.date, datetime.datetime, str)中的一种")

    def _active_on_day(periods):
        # Walk (start, end) pairs; a trailing unpaired start is ignored.
        return any(periods[k] <= day < periods[k + 1]
                   for k in range(0, len(periods) - 1, 2))

    return [code for code, periods in members if _active_on_day(periods)]
def get_concept_stocks(self, concept_code, date):
    """Return the stocks of a concept sector on a given date (in-memory data).

    ``self._dic['stocks']`` is looked up by the upper-cased concept code and
    yields ``(code, periods)`` pairs, where ``periods`` is a flat
    ``[start, end, start, end, ...]`` list of date strings. A stock is
    returned when ``SecurityStore`` knows it and some pair satisfies
    ``start <= date < end``.

    :param concept_code: concept code string (upper-cased before lookup).
    :param date: ``datetime.date``, ``datetime.datetime`` or a
        ``'YYYY-MM-DD'`` string.
    :return: list of stock codes; ``[]`` when the concept exists in
        ``self._dic['name']`` but has no stocks.
    :raises ParamsError: when the concept is unknown or ``date`` has an
        unsupported type.
    """
    assert isinstance(concept_code, six.string_types)
    key = concept_code.upper()
    members = self._dic['stocks'].get(key, [])
    if not members:
        if key in self._dic['name']:
            return []
        raise ParamsError("概念板块 '%s' 不存在" % key)

    if isinstance(date, (datetime.date, datetime.datetime)):
        day = date.strftime("%Y-%m-%d")
    elif isinstance(date, six.string_types):
        day = date
    else:
        raise ParamsError(
            "date参数必须是(datetime.date, datetime.datetime, str)中的一种")

    def _active_on_day(periods):
        # Half-open (start, end) pairs; a trailing unpaired start is ignored.
        return any(periods[k] <= day < periods[k + 1]
                   for k in range(0, len(periods) - 1, 2))

    return [code for code, periods in members
            if SecurityStore.instance().exists(code) and _active_on_day(periods)]
def get_concept_stocks(self, concept_code, date):
    """Return the stocks of a concept sector on a given date (database-backed).

    :param concept_code: concept code string (upper-cased before querying).
    :param date: ``datetime.date``, ``datetime.datetime`` or a
        ``'YYYY-MM-DD'`` string.
    :return: distinct stock codes from ``ConceptEntity`` rows with this code,
        a non-empty ``name`` and ``stock_startdate <= date < stock_enddate``.
    :raises ParamsError: when ``date`` has an unsupported type or no row
        exists for the concept code.
    """
    assert isinstance(concept_code, six.string_types)
    code = concept_code.upper()

    if isinstance(date, (datetime.date, datetime.datetime)):
        day = date.strftime("%Y-%m-%d")
    elif isinstance(date, six.string_types):
        day = date
    else:
        raise ParamsError(
            "date参数必须是(datetime.date, datetime.datetime, str)中的一种")

    session = get_session()
    # Query objects are generative: filtering below leaves `by_code` intact,
    # so it serves both the existence check and the member query.
    by_code = session.query(ConceptEntity.stock).filter(
        ConceptEntity.code == code)
    if by_code.count() == 0:
        raise ParamsError("概念板块 '%s' 不存在" % code)

    rows = (by_code
            .filter(ConceptEntity.name != '')
            .filter(ConceptEntity.stock_startdate <= day)
            .filter(ConceptEntity.stock_enddate > day)
            .distinct()
            .all())
    return [stock for (stock,) in rows]
def normalize_code(code):
    '''
    Normalize a security code to the 'NNNNNN.EXCHANGE' form.

    Code-allocation rules:
    Shanghai Stock Exchange
    https://biz.sse.com.cn/cs/zhs/xxfw/flgz/rules/sserules/sseruler20090810a.pdf
    Shenzhen Stock Exchange
    http://www.szse.cn/main/rule/bsywgz/39744233.shtml

    :param code: an int (e.g. 600000) or a string such as '600000',
        'sh600000', '000001.SZ' or an already-normalized code.
    :return: e.g. '600000.XSHG'. Strings already ending in '.XSHG', '.XSHE'
        or '.CCFX' are returned upper-cased and otherwise unchanged. Without
        an 'SH'/'SZ' marker in the string, numbers >= 500000 map to XSHG and
        smaller ones to XSHE (ints are zero-padded to 6 digits).
    :raises ParamsError: when a string contains no 6-digit run, or when the
        argument is neither a string nor an int.
    '''
    if isinstance(code, int):
        suffix = 'XSHG' if code >= 500000 else 'XSHE'
        return '%06d.%s' % (code, suffix)
    elif isinstance(code, six.string_types):
        code = code.upper()
        if code[-5:] in ('.XSHG', '.XSHE', '.CCFX'):
            return code
        match = re.search(r'[0-9]{6}', code)
        if match is None:
            raise ParamsError(u"wrong code={}".format(code))
        number = match.group(0)
        # An explicit exchange marker wins over the numeric heuristic.
        if 'SH' in code:
            suffix = 'XSHG'
        elif 'SZ' in code:
            suffix = 'XSHE'
        else:
            suffix = 'XSHG' if int(number) >= 500000 else 'XSHE'
        return '%s.%s' % (number, suffix)
    else:
        # Wrap in a 1-tuple: a tuple argument would otherwise be unpacked by
        # '%' and raise TypeError instead of the intended ParamsError.
        raise ParamsError(
            u"normalize_code(code=%s) 的参数必须是字符串或者整数" % (code,))
def convert_security(sec):
    """Resolve ``sec`` to a ``Security`` object.

    :param sec: a security code string (looked up in ``SecurityStore``) or
        a ``Security`` instance (returned as-is).
    :return: the ``Security`` object.
    :raises ParamsError: when the code is not found or ``sec`` has another
        type.
    """
    if isinstance(sec, six.string_types):
        found = SecurityStore.instance().get_security(sec)
        if found:
            return found
        raise ParamsError("找不到标的{}".format(sec))
    if isinstance(sec, Security):
        return sec
    raise ParamsError('security 必须是一个Security对象')
def get_locked_shares(stock_list, start_date=None, end_date=None, forward_count=None):
    '''
    Return restricted-share unlock records for the given stocks in a date range.

    :param stock_list: a single stock code or a list of codes; any exchange
        suffix ('000001.XSHE') is stripped before querying.
    :param start_date: first day of the range (inclusive); required.
    :param end_date: last day of the range (inclusive). Exactly one of
        ``end_date`` / ``forward_count`` must be given.
    :param forward_count: when given, ``end_date`` becomes
        ``start_date + forward_count`` *calendar* days. The original
        documentation described a trading-day count, but the code adds
        calendar days.
    :return: ``pandas.DataFrame`` with the selected columns day, code, num,
        rate1, rate2 ('code' parsed as str), ordered by day, then code
        descending. A stock may have several unlock records on the same day.
        NOTE(review): the meaning of rate1/rate2 is not visible here.
    :raises ParamsError: when both or neither of end_date/forward_count are
        given.
    :raises RuntimeError: on the server side when FUNDAMENTALS_SERVERS is not
        configured.
    '''
    import pandas as pd
    from six import StringIO
    from ..utils.utils import convert_date, is_lists
    from ..db_utils import query, request_mysql_server

    if forward_count is not None and end_date is not None:
        raise ParamsError("get_locked_shares 不能同时指定 end_date 和 forward_count 两个参数")
    if forward_count is None and end_date is None:
        raise ParamsError("get_locked_shares 必须指定 end_date 或 forward_count 之一")

    start_date = convert_date(start_date)
    codes = stock_list if is_lists(stock_list) else [stock_list]
    if codes:
        codes = [s.split('.')[0] for s in codes]
    if forward_count is not None:
        end_date = start_date + datetime.timedelta(days=forward_count)
    end_date = convert_date(end_date)

    tbl = StkLockShares
    q = (query(tbl.day, tbl.code, tbl.num, tbl.rate1, tbl.rate2)
         .filter(tbl.code.in_(codes),
                 tbl.day <= end_date,
                 tbl.day >= start_date)
         .order_by(tbl.day)
         .order_by(tbl.code.desc()))
    sql = compile_query(q)

    cfg = get_config()
    if os.getenv('JQENV') == 'client':  # client side
        csv = request_mysql_server(sql)
    elif not cfg.FUNDAMENTALS_SERVERS:
        raise RuntimeError(
            "you must config FUNDAMENTALS_SERVERS for jqdata")
    else:
        runner = get_sql_runner(
            server_name='fundamentals',
            keep_connection=cfg.KEEP_DB_CONNECTION,
            retry_policy=cfg.DB_RETRY_POLICY,
            is_random=False)
        csv = runner.run(sql, return_df=False)
    return pd.read_csv(StringIO(csv), dtype={'code': str})
def get_billboard_list(stock_list=None, start_date=None, end_date=None, count=None):
    '''
    Return the billboard (abnormal-trading / "dragon-tiger" list) entries in a
    date range.

    :param stock_list: a single stock code, a list of codes, or None for all
        stocks. Any exchange suffix ('000038.XSHE') is stripped.
    :param start_date: first day (inclusive). Exactly one of ``start_date``
        and ``count`` must be given.
    :param end_date: last day (inclusive); defaults to today.
    :param count: number of trading days ending at ``end_date``; must be
        > 0. Cannot be combined with ``start_date``.
    :return: ``pandas.DataFrame`` of ``StkAbnormal`` rows, ordered by day
        then code (both descending), with 'code' parsed as str on both the
        client and the server path.
    :raises ParamsError: for an invalid start_date/count combination or a
        non-positive count.
    :raises RuntimeError: on the server side when FUNDAMENTALS_SERVERS is not
        configured.
    '''
    import pandas as pd
    from ..utils.utils import convert_date, is_lists
    from ..db_utils import query, request_mysql_server

    if count is not None and start_date is not None:
        raise ParamsError("get_billboard_list 不能同时指定 start_date 和 count 两个参数")
    if count is None and start_date is None:
        raise ParamsError("get_billboard_list 必须指定 start_date 或 count 之一")
    # count=0 used to fall through to TRADE_MIN_DATE and return everything.
    if count is not None and not count > 0:
        raise ParamsError("count 参数需要大于 0 或者为 None")

    end_date = convert_date(end_date) if end_date else datetime.date.today()
    if start_date:
        start_date = convert_date(start_date)
    elif count is not None:
        start_date = get_trade_days(end_date=end_date, count=count)[0]
    else:
        start_date = TRADE_MIN_DATE

    if stock_list is not None and not is_lists(stock_list):
        stock_list = [stock_list]
    if stock_list:
        stock_list = [s.split('.')[0] for s in stock_list]

    q = query(StkAbnormal).filter(StkAbnormal.day <= end_date,
                                  StkAbnormal.day >= start_date)
    if stock_list is not None:
        q = q.filter(StkAbnormal.code.in_(stock_list))
    q = q.order_by(StkAbnormal.day.desc()).order_by(StkAbnormal.code.desc())
    sql = compile_query(q)

    cfg = get_config()
    if os.getenv('JQENV') == 'client':  # client side
        csv = request_mysql_server(sql)
    else:
        if not cfg.FUNDAMENTALS_SERVERS:
            raise RuntimeError(
                "you must config FUNDAMENTALS_SERVERS for jqdata")
        sql_runner = get_sql_runner(
            server_name='fundamentals',
            keep_connection=cfg.KEEP_DB_CONNECTION,
            retry_policy=cfg.DB_RETRY_POLICY,
            is_random=False)
        # Fetch raw CSV rather than a runner-built DataFrame so both paths
        # parse identically and 'code' stays a string.
        csv = sql_runner.run(sql, return_df=False)
    return pd.read_csv(six.StringIO(csv), dtype={'code': str})
def get_extras(info, security_list, start_date=None, end_date=None, df=True, count=None):
    """Return an extra per-trading-day series for each security.

    :param info: one of 'is_st', 'acc_net_value', 'unit_net_value',
        'futures_sett_price', 'futures_positions'.
    :param security_list: security or securities, normalized through
        ``list_or_str`` and ``convert_security``.
    :param start_date: first date; mutually exclusive with ``count``.
        Defaults to 2015-01-01 when neither is given.
    :param end_date: last date; defaults to 2015-12-31.
    :param df: truthy -> ``pandas.DataFrame`` (index from the trading days,
        one column per security code); falsy -> ``{code: values}`` dict.
    :param count: number of trading days ending at ``end_date``, counted on
        the calendar starting 2005-01-04.
    :raises ParamsError: when both ``start_date`` and ``count`` are given or
        ``count`` is not positive.
    """
    assert info in ('is_st', 'acc_net_value', 'unit_net_value',
                    'futures_sett_price', 'futures_positions')
    securities = convert_security(list_or_str(security_list))
    if start_date and count:
        raise ParamsError("start_date 参数与 count 参数只能二选一")
    if count is not None and not count > 0:
        raise ParamsError("count 参数需要大于 0 或者为 None")
    if count is not None:
        count = int(count)
    end_date = convert_date(end_date if end_date else '2015-12-31')

    from jqdata.stores import FundStore, StStore, FuturesStore, CalendarStore

    if start_date:
        start_date = convert_date(start_date)
    elif count:
        all_days = CalendarStore.instance().get_trade_days_between(
            datetime.date(2005, 1, 4), end_date)
        start_date = all_days[-count]
    else:
        start_date = convert_date('2015-01-01')

    as_frame = bool(df)
    dates = CalendarStore.instance().get_trade_days_between(
        start_date, end_date)

    if info == 'is_st':
        values = {s.code: StStore.instance().query(s, dates)
                  for s in securities}
    elif info in ('acc_net_value', 'unit_net_value'):
        values = {s.code: FundStore.instance().query(s, dates, info)
                  for s in securities}
    elif info in ('futures_sett_price', 'futures_positions'):
        values = {s.code: FuturesStore.instance().query(s, dates, info)
                  for s in securities}
    else:
        values = {}

    if not as_frame:
        return values
    columns = [s.code for s in securities]
    return pd.DataFrame(index=vec2combine(dates), columns=columns, data=values)
def convert_dt(dt):
    """Convert ``dt`` to a ``datetime.datetime``.

    Strings containing ':' are parsed as '%Y-%m-%d %H:%M:%S', other strings
    as '%Y-%m-%d'; datetimes are returned unchanged and dates go through
    ``date2dt``.

    >>> convert_dt(datetime.date(2015, 1, 1))
    datetime.datetime(2015, 1, 1, 0, 0)
    >>> convert_dt(datetime.datetime(2015, 1, 1))
    datetime.datetime(2015, 1, 1, 0, 0)
    >>> convert_dt('2015-1-1')
    datetime.datetime(2015, 1, 1, 0, 0)
    >>> convert_dt('2015-01-01 09:30:00')
    datetime.datetime(2015, 1, 1, 9, 30)
    >>> convert_dt(datetime.datetime(2015, 1, 1, 9, 30))
    datetime.datetime(2015, 1, 1, 9, 30)
    """
    if is_str(dt):
        fmt = '%Y-%m-%d %H:%M:%S' if ':' in dt else '%Y-%m-%d'
        return datetime.datetime.strptime(dt, fmt)
    # datetime must be tested before date: datetime is a subclass of date.
    if isinstance(dt, datetime.datetime):
        return dt
    if isinstance(dt, datetime.date):
        return date2dt(dt)
    raise ParamsError(
        "date 必须是datetime.date, datetime.datetime或者如下格式的字符串:'2015-01-05'")
def query(self, security, dates, field):
    """Return a fund's net-value series aligned to ``dates``.

    :param security: fund security; ``security.price_decimals`` sets the
        rounding of the returned values.
    :param dates: trading dates, assumed ascending (the first and last entries
        bound the stored range that is read).
    :param field: 'acc_net_value' (column 'acc') or 'unit_net_value'
        (column 'unit').
    :return: ``numpy.ndarray`` with one entry per element of ``dates``;
        dates without a stored row are NaN. Empty when ``dates`` is empty.
    :raises ParamsError: for any other ``field``.
    """
    # Validate first so a bad field is reported even when no rows exist.
    if field == 'acc_net_value':
        name = 'acc'
    elif field == 'unit_net_value':
        name = 'unit'
    else:
        raise ParamsError(
            "field should in (acc_net_value, unit_net_value)")

    n = len(dates)
    if n == 0:
        return np.array([])
    ct = self.open_table(security)
    start_idx = ct.index.searchsorted(to_timestamp(dates[0]))
    end_idx = ct.index.searchsorted(to_timestamp(dates[-1]), 'right') - 1
    if end_idx < start_idx:
        return np.array([np.nan] * n)

    values = np.round(ct.table.cols[name][start_idx:end_idx + 1],
                      security.price_decimals)
    if len(values) == n:
        return values

    # Row count differs from len(dates): place each stored row at its own
    # date. Both sequences are ascending, so a single forward merge is O(n+m)
    # and tolerates stored dates absent from `dates` (those rows are dropped)
    # instead of raising ValueError as a list.index() lookup would.
    row_dates = ct.table.cols['date'][start_idx:end_idx + 1]
    aligned = np.full(n, np.nan)
    pos = 0
    for offset, raw in enumerate(row_dates):
        row_date = to_date(raw)
        while pos < n and dates[pos] < row_date:
            pos += 1
        if pos == n:
            break
        if dates[pos] == row_date:
            aligned[pos] = values[offset]
    return aligned
def get_industries(self, name):
    """Return the industry list registered under ``name``.

    :param name: key into ``self._dic['lists']`` (the error message lists
        zjw/jq_l1/jq_l2/sw_l1/sw_l2/sw_l3).
    :raises ParamsError: when ``name`` is not a known list.
    """
    assert isinstance(name, six.string_types)
    lists = self._dic['lists']
    if name not in lists:
        raise ParamsError("name 参数必须是zjw/jq_l1/jq_l2/sw_l1/sw_l2/sw_l3中的一种")
    return lists[name]
def get_ticks(security, end_dt, start_dt=None, count=None,
              fields=('time', 'current', 'high', 'low', 'volume', 'money')):
    """Return tick records of ``security`` up to ``end_dt``.

    Exactly one of ``start_dt`` / ``count`` must be given.

    :param security: anything ``convert_security`` accepts.
    :param end_dt: upper bound; the slice stops before the index returned by
        ``table.find_great_or_equal(end_dt)`` (presumably the first tick at
        or after ``end_dt``).
    :param start_dt: lower bound located the same way; must be <= end_dt.
    :param count: number of ticks before the end index; a positive integer.
    :param fields: iterable of column names to return. The default is a
        tuple so it can never be shared/mutated across calls.
    :return: structured ``numpy.ndarray`` with one named column per field.
        'current', 'high', 'low', 'a1_p', 'b1_p' are divided by 10000,
        'time' by 1000 and every other field is converted to float.
    :raises ParamsError: when neither or both of start_dt/count are given,
        or start_dt > end_dt.
    """
    from jqdata.stores.tick_store import get_tick_store
    from jqdata.utils.security import convert_security
    from jqdata.utils.datetime_utils import parse_datetime

    security = convert_security(security)
    end_dt = parse_datetime(end_dt)
    if start_dt is None and count is None:
        raise ParamsError("start_dt和count不能同时为None")
    elif start_dt is not None and count is not None:
        raise ParamsError("start_dt和count只能有一个不为None")
    if start_dt is not None:
        start_dt = parse_datetime(start_dt)
    if count is not None:
        count = int(count)
        assert count > 0, "get_ticks, count必须是一个正整数"

    table = get_tick_store().get_table(security)
    idx = table.find_great_or_equal(end_dt)
    if start_dt is not None:
        if start_dt > end_dt:
            raise ParamsError("start_dt 必须小于等于 end_dt")
        start = table.find_great_or_equal(start_dt)
    else:
        start = max(0, idx - count)
    arr = table.array[start:idx]

    scaled = {}
    for f in fields:
        if f in ('current', 'high', 'low', 'a1_p', 'b1_p'):
            scaled[f] = arr[f] / 10000.
        elif f == 'time':
            scaled[f] = arr[f] / 1000.
        else:
            scaled[f] = arr[f] / 1.0
    # numpy requires native str field names (unicode breaks on Python 2).
    dtype = np.dtype([(str(f), scaled[f].dtype) for f in fields])
    cols = [scaled[f] for f in fields]
    return np.rec.fromarrays(cols, dtype=dtype).view(np.ndarray)
def check_date(date):
    """Normalize ``date`` to a date string.

    :param date: ``datetime.date`` / ``datetime.datetime`` (formatted as
        ``'%Y-%m-%d'``) or a string (returned unchanged).
    :return: the date string.
    :raises ParamsError: for any other type.
    """
    if isinstance(date, (datetime.date, datetime.datetime)):
        return date.strftime("%Y-%m-%d")
    if isinstance(date, six.string_types):
        return date
    raise ParamsError("date参数必须是(datetime.date, datetime.datetime, str)中的一种")
def convert_security_list(sec_list):
    """Resolve one or many securities to a list of ``Security`` objects.

    :param sec_list: a code string, a ``Security``, or a list/tuple/set of
        either (each element goes through ``convert_security``).
    :return: list of ``Security`` objects.
    :raises ParamsError: for any other type.
    """
    if isinstance(sec_list, six.string_types):
        return [convert_security(sec_list)]
    if isinstance(sec_list, Security):
        return [sec_list]
    if not isinstance(sec_list, (list, tuple, set)):
        raise ParamsError('security 必须是一个Security实例或者数组')
    return list(map(convert_security, sec_list))
def get_trade_days(start_date=None, end_date=None, count=None):
    """Return trading days ending at ``end_date``.

    :param start_date: when truthy, return all trading days between
        ``start_date`` and ``end_date``. Mutually exclusive with ``count``.
    :param end_date: last day; defaults to today when falsy.
    :param count: when given (and no start_date), return ``count`` trading
        days ending at ``end_date``; must be > 0.
    :raises ParamsError: when both or neither of start_date/count are given,
        or count is not positive.
    """
    if start_date and count:
        raise ParamsError("start_date 参数与 count 参数只能二选一")
    if count is not None and not count > 0:
        raise ParamsError("count 参数需要大于 0 或者为 None")

    last_day = parse_date(end_date) if end_date else datetime.date.today()
    store = get_calendar_store()
    if start_date:
        return store.get_trade_days_between(parse_date(start_date), last_day)
    if count is not None:
        return store.get_trade_days_by_count(last_day, count)
    raise ParamsError("start_date 参数与 count 参数必须输入一个")
def convert_security(s):
    """Resolve a code, a ``Security``, or a list/tuple of them.

    :param s: code string, ``Security``, or list/tuple whose elements are
        ``Security`` objects or code strings.
    :return: a ``Security`` for scalar input, a list of them for sequences.
    :raises ParamsError: when a code is not found, a sequence element has an
        unsupported type, or ``s`` itself has an unsupported type.
    """
    def _resolve_item(item):
        # Inside a sequence the Security check comes before the string check.
        if isinstance(item, Security):
            return item
        if isinstance(item, six.string_types):
            found = SecurityStore.instance().get_security(item)
            if found:
                return found
        raise ParamsError("找不到标的{}".format(item))

    if isinstance(s, six.string_types):
        found = SecurityStore.instance().get_security(s)
        if not found:
            raise ParamsError("找不到标的{}".format(s))
        return found
    if isinstance(s, Security):
        return s
    if isinstance(s, (list, tuple)):
        return [_resolve_item(item) for item in s]
    raise ParamsError('security 必须是一个Security实例或者数组')
def get_factor_by_date(self, code, date):
    """Return the factor in effect for ``code`` on ``date``.

    ``self._dic[code]`` holds ``(date_string, factor)`` entries; the value of
    the last entry whose date string is <= ``date`` is returned (scanning
    from the end).

    :return: the factor, or 1.0 when ``code`` has no entries or every entry
        is later than ``date``.
    :raises ParamsError: when ``date`` is not a date/datetime/string.
    """
    history = self._dic.get(code)
    if not history:
        return 1.0
    if isinstance(date, (datetime.date, datetime.datetime)):
        day = date.strftime("%Y-%m-%d")
    elif isinstance(date, six.string_types):
        day = date
    else:
        raise ParamsError(
            "date参数必须是(datetime.date, datetime.datetime, str)中的一种")
    return next((entry[1] for entry in reversed(history) if entry[0] <= day),
                1.0)
def get_history_name(self, code, date):
    """Return the name ``code`` had on ``date``.

    ``self._dic[code]`` holds ``(date_string, name)`` entries; the name of the
    last entry whose date string is <= ``date`` is returned (scanning from
    the end).

    :return: the name, or '' when ``code`` has no entries or every entry is
        later than ``date``.
    :raises ParamsError: when ``date`` is not a date/datetime/string.
    """
    history = self._dic.get(code)
    if not history:
        return ''
    if isinstance(date, (datetime.date, datetime.datetime)):
        day = date.strftime("%Y-%m-%d")
    elif isinstance(date, six.string_types):
        day = date
    else:
        raise ParamsError(
            "date参数必须是(datetime.date, datetime.datetime, str)中的一种")
    return next((entry[1] for entry in reversed(history) if entry[0] <= day),
                '')
def get_industry_stocks(self, industry_code, date):
    '''Return the stocks of an industry on a given date (database-backed).

    :param industry_code: industry code string (upper-cased before querying).
    :param date: anything ``check_date`` accepts.
    :return: distinct, non-empty ``IndustryEntity.stock`` values with
        ``stock_startdate <= date < stock_enddate``.
    :raises ParamsError: when no row exists for the industry code.
    '''
    assert isinstance(industry_code, six.string_types)
    code = industry_code.upper()
    day = check_date(date)
    session = get_session()

    n_rows = session.query(IndustryEntity.code).filter(
        IndustryEntity.code == code).count()
    if n_rows == 0:
        raise ParamsError("行业板块 '%s' 不存在" % code)

    rows = (session.query(IndustryEntity.stock)
            .filter(IndustryEntity.code == code)
            .filter(IndustryEntity.stock_startdate <= day)
            .filter(IndustryEntity.stock_enddate > day)
            .filter(IndustryEntity.stock != '')
            .distinct()
            .all())
    return [stock for (stock,) in rows]
def get_margin_stocks(self, day):
    """Return the margin-trading stock list in effect on ``day``.

    Uses the ``MarginStockEntity`` row with the latest ``margin_date`` that
    is <= ``day`` and decodes its ``margin_json``.

    :param day: ``datetime.date``, ``datetime.datetime`` or a
        ``'YYYY-MM-DD'`` string.
    :return: the decoded list, or [] when there is no such row or it is empty.
    :raises ParamsError: when ``day`` has an unsupported type.
    """
    if isinstance(day, (datetime.date, datetime.datetime)):
        day_s = day.strftime("%Y-%m-%d")
    elif isinstance(day, six.string_types):
        day_s = day
    else:
        raise ParamsError(
            "date参数必须是(datetime.date, datetime.datetime, str)中的一种")

    latest = (get_session().query(MarginStockEntity)
              .filter(MarginStockEntity.margin_date <= day_s)
              .order_by(MarginStockEntity.margin_date.desc())
              .first())
    if latest is None:
        return []
    stocks = json.loads(latest.margin_json)
    return stocks if stocks else []
def get_fundamentals(query_object=None, date=None, statDate=None, sql=None):  # noqa
    """Run a fundamentals query and return the result as a DataFrame.

    :param query_object: query object; when given, the SQL is generated from
        it (together with ``date`` / ``statDate``) and overrides ``sql``.
    :param date: query date (converted with ``convert_date`` when truthy).
    :param statDate: passed through to ``fundamentals_query_to_sql``.
    :param sql: raw SQL, used when ``query_object`` is not given.
    :raises ParamsError: when neither query_object nor sql is given.
    :raises RuntimeError: on the server side when FUNDAMENTALS_SERVERS is not
        configured.
    """
    if query_object is None and sql is None:
        raise ParamsError("get_fundamentals 至少输入 query_object 或者 sql 参数")
    cfg = get_config()
    if date:
        date = convert_date(date)
    if query_object:
        sql = fundamentals_query_to_sql(query_object, date, statDate)
    check_string(sql)

    if os.getenv('JQENV') == 'client':  # client side
        from jqdata.db_utils import request_mysql_server
        raw_csv = request_mysql_server(sql)
    elif not cfg.FUNDAMENTALS_SERVERS:
        raise RuntimeError("you must config FUNDAMENTALS_SERVERS for jqdata")
    else:
        runner = get_sql_runner(
            server_name='fundamentals',
            keep_connection=cfg.KEEP_DB_CONNECTION,
            retry_policy=cfg.DB_RETRY_POLICY,
            is_random=False)
        # Fetch CSV and parse it here rather than returning the runner's
        # DataFrame, to stay compatible with the earlier (kuanke) behaviour.
        raw_csv = runner.run(sql, return_df=False)
    return pd.read_csv(StringIO(raw_csv))
def get_industry_stocks(self, industry_code, date):
    '''Return the stocks of an industry on a given date (in-memory data).

    ``self._dic['stocks']`` maps an industry code to ``(code, periods)``
    pairs, ``periods`` being a flat ``[start, end, start, end, ...]`` list of
    date strings. Stocks unknown to ``SecurityStore`` are skipped.

    :param industry_code: industry code string (case-insensitive).
    :param date: anything ``check_date`` accepts.
    :return: stock codes active on ``date`` (start inclusive, end exclusive);
        [] when the industry is known (csrc/wind/sw) but has no stocks.
    :raises ParamsError: when the industry is unknown.
    '''
    assert isinstance(industry_code, six.string_types)
    # Upper-case once so the member lookup and the existence check below use
    # the same key; otherwise a valid lower-case code with no members was
    # reported as nonexistent.
    industry_code = industry_code.upper()
    stocks = self._dic['stocks'].get(industry_code)
    if not stocks:
        codes = self._dic['codes']
        if all(industry_code not in codes[src] for src in ('csrc', 'wind', 'sw')):
            raise ParamsError(u"行业'%s'不存在" % industry_code)
        return []

    date_s = check_date(date)
    ret = []
    for code, periods in stocks:
        if not SecurityStore.instance().exists(code):
            continue
        # Half-open (start, end) pairs: the end day is excluded.
        for i in range(0, len(periods) - 1, 2):
            if periods[i] <= date_s < periods[i + 1]:
                ret.append(code)
                break
    return ret
def get_fundamentals_continuously(query_object=None, end_date=None, count=1):
    '''
    Query fundamentals for each of the last ``count`` trading days.

    :param query_object: fundamentals query object (required).
    :param end_date: last date of the query; defaults to today.
    :param count: number of trading days looked back from ``end_date``
        (default 1).
    :return: a ``pandas.Panel`` whose three axes are field, date and
        security, built from the result indexed by ('day', 'code').
        Field descriptions: https://www.joinquant.com/data/dict/fundamentals
    :raises ParamsError: when ``query_object`` is None.
    :raises RuntimeError: on the server side when FUNDAMENTALS_SERVERS is not
        configured.
    '''
    if query_object is None:
        raise ParamsError("get_fundamentals_continuously 需要输入 query_object 参数")
    if end_date is None:
        end_date = datetime.date.today()
    cfg = get_config()
    trade_day = get_trade_days(end_date=end_date, count=count)
    # query_object is non-None here, so the SQL is always built (the old
    # truthiness guard could leave `sql` unbound).
    sql = fundamentals_continuously_query_to_sql(query_object, trade_day)
    check_string(sql)

    # Fetch the result as a CSV string.
    if os.getenv('JQENV') == 'client':  # client side
        from jqdata.db_utils import request_mysql_server
        csv = request_mysql_server(sql)
    else:
        if not cfg.FUNDAMENTALS_SERVERS:
            raise RuntimeError("you must config FUNDAMENTALS_SERVERS for jqdata")
        sql_runner = get_sql_runner(
            server_name='fundamentals',
            keep_connection=cfg.KEEP_DB_CONNECTION,
            retry_policy=cfg.DB_RETRY_POLICY,
            is_random=False)
        csv = sql_runner.run(sql, return_df=False)

    # Convert to a panel indexed by trading day and security code.
    df = pd.read_csv(StringIO(csv)).drop_duplicates()
    # NOTE(review): DataFrame.to_panel / pandas.Panel were removed in
    # pandas 1.0; this only works with older pandas versions.
    return df.set_index(['day', 'code']).to_panel()
def convert_date(date):
    """Convert ``date`` to a ``datetime.date``.

    Strings are parsed as '%Y-%m-%d' (anything after the first 10 characters
    is dropped when the string contains ':'); datetimes are truncated to
    their date; dates are returned unchanged.

    >>> convert_date('2015-1-1')
    datetime.date(2015, 1, 1)
    >>> convert_date('2015-01-01 00:00:00')
    datetime.date(2015, 1, 1)
    >>> convert_date(datetime.datetime(2015, 1, 1))
    datetime.date(2015, 1, 1)
    >>> convert_date(datetime.date(2015, 1, 1))
    datetime.date(2015, 1, 1)
    """
    if is_str(date):
        text = date[:10] if ':' in date else date
        return datetime.datetime.strptime(text, '%Y-%m-%d').date()
    # datetime must be tested before date: datetime is a subclass of date.
    if isinstance(date, datetime.datetime):
        return date.date()
    if isinstance(date, datetime.date):
        return date
    raise ParamsError(
        "date 必须是datetime.date, datetime.datetime或者如下格式的字符串:'2015-01-05'")
def query(self, security, dates, field):
    """Return a futures series (settlement price or open interest) aligned to ``dates``.

    :param security: futures security; ``security.price_decimals`` sets the
        rounding of the returned values.
    :param dates: trading dates, assumed ascending (the first and last entries
        bound the stored range that is read).
    :param field: 'futures_sett_price' (column 'settlement') or
        'futures_positions' (column 'open_interest').
    :return: ``numpy.ndarray`` with one entry per element of ``dates``;
        dates without a stored row are NaN. Empty when ``dates`` is empty.
    :raises ParamsError: for any other ``field``.
    """
    # Validate first so a bad field is reported even when no rows exist.
    if field == 'futures_sett_price':
        name = 'settlement'
    elif field == 'futures_positions':
        name = 'open_interest'
    else:
        raise ParamsError(
            "field should be in (futures_sett_price, futures_positions)")

    n = len(dates)
    if n == 0:
        return np.array([])
    ct = self.open_table(security)
    start_idx = ct.index.searchsorted(to_timestamp(dates[0]))
    end_idx = ct.index.searchsorted(to_timestamp(dates[-1]), 'right') - 1
    if end_idx < start_idx:
        return np.array([np.nan] * n)

    values = np.round(ct.table.cols[name][start_idx:end_idx + 1],
                      security.price_decimals)
    if len(values) == n:
        return values

    # Row count differs from len(dates). Padding only the two ends misaligns
    # the series when rows are missing in the middle, so place each stored
    # row at its own date with a single forward merge over both ascending
    # sequences. Rows whose date is not in `dates` are dropped.
    row_dates = ct.table.cols['date'][start_idx:end_idx + 1]
    aligned = np.full(n, np.nan)
    pos = 0
    for offset, raw in enumerate(row_dates):
        row_date = to_date(raw)
        while pos < n and dates[pos] < row_date:
            pos += 1
        if pos == n:
            break
        if dates[pos] == row_date:
            aligned[pos] = values[offset]
    return aligned
def get_mtss(security_list, start_date=None, end_date=None, fields=None, count=None):
    """Fetch margin-trading and securities-lending records.

    :param security_list: a stock code or a list of codes.
    :param start_date: first date (str / datetime.date / datetime.datetime).
        **Mutually exclusive with count.** Defaults to the earliest date the
        platform provides (``TRADE_MIN_DATE``).
    :param end_date: last date (str / datetime.date / datetime.datetime);
        defaults to ``datetime.date.today()``.
    :param fields: a field name or list of names; all fields when omitted.
    :param count: **mutually exclusive with start_date**; return the
        ``count`` trading days up to and including ``end_date``.
    :return: ``pandas.DataFrame`` with the requested columns, chosen from:
        date, sec_code, fin_value (financing balance), fin_buy_value
        (financing buy amount), fin_refund_value (financing repayment),
        sec_value (securities-lending balance), sec_sell_value
        (securities-lending sell amount), sec_refund_value
        (securities-lending repayment), fin_sec_value (total financing +
        securities-lending balance).
    :raises ParamsError: when both start_date and count are given or count
        is not positive.
    """
    if start_date and count:
        raise ParamsError("start_date 参数与 count 参数只能二选一")
    if count is not None and not count > 0:
        raise ParamsError("count 参数需要大于 0 或者为 None")
    if count:
        count = int(count)
    security_list = obj_to_tuple(security_list)
    check_string_list(security_list)

    end_date = convert_date(end_date) if end_date else datetime.date.today()
    if start_date:
        start_date = convert_date(start_date)
    elif count:
        start_date = get_trade_days(end_date=end_date, count=count)[0]
    else:
        start_date = TRADE_MIN_DATE

    # Column order of each comma-separated record returned by the service.
    keys = ["sec_code", "date", "fin_value", "fin_buy_value", "fin_refund_value",
            "sec_value", "sec_sell_value", "sec_refund_value", "fin_sec_value"]
    converters = [str, convert_dt] + [float] * 7

    if fields:
        fields = obj_to_tuple(fields)
        check_string_list(fields)
        check_fields(keys, fields)
    else:
        fields = ["date", "sec_code"] + keys[2:]

    request_path = "/stock/mtss/query"
    params = {
        "code": "",
        "startDate": start_date,
        "endDate": end_date,
    }
    rows = []
    for security in security_list:
        params["code"] = security
        if os.getenv('JQENV') == 'client':  # client side
            data = request_client_data(request_path, params)
        else:
            data = request_data(DATA_SERVER + request_path, params)
        for line in data:
            raw_items = line.split(",", len(keys))[:len(keys)]
            parsed = [convert(item.strip())
                      for convert, item in zip(converters, raw_items)]
            record = dict(zip(keys, parsed))
            record["sec_code"] = normalize_code(record["sec_code"])
            rows.append(filter_dict_values(record, fields))

    import pandas as pd
    return pd.DataFrame(columns=fields, data=rows)
def history(end_dt, count, unit='1d', field='avg', security_list=None, df=True, skip_paused=False, fq='pre', pre_factor_ref_date=None):
    '''
    Return the last ``count`` bars of one field for each security.

    May only be called in backtesting/simulation, not in research.

    Parameters:
        `unit`: 'Xd' or 'Xm'. For 'Xd', end_dt is converted to a
            datetime.date; for 'Xm', to a datetime.datetime. ``count * X``
            raw bars are fetched and grouped X at a time via ``group_array``.
        `count`: must be greater than 0.
        `field`: a single field name; 'price' is accepted with a warning
            and treated as 'avg'.
        `security_list`: must be a str or a tuple, not a list, otherwise
            lru_cache breaks.
        `df`: truthy -> pandas.DataFrame with one column per security code;
            falsy -> ``{code: array}`` dict.

    Raises ParamsError when minute data with a DataFrame index (df and not
    skip_paused) is requested for a mix of stocks and futures.
    '''
    count = int(count)
    assert count > 0, "history, count必须是一个正整数"
    check_unit_fields(unit, (field, ))
    # NOTE(review): security_list=None is not handled; the loops below would
    # fail on None — confirm callers always pass it.
    if security_list is not None:
        security_list = list_or_str(security_list)
        if isinstance(security_list, tuple):
            security_list = list(security_list)
        security_list = convert_security(security_list)
    if field == 'price':
        warn_price_as_avg('使用 price 作为 history 的 field 参数', 'history')
        field = 'avg'
    # Unit prefix, e.g. '5m' -> 5; that many raw bars form one output bar.
    group = int(unit[:-1])
    total = count * group
    dict_by_column = {}
    _index = None
    df = bool(df)
    skip_paused = bool(skip_paused)
    fq = ensure_fq(fq)
    if pre_factor_ref_date is not None:
        pre_factor_ref_date = convert_date(pre_factor_ref_date)
    # A shared time index is only produced for a DataFrame without skipping.
    need_index = df and not skip_paused
    if is_list(security_list) and _has_stock_with_future(security_list):
        if unit.endswith('m') and need_index:
            raise ParamsError("history 取分钟数据时,为了对齐数据,不能同时取股票和期货。")
    if unit.endswith('d'):
        end_dt = convert_date(end_dt)
        for security in security_list:
            a, _index = get_price_daily_single(
                security, end_date=end_dt, count=total, fields=(field, ),
                skip_paused=skip_paused, fq=fq, include_now=False,
                pre_factor_ref_date=pre_factor_ref_date)
            a = a[field]
            a = group_array(a, group, field)
            dict_by_column[security.code] = a
    else:
        end_dt = convert_dt(end_dt)
        for security in security_list:
            a, _index = get_price_minute_single(
                security, end_dt=end_dt, count=total, fields=(field, ),
                skip_paused=skip_paused, fq=fq, include_now=False,
                pre_factor_ref_date=pre_factor_ref_date)
            # Take the requested field's column.
            a = a[field]
            a = group_array(a, group, field)
            dict_by_column[security.code] = a
    if not df:
        return dict_by_column
    else:
        # The index comes from the last security fetched in the loop above.
        if need_index and _index is not None and len(_index) > 0:
            index = group_array(_index, group, 'index')
            index = vec2datetime(index)
        else:
            index = None
        return pd.DataFrame(index=index, columns=[s.code for s in security_list], data=dict_by_column)
def get_price(security, start_date=None, end_date=None, frequency='daily', fields=None, skip_paused=False, fq='pre', count=None, pre_factor_ref_date=None):
    """Return price bars for one security or a list of securities.

    :param security: security code/object or a list of them (normalized by
        ``convert_security``).
    :param start_date: first date; cannot be combined with ``count``.
        Defaults to 2015-01-01 and is clamped to the calendar's first day.
    :param end_date: last date; defaults to 2015-12-31 and is clamped to the
        calendar's last day.
    :param frequency: a key of ``frequency_compat`` or a unit like 'Xd'/'Xm'.
    :param fields: field name(s); defaults to ``DEFAULT_FIELDS``. 'price'
        triggers a warning and is read from the 'avg' column.
    :param skip_paused: not allowed together with a list of securities.
    :param fq: normalized by ``ensure_fq`` and passed to the readers.
    :param count: number of output bars; ``count * X`` raw bars are read.
    :param pre_factor_ref_date: forwarded to the single-security readers.
    :return: ``pandas.DataFrame`` for a single security; for a list, a
        ``pandas.Panel`` keyed by field, each frame with one column per code.
    :raises ParamsError: for start_date with count, skip_paused with a list,
        or minute data for a stock/futures mix.
    """
    security = convert_security(security)
    if count is not None and start_date is not None:
        raise ParamsError("get_price 不能同时指定 start_date 和 count 两个参数")
    if count is not None:
        count = int(count)
    end_dt = convert_dt(end_date) if end_date else datetime.datetime(
        2015, 12, 31)
    end_dt = min(end_dt, date2dt(CalendarStore.instance().last_day))
    # NOTE(review): start_dt always gets a value (default 2015-01-01), so the
    # `start_dt.date() if start_dt else None` below never passes None and a
    # start date is sent even when count is used — confirm how
    # get_price_daily_single / get_price_minute_single combine the two.
    start_dt = convert_dt(start_date) if start_date else datetime.datetime(
        2015, 1, 1)
    start_dt = max(start_dt, date2dt(CalendarStore.instance().first_day))
    if pre_factor_ref_date:
        pre_factor_ref_date = convert_date(pre_factor_ref_date)
    if frequency in frequency_compat:
        unit = frequency_compat.get(frequency)
    else:
        unit = frequency
    if fields is not None:
        fields = ensure_str_tuple(fields)
        if 'price' in fields:
            warn_price_as_avg('使用 price 作为 get_price 的 fields 参数', 'getprice')
    else:
        fields = tuple(DEFAULT_FIELDS)
    check_unit_fields(unit, fields)
    fq = ensure_fq(fq)
    skip_paused = bool(skip_paused)
    # Multi-security results share one index, so rows cannot be skipped.
    if is_list(security) and skip_paused:
        raise ParamsError("get_price 取多只股票数据时, 为了对齐日期, 不能跳过停牌")
    if is_list(security) and _has_stock_with_future(security):
        if unit.endswith('m'):
            raise ParamsError("get_price 取分钟数据时,为了对齐数据,不能同时取股票和期货。")
    # Unit prefix, e.g. '5m' -> 5; that many raw bars form one output bar.
    group = int(unit[:-1])
    res = {}
    for s in (security if is_list(security) else [security]):
        if unit.endswith('d'):
            a, index = get_price_daily_single(
                s, end_date=end_dt.date(),
                start_date=start_dt.date() if start_dt else None,
                count=count * group if count is not None else None,
                fields=fields, skip_paused=skip_paused, fq=fq,
                include_now=True, pre_factor_ref_date=pre_factor_ref_date)
        else:
            a, index = get_price_minute_single(
                s, end_dt=end_dt, start_dt=start_dt,
                count=count * group if count is not None else None,
                fields=fields, skip_paused=skip_paused, fq=fq,
                include_now=True, pre_factor_ref_date=pre_factor_ref_date)
        # Group raw bars into output bars; 'price' reads the 'avg' column.
        dict_by_column = {
            f: group_array(a[f if f != 'price' else 'avg'], group, f)
            for f in fields
        }
        if index is not None and len(index) > 0:
            index = group_array(index, group, 'index')
            index = vec2datetime(index)
        res[s.code] = dict(index=index, columns=fields, data=dict_by_column)
    if is_list(security):
        fields = fields or DEFAULT_FIELDS
        # NOTE(review): pandas.Panel was removed in pandas 1.0.
        if len(security) == 0:
            return pd.Panel(items=fields)
        pn_dict = {}
        # All frames reuse the first security's index.
        index = res[security[0].code]['index']
        for f in fields:
            df_dict = {s.code: res[s.code]['data'][f] for s in security}
            pn_dict[f] = pd.DataFrame(index=index, columns=[s.code for s in security], data=df_dict)
        return pd.Panel(pn_dict)
    else:
        return pd.DataFrame(**res[security.code])
def get_bars(end_dt, security, count, unit='1d',
             fields=('open', 'high', 'low', 'close'),
             include_now=False, fq=None, pre_factor_ref_date=None):
    '''
    Return the last ``count`` bars of ``security`` up to ``end_dt``.

    :param end_dt: cut-off date/datetime (anything ``convert_dt`` accepts).
    :param security: anything ``convert_security`` accepts.
    :param count: number of bars; must be a positive integer.
    :param unit: '1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w' or '1M'.
    :param fields: a field name or a list/tuple of names from
        ('date', 'open', 'close', 'high', 'low', 'volume', 'money').
    :param include_now: whether data for the period containing ``end_dt`` is
        included (for daily/weekly/monthly bars via ``get_snapshot``).
    :param fq: normalized with ``ensure_fq``; documented as 'pre' (forward
        adjusted), 'post' (backward adjusted) or None (raw prices).
    :param pre_factor_ref_date: reference date for forward adjustment (prices
        on that day are unadjusted); None leaves prices unadjusted.
    :return: structured ``numpy.ndarray`` with the requested columns (one
        column when ``fields`` is a string). 'date' holds dates for
        '1d'/'1w'/'1M' and datetimes otherwise.
    :raises ParamsError: when ``fields`` is neither a string nor a list/tuple.
    '''
    valid_bar_fields = ('date', 'open', 'close', 'high', 'low', 'volume', 'money')
    # Tuples below are wrapped as 1-tuples so '%' formats them as a single
    # value instead of raising TypeError while building the assert message.
    if isinstance(fields, (list, tuple)):
        for f in fields:
            assert f in valid_bar_fields, "get_bars 只支持 %s 字段" % (
                valid_bar_fields,)
        str_field = False
    elif isinstance(fields, six.string_types):
        assert fields in valid_bar_fields, "get_bars 只支持 %s 字段" % (
            valid_bar_fields,)
        str_field = True
    else:
        raise ParamsError("fields 应该是字符串或者list")
    if str_field:
        new_fields = [fields]
    else:
        new_fields = list(fields)
    # Always fetch 'factor' too (presumably needed by _pre_fq).
    if 'factor' not in fields:
        new_fields.append('factor')
    end_dt = convert_dt(end_dt)
    valid_unit = ('1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w', '1M')
    assert unit in valid_unit, 'get_bars, unit必须是 %s 中一种' % (valid_unit,)
    count = int(count)
    assert count > 0, "get_bars, count必须是一个正整数"
    # NOTE(review): fq is normalized but never used below; adjustment is
    # driven only by pre_factor_ref_date.
    fq = ensure_fq(fq)
    security = convert_security(security)
    include_now = bool(include_now)
    end_trade_date = CalendarStore.instance().get_current_trade_date(
        security, end_dt)

    def ensure_not_empty(cols_dict):
        # An empty dict (no data) becomes empty arrays for every bar field.
        if cols_dict == {}:
            return {f: np.zeros(0) for f in valid_bar_fields}
        return cols_dict

    if unit == '1d':
        if include_now:
            # The bar for the current trading day comes from the snapshot.
            snapshot = get_snapshot(security, end_trade_date, end_dt)
            if snapshot:
                cols_dict = get_daily_bar_by_count(
                    security, end_trade_date, count - 1, new_fields,
                    include_now=False)
                cols_dict = ensure_not_empty(cols_dict)
                for f in cols_dict:
                    cols_dict[f] = np.append(cols_dict[f], snapshot[f])
            else:
                cols_dict = get_daily_bar_by_count(
                    security, end_trade_date, count, new_fields,
                    include_now=False)
                cols_dict = ensure_not_empty(cols_dict)
        else:
            cols_dict = get_daily_bar_by_count(
                security, end_trade_date, count, new_fields,
                include_now=False)
            cols_dict = ensure_not_empty(cols_dict)
    elif unit == '1m':
        cols_dict = get_minute_bar_by_count(
            security, end_dt, count, new_fields, include_now=include_now)
        cols_dict = ensure_not_empty(cols_dict)
    elif unit in ('5m', '15m', '30m', '60m', '120m'):
        x = int(unit[:-1])
        if security.is_futures():
            # Futures: walk trading days backwards, resampling each day's
            # minute bars, until at least `count` bars are collected.
            trade_days = CalendarStore.instance().get_all_trade_days(security)
            trade_days = trade_days[(trade_days >= security.start_date) &
                                    (trade_days <= end_trade_date) &
                                    (trade_days <= security.end_date)]
            cols_dict = {f: np.zeros(0) for f in new_fields}
            for day in trade_days[::-1]:
                open_dt = CalendarStore.instance().get_open_dt(security, day)
                if day == end_trade_date:
                    if not include_now:
                        close_dt = _not_include_now(
                            security, end_trade_date, end_dt, unit)
                    else:
                        close_dt = end_dt
                else:
                    close_dt = CalendarStore.instance().get_close_dt(
                        security, day)
                tmp_dict = get_minute_bar_by_period(
                    security, open_dt, close_dt, new_fields, include_now=True)
                # Compare the length itself: len(arr == 0) is truthy for any
                # non-empty day and would skip every day that has data.
                if not tmp_dict or len(tmp_dict[new_fields[0]]) == 0:
                    continue
                tmp_dict = _resample_future_xm_bars(tmp_dict, x)
                # Prepend this (earlier) day's bars, column by column.
                for col in cols_dict:
                    cols_dict[col] = np.append(tmp_dict[col], cols_dict[col])
                if len(cols_dict[new_fields[0]]) >= count:
                    break
            for f in cols_dict:
                cols_dict[f] = cols_dict[f][-count:]
        else:
            # Stocks: resample today's minutes, then top up from earlier
            # minutes if fewer than `count` bars were produced.
            cols_dict = {f: np.zeros(0) for f in new_fields}
            open_dt = CalendarStore.instance().get_open_dt(
                security, end_dt.date())
            if not include_now:
                close_dt = _not_include_now(
                    security, end_trade_date, end_dt, unit)
            else:
                close_dt = end_dt
            tmp_dict = get_minute_bar_by_period(
                security, open_dt, close_dt, new_fields, include_now=True)
            if tmp_dict and len(tmp_dict[new_fields[0]]) > 0:
                tmp_dict = _resample_simple_xm_bars(tmp_dict, x)
                for col in cols_dict:
                    cols_dict[col] = np.append(tmp_dict[col], cols_dict[col])
            need_count = count - len(cols_dict[new_fields[0]])
            if need_count > 0:
                tmp_dict = get_minute_bar_by_count(
                    security, open_dt, need_count * x, new_fields,
                    include_now=False)
                tmp_dict = _resample_simple_xm_bars(tmp_dict, x)
                for col in cols_dict:
                    cols_dict[col] = np.append(tmp_dict[col], cols_dict[col])
            for f in cols_dict:
                cols_dict[f] = cols_dict[f][-count:]
    # Weekly and monthly bars must be adjusted first and then resampled.
    elif unit == '1w':
        if include_now:
            snapshot = get_snapshot(security, end_trade_date, end_dt)
            cols_dict = get_daily_bar_by_count(
                security, end_trade_date, count * 5, new_fields,
                include_now=False)
            cols_dict = ensure_not_empty(cols_dict)
            if snapshot:
                for f in cols_dict:
                    cols_dict[f] = np.append(cols_dict[f], snapshot[f])
        else:
            # Stop at the previous Sunday (Monday == 0 ... Sunday == 6) so
            # the current, unfinished week is excluded.
            weekday = end_trade_date.weekday()
            last_sunday = end_trade_date - datetime.timedelta(weekday + 1)
            cols_dict = get_daily_bar_by_count(
                security, last_sunday, count * 5, new_fields,
                include_now=False)
            cols_dict = ensure_not_empty(cols_dict)
        cols_dict = _pre_fq(security, cols_dict, pre_factor_ref_date)
        cols_dict = _resample_days_bars(cols_dict, unit)
        for f in cols_dict:
            cols_dict[f] = cols_dict[f][-count:]
    elif unit == '1M':
        if include_now:
            snapshot = get_snapshot(security, end_trade_date, end_dt)
            cols_dict = get_daily_bar_by_count(
                security, end_trade_date, count * 31, new_fields,
                include_now=False)
            cols_dict = ensure_not_empty(cols_dict)
            if snapshot:
                for f in cols_dict:
                    cols_dict[f] = np.append(cols_dict[f], snapshot[f])
        else:
            # Stop at the last day of the previous month.
            end_date = end_trade_date.replace(day=1) - datetime.timedelta(
                days=1)
            cols_dict = get_daily_bar_by_count(
                security, end_date, count * 31, new_fields,
                include_now=False)
            cols_dict = ensure_not_empty(cols_dict)
        cols_dict = _pre_fq(security, cols_dict, pre_factor_ref_date)
        cols_dict = _resample_days_bars(cols_dict, unit)
        for f in cols_dict:
            cols_dict[f] = cols_dict[f][-count:]
    else:
        # Only reachable when asserts are disabled (python -O).
        raise ParamsError("get_bars 支持 '1m', '1d'")

    # Convert timestamps to date (day-based units) or datetime.
    if 'date' in cols_dict:
        if unit in ('1d', '1w', '1M'):
            cols_dict['date'] = vec2date(cols_dict['date'])
        else:
            cols_dict['date'] = vec2datetime(cols_dict['date'])

    # Futures are not adjusted; weekly/monthly bars were adjusted before
    # resampling (factors can differ within one period).
    if not security.is_futures() and unit not in ('1w', '1M'):
        if pre_factor_ref_date is not None:
            cols_dict = _pre_fq(security, cols_dict, pre_factor_ref_date)

    # numpy requires native str field names (unicode breaks on Python 2).
    if str_field:
        dtype = np.dtype([(str(fields), cols_dict[fields].dtype)])
        cols = [cols_dict[fields]]
    else:
        dtype = np.dtype([(str(name), cols_dict[name].dtype) for name in fields])
        cols = [cols_dict[name] for name in fields]
    return np.rec.fromarrays(cols, dtype=dtype).view(np.ndarray)
def calc_stock_minute_index(security, date_index, dt, side='left', include_now=True):
    '''
    Compute the position of ``dt`` in a stock's minute-bar data (stocks only).

    The arithmetic assumes 240 one-minute bars per trading day: 120 for
    09:31-11:30 and 120 for 13:01-15:00.

    :param date_index: sorted day timestamps of the trading days (compared
        against the day timestamp of ``dt``).
    :param dt: the moment to locate.
    :param side: when ``dt`` is not itself a bar time, 'left' returns the
        index of the bar before it and 'right' the index of the bar after it.
    :param include_now: whether the bar at ``dt`` itself counts.
    :return: zero-based bar index; -1 means there is no such bar.
    '''
    # NOTE(review): `datetime` must be the datetime class here (not the
    # module). pd.Timestamp subclasses datetime.datetime, so in that case the
    # Timestamp branch below is never reached.
    if isinstance(dt, datetime):
        ts = to_timestamp(dt.date())
    elif isinstance(dt, pd.Timestamp):
        # Truncate nanoseconds since epoch to a whole-day timestamp (seconds).
        ts = int(dt.value / (10**9) / 86400) * 86400
    else:
        raise ParamsError("wrong dt=%s, type(dt)=%s" % (dt, type(dt)))
    # Index of the last trading day <= dt's day.
    total_days = date_index.searchsorted(ts, side='right') - 1
    if total_days < 0:
        return -1
    if date_index[total_days] == ts:
        trading_day = True
    else:
        trading_day = False
        total_days += 1
    # total_days = number of complete trading days before dt's day.
    dt_minutes = dt.hour * 60 + dt.minute
    total_minutes = 0
    if trading_day:
        # Before 9:31 (no bar of the day has closed yet).
        if dt_minutes < (9 * 60 + 31):
            if side == 'left':
                total_minutes = 0
            else:
                total_minutes = 1
        # Morning session: bars 9:31..11:30 -> 1..120.
        elif dt_minutes < (11 * 60 + 31):
            total_minutes = dt_minutes - (9 * 60 + 30)
            if not include_now:
                total_minutes -= 1
        # Lunch break 11:31..13:00.
        elif dt_minutes < 13 * 60 + 1:
            if side == 'left':
                total_minutes = 120
            else:
                total_minutes = 121
        # Afternoon session: bars 13:01..15:00 -> 121..240.
        elif dt_minutes < 15 * 60 + 1:
            total_minutes = 120 + dt_minutes - 13 * 60
            if not include_now:
                total_minutes -= 1
        # After the close.
        else:
            if side == 'left':
                total_minutes = 240
            else:
                total_minutes = 241
    else:
        # Non-trading day: last bar of the previous trading day ('left') or
        # first bar of the next one ('right').
        if side == 'left':
            total_minutes = 0
        else:
            total_minutes = 1
    # Indices start at 0; -1 means there is none.
    return (total_days * 240 + total_minutes) - 1