Ejemplo n.º 1
0
 def get_index_stocks(self, index_symbol, date):
     """Return the constituent stock codes of an index on a given date.

     :param index_symbol: index code string; upper-cased before lookup
     :param date: datetime.date / datetime.datetime, or a date string that is
         compared as-is against the stored period boundaries
     :return: list of codes whose membership period covers ``date``; ``[]``
         if the index is known but has no stored constituents
     :raises ParamsError: if the index is unknown or ``date`` has a bad type
     """
     assert isinstance(index_symbol, six.string_types)
     symbol = index_symbol.upper()
     entity = get_session().query(IndexEntity).get(symbol)
     members = [] if entity is None else json.loads(entity.index_json)

     if not members:
         # No stored constituents: tell "unknown index" apart from "empty".
         known = SecurityStore.instance().get_all_indexs().keys()
         if symbol in known:
             return []
         raise ParamsError("指数'%s'不存在" % symbol)

     if isinstance(date, (datetime.date, datetime.datetime)):
         day = date.strftime("%Y-%m-%d")
     elif isinstance(date, six.string_types):
         day = date
     else:
         raise ParamsError(
             "date参数必须是(datetime.date, datetime.datetime, str)中的一种")

     # ``periods`` is a flat [start, end, start, end, ...] list of day strings;
     # each interval includes its start day but excludes its end day.
     return [
         code for code, periods in members
         if any(periods[k] <= day < periods[k + 1]
                for k in range(0, len(periods) - 1, 2))
     ]
Ejemplo n.º 2
0
    def get_concept_stocks(self, concept_code, date):
        """Return the stocks belonging to a concept board on ``date``.

        Membership is read from the in-memory ``self._dic['stocks']`` table;
        codes for which ``SecurityStore`` reports non-existence are skipped.

        :param concept_code: concept board code; upper-cased before lookup
        :param date: datetime.date / datetime.datetime or a date string
        :return: list of stock codes whose membership period covers ``date``
        :raises ParamsError: if the concept is unknown or ``date`` has a bad type
        """
        assert isinstance(concept_code, six.string_types)
        code_key = concept_code.upper()
        members = self._dic['stocks'].get(code_key, [])
        if not members:
            if code_key in self._dic['name']:
                return []
            raise ParamsError("概念板块 '%s' 不存在" % code_key)

        if isinstance(date, (datetime.date, datetime.datetime)):
            day = date.strftime("%Y-%m-%d")
        elif isinstance(date, six.string_types):
            day = date
        else:
            raise ParamsError(
                "date参数必须是(datetime.date, datetime.datetime, str)中的一种")

        selected = []
        for stock, periods in members:
            if not SecurityStore.instance().exists(stock):
                continue
            # periods: [start, end, start, end, ...]; the end day is excluded.
            in_range = any(periods[k] <= day < periods[k + 1]
                           for k in range(0, len(periods) - 1, 2))
            if in_range:
                selected.append(stock)
        return selected
Ejemplo n.º 3
0
    def get_concept_stocks(self, concept_code, date):
        """Query the database for a concept board's stocks on ``date``.

        :param concept_code: concept board code; upper-cased before querying
        :param date: datetime.date / datetime.datetime or a date string
        :return: distinct stock codes of rows with a non-empty name and
            ``stock_startdate <= date < stock_enddate``
        :raises ParamsError: if ``date`` has a bad type or the concept code has
            no rows at all
        """
        assert isinstance(concept_code, six.string_types)
        code_key = concept_code.upper()
        if isinstance(date, (datetime.date, datetime.datetime)):
            day = date.strftime("%Y-%m-%d")
        elif isinstance(date, six.string_types):
            day = date
        else:
            raise ParamsError(
                "date参数必须是(datetime.date, datetime.datetime, str)中的一种")

        session = get_session()
        by_code = ConceptEntity.code == code_key
        if session.query(ConceptEntity.stock).filter(by_code).count() == 0:
            raise ParamsError("概念板块 '%s' 不存在" % code_key)

        rows = (session.query(ConceptEntity.stock)
                .filter(by_code)
                .filter(ConceptEntity.name != '')
                .filter(ConceptEntity.stock_startdate <= day)
                .filter(ConceptEntity.stock_enddate > day)
                .distinct()
                .all())
        # Each row is a one-column tuple holding the stock code.
        return [row[0] for row in rows]
Ejemplo n.º 4
0
def normalize_code(code):
    '''
    Normalize a security code to the ``XXXXXX.XSHG`` / ``XXXXXX.XSHE`` form.

    Accepted inputs:

    * an integer such as ``600000`` (zero-padded to six digits);
    * a string already ending in ``.XSHG`` / ``.XSHE`` / ``.CCFX``
      (returned upper-cased, otherwise unchanged);
    * any other string containing six consecutive digits, optionally tagged
      with an exchange marker (``XSHE``, ``SH`` or ``SZ``).

    Without an exchange marker the exchange is inferred from the number:
    codes >= 500000 map to XSHG, all others to XSHE.

    Shanghai Stock Exchange code allocation rules:
    https://biz.sse.com.cn/cs/zhs/xxfw/flgz/rules/sserules/sseruler20090810a.pdf

    Shenzhen Stock Exchange code allocation rules:
    http://www.szse.cn/main/rule/bsywgz/39744233.shtml

    :raises ParamsError: for strings without six digits or unsupported types
    '''
    if isinstance(code, six.integer_types):
        suffix = 'XSHG' if code >= 500000 else 'XSHE'
        return '%06d.%s' % (code, suffix)
    elif isinstance(code, six.string_types):
        code = code.upper()
        if code[-5:] in ('.XSHG', '.XSHE', '.CCFX'):
            return code
        match = re.search(r'[0-9]{6}', code)
        if match is None:
            raise ParamsError(u"wrong code={}".format(code))
        number = match.group(0)
        # 'XSHE' contains the substring 'SH', so it must be tested before the
        # generic 'SH' marker; otherwise '000001XSHE' would map to XSHG.
        if 'XSHE' in code:
            suffix = 'XSHE'
        elif 'SH' in code:
            suffix = 'XSHG'
        elif 'SZ' in code:
            suffix = 'XSHE'
        else:
            suffix = 'XSHG' if int(number) >= 500000 else 'XSHE'
        return '%s.%s' % (number, suffix)
    else:
        # Wrap in a 1-tuple: a bare ``% code`` raises TypeError (or formats
        # the wrong thing) when ``code`` is itself a tuple.
        raise ParamsError(u"normalize_code(code=%s) 的参数必须是字符串或者整数" % (code,))
Ejemplo n.º 5
0
def convert_security(sec):
    """Resolve ``sec`` to a ``Security`` object.

    :param sec: a security code string or an existing ``Security``
    :return: the matching ``Security`` (``sec`` itself if already one)
    :raises ParamsError: if the code is unknown or ``sec`` has another type
    """
    if not isinstance(sec, six.string_types):
        if isinstance(sec, Security):
            return sec
        raise ParamsError('security 必须是一个Security对象')
    found = SecurityStore.instance().get_security(sec)
    if not found:
        raise ParamsError("找不到标的{}".format(sec))
    return found
