def get_bar_by_period(self, security, start_date, end_date, fields, include_now=True):
    """
    Fetch the daily bars of one security over a date range.

    `security`: the security, e.g. '000001.XSHE'.
    `start_date`: first date of the range, e.g. datetime.date(2015, 1, 1).
    `end_date`: last date of the range, e.g. datetime.date(2016, 12, 30).
    `fields`: names of the bar columns to return.
    `include_now`: when False, a bar stamped exactly at `end_date` is left out.

    Returns a dict mapping every requested field, plus 'date', to the column
    slice covering the range, or an empty dict when no bar is in the range.
    """
    tbl = self.open_bcolz_table(security)
    lo_ts = to_timestamp(start_date)
    hi_ts = to_timestamp(end_date)

    first = tbl.index.searchsorted(lo_ts)
    last = tbl.index.searchsorted(hi_ts, 'right') - 1
    # Drop the final bar when it sits exactly on end_date and the caller
    # asked not to include it.
    if last >= 0 and not include_now and tbl.index[last] == hi_ts:
        last -= 1

    if last < first:
        return {}

    window = slice(first, last + 1)
    bars = {}
    for field in fields:
        bars[field] = tbl.table.cols[field][window]
    bars['date'] = tbl.table.cols['date'][window]
    return bars
def get_minute_by_period(self, security, start_dt, end_dt, include_now=True):
    """
    Return the stored minute stamps of `security` between `start_dt` and
    `end_dt`.

    For futures the table index itself is searched by timestamp and sliced;
    for everything else the positions come from `calc_stock_minute_index`
    and the 'date' column is sliced. When `include_now` is False the entry
    at `end_dt` itself is excluded. `_EMPTY_NP_ARRAY` is returned when the
    table cannot be opened or the range is empty.
    """
    tbl = self.open_bcolz_table(security)
    if tbl is None:
        return _EMPTY_NP_ARRAY

    is_futures = security.is_futures()
    if is_futures:
        lo_ts = to_timestamp(start_dt)
        hi_ts = to_timestamp(end_dt)
        first = tbl.index.searchsorted(lo_ts)
        last = tbl.index.searchsorted(hi_ts, 'right') - 1
        if last >= 0 and not include_now and tbl.index[last] == hi_ts:
            last -= 1
    else:
        first = calc_stock_minute_index(security, tbl.index, start_dt, side='right')
        last = calc_stock_minute_index(security, tbl.index, end_dt, include_now=include_now)

    if first > last:
        return _EMPTY_NP_ARRAY

    source = tbl.index if is_futures else tbl.table.cols['date']
    return source[first:last + 1]
def _get_idx_by_period(self, security, start_date, end_date, include_end):
    """
    Locate the rows of `security` that fall between `start_date` and
    `end_date` in its shared-memory block.

    Returns a `(block, slice)` pair. The slice starts at the first row not
    earlier than `start_date`; rows stamped exactly at `end_date` are part
    of the slice only when `include_end` is true.
    """
    block = self.open_shm_block(security)
    lo_ts = to_timestamp(start_date)
    hi_ts = to_timestamp(end_date)

    if include_end:
        stop_side = 'right'
    else:
        stop_side = 'left'

    begin = block.index.searchsorted(lo_ts)
    end = block.index.searchsorted(hi_ts, side=stop_side)
    return block, slice(begin, end)
