Ejemplo n.º 1
0
def _resample_days_bars(bars, unit):
    '''
    Merge daily bars into weekly or monthly bars.

    :param bars: dict mapping a field name to a numpy array of daily values.
        It must contain a 'date' entry that ``vec2datetime`` can convert.
    :param unit: '1w' to merge by week, '1M' to merge by month.
    :return: dict with the same keys as ``bars`` holding one value per
        week/month. Each field is combined as ``FIELD_AGG_FUNCTIONS``
        prescribes: 'first', 'last', or an object exposing ``reduceat``.
    '''
    assert unit in ('1w', '1M')
    same_period = _is_same_week if unit == '1w' else _is_same_month

    if len(bars['date']) == 0:
        return {f: np.zeros(0, dtype=bars[f].dtype) for f in bars}

    dates = vec2datetime(bars['date'])

    # Record the position of the first daily bar of every period.
    period_starts = []
    start = 0
    for pos in range(len(dates)):
        if not same_period(dates[start], dates[pos]):
            period_starts.append(start)
            start = pos
    # The period that is still open when the data ends.
    period_starts.append(start)
    period_starts = np.array(period_starts)

    merged = {}
    for f in bars:
        daily = bars[f]
        how = FIELD_AGG_FUNCTIONS[f]
        if how == 'last':
            # The last bar of a period sits right before the next period's
            # first bar; the final period ends with the very last daily bar.
            column = np.zeros(len(period_starts), dtype=daily.dtype)
            column[:-1] = daily[period_starts[1:] - 1]
            column[-1] = daily[-1]
        elif how == 'first':
            column = daily[period_starts]
        else:
            column = how.reduceat(daily, period_starts)
        merged[f] = column

    return merged
