def get_bar_by_period(self,
                          security,
                          start_date,
                          end_date,
                          fields,
                          include_now=True):
        '''
        Fetch the daily bars of one security for the range [start_date, end_date].

        `security`: security code, e.g. '000001.XSHE'.
        `start_date`: first date, e.g. datetime.date(2015, 1, 1).
        `end_date`: last date, e.g. datetime.date(2016, 12, 30).
        `fields`: names of the bar columns to return.
        `include_now`: when False, a row stamped exactly at `end_date` is left out.

        Returns a dict mapping every name in `fields`, plus 'date', to the
        sliced column, or {} when the range contains no rows.
        '''
        table = self.open_bcolz_table(security)
        stamps = table.index

        start_ts, end_ts = to_timestamp(start_date), to_timestamp(end_date)
        first = stamps.searchsorted(start_ts)
        last = stamps.searchsorted(end_ts, 'right') - 1
        # Step back one row when the caller wants end_date itself excluded.
        if not include_now and last >= 0 and stamps[last] == end_ts:
            last -= 1
        if last < first:
            return {}

        window = slice(first, last + 1)
        columns = table.table.cols
        result = {name: columns[name][window] for name in fields}
        result['date'] = columns['date'][window]
        return result
    def get_minute_by_period(self,
                             security,
                             start_dt,
                             end_dt,
                             include_now=True):
        """
        Return the minute timestamps of `security` between start_dt and end_dt.

        Futures use binary search on the table index. The range runs from the
        first entry >= start_dt to the last entry <= end_dt. An entry equal to
        end_dt is dropped when include_now is False. Values come from the
        table index.

        Other securities get their bounds from calc_stock_minute_index, and
        their values come from the 'date' column.

        Returns _EMPTY_NP_ARRAY when the table is missing or the range is empty.
        """
        table = self.open_bcolz_table(security)
        if table is None:
            return _EMPTY_NP_ARRAY

        futures = security.is_futures()
        if futures:
            start_ts, end_ts = to_timestamp(start_dt), to_timestamp(end_dt)
            first = table.index.searchsorted(start_ts)
            last = table.index.searchsorted(end_ts, 'right') - 1
            if not include_now and last >= 0 and table.index[last] == end_ts:
                last -= 1
        else:
            first = calc_stock_minute_index(security, table.index, start_dt,
                                            side='right')
            last = calc_stock_minute_index(security, table.index, end_dt,
                                           include_now=include_now)

        if first > last:
            return _EMPTY_NP_ARRAY

        values = table.index if futures else table.table.cols['date']
        return values[first:last + 1]
 def _get_idx_by_period(self, security, start_date, end_date, include_end):
     """Return (block, slice) covering the rows from start_date to end_date.

     The slice starts at the first row >= start_date. When include_end is
     truthy it stops after rows equal to end_date; otherwise it stops before
     them.
     """
     block = self.open_shm_block(security)
     stop_side = 'right' if include_end else 'left'
     lo = block.index.searchsorted(to_timestamp(start_date))
     hi = block.index.searchsorted(to_timestamp(end_date), side=stop_side)
     return block, slice(lo, hi)