def query(self, security, dates, field):
    """
    Look up a net-value series of `security` for the given dates.

    `dates`: ordered sequence of dates; only the first and last are used to
        bound the table search.
    `field`: 'acc_net_value' (column 'acc') or 'unit_net_value' (column
        'unit').

    Returns a numpy array with one value per entry of `dates`, rounded to
    `security.price_decimals`; dates without a stored value are NaN. An
    empty `dates` gives an empty array.

    Raises ParamsError when `field` is not one of the supported names.
    """
    n = len(dates)
    if n == 0:
        return np.array([])

    # Validate up front so that a bad field is reported even when the table
    # happens to hold no rows in the requested range.
    if field == 'acc_net_value':
        name = 'acc'
    elif field == 'unit_net_value':
        name = 'unit'
    else:
        raise ParamsError(
            "field should in (acc_net_value, unit_net_value)")

    ct = self.open_table(security)
    start_ts = to_timestamp(dates[0])
    end_ts = to_timestamp(dates[-1])
    start_idx = ct.index.searchsorted(start_ts)
    end_idx = ct.index.searchsorted(end_ts, 'right') - 1
    if end_idx < start_idx:
        return np.full(n, np.nan)

    ret = np.round(ct.table.cols[name][start_idx:end_idx + 1],
                   security.price_decimals)
    if len(ret) != n:
        # The stored rows do not line up one-to-one with `dates`, so place
        # each value at the position of its date. A dict gives O(1) lookups;
        # setdefault keeps the first position of a repeated date.
        position = {}
        for pos, day in enumerate(dates):
            position.setdefault(day, pos)

        aligned = np.full(n, np.nan)
        row_dates = ct.table.cols['date'][start_idx:end_idx + 1]
        for value, raw_date in zip(ret, row_dates):
            pos = position.get(to_date(raw_date))
            # Rows stored on a date that was not requested are ignored.
            if pos is not None:
                aligned[pos] = value
        ret = aligned
    return ret
def get_date_by_period(self, security, start_date, end_date, include_now=True):
    """
    Return the index entries of `security` lying between `start_date` and
    `end_date`.

    When `include_now` is False an entry stamped exactly at `end_date` is
    excluded. `_EMPTY_NP_ARRAY` is returned when the range holds no entry.
    """
    tbl = self.open_bcolz_table(security)
    lo_ts = to_timestamp(start_date)
    hi_ts = to_timestamp(end_date)

    first = tbl.index.searchsorted(lo_ts)
    last = tbl.index.searchsorted(hi_ts, side='right') - 1
    if last >= 0 and not include_now and tbl.index[last] == hi_ts:
        last -= 1

    if first > last:
        return _EMPTY_NP_ARRAY
    return tbl.index[first:last + 1]
def get_bar_by_count(self, security, end_date, count, fields, include_now=True):
    """
    Fetch the last `count` daily bars of one security up to `end_date`.

    `security`: the security, e.g. '000001.XSHE'.
    `end_date`: last date to consider, e.g. datetime.date(2016, 12, 30).
    `count`: number of bars wanted, e.g. 300 (converted with int()).
    `fields`: names of the bar columns to return.
    `include_now`: when False, a bar stamped exactly at `end_date` is left out.

    Returns a dict mapping every requested field, plus 'date', to the column
    slice, or an empty dict when there is no bar at or before `end_date`.
    Fewer than `count` bars are returned when the table does not reach back
    far enough.
    """
    tbl = self.open_bcolz_table(security)
    count = int(count)
    hi_ts = to_timestamp(end_date)

    last = tbl.index.searchsorted(hi_ts, side='right') - 1
    if last >= 0 and not include_now and tbl.index[last] == hi_ts:
        last -= 1
    if last < 0:
        return {}

    # Clamp to the beginning of the table.
    first = max(last - count + 1, 0)
    window = slice(first, last + 1)

    bars = {}
    for field in fields:
        bars[field] = tbl.table.cols[field][window]
    bars['date'] = tbl.table.cols['date'][window]
    return bars
def get_bar_by_date(self, security, somedate):
    """
    Fetch the daily bar of one security for a single date.

    `security`: the security, e.g. '000001.XSHE'.
    `somedate`: the date wanted; when no bar is stored for that date, the
        bar of the latest earlier entry is returned instead.

    Returns a dict with one entry per table column, or None when the table
    holds nothing at or before `somedate`. Price-like columns are rounded
    to `security.price_decimals`, 'volume' and 'money' to 0 decimals, and
    'date' is converted with `to_date`.
    """
    tbl = self.open_bcolz_table(security)
    ts = to_timestamp(somedate)
    row = tbl.index.searchsorted(ts, side='right') - 1
    if row < 0:
        return None

    price_decimals = security.price_decimals
    price_like = (
        'open', 'close', 'high', 'low', 'price', 'avg', 'pre_close',
        'high_limit', 'low_limit', 'unit', 'acc', 'refactor',
    )
    amount_like = ('volume', 'money')

    bar = {}
    for name in tbl.table.names:
        value = tbl.table.cols[name][row]
        if name in price_like:
            value = fixed_round(value, price_decimals)
        elif name in amount_like:
            value = fixed_round(value, 0)
        bar[name] = value

    bar['date'] = to_date(bar['date'])
    return bar