Ejemplo n.º 6
0
def get_locked_shares(stock_list, start_date=None, end_date=None, forward_count=None):
    '''
    Fetch restricted-share unlock records for stocks within a date range.

    :param stock_list: a single stock code or a list of codes; any exchange
        suffix (e.g. ``.XSHE``) is stripped before querying
    :param start_date: first day of the range (required)
    :param end_date: last day of the range (inclusive); mutually exclusive
        with ``forward_count``
    :param forward_count: mutually exclusive with ``end_date``. The range then
        ends at ``start_date + forward_count`` *calendar* days, both ends
        inclusive.
        NOTE(review): this was documented as a trading-day count, but the
        computation below adds calendar days. Confirm the intended meaning.
    :return: DataFrame parsed from the query's CSV result, with the selected
        columns day, code, num, rate1, rate2 (column names presumably follow
        the model attributes; confirm). Rows are ordered by day ascending,
        then code descending, and ``code`` is read as a string.
        NOTE: a stock may have several unlock records on the same day.
    :raises ParamsError: on missing or conflicting range arguments
    :raises RuntimeError: if FUNDAMENTALS_SERVERS is not configured (server mode)
    '''
    import pandas as pd
    from six import StringIO

    from ..utils.utils import convert_date, is_lists
    from ..db_utils import query, request_mysql_server
    if forward_count is not None and end_date is not None:
        raise ParamsError("get_locked_shares 不能同时指定 end_date 和 forward_count 两个参数")
    if forward_count is None and end_date is None:
        raise ParamsError("get_locked_shares 必须指定 end_date 或 forward_count 之一")
    if start_date is None:
        # Name the missing argument explicitly. Otherwise the failure comes
        # from convert_date(None), or from None + timedelta below.
        raise ParamsError("get_locked_shares 必须指定 start_date")
    start_date = convert_date(start_date)
    if not is_lists(stock_list):
        stock_list = [stock_list]
    # Query by the bare code part (presumably how the table stores codes).
    stock_list = [s.split('.')[0] for s in stock_list]

    if forward_count is not None:
        end_date = start_date + datetime.timedelta(days=forward_count)

    end_date = convert_date(end_date)
    q = query(StkLockShares.day, StkLockShares.code, StkLockShares.num, StkLockShares.rate1,
              StkLockShares.rate2).filter(StkLockShares.code.in_(stock_list),
                                          StkLockShares.day <= end_date, StkLockShares.day >= start_date
                                          ).order_by(StkLockShares.day).order_by(StkLockShares.code.desc())

    sql = compile_query(q)
    cfg = get_config()
    if os.getenv('JQENV') == 'client':  # client side
        csv = request_mysql_server(sql)
    else:
        if not cfg.FUNDAMENTALS_SERVERS:
            raise RuntimeError(
                "you must config FUNDAMENTALS_SERVERS for jqdata")
        sql_runner = get_sql_runner(
            server_name='fundamentals', keep_connection=cfg.KEEP_DB_CONNECTION,
            retry_policy=cfg.DB_RETRY_POLICY, is_random=False)
        csv = sql_runner.run(sql, return_df=False)
    # Read ``code`` as text so leading zeros survive CSV parsing.
    return pd.read_csv(StringIO(csv), dtype={'code': str})
Ejemplo n.º 7
0
def get_billboard_list(stock_list=None, start_date=None, end_date=None, count=None):
    '''
    Return billboard (abnormal-trading, "dragon-tiger list") records in a date range.

    :param stock_list: a single stock code, a list of codes, or None for all
        stocks; any exchange suffix (e.g. ``.XSHE``) is stripped
    :param start_date: first day of the range; mutually exclusive with ``count``
    :param end_date: last day of the range (inclusive); defaults to today
    :param count: number of trading days ending at ``end_date`` (the range
        starts at the first of them); mutually exclusive with ``start_date``.
        Must be positive.
    :return: DataFrame of StkAbnormal rows, newest day first, then code
        descending. Illustrative layout:
        |   date   | stock_code | abnormal_code |     abnormal_name        | sales_depart_name | abnormal_type | buy_value | buy_rate | sell_value | sell_rate | net_value | amount |
        |----------|------------|---------------|--------------------------|-------------------|---------------|-----------|----------|------------|-----------|-----------|--------|
        |2017-07-01| 000038.XSHE|     1         |日价格涨幅偏离值达7%以上的证券|        None       |      ALL      |  35298494 |0.37108699|  32098850  | 0.33744968|   3199644 |95121886|
    :raises ParamsError: if neither or both of ``start_date``/``count`` are
        given, or ``count`` is not positive
    :raises RuntimeError: if FUNDAMENTALS_SERVERS is not configured (server mode)
    '''
    import pandas as pd

    from ..utils.utils import convert_date, is_lists
    from ..db_utils import query, request_mysql_server
    if count is not None and start_date is not None:
        raise ParamsError("get_billboard_list 不能同时指定 start_date 和 count 两个参数")
    if count is None and start_date is None:
        raise ParamsError("get_billboard_list 必须指定 start_date 或 count 之一")
    if count is not None and count <= 0:
        # A zero count is falsy below and would silently widen the range to
        # TRADE_MIN_DATE, i.e. return the whole history.
        raise ParamsError("count 参数需要大于 0 或者为 None")
    end_date = convert_date(end_date) if end_date else datetime.date.today()
    start_date = convert_date(start_date) if start_date else \
        (get_trade_days(end_date=end_date, count=count)[0] if count else TRADE_MIN_DATE)
    if not is_lists(stock_list):
        if stock_list is not None:
            stock_list = [stock_list]
    if stock_list:
        stock_list = [s.split('.')[0] for s in stock_list]

    q = query(StkAbnormal).filter(StkAbnormal.day <= end_date, StkAbnormal.day >= start_date)
    if stock_list is not None:
        q = q.filter(StkAbnormal.code.in_(stock_list))
    q = q.order_by(StkAbnormal.day.desc()).order_by(StkAbnormal.code.desc())
    sql = compile_query(q)
    cfg = get_config()
    if os.getenv('JQENV') == 'client':  # client side
        csv = request_mysql_server(sql)
        # Read ``code`` as text so leading zeros survive CSV parsing.
        df = pd.read_csv(six.StringIO(csv), dtype={'code': str})
    else:
        if not cfg.FUNDAMENTALS_SERVERS:
            raise RuntimeError(
                "you must config FUNDAMENTALS_SERVERS for jqdata")
        sql_runner = get_sql_runner(
            server_name='fundamentals', keep_connection=cfg.KEEP_DB_CONNECTION,
            retry_policy=cfg.DB_RETRY_POLICY, is_random=False)
        df = sql_runner.run(sql, return_df=True)
    return df
Ejemplo n.º 8
0
def get_extras(info,
               security_list,
               start_date=None,
               end_date=None,
               df=True,
               count=None):
    """Fetch per-trading-day auxiliary data for one or more securities.

    :param info: one of 'is_st', 'acc_net_value', 'unit_net_value',
        'futures_sett_price', 'futures_positions'
    :param security_list: security code(s), resolved through list_or_str and
        convert_security (assumed to yield an iterable of Security objects;
        confirm for single-code input)
    :param start_date: first day; mutually exclusive with ``count``. Defaults
        to 2015-01-01 when neither is given.
    :param end_date: last day; defaults to 2015-12-31
    :param df: truthy -> DataFrame (index: dates, columns: codes);
        falsy -> dict mapping code -> values
    :param count: number of trading days ending at ``end_date``; values larger
        than the available calendar are clamped to its first day
    :raises ParamsError: on conflicting or non-positive ``start_date``/``count``
    """
    assert info in ('is_st', 'acc_net_value', 'unit_net_value',
                    'futures_sett_price', 'futures_positions')
    securities = list_or_str(security_list)
    securities = convert_security(securities)
    if start_date and count:
        raise ParamsError("start_date 参数与 count 参数只能二选一")
    if not (count is None or count > 0):
        raise ParamsError("count 参数需要大于 0 或者为 None")
    if count is not None:
        count = int(count)
    end_date = convert_date(end_date) if end_date else convert_date(
        '2015-12-31')

    from jqdata.stores import FundStore, StStore, FuturesStore, CalendarStore

    if start_date:
        start_date = convert_date(start_date)
    elif count:
        ix = CalendarStore.instance().get_trade_days_between(
            datetime.date(2005, 1, 4), end_date)
        if len(ix) == 0:
            # No trading days up to end_date: use an empty/degenerate range.
            start_date = end_date
        else:
            # Clamp so that asking for more days than the calendar holds
            # starts at its first day instead of raising IndexError.
            start_date = ix[-min(count, len(ix))]
    else:
        start_date = convert_date('2015-01-01')
    df = bool(df)
    dates = CalendarStore.instance().get_trade_days_between(
        start_date, end_date)
    values = {}
    if info == 'is_st':
        for s in securities:
            values[s.code] = StStore.instance().query(s, dates)
    elif info in ('acc_net_value', 'unit_net_value'):
        for s in securities:
            values[s.code] = FundStore.instance().query(s, dates, info)
    elif info in ('futures_sett_price', 'futures_positions'):
        for s in securities:
            values[s.code] = FuturesStore.instance().query(s, dates, info)
    if df:
        columns = [s.code for s in securities]
        ret = dict(index=vec2combine(dates), columns=columns, data=values)
        ret = pd.DataFrame(**ret)
        return ret
    else:
        return values