# Example #4
# 0
    def query(self, security, dates, field):
        """
        Return the fund net value of `security` aligned to `dates`.

        `dates`: ascending sequence of dates. The result has one entry per date.
        `field`: 'acc_net_value' (column 'acc') or 'unit_net_value'
            (column 'unit').

        Values are rounded to `security.price_decimals`. Dates without a row
        in the table get NaN. An empty `dates` returns an empty array.

        Raises ParamsError for an unsupported `field`.
        """
        n = len(dates)
        if n == 0:
            return np.array([])
        # Validate before touching the table so an invalid field always raises,
        # even when the requested range has no data.
        if field == 'acc_net_value':
            name = 'acc'
        elif field == 'unit_net_value':
            name = 'unit'
        else:
            raise ParamsError(
                "field should in (acc_net_value, unit_net_value)")

        ct = self.open_table(security)
        start_ts = to_timestamp(dates[0])
        end_ts = to_timestamp(dates[-1])

        start_idx = ct.index.searchsorted(start_ts)
        end_idx = ct.index.searchsorted(end_ts, 'right') - 1
        if end_idx < start_idx:
            return np.array([np.nan] * n)

        ret = np.round(ct.table.cols[name][start_idx:end_idx + 1],
                       security.price_decimals)

        if len(ret) < n:
            # Some requested dates have no row. Scatter each row's value to
            # its date's position in a NaN-filled result. The dict gives O(1)
            # lookups (list.index was O(n) per row). Rows whose date was not
            # requested are skipped instead of raising ValueError.
            position = {d: i for i, d in enumerate(dates)}
            row_dates = ct.table.cols['date'][start_idx:end_idx + 1]
            aligned = np.full(n, np.nan)
            for value, raw_date in zip(ret, row_dates):
                pos = position.get(to_date(raw_date))
                if pos is not None:
                    aligned[pos] = value
            ret = aligned
        return ret
 def get_date_by_period(self,
                        security,
                        start_date,
                        end_date,
                        include_now=True):
     """Return the index timestamps of `security` within [start_date, end_date].

     With include_now=False, a row stamped exactly at end_date is left out.
     Returns _EMPTY_NP_ARRAY when nothing falls in the range.
     """
     stamps = self.open_bcolz_table(security).index
     start_ts, end_ts = to_timestamp(start_date), to_timestamp(end_date)
     first = stamps.searchsorted(start_ts)
     last = stamps.searchsorted(end_ts, side='right') - 1
     drop_end = not include_now and last >= 0 and stamps[last] == end_ts
     if drop_end:
         last -= 1
     if last < first:
         return _EMPTY_NP_ARRAY
     return stamps[first:last + 1]
    def get_bar_by_count(self,
                         security,
                         end_date,
                         count,
                         fields,
                         include_now=True):
        '''
        Fetch up to `count` daily bars of one security, ending at `end_date`.

        `security`: security code, e.g. '000001.XSHE'.
        `end_date`: last date, e.g. datetime.date(2016, 12, 30).
        `count`: maximum number of rows, e.g. 300 (converted with int()).
        `fields`: names of the bar columns to return.
        `include_now`: when False, a row stamped exactly at `end_date` is skipped.

        Returns a dict mapping every name in `fields`, plus 'date', to the
        sliced column, or {} when no row lies at or before end_date.
        '''
        table = self.open_bcolz_table(security)
        count = int(count)
        end_ts = to_timestamp(end_date)
        stamps = table.index

        last = stamps.searchsorted(end_ts, side='right') - 1
        if not include_now and last >= 0 and stamps[last] == end_ts:
            last -= 1
        if last < 0:
            return {}
        # Clamp at the first row when fewer than `count` rows exist.
        first = max(last - count + 1, 0)

        columns = table.table.cols
        result = {name: columns[name][first:last + 1] for name in fields}
        result['date'] = columns['date'][first:last + 1]
        return result
    def get_bar_by_date(self, security, somedate):
        '''
        Return the daily bar of one security on `somedate` as a dict.

        `security`: security code, e.g. '000001.XSHE'.
        `somedate`: if there is no row on this date, the most recent earlier
            row is used. Returns None when no row exists at or before it.

        Price-like and net-value columns are rounded to the security's
        price_decimals, 'volume'/'money' to 0 decimals. The 'date' value is
        converted with to_date.
        '''
        table = self.open_bcolz_table(security)
        pos = table.index.searchsorted(to_timestamp(somedate), side='right') - 1
        if pos < 0:
            return None

        decimals = security.price_decimals
        rounded_to_price = ('open', 'close', 'high', 'low', 'price', 'avg',
                            'pre_close', 'high_limit', 'low_limit',
                            'unit', 'acc', 'refactor')
        bar = {}
        for col in table.table.names:
            raw = table.table.cols[col][pos]
            if col in rounded_to_price:
                bar[col] = fixed_round(raw, decimals)
            elif col in ('volume', 'money'):
                bar[col] = fixed_round(raw, 0)
            else:
                bar[col] = raw
        bar['date'] = to_date(bar['date'])
        return bar
    def get_minute_by_count(self, security, end_dt, count, include_now=True):
        """
        Return up to `count` minute timestamps of `security` ending at end_dt.

        Futures locate the end row by binary search on the table index and
        take their values from it. Other securities use
        calc_stock_minute_index and the 'date' column. With include_now=False,
        the minute at end_dt itself is excluded.

        Returns _EMPTY_NP_ARRAY when the table is missing, nothing lies at or
        before end_dt, or count is not positive.
        """
        count = int(count)
        table = self.open_bcolz_table(security)
        if table is None:
            return _EMPTY_NP_ARRAY

        futures = security.is_futures()
        if futures:
            end_ts = to_timestamp(end_dt)
            last = table.index.searchsorted(end_ts, 'right') - 1
            if not include_now and last >= 0 and table.index[last] == end_ts:
                last -= 1
        else:
            last = calc_stock_minute_index(security, table.index, end_dt,
                                           include_now=include_now)
        if last < 0:
            return _EMPTY_NP_ARRAY

        first = max(last + 1 - count, 0)
        if first > last:
            return _EMPTY_NP_ARRAY

        values = table.index if futures else table.table.cols['date']
        return values[first:last + 1]