def get_minute_by_count(self, security, end_dt, count, include_now=True):
    """
    Return the last `count` stored minute stamps of `security` up to
    `end_dt`.

    For futures the table index is searched by timestamp and sliced; for
    everything else the end position comes from `calc_stock_minute_index`
    and the 'date' column is sliced. When `include_now` is False the entry
    at `end_dt` itself is excluded. `_EMPTY_NP_ARRAY` is returned when the
    table cannot be opened or nothing qualifies.
    """
    count = int(count)
    tbl = self.open_bcolz_table(security)
    if tbl is None:
        return _EMPTY_NP_ARRAY

    is_futures = security.is_futures()
    if is_futures:
        hi_ts = to_timestamp(end_dt)
        last = tbl.index.searchsorted(hi_ts, 'right') - 1
        if last >= 0 and not include_now and tbl.index[last] == hi_ts:
            last -= 1
    else:
        last = calc_stock_minute_index(security, tbl.index, end_dt, include_now=include_now)

    if last < 0:
        return _EMPTY_NP_ARRAY

    # Clamp to the beginning of the table.
    first = max(last + 1 - count, 0)
    if first > last:
        return _EMPTY_NP_ARRAY

    source = tbl.index if is_futures else tbl.table.cols['date']
    return source[first:last + 1]
def fill_paused(security, cols_dict, index, full_index, fields, index_type='date'):
    """
    Expand the columns of `cols_dict` from `index` onto `full_index`,
    filling the entries that have no data of their own.

    `cols_dict`: maps column name to an array aligned with `index`; it must
        contain 'close' and every name in `fields`, and may contain 'factor'.
    `index`: timestamps of the rows that actually exist.
    `full_index`: the complete trading index the result is aligned with.
    `fields`: the columns to rebuild.
    `index_type`: 'date' when the indexes hold days, 'minute' when they
        hold minutes.

    For every entry of `full_index`:
      * no row at or before it            -> NaN (not listed yet);
      * a row with exactly that timestamp -> that row is copied;
      * otherwise, later than the last valid timestamp -> NaN;
      * otherwise -> `paused_day_array(close, factor, fields)` built from
        the latest earlier row (factor defaults to 1.0).

    The last valid timestamp is the current time (minute) or today (date),
    capped by `security.end_date` when that is set.

    `cols_dict` is updated in place for the names in `fields` and returned.
    Raises Exception for an unknown `index_type`.
    """
    # Work out the latest timestamp for which data can exist; anything
    # beyond it stays NaN.
    end_date = security.end_date
    if index_type == 'minute':
        now = datetime.datetime.now().replace(second=0, microsecond=0)
        if end_date:
            last_minute = datetime.datetime.combine(end_date, datetime.time(14, 59))
            max_valid_date = min(now, last_minute)
        else:
            max_valid_date = now
    elif index_type == 'date':
        today = datetime.date.today()
        if end_date:
            max_valid_date = min(today, end_date)
        else:
            max_valid_date = today
    else:
        raise Exception("wrong index_type=%s, should be 'date' or 'minute'" % index_type)
    max_valid_ts = to_timestamp(max_valid_date)

    # For each full_index entry, position of the latest row at or before it
    # (-1 when there is none).
    v = np.searchsorted(index, full_index, 'right') - 1
    # column_stack needs a real sequence: generators were deprecated in
    # NumPy 1.16 and are rejected by newer releases.
    a = np.column_stack([cols_dict[f] for f in fields])
    nan_const = np.full(a.shape[1], np.nan)

    b = []
    for i, vi in enumerate(v):
        if vi < 0:
            b.append(nan_const)  # not listed yet
            continue
        index_ts = index[vi]
        full_index_ts = full_index[i]
        if full_index_ts == index_ts:
            b.append(a[vi])
        elif full_index_ts > max_valid_ts:
            b.append(nan_const)
        else:
            # The 'close' column is required here.
            close_value = cols_dict['close'][vi]
            if 'factor' in cols_dict:
                factor_value = cols_dict['factor'][vi]
            else:
                factor_value = 1.0
            b.append(paused_day_array(close_value, factor_value, fields))

    new_a = np.array(b).reshape(len(b), a.shape[1])
    for i, col in enumerate(fields):
        cols_dict[col] = new_a[:, i]
    return cols_dict