Ejemplo n.º 9
0
def convert_dt(dt):
    """
    Convert ``dt`` to a ``datetime.datetime``.

    Accepts a ``datetime.datetime`` (returned as-is), a ``datetime.date``
    (converted with ``date2dt``), or a string in one of the formats
    ``'%Y-%m-%d'``, ``'%Y-%m-%d %H:%M:%S'`` or ``'%Y-%m-%d %H:%M'``.

    >>> convert_dt(datetime.date(2015, 1, 1))
    datetime.datetime(2015, 1, 1, 0, 0)

    >>> convert_dt(datetime.datetime(2015, 1, 1))
    datetime.datetime(2015, 1, 1, 0, 0)

    >>> convert_dt('2015-1-1')
    datetime.datetime(2015, 1, 1, 0, 0)

    >>> convert_dt('2015-01-01 09:30:00')
    datetime.datetime(2015, 1, 1, 9, 30)

    >>> convert_dt('2015-01-01 09:30')
    datetime.datetime(2015, 1, 1, 9, 30)

    >>> convert_dt(datetime.datetime(2015, 1, 1, 9, 30))
    datetime.datetime(2015, 1, 1, 9, 30)

    :raises ValueError: if a string matches none of the accepted formats
    :raises ParamsError: for any other argument type
    """
    if is_str(dt):
        if ':' not in dt:
            return datetime.datetime.strptime(dt, '%Y-%m-%d')
        # Accept minute precision first. If that fails, fall through to the
        # canonical seconds format, which either parses the string or raises
        # the usual ValueError.
        try:
            return datetime.datetime.strptime(dt, '%Y-%m-%d %H:%M')
        except ValueError:
            pass
        return datetime.datetime.strptime(dt, '%Y-%m-%d %H:%M:%S')
    elif isinstance(dt, datetime.datetime):
        return dt
    elif isinstance(dt, datetime.date):
        return date2dt(dt)
    raise ParamsError(
        "date 必须是datetime.date, datetime.datetime或者如下格式的字符串:'2015-01-05'")
Ejemplo n.º 10
0
    def query(self, security, dates, field):
        """Return fund net values for ``security`` aligned to ``dates``.

        :param security: fund whose table is opened via ``self.open_table``;
            its ``price_decimals`` controls rounding
        :param dates: ordered sequence of dates to align to
        :param field: 'acc_net_value' (column 'acc') or 'unit_net_value'
            (column 'unit')
        :return: numpy array with one value per entry of ``dates`` when the
            table rows need realignment; dates without a row are NaN
        :raises ParamsError: for an unsupported ``field``
        """
        n = len(dates)
        if n == 0:
            return np.array([])
        ct = self.open_table(security)
        start_ts = to_timestamp(dates[0])
        end_ts = to_timestamp(dates[-1])

        # Table rows [start_idx, end_idx] fall within [dates[0], dates[-1]].
        start_idx = ct.index.searchsorted(start_ts)
        end_idx = ct.index.searchsorted(end_ts, 'right') - 1
        if end_idx < start_idx:
            return np.array([np.nan] * n)
        if field == 'acc_net_value':
            name = 'acc'
        elif field == 'unit_net_value':
            name = 'unit'
        else:
            raise ParamsError(
                "field should in (acc_net_value, unit_net_value)")
        ret = np.round(ct.table.cols[name][start_idx:end_idx + 1],
                       security.price_decimals)

        # NOTE(review): when the row count equals n the rows are assumed to
        # match ``dates`` one-to-one.
        if len(ret) != n:
            # Rows do not line up with ``dates`` (missing days, or extra rows
            # whose day is not requested). Place each row at its date's slot
            # and leave the other slots as NaN.
            dates = list(dates)
            aligned = [np.nan] * n
            for offset, idx in enumerate(range(start_idx, end_idx + 1)):
                row_date = to_date(ct.table.cols['date'][idx])
                try:
                    day_idx = dates.index(row_date)
                except ValueError:
                    # Row date is not among the requested dates (presumably
                    # a non-trading day); it has no slot in the output.
                    continue
                aligned[day_idx] = ret[offset]
            ret = np.array(aligned)
        return ret
Ejemplo n.º 11
0
 def get_industries(self, name):
     """Return the industry list stored under ``name`` in ``self._dic['lists']``.

     :param name: list key; the error message names zjw/jq_l1/jq_l2/sw_l1/
         sw_l2/sw_l3 as the accepted values
     :raises ParamsError: if ``name`` is not a key of ``self._dic['lists']``
     """
     assert isinstance(name, six.string_types)
     lists = self._dic['lists']
     if name not in lists:
         raise ParamsError("name 参数必须是zjw/jq_l1/jq_l2/sw_l1/sw_l2/sw_l3中的一种")
     return lists[name]
Ejemplo n.º 12
0
def get_ticks(security,
              end_dt,
              start_dt=None,
              count=None,
              fields=('time', 'current', 'high', 'low', 'volume', 'money')):
    """Return tick records of ``security`` up to ``end_dt``.

    Exactly one of ``start_dt`` / ``count`` must be given.

    :param security: security code or ``Security`` object
    :param end_dt: end of the window. The slice stops at the index returned by
        ``table.find_great_or_equal(end_dt)``, which is excluded (presumably
        the first tick at or after ``end_dt``; confirm).
    :param start_dt: start of the window (must be <= ``end_dt``)
    :param count: number of ticks immediately before the end index (positive)
    :param fields: iterable of field names. 'current', 'high', 'low', 'a1_p'
        and 'b1_p' are divided by 10000, 'time' by 1000, and all other fields
        are converted to float.
    :return: numpy structured array with one named column per field
    :raises ParamsError: on missing/conflicting window arguments, a
        non-positive ``count``, or ``start_dt > end_dt``
    """
    from jqdata.stores.tick_store import get_tick_store
    from jqdata.utils.security import convert_security
    from jqdata.utils.datetime_utils import parse_datetime

    security = convert_security(security)
    end_dt = parse_datetime(end_dt)
    if start_dt is None and count is None:
        raise ParamsError("start_dt和count不能同时为None")
    elif start_dt is not None and count is not None:
        raise ParamsError("start_dt和count只能有一个不为None")
    if start_dt is not None:
        start_dt = parse_datetime(start_dt)
    if count is not None:
        count = int(count)
        # Explicit check instead of ``assert``: assertions are stripped
        # when Python runs with -O.
        if count <= 0:
            raise ParamsError("get_ticks, count必须是一个正整数")

    store = get_tick_store()
    table = store.get_table(security)
    idx = table.find_great_or_equal(end_dt)
    if start_dt is not None:
        if start_dt > end_dt:
            raise ParamsError("start_dt 必须小于等于 end_dt")
        start = table.find_great_or_equal(start_dt)
    else:
        # count is set here: exactly one of start_dt/count was validated above.
        start = max(0, idx - count)

    arr = table.array[start:idx]

    ret = {}
    for f in fields:
        if f in ('current', 'high', 'low', 'a1_p', 'b1_p'):
            ret[f] = arr[f] / 10000.
        elif f == 'time':
            ret[f] = arr[f] / 1000.
        else:
            ret[f] = arr[f] / 1.0
    dtype = np.dtype([(str(f), ret[f].dtype) for f in fields])
    cols = [ret[f] for f in fields]
    return np.rec.fromarrays(cols, dtype=dtype).view(np.ndarray)