def fill_paused(security,
                cols_dict,
                index,
                full_index,
                fields,
                index_type='date'):
    '''
    Expand the columns of `cols_dict` from `index` onto `full_index`.

    `security`: object whose `end_date` (may be falsy) caps the valid range.
    `cols_dict`: mapping of field name -> column aligned with `index`. It must
        contain 'close' and may contain 'factor'.
    `index`: sorted timestamps of the rows present in `cols_dict`.
    `full_index`: the complete sorted trading index to expand onto.
    `fields`: names of the columns to rebuild. Each is replaced in `cols_dict`.
    `index_type`: 'date' for daily data, 'minute' for minute data.

    Each entry of `full_index` is filled as follows:
      * before the first row of `index`: NaN (not listed yet);
      * present in `index`: the original row;
      * missing, and later than the last valid moment (now, or the
        security's end_date if earlier): NaN;
      * otherwise: paused_day_array(previous close, previous factor or 1.0,
        fields).

    Returns `cols_dict`, mutated in place.
    Raises Exception for an unknown `index_type`.
    '''
    # Latest moment whose data can be valid; missing rows after it become NaN.
    end_date = security.end_date
    if index_type == 'minute':
        t0 = datetime.datetime.now().replace(second=0, microsecond=0)
        if end_date:
            # Cap at 14:59 on the security's end_date.
            t1 = datetime.datetime.combine(end_date, datetime.time(14, 59))
            max_valid_date = min(t0, t1)
        else:
            max_valid_date = t0
    elif index_type == 'date':
        t0 = datetime.date.today()
        if end_date:
            max_valid_date = min(t0, end_date)
        else:
            max_valid_date = t0
    else:
        raise Exception("wrong index_type=%s, should be 'date' or 'minute'" %
                        index_type)
    max_valid_ts = to_timestamp(max_valid_date)

    # For each full_index entry, the position of the latest row in `index`
    # at or before it (-1 if none).
    v = np.searchsorted(index, full_index, 'right') - 1
    # column_stack needs a real sequence. Generator input is deprecated
    # (NumPy 1.16) and rejected by newer NumPy releases.
    a = np.column_stack([cols_dict[f] for f in fields])
    nan_const = np.full(a.shape[1], np.nan)
    b = []
    for i, vi in enumerate(v):
        if vi < 0:
            b.append(nan_const)  # not listed yet
            continue
        full_index_ts = full_index[i]
        if full_index_ts == index[vi]:
            b.append(a[vi])  # a real row exists for this entry
        elif full_index_ts > max_valid_ts:
            b.append(nan_const)  # past the last valid moment
        else:
            # Missing row inside the valid range. Build it from the previous
            # row's close (a 'close' column is required) and factor.
            close_value = cols_dict['close'][vi]
            if 'factor' in cols_dict:
                factor_value = cols_dict['factor'][vi]
            else:
                factor_value = 1.0
            b.append(paused_day_array(close_value, factor_value, fields))
    new_a = np.array(b).reshape(len(b), a.shape[1])
    for i, col in enumerate(fields):
        cols_dict[col] = new_a[:, i]
    return cols_dict
    def get_bar_by_date(self, security, somedate):
        """
        Return the daily bar of one security on `somedate`, read from shared memory.

        `security`: security code, e.g. '000001.XSHE'.
        `somedate`: if no row exists on this date, the most recent earlier row
            is used. Returns None when there is no row at or before it.

        Price columns are rounded to the security's price_decimals and
        'volume'/'money' to 0 decimals. 'date' is converted with to_date.
        """
        block = self.open_shm_block(security)
        pos = block.index.searchsorted(to_timestamp(somedate), side='right') - 1
        if pos < 0:
            return None

        decimals = security.price_decimals
        price_cols = ('open', 'close', 'high', 'low', 'price', 'avg',
                      'pre_close', 'high_limit', 'low_limit')
        bar = {}
        for col in block.columns:
            value = block.getitem(pos, col)
            if col in price_cols:
                value = fixed_round(value, decimals)
            elif col in ('volume', 'money'):
                value = fixed_round(value, 0)
            bar[col] = value
        bar['date'] = to_date(bar['date'])
        return bar
 def _get_idx_by_count(self, security, end_date, count, include_end):
     """Return (block, slice) covering the last `count` rows up to end_date.

     When include_end is truthy, rows stamped exactly at end_date are
     included; otherwise the slice stops before them. The start is clamped
     at 0, so fewer than `count` rows are covered when not enough exist.
     """
     # Coerce like the other *_by_count readers, so a float/str count cannot
     # produce a slice with non-integer bounds.
     count = int(count)
     ct = self.open_shm_block(security)
     end_ts = to_timestamp(end_date)
     stop_idx = ct.index.searchsorted(end_ts, side='right' if include_end else 'left')
     start_idx = max(stop_idx - count, 0)
     return (ct, slice(start_idx, stop_idx))
 def have_data(self, security, date):
     """Return True if `date` is one of the trading days of `security`, else False.

     Looks up `date` in the index returned by self.get_trading_days(security).
     """
     days = self.get_trading_days(security)
     target = to_timestamp(date)
     pos = days.searchsorted(target, side='right') - 1
     return bool(pos >= 0 and days[pos] == target)
    def get_bar_by_period(self,
                          security,
                          start_dt,
                          end_dt,
                          fields,
                          include_now=True):
        '''
        Fetch the minute bars of one security between start_dt and end_dt.

        `security`: security code, e.g. '000001.XSHE'.
        `start_dt`: start time, e.g. datetime.datetime(2015, 1, 1, 0, 0, 0).
        `end_dt`: end time, e.g. datetime.datetime(2016, 12, 30, 0, 0, 0).
        `fields`: names of the bar columns to return.
        `include_now`: when False, the minute at end_dt itself is excluded.

        Each field column is divided by its _COL_POWERS entry (default 1).
        'date' comes from the table index for futures and from the 'date'
        column otherwise. When the table is missing or the range is empty,
        every key maps to _EMPTY_NP_ARRAY.
        '''

        def empty_result():
            return {k: _EMPTY_NP_ARRAY for k in ('date', ) + tuple(fields)}

        table = self.open_bcolz_table(security)
        if table is None:
            return empty_result()

        futures = security.is_futures()
        if futures:
            start_ts, end_ts = to_timestamp(start_dt), to_timestamp(end_dt)
            first = table.index.searchsorted(start_ts)
            last = table.index.searchsorted(end_ts, 'right') - 1
            if not include_now and last >= 0 and table.index[last] == end_ts:
                last -= 1
        else:
            first = calc_stock_minute_index(security, table.index, start_dt,
                                            side='right')
            last = calc_stock_minute_index(security, table.index, end_dt,
                                           include_now=include_now)

        if first > last:
            return empty_result()

        window = slice(first, last + 1)
        data = {
            name: table.table.cols[name][window] / _COL_POWERS.get(name, 1)
            for name in fields
        }
        date_source = table.index if futures else table.table.cols['date']
        data['date'] = date_source[window]
        return data
    def get_factor_by_date(self, security, date):
        """
        Return the adjustment factor of `security` on `date` from shared memory.

        Uses the latest row at or before `date`. Returns 1.0 when there is
        no such row.
        """
        block = self.open_shm_block(security)
        pos = block.index.searchsorted(to_timestamp(date), side='right') - 1
        # searchsorted(...) - 1 is at most len(index) - 1, so only the lower
        # bound needs checking.
        if pos < 0:
            return 1.0
        return block.getitem(pos, 'factor')
    def get_factor_by_period(self, security, start_date, end_date):
        '''
        Return (dates, factors) for `security` within [start_date, end_date].

        `dates` is the 'date' column converted with vec2date, and `factors` is
        the matching slice of the 'factor' column. When the range has no rows
        (e.g. during a suspension), a single entry is returned: start_date
        paired with self.get_factor_by_date(security, start_date), i.e. the
        factor from before the suspension.
        '''
        table = self.open_bcolz_table(security)
        start_ts, end_ts = to_timestamp(start_date), to_timestamp(end_date)
        first = table.index.searchsorted(start_ts)
        last = table.index.searchsorted(end_ts, side='right') - 1

        if first > last:
            factor = self.get_factor_by_date(security, start_date)
            return np.array([start_date]), np.array([factor])

        rows = slice(first, last + 1)
        dates = vec2date(table.table.cols['date'][rows])
        return dates, table.table.cols['factor'][rows]
 def get_factor_by_date(self, security, date):
     '''Return the adjustment factor of `security` on `date`, or 1.0 if none.

     Uses the latest row at or before `date`. The stored factor is asserted
     to be positive.
     '''
     table = self.open_bcolz_table(security)
     pos = table.index.searchsorted(to_timestamp(date), side='right') - 1
     # pos is at most len(index) - 1, so only "no earlier row" needs handling.
     if pos < 0:
         return 1.0
     factor = table.table.cols['factor'][pos]
     assert factor > 0, 'security=%s, date=%s, factor=%s' % (security.code,
                                                            date, factor)
     return factor
 def get_date_by_count(self, security, end_date, count, include_now=True):
     """Return up to `count` index timestamps of `security` ending at end_date.

     `count` is converted with int(). With include_now=False, a row stamped
     exactly at end_date is skipped. Returns _EMPTY_NP_ARRAY when no row lies
     at or before end_date.
     """
     count = int(count)
     stamps = self.open_bcolz_table(security).index
     end_ts = to_timestamp(end_date)
     last = stamps.searchsorted(end_ts, side='right') - 1
     if not include_now and last >= 0 and stamps[last] == end_ts:
         last -= 1
     if last < 0:
         return _EMPTY_NP_ARRAY
     first = max(last - count + 1, 0)
     return stamps[first:last + 1]
    def query(self, security, dates, field):
        """
        Return a futures series for `security` aligned to `dates`.

        `dates`: ascending sequence of dates. The result has one entry per date.
        `field`: 'futures_sett_price' (column 'settlement') or
            'futures_positions' (column 'open_interest').

        Values are rounded to `security.price_decimals`. Dates without a row
        in the table get NaN. An empty `dates` returns an empty array.

        Raises ParamsError for an unsupported `field`.
        """
        n = len(dates)
        if n == 0:
            return np.array([])
        # Validate before touching the table so an invalid field always raises.
        if field == 'futures_sett_price':
            name = 'settlement'
        elif field == 'futures_positions':
            name = 'open_interest'
        else:
            raise ParamsError(
                "field should in (futures_sett_price, futures_positions)")

        ct = self.open_table(security)
        start_ts = to_timestamp(dates[0])
        end_ts = to_timestamp(dates[-1])

        start_idx = ct.index.searchsorted(start_ts)
        end_idx = ct.index.searchsorted(end_ts, 'right') - 1
        if end_idx < start_idx:
            return np.array([np.nan] * n)

        ret = np.round(ct.table.cols[name][start_idx:end_idx + 1],
                       security.price_decimals)

        if len(ret) < n:
            # Scatter by date instead of padding only the two ends. Gaps in
            # the middle of the range then also become NaN, and the result
            # always has length n and stays aligned with `dates`.
            position = {d: i for i, d in enumerate(dates)}
            row_dates = ct.table.cols['date'][start_idx:end_idx + 1]
            aligned = np.full(n, np.nan)
            for value, raw_date in zip(ret, row_dates):
                pos = position.get(to_date(raw_date))
                if pos is not None:
                    aligned[pos] = value
            ret = aligned
        return ret
    def get_bar_by_dt(self, security, somedt):
        """
        Return the minute bar of one security at `somedt`.

        Reads the shared-memory block for get_latest_shm_date(security). Uses
        the latest row at or before `somedt`; returns None when there is none.
        Price columns are rounded to the security's price_decimals and
        'volume'/'money' to 0 decimals. 'date' is converted with to_datetime.
        """
        latest = self.get_latest_shm_date(security)
        block = self.open_shm_block(security, latest)
        pos = block.index.searchsorted(to_timestamp(somedt), side='right') - 1
        if pos < 0:
            return None

        decimals = security.price_decimals
        price_cols = ('open', 'close', 'high', 'low', 'price', 'avg',
                      'pre_close', 'high_limit', 'low_limit')
        bar = {}
        for col in block.columns:
            value = block.getitem(pos, col)
            if col in price_cols:
                value = fixed_round(value, decimals)
            elif col in ('volume', 'money'):
                value = fixed_round(value, 0)
            bar[col] = value

        bar['date'] = to_datetime(bar['date'])
        return bar
    def get_bar_by_dt(self, security, somedt):
        '''
        Return the minute bar of one security at `somedt` as a dict.

        `security`: security code, e.g. '000001.XSHE'.
        `somedt`: if this minute has no data, the previous trading minute's
            bar is returned.

        Futures locate the row by binary search on the table index; other
        securities use calc_stock_minute_index. Returns None when the table
        is missing or no row qualifies. Price and 'volume'/'money' columns are
        divided by their _COL_POWERS entry (default 1), then rounded: prices
        to price_decimals, 'volume'/'money' to 0. 'date' is converted with
        to_datetime.
        '''

        table = self.open_bcolz_table(security)
        if table is None:
            return None
        if security.is_futures():
            pos = table.index.searchsorted(to_timestamp(somedt), side='right') - 1
        else:
            pos = calc_stock_minute_index(security, table.index, somedt)

        if pos < 0:
            return None

        decimals = security.price_decimals
        price_cols = ('open', 'close', 'high', 'low', 'price', 'avg',
                      'pre_close', 'high_limit', 'low_limit')
        bar = {}
        for col in table.table.names:
            raw = table.table.cols[col][pos]
            if col in price_cols:
                bar[col] = fixed_round(raw / _COL_POWERS.get(col, 1), decimals)
            elif col in ('volume', 'money'):
                bar[col] = fixed_round(raw / _COL_POWERS.get(col, 1), 0)
            else:
                bar[col] = raw

        bar['date'] = to_datetime(bar['date'])
        return bar
    def get_price(self, somedt):
        '''
        Return the back-adjusted price (close * factor) of self.security at `somedt`.

        `somedt`: a datetime. The latest row of self.table at or before it is
            used. If that row falls on somedt's own date and somedt is earlier
            than that day's open time (CalendarStore get_open_dt), the previous
            row is used instead.

        Returns np.nan when no suitable row exists. Otherwise returns the value
        rounded by fixed_round to self.security.price_decimals.
        '''
        table = self.table
        pos = table.index.searchsorted(to_timestamp(somedt), side='right') - 1
        if pos < 0:
            return np.nan

        day = somedt.date()
        if to_date(table.index[pos]) == day:
            open_dt = CalendarStore.instance().get_open_dt(self.security, day)
            # somedt precedes the day's open: fall back to the previous row.
            if somedt < open_dt:
                pos -= 1
                if pos < 0:
                    return np.nan

        return fixed_round(table.closes[pos] * table.factors[pos],
                           self.security.price_decimals)