def get_bar_by_date(self, security, somedate):
    """
    Fetch the daily bar of one security for a single date from its
    shared-memory block.

    `security`: the security, e.g. '000001.XSHE'.
    `somedate`: the date wanted; when no bar is stored for that date, the
        bar of the latest earlier entry is returned instead.

    Returns a dict with one entry per block column, or None when the block
    holds nothing at or before `somedate`. Price columns are rounded to
    `security.price_decimals`, 'volume' and 'money' to 0 decimals, and
    'date' is converted with `to_date`.
    """
    block = self.open_shm_block(security)
    ts = to_timestamp(somedate)
    row = block.index.searchsorted(ts, side='right') - 1
    if row < 0:
        return None

    price_decimals = security.price_decimals
    price_cols = ('open', 'close', 'high', 'low', 'price', 'avg',
                  'pre_close', 'high_limit', 'low_limit')
    amount_cols = ('volume', 'money')

    bar = {}
    for name in block.columns:
        value = block.getitem(row, name)
        if name in price_cols:
            value = fixed_round(value, price_decimals)
        elif name in amount_cols:
            value = fixed_round(value, 0)
        bar[name] = value

    bar['date'] = to_date(bar['date'])
    return bar
def _get_idx_by_count(self, security, end_date, count, include_end):
    """
    Locate the last `count` rows of `security` up to `end_date` in its
    shared-memory block.

    Returns a `(block, slice)` pair. Rows stamped exactly at `end_date`
    are part of the slice only when `include_end` is true; the slice start
    is clamped to 0 when fewer than `count` rows are available.
    """
    block = self.open_shm_block(security)
    hi_ts = to_timestamp(end_date)

    stop_side = 'right' if include_end else 'left'
    stop = block.index.searchsorted(hi_ts, side=stop_side)
    start = max(stop - count, 0)
    return block, slice(start, stop)
def have_data(self, security, date):
    """
    Tell whether `security` has data on `date`.

    Returns True when `date` is one of the security's trading days as
    reported by `get_trading_days`, False otherwise.
    """
    days = self.get_trading_days(security)
    ts = to_timestamp(date)
    pos = days.searchsorted(ts, side='right') - 1
    if pos < 0:
        return False
    return True if days[pos] == ts else False
def get_bar_by_period(self, security, start_dt, end_dt, fields, include_now=True):
    """
    Fetch the minute bars of one security over a time range.

    `security`: the security, e.g. '000001.XSHE'.
    `start_dt`: start of the range, e.g. datetime.datetime(2015, 1, 1, 0, 0, 0).
    `end_dt`: end of the range, e.g. datetime.datetime(2016, 12, 30, 0, 0, 0).
    `fields`: names of the bar columns to return.
    `include_now`: when False, the bar at `end_dt` itself is left out.

    Returns a dict mapping every requested field, plus 'date', to an array.
    Field values are divided by the column's entry in `_COL_POWERS`
    (1 when absent). For futures 'date' is taken from the table index,
    otherwise from the 'date' column. When the table cannot be opened or
    the range is empty, every key maps to `_EMPTY_NP_ARRAY`.
    """
    tbl = self.open_bcolz_table(security)
    if tbl is None:
        return dict.fromkeys(('date', ) + tuple(fields), _EMPTY_NP_ARRAY)

    is_futures = security.is_futures()
    if is_futures:
        lo_ts = to_timestamp(start_dt)
        hi_ts = to_timestamp(end_dt)
        first = tbl.index.searchsorted(lo_ts)
        last = tbl.index.searchsorted(hi_ts, 'right') - 1
        if last >= 0 and not include_now and tbl.index[last] == hi_ts:
            last -= 1
    else:
        first = calc_stock_minute_index(security, tbl.index, start_dt, side='right')
        last = calc_stock_minute_index(security, tbl.index, end_dt, include_now=include_now)

    if first > last:
        return dict.fromkeys(('date', ) + tuple(fields), _EMPTY_NP_ARRAY)

    window = slice(first, last + 1)
    bars = {}
    for field in fields:
        bars[field] = tbl.table.cols[field][window] / _COL_POWERS.get(field, 1)

    if is_futures:
        bars['date'] = tbl.index[window]
    else:
        bars['date'] = tbl.table.cols['date'][window]
    return bars