Ejemplo n.º 13
0
def check_date(date):
    """Return ``date`` as a day string.

    date/datetime objects are formatted as 'YYYY-MM-DD'; strings are returned
    unchanged (their format is not validated).

    :raises ParamsError: if ``date`` is neither a date/datetime nor a string
    """
    if isinstance(date, (datetime.date, datetime.datetime)):
        return date.strftime("%Y-%m-%d")
    if not isinstance(date, six.string_types):
        raise ParamsError("date参数必须是(datetime.date, datetime.datetime, str)中的一种")
    return date
Ejemplo n.º 14
0
def convert_security_list(sec_list):
    """Normalize ``sec_list`` to a list of ``Security`` objects.

    A single code string or ``Security`` becomes a one-element list; a list,
    tuple or set is converted element-wise with ``convert_security``.

    :raises ParamsError: for any other input type, or for unknown codes
        (raised by ``convert_security``)
    """
    if isinstance(sec_list, six.string_types):
        return [convert_security(sec_list)]
    if isinstance(sec_list, Security):
        return [sec_list]
    if isinstance(sec_list, (list, tuple, set)):
        return [convert_security(item) for item in sec_list]
    raise ParamsError('security 必须是一个Security实例或者数组')
Ejemplo n.º 15
0
def get_trade_days(start_date=None, end_date=None, count=None):
    """Return trading days ending at ``end_date``.

    :param start_date: when given, return the days from
        ``store.get_trade_days_between(start_date, end_date)``
    :param end_date: last day; defaults to today
    :param count: when given instead of ``start_date``, return
        ``store.get_trade_days_by_count(end_date, count)``
    :raises ParamsError: if both or neither of ``start_date``/``count`` are
        given, or ``count`` is not positive
    """
    if start_date and count:
        raise ParamsError("start_date 参数与 count 参数只能二选一")
    if count is not None and not count > 0:
        raise ParamsError("count 参数需要大于 0 或者为 None")

    end_date = parse_date(end_date) if end_date else datetime.date.today()

    store = get_calendar_store()
    if start_date:
        return store.get_trade_days_between(parse_date(start_date), end_date)
    if count is None:
        raise ParamsError("start_date 参数与 count 参数必须输入一个")
    return store.get_trade_days_by_count(end_date, count)
Ejemplo n.º 16
0
def convert_security(s):
    """Resolve ``s`` to ``Security`` object(s).

    * a code string -> the matching ``Security`` from ``SecurityStore``;
    * a ``Security`` -> returned unchanged;
    * a list/tuple of codes and/or ``Security`` objects -> a list of
      ``Security`` objects in the same order.

    :raises ParamsError: for unknown codes, unsupported element types, or an
        unsupported argument type
    """
    def lookup(code):
        # Resolve one code string, failing loudly when it is unknown.
        found = SecurityStore.instance().get_security(code)
        if not found:
            raise ParamsError("找不到标的{}".format(code))
        return found

    if isinstance(s, six.string_types):
        return lookup(s)
    if isinstance(s, Security):
        return s
    if not isinstance(s, (list, tuple)):
        raise ParamsError('security 必须是一个Security实例或者数组')

    resolved = []
    for item in s:
        if isinstance(item, Security):
            resolved.append(item)
        elif isinstance(item, six.string_types):
            resolved.append(lookup(item))
        else:
            raise ParamsError("找不到标的{}".format(item))
    return resolved
Ejemplo n.º 17
0
 def get_factor_by_date(self, code, date):
     """Return the factor in effect for ``code`` on ``date``.

     ``self._dic[code]`` is assumed to hold (day_string, factor) entries in
     ascending day order. The factor of the last entry whose day is <=
     ``date`` is returned, or 1.0 when ``code`` has no entries or none
     qualifies.

     :raises ParamsError: if ``date`` is not a date/datetime/string
     """
     history = self._dic.get(code)
     if not history:
         return 1.0
     if isinstance(date, (datetime.date, datetime.datetime)):
         day = date.strftime("%Y-%m-%d")
     elif isinstance(date, six.string_types):
         day = date
     else:
         raise ParamsError(
             "date参数必须是(datetime.date, datetime.datetime, str)中的一种")
     for entry in reversed(history):
         if entry[0] <= day:
             return entry[1]
     return 1.0
Ejemplo n.º 18
0
 def get_history_name(self, code, date):
     """Return the name ``code`` carried on ``date``.

     ``self._dic[code]`` is assumed to hold (day_string, name) entries in
     ascending day order. The name of the last entry whose day is <= ``date``
     is returned, or '' when ``code`` has no entries or none qualifies.

     :raises ParamsError: if ``date`` is not a date/datetime/string
     """
     history = self._dic.get(code)
     if not history:
         return ''
     if isinstance(date, (datetime.date, datetime.datetime)):
         day = date.strftime("%Y-%m-%d")
     elif isinstance(date, six.string_types):
         day = date
     else:
         raise ParamsError(
             "date参数必须是(datetime.date, datetime.datetime, str)中的一种")
     for entry in reversed(history):
         if entry[0] <= day:
             return entry[1]
     return ''
Ejemplo n.º 19
0
 def get_industry_stocks(self, industry_code, date):
     '''Return the stocks of an industry on ``date`` (database-backed).

     :param industry_code: industry code; upper-cased before querying
     :param date: date/datetime or date string, normalized by check_date
     :return: distinct non-empty stock codes with
         ``stock_startdate <= date < stock_enddate``
     :raises ParamsError: if the industry code has no rows, or ``date`` is invalid
     '''
     assert isinstance(industry_code, six.string_types)
     code_key = industry_code.upper()
     day = check_date(date)
     session = get_session()
     by_code = IndustryEntity.code == code_key
     if session.query(IndustryEntity.code).filter(by_code).count() == 0:
         raise ParamsError("行业板块 '%s' 不存在" % code_key)
     rows = (session.query(IndustryEntity.stock)
             .filter(by_code)
             .filter(IndustryEntity.stock_startdate <= day)
             .filter(IndustryEntity.stock_enddate > day)
             .filter(IndustryEntity.stock != '')
             .distinct()
             .all())
     # Each row is a one-column tuple holding the stock code.
     return [row[0] for row in rows]
Ejemplo n.º 20
0
 def get_margin_stocks(self, day):
     """Return the margin-trading stock list in effect on ``day``.

     Decodes the JSON payload of the latest MarginStockEntity row with
     ``margin_date <= day``. Returns [] when no such row exists or the decoded
     value is empty/falsy.

     :raises ParamsError: if ``day`` is not a date/datetime/string
     """
     if isinstance(day, (datetime.date, datetime.datetime)):
         day_s = day.strftime("%Y-%m-%d")
     elif isinstance(day, six.string_types):
         day_s = day
     else:
         raise ParamsError(
             "date参数必须是(datetime.date, datetime.datetime, str)中的一种")
     latest = (get_session().query(MarginStockEntity)
               .filter(MarginStockEntity.margin_date <= day_s)
               .order_by(MarginStockEntity.margin_date.desc())
               .first())
     if latest is None:
         return []
     return json.loads(latest.margin_json) or []