Ejemplo n.º 2
0
def get_bars(end_dt,
             security,
             count,
             unit='1d',
             fields=('open', 'high', 'low', 'close'),
             include_now=False,
             fq=None,
             pre_factor_ref_date=None):
    '''
    Return up to ``count`` of the most recent bars of ``security`` ending at
    ``end_dt``.

    :param end_dt: end date/time of the query (normalised with
        ``convert_dt``).
    :param security: the instrument (normalised with ``convert_security``).
    :param count: number of bars; must convert to a positive integer.
    :param unit: bar frequency, one of '1m', '5m', '15m', '30m', '60m',
        '120m', '1d', '1w', '1M'.
    :param fields: one field name, or a list/tuple of field names, taken from
        'date', 'open', 'close', 'high', 'low', 'volume', 'money'.
    :param include_now: forwarded to the bar-fetching helpers; presumably
        controls whether the bar still forming at ``end_dt`` is included.
    :param fq: 'pre' for forward-adjusted prices, 'post' for
        backward-adjusted prices, None for real prices. It is validated with
        ``ensure_fq``; the code visible here does not use it any further.
    :param pre_factor_ref_date: reference date for forward adjustment; prices
        on that day are the real prices. None means all prices are real.
    :return: numpy structured array with one named column per requested field.
    :raises ParamsError: if ``fields`` is neither a string nor a list/tuple.
    :raises AssertionError: on an unsupported field, unit or non-positive
        count.
    '''
    valid_bar_fields = ('date', 'open', 'close', 'high', 'low', 'volume',
                        'money')
    # The tuples are wrapped in a 1-tuple for '%': formatting a multi-item
    # tuple with a single %s raises TypeError instead of showing the message.
    if isinstance(fields, (list, tuple)):
        for f in fields:
            assert f in valid_bar_fields, "get_bars 只支持 %s 字段" % (
                valid_bar_fields, )
        str_field = False
    elif isinstance(fields, six.string_types):
        assert fields in valid_bar_fields, "get_bars 只支持 %s 字段" % (
            valid_bar_fields, )
        str_field = True
    else:
        raise ParamsError("fields 应该是字符串或者list")

    # Always fetch 'factor' as well; presumably _pre_fq needs it to adjust
    # prices.
    new_fields = [fields] if str_field else list(fields)
    if 'factor' not in new_fields:
        new_fields.append('factor')
    end_dt = convert_dt(end_dt)
    valid_unit = ('1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w', '1M')
    assert unit in valid_unit, 'get_bars, unit必须是 %s 中一种' % (valid_unit, )
    count = int(count)
    assert count > 0, "get_bars, count必须是一个正整数"
    fq = ensure_fq(fq)
    security = convert_security(security)
    include_now = bool(include_now)

    calendar = CalendarStore.instance()
    end_trade_date = calendar.get_current_trade_date(security, end_dt)

    def ensure_not_empty(cols_dict):
        # Replace an empty result by zero-length columns for every bar field.
        if cols_dict == {}:
            return {f: np.zeros(0) for f in valid_bar_fields}
        return cols_dict

    def has_rows(cols):
        # True when a fetch returned at least one bar.
        return bool(cols) and len(cols[new_fields[0]]) > 0

    if unit == '1d':
        # With include_now the snapshot of the current day (if any) becomes
        # the last bar, so one stored daily bar fewer is needed.
        snapshot = None
        if include_now:
            snapshot = get_snapshot(security, end_trade_date, end_dt)
        daily_count = count - 1 if snapshot else count
        cols_dict = get_daily_bar_by_count(security,
                                           end_trade_date,
                                           daily_count,
                                           new_fields,
                                           include_now=False)
        cols_dict = ensure_not_empty(cols_dict)
        if snapshot:
            for f in cols_dict:
                cols_dict[f] = np.append(cols_dict[f], snapshot[f])
    elif unit == '1m':
        cols_dict = get_minute_bar_by_count(security,
                                            end_dt,
                                            count,
                                            new_fields,
                                            include_now=include_now)
        cols_dict = ensure_not_empty(cols_dict)
    elif unit in ('5m', '15m', '30m', '60m', '120m'):
        x = int(unit[:-1])
        cols_dict = {f: np.zeros(0) for f in new_fields}
        if security.is_futures():
            trade_days = calendar.get_all_trade_days(security)
            trade_days = trade_days[(trade_days >= security.start_date) &
                                    (trade_days <= end_trade_date) &
                                    (trade_days <= security.end_date)]

            # Walk the trade days backwards, prepending each day's resampled
            # bars until enough bars have been collected.
            for idx in range(len(trade_days) - 1, -1, -1):
                trade_day = trade_days[idx]
                open_dt = calendar.get_open_dt(security, trade_day)
                if trade_day == end_trade_date:
                    if not include_now:
                        close_dt = _not_include_now(security, end_trade_date,
                                                    end_dt, unit)
                    else:
                        close_dt = end_dt
                else:
                    close_dt = calendar.get_close_dt(security, trade_day)
                tmp_dict = get_minute_bar_by_period(security,
                                                    open_dt,
                                                    close_dt,
                                                    new_fields,
                                                    include_now=True)
                # Skip days without data. The length must be compared to 0,
                # not taken of the element-wise comparison result.
                if not has_rows(tmp_dict):
                    continue
                tmp_dict = _resample_future_xm_bars(tmp_dict, x)
                for col in cols_dict:
                    # Prepend this day's bars to the same column.
                    cols_dict[col] = np.append(tmp_dict[col], cols_dict[col])
                if len(cols_dict[new_fields[0]]) >= count:
                    break
        else:
            open_dt = calendar.get_open_dt(security, end_dt.date())
            if not include_now:
                close_dt = _not_include_now(security, end_trade_date, end_dt,
                                            unit)
            else:
                close_dt = end_dt
            # Bars of the current day first ...
            tmp_dict = get_minute_bar_by_period(security,
                                                open_dt,
                                                close_dt,
                                                new_fields,
                                                include_now=True)
            if has_rows(tmp_dict):
                tmp_dict = _resample_simple_xm_bars(tmp_dict, x)
                for col in cols_dict:
                    cols_dict[col] = np.append(tmp_dict[col], cols_dict[col])
            # ... then back-fill what is still missing from before the open.
            need_count = count - len(cols_dict[new_fields[0]])
            if need_count > 0:
                tmp_dict = get_minute_bar_by_count(security,
                                                   open_dt,
                                                   need_count * x,
                                                   new_fields,
                                                   include_now=False)
                # Guarded like the fetch above so that an empty result does
                # not break the resampling or the column lookup.
                if has_rows(tmp_dict):
                    tmp_dict = _resample_simple_xm_bars(tmp_dict, x)
                    for col in cols_dict:
                        cols_dict[col] = np.append(tmp_dict[col],
                                                   cols_dict[col])
        for f in cols_dict:
            cols_dict[f] = cols_dict[f][-count:]

    # Weekly and monthly bars must be price-adjusted first, then resampled.
    elif unit in ('1w', '1M'):
        weekly = unit == '1w'
        # Number of daily bars requested per weekly / monthly bar.
        days_per_bar = 5 if weekly else 31
        if include_now:
            snapshot = get_snapshot(security, end_trade_date, end_dt)
            daily_end = end_trade_date
        else:
            snapshot = None
            if weekly:
                # Monday == 0 ... Sunday == 6: step back to the Sunday
                # before the current week.
                daily_end = end_trade_date - datetime.timedelta(
                    end_trade_date.weekday() + 1)
            else:
                # Last day of the previous month.
                daily_end = end_trade_date.replace(
                    day=1) - datetime.timedelta(days=1)
        cols_dict = get_daily_bar_by_count(security,
                                           daily_end,
                                           count * days_per_bar,
                                           new_fields,
                                           include_now=False)
        cols_dict = ensure_not_empty(cols_dict)
        if snapshot:
            for f in cols_dict:
                cols_dict[f] = np.append(cols_dict[f], snapshot[f])
        cols_dict = _pre_fq(security, cols_dict, pre_factor_ref_date)
        cols_dict = _resample_days_bars(cols_dict, unit)
        for f in cols_dict:
            cols_dict[f] = cols_dict[f][-count:]
    else:
        raise ParamsError("get_bars 支持 '1m', '1d'")

    # Convert the timestamps to date (daily and coarser) or datetime.
    if 'date' in cols_dict:
        if unit in ('1d', '1w', '1M'):
            cols_dict['date'] = vec2date(cols_dict['date'])
        else:
            cols_dict['date'] = vec2datetime(cols_dict['date'])

    # Futures are never adjusted. Weekly and monthly bars were adjusted
    # before resampling, because the factor may change within one period.
    if not security.is_futures() and unit not in ('1w', '1M'):
        if pre_factor_ref_date is not None:
            cols_dict = _pre_fq(security, cols_dict, pre_factor_ref_date)

    out_fields = (fields, ) if str_field else fields
    # numpy bug: a field name cannot be unicode, hence str(name).
    dtype = np.dtype([(str(name), cols_dict[name].dtype)
                      for name in out_fields])
    cols = [cols_dict[name] for name in out_fields]
    return np.rec.fromarrays(cols, dtype=dtype).view(np.ndarray)