def get_factor_by_date(self, security, date):
    """
    Return the adjustment factor of `security` on `date`, read from its
    shared-memory block.

    The factor of the latest row at or before `date` is used; 1.0 is
    returned when no such row exists.
    """
    block = self.open_shm_block(security)
    ts = to_timestamp(date)
    row = block.index.searchsorted(ts, side='right') - 1
    if 0 <= row < len(block.index):
        return block.getitem(row, 'factor')
    return 1.0
def get_factor_by_period(self, security, start_date, end_date):
    """
    Return the adjustment factors of `security` within
    [start_date, end_date].

    Returns an `(index, data)` pair: the dates of the rows in the range
    (converted with `vec2date`) and their 'factor' values. When the range
    holds no row (e.g. the security was suspended), a single-element pair
    is returned holding `start_date` and the factor in effect on that date
    as given by `get_factor_by_date`.
    """
    tbl = self.open_bcolz_table(security)
    lo_ts = to_timestamp(start_date)
    hi_ts = to_timestamp(end_date)
    first = tbl.index.searchsorted(lo_ts)
    last = tbl.index.searchsorted(hi_ts, side='right') - 1

    if first > last:
        # Nothing stored in the range: fall back to the factor before it.
        fallback = self.get_factor_by_date(security, start_date)
        return np.array([start_date]), np.array([fallback])

    window = slice(first, last + 1)
    days = vec2date(tbl.table.cols['date'][window])
    factors = tbl.table.cols['factor'][window]
    return days, factors
def get_factor_by_date(self, security, date):
    """
    Return the adjustment factor of `security` on `date`.

    The factor of the latest row at or before `date` is used; 1.0 is
    returned when no such row exists. The stored factor is asserted to be
    positive.
    """
    tbl = self.open_bcolz_table(security)
    ts = to_timestamp(date)
    row = tbl.index.searchsorted(ts, side='right') - 1
    if row < 0 or row >= len(tbl.index):
        return 1.0

    factor = tbl.table.cols['factor'][row]
    assert factor > 0, 'security=%s, date=%s, factor=%s' % (security.code, date, factor)
    return factor
def get_date_by_count(self, security, end_date, count, include_now=True):
    """
    Return the last `count` index entries of `security` up to `end_date`.

    When `include_now` is False an entry stamped exactly at `end_date` is
    excluded. `_EMPTY_NP_ARRAY` is returned when nothing lies at or before
    `end_date`; fewer than `count` entries are returned when the table does
    not reach back far enough.
    """
    count = int(count)
    tbl = self.open_bcolz_table(security)
    hi_ts = to_timestamp(end_date)

    last = tbl.index.searchsorted(hi_ts, side='right') - 1
    if last >= 0 and not include_now and tbl.index[last] == hi_ts:
        last -= 1
    if last < 0:
        return _EMPTY_NP_ARRAY

    # Clamp to the beginning of the table.
    first = max(last - count + 1, 0)
    return tbl.index[first:last + 1]