Ejemplo n.º 21
0
def get_fundamentals(query_object=None, date=None, statDate=None, sql=None): # noqa
    """Run a fundamentals query and return the result as a DataFrame.

    :param query_object: query object, translated to SQL by
        fundamentals_query_to_sql; when truthy it takes precedence over ``sql``
    :param date: query date, normalized with convert_date when truthy
    :param statDate: passed through to fundamentals_query_to_sql
    :param sql: raw SQL, used when ``query_object`` is falsy
    :raises ParamsError: if neither ``query_object`` nor ``sql`` is given
    :raises RuntimeError: if FUNDAMENTALS_SERVERS is not configured (server mode)
    """
    if query_object is None and sql is None:
        raise ParamsError("get_fundamentals 至少输入 query_object 或者 sql 参数")

    cfg = get_config()
    query_date = convert_date(date) if date else date
    if query_object:
        sql = fundamentals_query_to_sql(query_object, query_date, statDate)
    check_string(sql)

    if os.getenv('JQENV') == 'client':  # client side
        from jqdata.db_utils import request_mysql_server
        csv = request_mysql_server(sql)
    elif not cfg.FUNDAMENTALS_SERVERS:
        raise RuntimeError("you must config FUNDAMENTALS_SERVERS for jqdata")
    else:
        runner = get_sql_runner(
            server_name='fundamentals', keep_connection=cfg.KEEP_DB_CONNECTION,
            retry_policy=cfg.DB_RETRY_POLICY, is_random=False)
        # Fetch CSV text and parse it here rather than using return_df=True.
        # This keeps the result identical to earlier (kuanke-compatible)
        # behaviour.
        csv = runner.run(sql, return_df=False)
    return pd.read_csv(StringIO(csv))
Ejemplo n.º 22
0
    def get_industry_stocks(self, industry_code, date):
        '''Return the stocks of an industry on ``date`` (in-memory tables).

        :param industry_code: industry code; looked up upper-cased in
            ``self._dic['stocks']``
        :param date: date/datetime or date string, normalized by check_date
        :return: codes known to SecurityStore whose membership period covers
            ``date``; [] if the industry is known but has no stocks
        :raises ParamsError: if the industry is unknown to the csrc/wind/sw
            code tables, or ``date`` is invalid
        '''
        assert isinstance(industry_code, six.string_types)
        code_key = industry_code.upper()
        stocks = self._dic['stocks'].get(code_key)
        if not stocks:
            # The stock table is keyed by the upper-cased code, so the
            # existence check must accept that form too. Otherwise a lower-case
            # code of a known industry with no stocks is reported as missing.
            # The raw input is still accepted as before.
            codes = self._dic['codes']
            tables = (codes['csrc'], codes['wind'], codes['sw'])
            candidates = (code_key, industry_code)
            if not any(c in t for t in tables for c in candidates):
                raise ParamsError(u"行业'%s'不存在" % industry_code)
            return []
        ret = []
        date_s = check_date(date)
        for code, periods in stocks:
            if not SecurityStore.instance().exists(code):
                continue
            # periods: [start, end, start, end, ...]; the end day is excluded.
            for i in range(0, len(periods) - 1, 2):
                if periods[i] <= date_s < periods[i + 1]:
                    ret.append(code)
                    break

        return ret
Ejemplo n.º 23
0
def get_fundamentals_continuously(query_object=None, end_date=None, count=1):
    '''
    Run a fundamentals query over several consecutive trading days.

    :param query_object: query object (required)
    :param end_date: last day of the window; defaults to today
    :param count: number of trading days ending at ``end_date`` (default 1)
    :return: ``pd.Panel`` built from the (deduplicated) result indexed by
        (day, code); its three axes are field, day and security code.
        Field reference: https://www.joinquant.com/data/dict/fundamentals
    NOTE(review): ``DataFrame.to_panel`` / ``pd.Panel`` were removed in
    pandas 0.25, so this function requires an older pandas.
    :raises ParamsError: if ``query_object`` is None
    :raises RuntimeError: if FUNDAMENTALS_SERVERS is not configured (server mode)
    '''
    if query_object is None:
        raise ParamsError("get_fundamentals_continuously 需要输入 query_object 参数")
    if end_date is None:
        end_date = datetime.date.today()
    cfg = get_config()

    trade_day = get_trade_days(end_date=end_date, count=count)
    # query_object is known to be non-None here, so build the SQL
    # unconditionally. An ``if query_object:`` guard would leave ``sql``
    # unbound for a falsy query object.
    sql = fundamentals_continuously_query_to_sql(query_object, trade_day)
    check_string(sql)
    # Run the query; the result comes back as CSV text.
    if os.getenv('JQENV') == 'client':  # client side
        from jqdata.db_utils import request_mysql_server
        csv = request_mysql_server(sql)
    else:
        if not cfg.FUNDAMENTALS_SERVERS:
            raise RuntimeError("you must config FUNDAMENTALS_SERVERS for jqdata")
        sql_runner = get_sql_runner(
            server_name='fundamentals', keep_connection=cfg.KEEP_DB_CONNECTION,
            retry_policy=cfg.DB_RETRY_POLICY, is_random=False)
        csv = sql_runner.run(sql, return_df=False)
    # Convert to a panel indexed by day and security code.
    df = pd.read_csv(StringIO(csv)).drop_duplicates()
    return df.set_index(['day', 'code']).to_panel()
Ejemplo n.º 24
0
def convert_date(date):
    """
    Convert ``date`` to a ``datetime.date``.

    Strings must start with 'YYYY-MM-DD' (month/day need not be zero-padded).
    A time part after a space or 'T' is discarded.

    >>> convert_date('2015-1-1')
    datetime.date(2015, 1, 1)

    >>> convert_date('2015-01-01 00:00:00')
    datetime.date(2015, 1, 1)

    >>> convert_date('2015-1-1 09:30:00')
    datetime.date(2015, 1, 1)

    >>> convert_date(datetime.datetime(2015, 1, 1))
    datetime.date(2015, 1, 1)

    >>> convert_date(datetime.date(2015, 1, 1))
    datetime.date(2015, 1, 1)

    :raises ValueError: if a string's date part is not in '%Y-%m-%d' format
    :raises ParamsError: for any other argument type
    """
    if is_str(date):
        if ':' in date:
            # Cut at the date/time separator instead of slicing [:10]. A
            # fixed-width slice mangles non-padded dates, e.g.
            # '2015-1-1 09:30:00' -> '2015-1-1 0'.
            date = date.replace('T', ' ').split(' ', 1)[0]
        return datetime.datetime.strptime(date, '%Y-%m-%d').date()
    elif isinstance(date, datetime.datetime):
        return date.date()
    elif isinstance(date, datetime.date):
        return date
    raise ParamsError(
        "date 必须是datetime.date, datetime.datetime或者如下格式的字符串:'2015-01-05'")
Ejemplo n.º 25
0
    def query(self, security, dates, field):
        """Return futures settlement price or open interest aligned to ``dates``.

        :param security: contract whose table is opened via ``self.open_table``;
            its ``price_decimals`` controls rounding
        :param dates: ordered sequence of dates to align to
        :param field: 'futures_sett_price' (column 'settlement') or
            'futures_positions' (column 'open_interest')
        :return: numpy array of table values; requested dates before the first
            or after the last table row in range are filled with NaN
        :raises ParamsError: for an unsupported ``field``
        """
        n = len(dates)
        if n == 0:
            return np.array([])

        ct = self.open_table(security)
        start_ts = to_timestamp(dates[0])
        end_ts = to_timestamp(dates[-1])

        # Table rows [start_idx, end_idx] fall within [dates[0], dates[-1]].
        start_idx = ct.index.searchsorted(start_ts)
        end_idx = ct.index.searchsorted(end_ts, 'right') - 1
        if end_idx < start_idx:
            return np.array([np.nan] * n)
        if field == 'futures_sett_price':
            name = 'settlement'
        elif field == 'futures_positions':
            name = 'open_interest'
        else:
            # List the accepted *field* names, not the underlying column name.
            raise ParamsError("field should in (futures_sett_price, futures_positions)")

        ret = np.round(ct.table.cols[name][start_idx:end_idx+1], security.price_decimals)

        if len(ret) < n:
            # Pad with NaN for requested dates before the first row (st) and
            # after the last row (et).
            # NOTE(review): only leading/trailing gaps are padded. A missing
            # row inside the range still leaves ret shorter than ``dates``.
            st = to_date(ct.table.cols['date'][start_idx])
            et = to_date(ct.table.cols['date'][end_idx])
            for i in range(0, n):
                if dates[i] >= st:
                    break
            if i > 0:
                ret = np.concatenate([np.array([np.nan]*i), ret])
            for i in range(n-1, -1, -1):
                if dates[i] <= et:
                    break
            if i < n - 1:
                ret = np.concatenate([ret, np.array([np.nan] * (n - i - 1))])
        return ret