Ejemplo n.º 3
0
def attribute_history(end_dt,
                      security,
                      count,
                      unit='1d',
                      fields=tuple(DEFAULT_FIELDS),
                      skip_paused=True,
                      df=True,
                      fq='pre',
                      pre_factor_ref_date=None):
    '''
    Return the last ``count`` bars of several fields for one security.

    May only be called in backtest / simulated trading, not in research.

    Parameters:

    When `unit` is 'Xd', end_dt is a datetime.date;
    when `unit` is 'Xm', end_dt is a datetime.datetime;

    `count` must be greater than 0.

    The current bar is never included (``include_now=False``). With ``df``
    true a pandas DataFrame with one column per field is returned, otherwise
    a dict mapping each field to its grouped values.
    '''
    count = int(count)
    assert count > 0, "attribute_history, count必须是一个正整数"
    fields = ensure_str_tuple(fields)
    check_unit_fields(unit, fields)
    security = convert_security(security)
    if 'price' in fields:
        warn_price_as_avg('使用 price 作为 attribute_history 的 fields 参数',
                          'attributehistory')

    # 'Xd' / 'Xm': X raw bars are merged into every returned bar.
    bars_per_group = int(unit[:-1])
    raw_count = int(count * bars_per_group)
    skip_paused = bool(skip_paused)
    df = bool(df)
    fq = ensure_fq(fq)
    if pre_factor_ref_date is not None:
        pre_factor_ref_date = convert_date(pre_factor_ref_date)

    query = dict(count=raw_count,
                 fields=fields,
                 skip_paused=skip_paused,
                 fq=fq,
                 include_now=False,
                 pre_factor_ref_date=pre_factor_ref_date)
    if unit.endswith('d'):
        query['end_date'] = convert_date(end_dt)
        raw, index = get_price_daily_single(security, **query)
    else:
        query['end_dt'] = convert_dt(end_dt)
        raw, index = get_price_minute_single(security, **query)

    dict_by_column = {}
    for name in fields:
        # 'price' is served from the 'avg' column.
        source = 'avg' if name == 'price' else name
        dict_by_column[name] = group_array(raw[source], bars_per_group, name)

    if not df:
        return dict_by_column

    if index is not None and len(index) > 0:
        index = vec2datetime(group_array(index, bars_per_group, 'index'))
    return pd.DataFrame(index=index, columns=fields, data=dict_by_column)