def query(self, security, dates, field):
    """
    Look up a futures series of `security` for the given dates.

    `dates`: ordered sequence of dates; the first and last bound the table
        search.
    `field`: 'futures_sett_price' (column 'settlement') or
        'futures_positions' (column 'open_interest').

    Returns a numpy array of the stored values rounded to
    `security.price_decimals`. An empty `dates` gives an empty array; when
    the table has no row in the range, an all-NaN array of len(dates) is
    returned. When fewer rows than dates are found, NaN is added in front
    for the dates before the first stored row and at the back for the dates
    after the last stored row.

    NOTE(review): only leading and trailing gaps are padded; a date missing
    in the middle of the range leaves the result shorter than `dates`.

    Raises ParamsError when `field` is not one of the supported names.
    """
    n = len(dates)
    if n == 0:
        return np.array([])

    # Validate up front so that a bad field is reported even when the table
    # happens to hold no rows in the requested range.
    if field == 'futures_sett_price':
        name = 'settlement'
    elif field == 'futures_positions':
        name = 'open_interest'
    else:
        raise ParamsError(
            "field should be in (futures_sett_price, futures_positions)")

    ct = self.open_table(security)
    start_ts = to_timestamp(dates[0])
    end_ts = to_timestamp(dates[-1])
    start_idx = ct.index.searchsorted(start_ts)
    end_idx = ct.index.searchsorted(end_ts, 'right') - 1
    if end_idx < start_idx:
        return np.full(n, np.nan)

    ret = np.round(ct.table.cols[name][start_idx:end_idx + 1],
                   security.price_decimals)
    if len(ret) < n:
        st = to_date(ct.table.cols['date'][start_idx])
        et = to_date(ct.table.cols['date'][end_idx])

        # Number of requested dates that precede the first stored row.
        head = next((i for i in range(n) if dates[i] >= st), n - 1)
        if head > 0:
            ret = np.concatenate([np.full(head, np.nan), ret])

        # Position of the last requested date covered by the stored rows.
        tail = next((i for i in range(n - 1, -1, -1) if dates[i] <= et), 0)
        if tail < n - 1:
            ret = np.concatenate([ret, np.full(n - tail - 1, np.nan)])
    return ret
def get_bar_by_dt(self, security, somedt):
    """
    Fetch the bar of `security` at `somedt` from its latest shared-memory
    block.

    The latest row at or before `somedt` is used. Returns a dict with one
    entry per block column, or None when no such row exists. Price columns
    are rounded to `security.price_decimals`, 'volume' and 'money' to 0
    decimals, and 'date' is converted with `to_datetime`.
    """
    shm_date = self.get_latest_shm_date(security)
    block = self.open_shm_block(security, shm_date)
    ts = to_timestamp(somedt)
    row = block.index.searchsorted(ts, side='right') - 1
    if row < 0:
        return None

    price_decimals = security.price_decimals
    price_cols = ('open', 'close', 'high', 'low', 'price', 'avg',
                  'pre_close', 'high_limit', 'low_limit')
    amount_cols = ('volume', 'money')

    bar = {}
    for name in block.columns:
        value = block.getitem(row, name)
        if name in price_cols:
            value = fixed_round(value, price_decimals)
        elif name in amount_cols:
            value = fixed_round(value, 0)
        bar[name] = value

    bar['date'] = to_datetime(bar['date'])
    return bar
def get_bar_by_dt(self, security, somedt):
    """
    Fetch the minute bar of one security at a given time.

    `security`: the security, e.g. '000001.XSHE'.
    `somedt`: the minute wanted; when no bar is stored for that minute, the
        bar of the previous stored minute is returned instead.

    Returns a dict with one entry per table column, or None when the table
    cannot be opened or holds nothing at or before `somedt`. Price columns
    and 'volume'/'money' are divided by the column's entry in `_COL_POWERS`
    (1 when absent) and rounded to `security.price_decimals` and 0 decimals
    respectively; 'date' is converted with `to_datetime`.
    """
    tbl = self.open_bcolz_table(security)
    if tbl is None:
        return None

    if security.is_futures():
        ts = to_timestamp(somedt)
        row = tbl.index.searchsorted(ts, side='right') - 1
    else:
        row = calc_stock_minute_index(security, tbl.index, somedt)
    if row < 0:
        return None

    price_decimals = security.price_decimals
    price_cols = ('open', 'close', 'high', 'low', 'price', 'avg',
                  'pre_close', 'high_limit', 'low_limit')
    amount_cols = ('volume', 'money')

    bar = {}
    for name in tbl.table.names:
        raw = tbl.table.cols[name][row]
        if name in price_cols:
            bar[name] = fixed_round(raw / _COL_POWERS.get(name, 1), price_decimals)
        elif name in amount_cols:
            bar[name] = fixed_round(raw / _COL_POWERS.get(name, 1), 0)
        else:
            bar[name] = raw

    bar['date'] = to_datetime(bar['date'])
    return bar