def calc_stock_minute_index(security,
                            date_index,
                            dt,
                            side='left',
                            include_now=True):
    '''
    Compute the position of `dt` in a stock's minute data.

    The computation assumes exactly 240 rows per trading day in `date_index`
    (09:31-11:30 and 13:01-15:00). The position is derived arithmetically
    instead of by searching.

    If `dt` is not itself a data minute:
      side='left'  returns the index of the nearest minute before it;
      side='right' returns the index of the nearest minute after it.
    -1 means there is no such minute. If `dt` lies before the first trading
    day, side='left' gives -1 and side='right' gives 0.

    include_now: when False and `dt` is inside a trading session, the minute
        at `dt` itself is excluded (the result moves one minute earlier).

    `security` is not used by the computation.
    Raises ParamsError if `dt` is neither a datetime nor a pd.Timestamp.
    '''
    if isinstance(dt, datetime):
        ts = to_timestamp(dt.date())
    elif isinstance(dt, pd.Timestamp):
        # Integer floor division. Float division of the nanosecond value
        # could round a time just before midnight into the next day.
        ts = dt.value // (10**9 * 86400) * 86400
    else:
        raise ParamsError("wrong dt=%s, type(dt)=%s" % (dt, type(dt)))

    # Index of the last trading day on or before dt's day (-1 if none).
    total_days = date_index.searchsorted(ts, side='right') - 1
    trading_day = total_days >= 0 and date_index[total_days] == ts
    if not trading_day:
        # dt's day is not a trading day, or lies before the first one.
        # total_days becomes the number of trading days before it. This
        # also makes side='right' return 0 (not -1) when dt precedes all
        # data, so callers slicing [start:end + 1] do not start at -1.
        total_days += 1
    # There are total_days full trading days before dt's day.
    dt_minutes = dt.hour * 60 + dt.minute
    if trading_day:
        if dt_minutes < (9 * 60 + 31):
            # Before 09:31.
            total_minutes = 0 if side == 'left' else 1
        elif dt_minutes < (11 * 60 + 31):
            # Morning session.
            total_minutes = dt_minutes - (9 * 60 + 30)
            if not include_now:
                total_minutes -= 1
        elif dt_minutes < 13 * 60 + 1:
            # Lunch break.
            total_minutes = 120 if side == 'left' else 121
        elif dt_minutes < 15 * 60 + 1:
            # Afternoon session.
            total_minutes = 120 + dt_minutes - 13 * 60
            if not include_now:
                total_minutes -= 1
        else:
            # After 15:00.
            total_minutes = 240 if side == 'left' else 241
    else:
        total_minutes = 0 if side == 'left' else 1

    # Indices are 0-based; -1 means none.
    return (total_days * 240 + total_minutes) - 1