Ejemplo n.º 26
0
def get_mtss(security_list, start_date=None, end_date=None, fields=None, count=None):
    """
    Fetch margin-trading and securities-lending data.

    security_list: a stock code or a list of codes
    start_date: first day; **mutually exclusive with count**.
        str / datetime.date / datetime.datetime; defaults to TRADE_MIN_DATE,
        the earliest date the platform provides
    end_date: last day, str / datetime.date / datetime.datetime;
        defaults to datetime.date.today()
    fields: a field name or a list of field names; optional, defaults to all
    count: **mutually exclusive with start_date**; the number of trading days
        ending at (and including) end_date

    Returns a pd.DataFrame. Default columns, in order:
    date, sec_code,
    fin_value (financing balance), fin_buy_value (financing purchases),
    fin_refund_value (financing repayments),
    sec_value (securities-lending balance), sec_sell_value (securities sold
    short), sec_refund_value (securities-lending repayments),
    fin_sec_value (total margin balance).
    When ``fields`` is given, the columns follow its order.
    """
    if start_date and count:
        raise ParamsError("start_date 参数与 count 参数只能二选一")
    if not (count is None or count > 0):
        raise ParamsError("count 参数需要大于 0 或者为 None")
    if count:
        count = int(count)
    security_list = obj_to_tuple(security_list)
    check_string_list(security_list)

    end_date = convert_date(end_date) if end_date else datetime.date.today()
    start_date = convert_date(start_date) if start_date else \
        (get_trade_days(end_date=end_date, count=count)[0] if count else TRADE_MIN_DATE)

    # Column order of each comma-separated record returned by the server.
    keys = ["sec_code", "date", "fin_value", "fin_buy_value", "fin_refund_value",
            "sec_value", "sec_sell_value", "sec_refund_value", "fin_sec_value"]
    nkeys = len(keys)
    if fields:
        fields = obj_to_tuple(fields)
        check_string_list(fields)
        check_fields(keys, fields)
    else:
        fields = ["date", "sec_code"] + keys[2:]

    request_path = "/stock/mtss/query"
    request_params = {
        "code": "",
        "startDate": start_date,
        "endDate": end_date,
    }

    lists = []
    # One converter per entry of ``keys``.
    convert_funcs = [str, convert_dt, float, float, float, float, float, float, float]
    for security in security_list:
        request_params["code"] = security
        if os.getenv('JQENV') == 'client':  # client side
            data = request_client_data(request_path, request_params)
        else:
            data = request_data(DATA_SERVER + request_path, request_params)
        for d in data:
            if not d.strip():
                # Skip blank records (e.g. a trailing newline); they would
                # otherwise reach normalize_code('') and raise.
                continue
            values = [item.strip() for item in d.split(",", nkeys)[:nkeys]]
            values = [convert_funcs[i](v) for i, v in enumerate(values)]
            sec_dict = dict(zip(keys, values))
            sec_dict["sec_code"] = normalize_code(sec_dict["sec_code"])
            lists.append(filter_dict_values(sec_dict, fields))

    import pandas as pd
    df = pd.DataFrame(columns=fields, data=lists)
    return df
Ejemplo n.º 27
0
def history(end_dt,
            count,
            unit='1d',
            field='avg',
            security_list=None,
            df=True,
            skip_paused=False,
            fq='pre',
            pre_factor_ref_date=None):
    '''
    Fetch ``count`` bars of one field for each security, up to ``end_dt``.

    Only callable in backtests/simulations, not in research.

    :param end_dt: for 'Xd' units a date (converted with convert_date); for
        'Xm' units a datetime (converted with convert_dt). Data is fetched with
        ``include_now=False``.
    :param count: number of bars; must be a positive integer
    :param unit: bar size such as '1d' or '5m'. ``count * X`` base bars are
        fetched and grouped X at a time with group_array.
    :param field: a single field name; 'price' is accepted as a deprecated
        alias for 'avg'
    :param security_list: must be a str or a tuple, not a list, otherwise
        lru_cache breaks
    :param df: truthy -> DataFrame (columns: codes); falsy -> dict code -> array
    :param skip_paused: skip paused periods (the result then has no index)
    :param fq: price adjustment mode, normalized by ensure_fq
    :param pre_factor_ref_date: reference date for pre-adjustment factors
    :raises ParamsError: if ``count`` is not positive, or when mixing stocks and
        futures for minute data while an aligned index is needed
    '''
    count = int(count)
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if count <= 0:
        raise ParamsError("history, count必须是一个正整数")
    check_unit_fields(unit, (field, ))
    if security_list is not None:
        security_list = list_or_str(security_list)
    if isinstance(security_list, tuple):
        security_list = list(security_list)
    security_list = convert_security(security_list)

    if field == 'price':
        warn_price_as_avg('使用 price 作为 history 的 field 参数', 'history')
        field = 'avg'

    group = int(unit[:-1])
    total = count * group
    dict_by_column = {}

    _index = None
    df = bool(df)
    skip_paused = bool(skip_paused)
    fq = ensure_fq(fq)
    if pre_factor_ref_date is not None:
        pre_factor_ref_date = convert_date(pre_factor_ref_date)
    need_index = df and not skip_paused

    if is_list(security_list) and _has_stock_with_future(security_list):
        if unit.endswith('m') and need_index:
            raise ParamsError("history 取分钟数据时,为了对齐数据,不能同时取股票和期货。")

    # Daily and minute data differ only in the fetch function, the end-time
    # conversion and the name of the end-time keyword argument.
    if unit.endswith('d'):
        end_dt = convert_date(end_dt)
        fetch, end_kwarg = get_price_daily_single, 'end_date'
    else:
        end_dt = convert_dt(end_dt)
        fetch, end_kwarg = get_price_minute_single, 'end_dt'

    for security in security_list:
        end_kwargs = {end_kwarg: end_dt}
        a, _index = fetch(
            security,
            count=total,
            fields=(field, ),
            skip_paused=skip_paused,
            fq=fq,
            include_now=False,
            pre_factor_ref_date=pre_factor_ref_date,
            **end_kwargs)
        # Keep only the requested column, then merge base bars into units.
        dict_by_column[security.code] = group_array(a[field], group, field)

    if not df:
        return dict_by_column
    # The index of the last fetched security is used for all columns.
    if need_index and _index is not None and len(_index) > 0:
        index = vec2datetime(group_array(_index, group, 'index'))
    else:
        index = None
    return pd.DataFrame(index=index,
                        columns=[s.code for s in security_list],
                        data=dict_by_column)