def get_price(self, somedt):
    """
    Return the factor-adjusted close price of this store's security as of
    `somedt`.

    `somedt`: the moment wanted. The latest row at or before it is used;
        when that row belongs to the same day as `somedt` but `somedt` is
        earlier than that day's open time, the previous row is used
        instead.

    Returns close * factor of the chosen row, rounded to the security's
    `price_decimals`, or NaN when no row qualifies.
    """
    tbl = self.table
    ts = to_timestamp(somedt)
    row = tbl.index.searchsorted(ts, side='right') - 1
    if row < 0:
        return np.nan

    if to_date(tbl.index[row]) == somedt.date():
        open_dt = CalendarStore.instance().get_open_dt(
            self.security, somedt.date())
        if somedt < open_dt:
            # Before the open: today's row does not apply yet.
            row -= 1
            if row < 0:
                return np.nan

    adjusted = tbl.closes[row] * tbl.factors[row]
    return fixed_round(adjusted, self.security.price_decimals)
def calc_stock_minute_index(security, date_index, dt, side='left', include_now=True):
    """
    Compute the position of `dt` within minute data (stocks only).

    The position is derived arithmetically: 240 minutes per entry of
    `date_index` before dt's day, plus the minutes elapsed within dt's day
    (morning minutes 9:31-11:30 count 1..120, afternoon minutes 13:01-15:00
    count 121..240).

    `security`: not used by the computation.
    `date_index`: sorted day timestamps.
    `dt`: the moment to locate.
    `side`: applies when `dt` does not fall on a trading minute (before the
        open, in the midday break, after the close, or on a day missing
        from `date_index`): 'left' gives the position on the left,
        anything else the position on the right.
    `include_now`: whether the minute `dt` itself is included; when False
        and `dt` is a trading minute, the previous position is returned.

    Returns a 0-based position; -1 means there is none.
    Raises ParamsError when `dt` is of an unsupported type.
    """
    if isinstance(dt, datetime):
        day_ts = to_timestamp(dt.date())
    elif isinstance(dt, pd.Timestamp):
        day_ts = int(dt.value / (10**9) / 86400) * 86400
    else:
        raise ParamsError("wrong dt=%s, type(dt)=%s" % (dt, type(dt)))

    day_pos = date_index.searchsorted(day_ts, side='right') - 1
    if day_pos < 0:
        return -1

    on_trading_day = date_index[day_pos] == day_ts
    # Number of complete trading days that precede dt's day.
    full_days = day_pos if on_trading_day else day_pos + 1

    minute_of_day = dt.hour * 60 + dt.minute
    # Extra minute counted when dt is outside a session and side is not 'left'.
    boundary = 0 if side == 'left' else 1
    # One minute is taken off inside a session when dt itself is excluded.
    step_back = 0 if include_now else 1

    if not on_trading_day:
        minutes = boundary
    elif minute_of_day < 9 * 60 + 31:
        # Before 9:31.
        minutes = boundary
    elif minute_of_day < 11 * 60 + 31:
        minutes = minute_of_day - (9 * 60 + 30) - step_back
    elif minute_of_day < 13 * 60 + 1:
        minutes = 120 + boundary
    elif minute_of_day < 15 * 60 + 1:
        minutes = 120 + minute_of_day - 13 * 60 - step_back
    else:
        minutes = 240 + boundary

    # Positions start at 0; -1 means there is none.
    return (full_days * 240 + minutes) - 1