Ejemplo n.º 4
0
def get_price(security,
              start_date=None,
              end_date=None,
              frequency='daily',
              fields=None,
              skip_paused=False,
              fq='pre',
              count=None,
              pre_factor_ref_date=None):
    '''
    Return price data for one security or a list of securities.

    :param security: a single security or a list of them (normalised with
        ``convert_security``).
    :param start_date: start of the range; 2015-01-01 when omitted. It is
        clamped to the calendar's first day and cannot be combined with
        ``count``.
    :param end_date: end of the range; 2015-12-31 when omitted. It is clamped
        to the calendar's last day.
    :param frequency: a unit such as 'Xd' / 'Xm', or an alias found in
        ``frequency_compat``.
    :param fields: field names; ``DEFAULT_FIELDS`` when None.
    :param skip_paused: must be false when several securities are requested.
    :param fq: price adjustment mode, validated with ``ensure_fq``.
    :param count: number of bars, as an alternative to ``start_date``.
    :param pre_factor_ref_date: reference date for forward adjustment.
    :return: a pandas DataFrame for a single security; for a list, a
        ``pd.Panel`` keyed by field whose frames have one column per security
        code (``pd.Panel`` only exists in old pandas releases).
    :raises ParamsError: if both ``start_date`` and ``count`` are given, if
        ``skip_paused`` is set for a list, or if minute data is requested for
        a list that ``_has_stock_with_future`` flags.
    '''
    security = convert_security(security)

    if count is not None and start_date is not None:
        raise ParamsError("get_price 不能同时指定 start_date 和 count 两个参数")

    if count is not None:
        count = int(count)

    if end_date:
        end_dt = convert_dt(end_date)
    else:
        end_dt = datetime.datetime(2015, 12, 31)
    end_dt = min(end_dt, date2dt(CalendarStore.instance().last_day))

    if start_date:
        start_dt = convert_dt(start_date)
    else:
        start_dt = datetime.datetime(2015, 1, 1)
    start_dt = max(start_dt, date2dt(CalendarStore.instance().first_day))

    if pre_factor_ref_date:
        pre_factor_ref_date = convert_date(pre_factor_ref_date)

    # Translate legacy frequency names into units.
    unit = (frequency_compat.get(frequency)
            if frequency in frequency_compat else frequency)

    if fields is None:
        fields = tuple(DEFAULT_FIELDS)
    else:
        fields = ensure_str_tuple(fields)
        if 'price' in fields:
            warn_price_as_avg('使用 price 作为 get_price 的 fields 参数', 'getprice')

    check_unit_fields(unit, fields)
    fq = ensure_fq(fq)
    skip_paused = bool(skip_paused)
    multiple = is_list(security)
    if multiple and skip_paused:
        raise ParamsError("get_price 取多只股票数据时, 为了对齐日期, 不能跳过停牌")

    if (multiple and _has_stock_with_future(security)
            and unit.endswith('m')):
        raise ParamsError("get_price 取分钟数据时,为了对齐数据,不能同时取股票和期货。")

    # 'Xd' / 'Xm': X raw bars are merged into every returned bar.
    group = int(unit[:-1])
    raw_count = None if count is None else count * group
    daily = unit.endswith('d')

    res = {}
    for s in (security if multiple else [security]):
        if daily:
            a, index = get_price_daily_single(
                s,
                end_date=end_dt.date(),
                start_date=start_dt.date() if start_dt else None,
                count=raw_count,
                fields=fields,
                skip_paused=skip_paused,
                fq=fq,
                include_now=True,
                pre_factor_ref_date=pre_factor_ref_date)
        else:
            a, index = get_price_minute_single(
                s,
                end_dt=end_dt,
                start_dt=start_dt,
                count=raw_count,
                fields=fields,
                skip_paused=skip_paused,
                fq=fq,
                include_now=True,
                pre_factor_ref_date=pre_factor_ref_date)

        # Merge the raw bars group by group; 'price' is read from 'avg'.
        dict_by_column = {}
        for f in fields:
            source = 'avg' if f == 'price' else f
            dict_by_column[f] = group_array(a[source], group, f)
        if index is not None and len(index) > 0:
            index = vec2datetime(group_array(index, group, 'index'))
        res[s.code] = {
            'index': index,
            'columns': fields,
            'data': dict_by_column,
        }

    if not multiple:
        return pd.DataFrame(**res[security.code])

    fields = fields or DEFAULT_FIELDS
    if len(security) == 0:
        return pd.Panel(items=fields)

    # All frames share the index of the first security.
    shared_index = res[security[0].code]['index']
    codes = [s.code for s in security]
    pn_dict = {}
    for f in fields:
        per_code = {s.code: res[s.code]['data'][f] for s in security}
        pn_dict[f] = pd.DataFrame(index=shared_index,
                                  columns=codes,
                                  data=per_code)
    return pd.Panel(pn_dict)