Ejemplo n.º 28
0
def get_price(security,
              start_date=None,
              end_date=None,
              frequency='daily',
              fields=None,
              skip_paused=False,
              fq='pre',
              count=None,
              pre_factor_ref_date=None):
    """Return price data for one security or a list of securities.

    :param security: a single security or a list of them (normalised by
        ``convert_security``).
    :param start_date: first date/time of the range; mutually exclusive with
        ``count``. Defaults to 2015-01-01 and is raised to at least the
        calendar's first day.
    :param end_date: last date/time of the range. Defaults to 2015-12-31 and
        is lowered to at most the calendar's last day.
    :param frequency: bar frequency; aliases found in ``frequency_compat`` are
        translated to a unit string such as ``'1d'`` or ``'5m'``.
    :param fields: field name(s); defaults to ``DEFAULT_FIELDS``. ``'price'``
        is read from the ``'avg'`` column and triggers a warning.
    :param skip_paused: skip suspended bars; not allowed for a list.
    :param fq: price adjustment mode, normalised by ``ensure_fq``.
    :param count: number of bars (in units of ``frequency``).
    :param pre_factor_ref_date: reference date for forward adjustment.
    :return: a ``pd.DataFrame`` for a single security, or a ``pd.Panel``
        with one DataFrame per field (one column per security) for a list.
    :raises ParamsError: on conflicting or unsupported parameter combinations.
    """
    security = convert_security(security)
    many = is_list(security)

    if count is not None:
        if start_date is not None:
            raise ParamsError("get_price 不能同时指定 start_date 和 count 两个参数")
        count = int(count)

    if end_date:
        end_dt = convert_dt(end_date)
    else:
        end_dt = datetime.datetime(2015, 12, 31)
    calendar = CalendarStore.instance()
    end_dt = min(end_dt, date2dt(calendar.last_day))

    if start_date:
        start_dt = convert_dt(start_date)
    else:
        start_dt = datetime.datetime(2015, 1, 1)
    start_dt = max(start_dt, date2dt(calendar.first_day))

    if pre_factor_ref_date:
        pre_factor_ref_date = convert_date(pre_factor_ref_date)

    # Translate legacy frequency names; anything else is already a unit.
    unit = frequency_compat.get(frequency, frequency)

    if fields is None:
        fields = tuple(DEFAULT_FIELDS)
    else:
        fields = ensure_str_tuple(fields)
        if 'price' in fields:
            warn_price_as_avg('使用 price 作为 get_price 的 fields 参数', 'getprice')

    check_unit_fields(unit, fields)
    fq = ensure_fq(fq)
    skip_paused = bool(skip_paused)
    if many and skip_paused:
        raise ParamsError("get_price 取多只股票数据时, 为了对齐日期, 不能跳过停牌")
    if many and _has_stock_with_future(security) and unit.endswith('m'):
        raise ParamsError("get_price 取分钟数据时,为了对齐数据,不能同时取股票和期货。")

    # e.g. '5m' -> 5 one-minute bars are merged into one output bar.
    group = int(unit[:-1])
    scaled_count = None if count is None else count * group
    targets = security if many else [security]

    res = {}
    for sec in targets:
        if unit.endswith('d'):
            values, index = get_price_daily_single(
                sec,
                end_date=end_dt.date(),
                start_date=start_dt.date() if start_dt else None,
                count=scaled_count,
                fields=fields,
                skip_paused=skip_paused,
                fq=fq,
                include_now=True,
                pre_factor_ref_date=pre_factor_ref_date)
        else:
            values, index = get_price_minute_single(
                sec,
                end_dt=end_dt,
                start_dt=start_dt,
                count=scaled_count,
                fields=fields,
                skip_paused=skip_paused,
                fq=fq,
                include_now=True,
                pre_factor_ref_date=pre_factor_ref_date)

        data = {}
        for name in fields:
            source = 'avg' if name == 'price' else name
            data[name] = group_array(values[source], group, name)
        if index is not None and len(index) > 0:
            index = vec2datetime(group_array(index, group, 'index'))
        res[sec.code] = dict(index=index, columns=fields, data=data)

    if not many:
        return pd.DataFrame(**res[security.code])

    fields = fields or DEFAULT_FIELDS
    if len(security) == 0:
        return pd.Panel(items=fields)
    codes = [s.code for s in security]
    # All securities share the first one's index (paused bars are not skipped).
    index = res[codes[0]]['index']
    frames = {}
    for name in fields:
        frames[name] = pd.DataFrame(
            index=index,
            columns=codes,
            data={c: res[c]['data'][name] for c in codes})
    return pd.Panel(frames)