Ejemplo n.º 5
0
def history(end_dt,
            count,
            unit='1d',
            field='avg',
            security_list=None,
            df=True,
            skip_paused=False,
            fq='pre',
            pre_factor_ref_date=None):
    '''
    Return the last ``count`` bars of one field for several securities.

    May only be called in backtest / simulated trading, not in research.

    Parameters:

    When `unit` is 'Xd', end_dt is a datetime.date;
    when `unit` is 'Xm', end_dt is a datetime.datetime;

    `count` must be greater than 0;

    `security_list`: must be a str or a tuple, not a list, otherwise
    lru_cache fails.

    With ``df`` true a pandas DataFrame with one column per security code is
    returned, otherwise a dict mapping each code to its grouped values.
    '''
    count = int(count)
    assert count > 0, "history, count必须是一个正整数"
    check_unit_fields(unit, (field, ))
    if security_list is not None:
        security_list = list_or_str(security_list)
    if isinstance(security_list, tuple):
        security_list = list(security_list)
    security_list = convert_security(security_list)

    if field == 'price':
        warn_price_as_avg('使用 price 作为 history 的 field 参数', 'history')
        field = 'avg'

    # 'Xd' / 'Xm': X raw bars are merged into every returned bar.
    group = int(unit[:-1])
    df = bool(df)
    skip_paused = bool(skip_paused)
    fq = ensure_fq(fq)
    if pre_factor_ref_date is not None:
        pre_factor_ref_date = convert_date(pre_factor_ref_date)
    # A shared index is only attached to a DataFrame built without skipping
    # paused bars.
    need_index = df and not skip_paused

    if (is_list(security_list) and _has_stock_with_future(security_list)
            and unit.endswith('m') and need_index):
        raise ParamsError("history 取分钟数据时,为了对齐数据,不能同时取股票和期货。")

    query = dict(count=count * group,
                 fields=(field, ),
                 skip_paused=skip_paused,
                 fq=fq,
                 include_now=False,
                 pre_factor_ref_date=pre_factor_ref_date)
    if unit.endswith('d'):
        fetch_single = get_price_daily_single
        query['end_date'] = convert_date(end_dt)
    else:
        fetch_single = get_price_minute_single
        query['end_dt'] = convert_dt(end_dt)

    dict_by_column = {}
    last_index = None
    for security in security_list:
        raw, last_index = fetch_single(security, **query)
        # Keep only the requested column, merged group by group.
        dict_by_column[security.code] = group_array(raw[field], group, field)

    if not df:
        return dict_by_column

    index = None
    if need_index and last_index is not None and len(last_index) > 0:
        index = vec2datetime(group_array(last_index, group, 'index'))
    return pd.DataFrame(index=index,
                        columns=[s.code for s in security_list],
                        data=dict_by_column)