Ejemplo n.º 29
0
def get_bars(end_dt,
             security,
             count,
             unit='1d',
             fields=('open', 'high', 'low', 'close'),
             include_now=False,
             fq=None,
             pre_factor_ref_date=None):
    '''
    Return the last ``count`` bars of ``security`` ending at ``end_dt``.

    :param end_dt: end date/time (normalised with ``convert_dt``).
    :param security: the instrument (normalised with ``convert_security``).
    :param count: number of bars; must be a positive integer.
    :param unit: bar frequency: '1d' one day, 'xm' x minutes ('1m', '5m',
        '15m', '30m', '60m', '120m'), '1w' one week, '1M' one month.
    :param fields: one field name, or a list/tuple of names, taken from
        ('date', 'open', 'close', 'high', 'low', 'volume', 'money').
    :param include_now: whether data up to ``end_dt`` itself is included
        (for '1d'/'1w'/'1M' this appends the current snapshot, if any).
    :param fq: 'pre' for forward adjustment, 'post' for backward adjustment,
        None for raw prices. It is validated with ``ensure_fq``;
        NOTE(review): the value is not otherwise used here - adjustment is
        driven solely by ``pre_factor_ref_date``.
    :param pre_factor_ref_date: reference date for forward adjustment (prices
        on this day are the real prices). None means all prices are raw.
    :return: a numpy structured array with one named column per field.
    :raises AssertionError: if ``fields``, ``unit`` or ``count`` is invalid.
    :raises ParamsError: if ``fields`` is neither a string nor a list/tuple.
    '''
    valid_bar_fields = ('date', 'open', 'close', 'high', 'low', 'volume',
                        'money')
    if isinstance(fields, (list, tuple)):
        for f in fields:
            # The tuple must be wrapped in a 1-tuple, otherwise "%" tries to
            # consume all of its items and raises TypeError instead.
            assert f in valid_bar_fields, "get_bars 只支持 %s 字段" % (
                valid_bar_fields, )
        str_field = False
    elif isinstance(fields, six.string_types):
        assert fields in valid_bar_fields, "get_bars 只支持 %s 字段" % (
            valid_bar_fields, )
        str_field = True
    else:
        raise ParamsError("fields 应该是字符串或者list")

    # Output column names, in request order.
    names = [fields] if str_field else list(fields)
    # Columns actually fetched: the adjustment factor is always needed.
    new_fields = list(names)
    if 'factor' not in new_fields:
        new_fields.append('factor')

    end_dt = convert_dt(end_dt)
    valid_unit = ('1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w', '1M')
    assert unit in valid_unit, 'get_bars, unit必须是 %s 中一种' % (valid_unit, )
    count = int(count)
    assert count > 0, "get_bars, count必须是一个正整数"
    fq = ensure_fq(fq)
    security = convert_security(security)
    include_now = bool(include_now)

    calendar = CalendarStore.instance()
    end_trade_date = calendar.get_current_trade_date(security, end_dt)

    def ensure_not_empty(cols_dict):
        # An empty result still needs every output column for the final
        # structured-array assembly.
        if cols_dict == {}:
            return {f: np.zeros(0) for f in valid_bar_fields}
        return cols_dict

    def daily_bars(last_day, n, snapshot=None):
        # Up to ``n`` daily bars ending at ``last_day``, optionally followed
        # by the current-day ``snapshot``.
        cols = ensure_not_empty(
            get_daily_bar_by_count(security,
                                   last_day,
                                   n,
                                   new_fields,
                                   include_now=False))
        if snapshot:
            for f in cols:
                cols[f] = np.append(cols[f], snapshot[f])
        return cols

    def has_rows(cols):
        # True when a fetched column dict exists and holds at least one bar.
        return bool(cols) and len(cols[new_fields[0]]) > 0

    def prepend(dst, src):
        # Put the (older) bars of ``src`` in front of those already in ``dst``.
        for col in dst:
            dst[col] = np.append(src[col], dst[col])

    def keep_last(cols):
        # Trim every column to the most recent ``count`` bars, in place.
        for f in cols:
            cols[f] = cols[f][-count:]

    if unit == '1d':
        snapshot = (get_snapshot(security, end_trade_date, end_dt)
                    if include_now else None)
        # A snapshot stands in for today's bar, so one fewer historical bar.
        cols_dict = daily_bars(end_trade_date,
                               count - 1 if snapshot else count, snapshot)
    elif unit == '1m':
        cols_dict = ensure_not_empty(
            get_minute_bar_by_count(security,
                                    end_dt,
                                    count,
                                    new_fields,
                                    include_now=include_now))
    elif unit in ('5m', '15m', '30m', '60m', '120m'):
        x = int(unit[:-1])
        cols_dict = {f: np.zeros(0) for f in new_fields}
        if security.is_futures():
            trade_days = calendar.get_all_trade_days(security)
            trade_days = trade_days[(trade_days >= security.start_date) &
                                    (trade_days <= end_trade_date) &
                                    (trade_days <= security.end_date)]
            # Walk back one trading day at a time, resampling each day on its
            # own, until enough bars have been collected.
            for day in trade_days[::-1]:
                open_dt = calendar.get_open_dt(security, day)
                if day != end_trade_date:
                    close_dt = calendar.get_close_dt(security, day)
                elif include_now:
                    close_dt = end_dt
                else:
                    close_dt = _not_include_now(security, end_trade_date,
                                                end_dt, unit)
                tmp_dict = get_minute_bar_by_period(security,
                                                    open_dt,
                                                    close_dt,
                                                    new_fields,
                                                    include_now=True)
                if not has_rows(tmp_dict):
                    continue
                prepend(cols_dict, _resample_future_xm_bars(tmp_dict, x))
                if len(cols_dict[new_fields[0]]) >= count:
                    break
        else:
            # NOTE(review): uses end_dt.date() rather than end_trade_date for
            # the session open, as the original did - confirm for non-trading
            # days.
            open_dt = calendar.get_open_dt(security, end_dt.date())
            if include_now:
                close_dt = end_dt
            else:
                close_dt = _not_include_now(security, end_trade_date, end_dt,
                                            unit)
            # Bars of the current session up to close_dt.
            tmp_dict = get_minute_bar_by_period(security,
                                                open_dt,
                                                close_dt,
                                                new_fields,
                                                include_now=True)
            if has_rows(tmp_dict):
                prepend(cols_dict, _resample_simple_xm_bars(tmp_dict, x))
            need_count = count - len(cols_dict[new_fields[0]])
            if need_count > 0:
                # Top up with x one-minute bars per missing bar, fetched up
                # to the session open.
                tmp_dict = get_minute_bar_by_count(security,
                                                   open_dt,
                                                   need_count * x,
                                                   new_fields,
                                                   include_now=False)
                if has_rows(tmp_dict):
                    prepend(cols_dict, _resample_simple_xm_bars(tmp_dict, x))
        keep_last(cols_dict)
    elif unit in ('1w', '1M'):
        # Weekly/monthly bars must be adjusted before resampling, because the
        # adjustment factor may change within one period.
        if include_now:
            snapshot = get_snapshot(security, end_trade_date, end_dt)
            last_day = end_trade_date
        else:
            snapshot = None
            if unit == '1w':
                # weekday(): Monday == 0 ... Sunday == 6 -> previous Sunday.
                last_day = end_trade_date - datetime.timedelta(
                    end_trade_date.weekday() + 1)
            else:
                # Last calendar day of the previous month.
                last_day = end_trade_date.replace(day=1) - datetime.timedelta(
                    days=1)
        # Daily bars to fetch per output bar: 5 trading days per week,
        # 31 as an upper bound for a month.
        days_per_bar = 5 if unit == '1w' else 31
        cols_dict = daily_bars(last_day, count * days_per_bar, snapshot)
        cols_dict = _pre_fq(security, cols_dict, pre_factor_ref_date)
        cols_dict = _resample_days_bars(cols_dict, unit)
        keep_last(cols_dict)
    else:
        raise ParamsError("get_bars 支持 '1m', '1d'")

    # Convert timestamps to date (day-based units) or datetime (minute units).
    if 'date' in cols_dict:
        if unit in ('1d', '1w', '1M'):
            cols_dict['date'] = vec2date(cols_dict['date'])
        else:
            cols_dict['date'] = vec2datetime(cols_dict['date'])

    # Futures are never adjusted; weekly/monthly bars were adjusted above.
    if not security.is_futures() and unit not in ('1w', '1M'):
        if pre_factor_ref_date is not None:
            cols_dict = _pre_fq(security, cols_dict, pre_factor_ref_date)

    # numpy bug: field names must not be unicode, hence str().
    dtype = np.dtype([(str(name), cols_dict[name].dtype) for name in names])
    cols = [cols_dict[name] for name in names]
    return np.rec.fromarrays(cols, dtype=dtype).view(np.ndarray)
Ejemplo n.º 30
0
def calc_stock_minute_index(security,
                            date_index,
                            dt,
                            side='left',
                            include_now=True):
    '''
    Compute the position of ``dt`` in a stock's minute-bar array (stocks only).

    The arithmetic assumes 240 minute bars per trading day (09:31-11:30 and
    13:01-15:00), laid out day after day in the order of ``date_index``.

    If ``dt`` does not fall on a minute bar, ``side='left'`` returns the
    index of the bar before it; any other ``side`` returns the index of the
    bar after it. -1 means there is no such bar.

    :param security: unused; kept for interface compatibility.
    :param date_index: sorted array of trading-day timestamps, searched with
        ``searchsorted`` against the start-of-day timestamp of ``dt``
        (presumably epoch seconds - confirm with callers).
    :param dt: a ``datetime.datetime`` or ``pd.Timestamp``.
    :param side: 'left' or 'right', see above.
    :param include_now: whether the minute bar stamped at ``dt`` is included.
    :return: zero-based index into the minute array, or -1.
    :raises ParamsError: if ``dt`` has an unsupported type.
    '''
    # Local import: elsewhere in this file ``datetime`` names the module, so
    # ``isinstance(dt, datetime)`` would raise TypeError.
    import datetime as _datetime

    # NOTE: pd.Timestamp subclasses datetime.datetime, so Timestamps are
    # normally handled by the first branch; the second is kept as before.
    if isinstance(dt, _datetime.datetime):
        ts = to_timestamp(dt.date())
    elif isinstance(dt, pd.Timestamp):
        ts = int(dt.value / (10**9) / 86400) * 86400
    else:
        raise ParamsError("wrong dt=%s, type(dt)=%s" % (dt, type(dt)))

    # Index of the last trading day on or before dt's date.
    total_days = date_index.searchsorted(ts, side='right') - 1
    if total_days < 0:
        return -1
    trading_day = date_index[total_days] == ts
    if not trading_day:
        total_days += 1
    # total_days is now the number of complete trading days before dt's date.

    morning_open = 9 * 60 + 30
    afternoon_open = 13 * 60
    dt_minutes = dt.hour * 60 + dt.minute
    # Between bars: count 0 extra bars for 'left', 1 for 'right'.
    between = 0 if side == 'left' else 1
    # The bar stamped at dt itself is dropped when include_now is false.
    drop_now = 0 if include_now else 1

    if not trading_day or dt_minutes <= morning_open:
        total_minutes = between
    elif dt_minutes <= 11 * 60 + 30:
        total_minutes = dt_minutes - morning_open - drop_now
    elif dt_minutes <= afternoon_open:
        # Lunch break: the whole 120-bar morning session is behind us.
        total_minutes = 120 + between
    elif dt_minutes <= 15 * 60:
        total_minutes = 120 + dt_minutes - afternoon_open - drop_now
    else:
        total_minutes = 240 + between

    # Indices are zero-based; -1 means there is no bar.
    return (total_days * 240 + total_minutes) - 1