Esempio n. 1
0
def load_fin_data_cwbbzy(timeout=30):
    """Download the financial-statement summary (财务报表摘要) of listed stocks.

    The URL template and output directory are read from ``config.ini``
    (``fin_data.cwbbzy_url``, ``factor_db.db_path`` and ``fin_data.cwbbzy_path``).
    The listed SH/SZ stock symbols are queried from the DataApi server, one
    comma-separated text is downloaded per symbol, and each result is written to
    ``<db_path>/<cwbbzy_path>/<utils.code_to_symbol(symbol)>.csv``.

    Parameters
    ----------
    timeout : float, optional
        Seconds to wait for each HTTP download, default 30.

    Returns
    -------
    None
        Symbols whose download fails (network error, non-OK status or empty
        payload) are reported with ``print`` and skipped.
    """
    cfg = ConfigParser()
    cfg.read('config.ini')
    cwbbzy_url = cfg.get('fin_data', 'cwbbzy_url')
    cwbbzy_path = os.path.join(cfg.get('factor_db', 'db_path'),
                               cfg.get('fin_data', 'cwbbzy_path'))
    # Read the stock symbols.
    # SECURITY(review): the login credentials are hard-coded in source; consider
    # moving them into config.ini.
    data_api = DataApi(addr='tcp://data.tushare.org:8910')
    data_api.login(
        '13811931480',
        'eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MTI4Nzk0NTI2MjkiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTM4MTE5MzE0ODAifQ.I0SXsA1bK--fbGu0B5Is2xdKOjALAeWBJRX6GdVmUL8'
    )
    df_stock_basics, msg = data_api.query(
        view='jz.instrumentInfo',
        fields='status,list_date,name,market',
        filter='inst_type=1&status=1&market=SH,SZ&symbol=',
        data_format='pandas')
    if msg != '0,':
        print('读取市场个股代码失败。')
        return
    # Keep only the code part of each symbol (text before the first '.').
    df_stock_basics.symbol = df_stock_basics.symbol.map(
        lambda x: x.split('.')[0])
    # Iterate over the stocks and download their summary data.
    for _, stock_info in df_stock_basics.iterrows():
        url = cwbbzy_url % stock_info.symbol
        try:
            # Without a timeout a single unresponsive request hangs the whole run.
            resp = requests.get(url, timeout=timeout)
        except requests.RequestException:
            resp = None
        if resp is None or resp.status_code != requests.codes.ok:
            print('%s的财务报表摘要数据下载失败!' % stock_info.symbol)
            continue
        print('下载%s的财务报表摘要数据.' % stock_info.symbol)
        # Drop the trailing token after the last ','. Only that one occurrence
        # is removed; str.replace() would also delete identical text elsewhere.
        head, sep, _ = resp.text.rpartition(',')
        fin_text = head + sep
        # One row per CRLF-separated line; the last cell of every line is dropped.
        fin_rows = [line.split(',')[:-1] for line in fin_text.split('\r\n')]
        # Rows that end up empty have no header cell and would break row[0].
        fin_rows = [row for row in fin_rows if row]
        if not fin_rows:
            print('%s的财务报表摘要数据下载失败!' % stock_info.symbol)
            continue
        # Truncate every row to the shortest row so all columns have equal length.
        n = min(len(row) for row in fin_rows)
        dict_fin_data = {row[0]: row[1:n] for row in fin_rows}
        fin_header = [row[0] for row in fin_rows]
        df_fin_data = DataFrame(dict_fin_data, columns=fin_header)
        df_fin_data = df_fin_data.sort_values(by=fin_header[0])
        df_fin_data.to_csv(os.path.join(
            cwbbzy_path, '%s.csv' % utils.code_to_symbol(stock_info.symbol)),
                           index=False)
Esempio n. 2
0
def load_ipo_info(timeout=30):
    """Download the IPO information of stocks from NetEase Finance (网易财经).

    The URL template and output directory come from ``config.ini``
    (``ipo_info.ipo_info_url``, ``factor_db.db_path`` and ``ipo_info.db_path``).
    For each stock, the table following the ``<h2>IPO资料</h2>`` heading is
    parsed into (name, value) pairs. The pairs plus a '代码' entry are written
    to ``<db_path>/<code>.csv``. All stocks are also collected into
    ``<db_path>/ipo_info.csv``.

    Parameters
    ----------
    timeout : float, optional
        Seconds to wait for each HTTP download, default 30.

    Returns
    -------
    None
    """
    cfg = ConfigParser()
    cfg.read('config.ini')
    ipo_info_url = cfg.get('ipo_info', 'ipo_info_url')
    db_path = Path(cfg.get('factor_db', 'db_path'), cfg.get('ipo_info', 'db_path'))
    # Read the codes of the stocks.
    # NOTE(review): the filter leaves ``status`` empty (load_fin_data_cwbbzy uses
    # status=1), which presumably includes non-listed stocks; confirm intent.
    # SECURITY(review): the login credentials are hard-coded in source.
    data_api = DataApi(addr='tcp://data.tushare.org:8910')
    data_api.login('13811931480', 'eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MTI4Nzk0NTI2MjkiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTM4MTE5MzE0ODAifQ.I0SXsA1bK--fbGu0B5Is2xdKOjALAeWBJRX6GdVmUL8')
    df_stock_basics, msg = data_api.query(view='jz.instrumentInfo',
                                          fields='status,list_date,name,market',
                                          filter='inst_type=1&status=&market=SH,SZ&symbol=',
                                          data_format='pandas')
    if msg != '0,':
        print('读取市场个股代码失败。')
        return
    # Keep only the code part of each symbol (text before the first '.').
    df_stock_basics.symbol = df_stock_basics.symbol.map(lambda x: x.split('.')[0])

    def _clean(text):
        """Remove spaces and line breaks from a table cell's text."""
        return text.replace(' ', '').replace('\n', '').replace('\r', '')

    # Iterate over the stocks and download the IPO data.
    ipo_records = []
    for _, stock_info in df_stock_basics.iterrows():
        print('下载%s的IPO数据.' % stock_info.symbol)
        secu_code = utils.code_to_symbol(stock_info.symbol)
        url = ipo_info_url % stock_info.symbol
        try:
            # Without a timeout a single unresponsive request hangs the whole run.
            html = requests.get(url, timeout=timeout).content
        except requests.RequestException:
            print('%s的IPO数据下载失败!' % stock_info.symbol)
            continue
        soup = BeautifulSoup(html, 'html.parser')
        for tag in soup.find_all(name='h2'):
            if tag.get_text().strip() != 'IPO资料':
                continue
            ipo_table = tag.find_next(name='table')
            if ipo_table is None:
                # Heading without a following table: nothing to parse.
                break
            ipo_info_header = []
            ipo_info_data = []
            for tr in ipo_table.find_all(name='tr'):
                tds = tr.find_all(name='td')
                if len(tds) < 2:
                    # Not a (name, value) row.
                    continue
                ipo_info_header.append(_clean(tds[0].get_text()))
                # Values additionally lose thousands separators.
                ipo_info_data.append(_clean(tds[1].get_text()).replace(',', ''))
            ipo_info = Series(ipo_info_data, index=ipo_info_header)
            ipo_info['代码'] = secu_code
            ipo_info.to_csv(db_path.joinpath('%s.csv' % secu_code))
            ipo_records.append(ipo_info)
            break
    # DataFrame.append was removed in pandas 2.0. Building the frame once from
    # the collected records also avoids re-copying the frame on every stock.
    df_ipo_info = DataFrame([record.to_dict() for record in ipo_records])
    df_ipo_info.to_csv(db_path.joinpath('ipo_info.csv'), index=False)
Esempio n. 3
0
def get_index_basic_information(index_symbol=None):
    """Return the symbol, name and publication date range of an index.

    Parameters
    ----------
    index_symbol : str, optional
        Index symbol to look up. When omitted, the module-level name ``index``
        is used, as before.

    Returns
    -------
    tuple
        ``(symbol, name, start_date, end_date)``. ``start_date`` is the index's
        ``listdate``. ``end_date`` is today ('%Y%m%d'), capped at the index's
        ``expire_date`` when one is set.

    Raises
    ------
    ValueError
        If the query returns no data.
    """
    if index_symbol is None:
        index_symbol = index
    api = DataApi(addr=data_config.get('remote.data.address'))
    api.login(data_config.get('remote.data.username'), data_config.get('remote.data.password'))
    df, msg = api.query(
        view="lb.indexInfo",
        fields="symbol,name,listdate,expire_date",
        filter="symbol=" + index_symbol,
        data_format='pandas')
    if df is None or len(df) == 0:
        raise ValueError("no index information for {}: {}".format(index_symbol, msg))

    row = df.iloc[0]
    # If today is past the index's expire date, use the expire date instead.
    # The start date is the index's publication date.
    today = dt.today().strftime('%Y%m%d')
    if row['expire_date']:
        # Compare as text: today is a '%Y%m%d' string.
        end_date = min(today, str(row['expire_date']))
    else:
        end_date = today

    return (row['symbol'], row['name'], row['listdate'], end_date)
Esempio n. 4
0
def test_data_api():
    """Smoke-test the remote DataApi: login, daily data and 1-minute bars.

    Reads ``remote.address``, ``remote.username`` and ``remote.password`` from
    ``etc/data_config.json``. Asserts the status messages and the result shapes
    of a few fixed queries. Requires network access to the data server.

    Raises
    ------
    ValueError
        If the data service configuration is incomplete.
    """
    dic = fileio.read_json(fileio.join_relative_path('etc/data_config.json'))
    address = dic.get("remote.address", None)
    username = dic.get("remote.username", None)
    password = dic.get("remote.password", None)
    if address is None or username is None or password is None:
        raise ValueError("no data service config available!")

    api = DataApi(address, use_jrpc=False)
    login_msg = api.login(username=username, password=password)
    # print(x) with a single argument behaves the same on Python 2 and 3.
    print(login_msg)

    daily_fields = "open,high,low,close,volume,last,trade_date,settle"
    daily, msg = api.daily(
        symbol="600030.SH,000002.SZ",
        start_date=20170103,
        end_date=20170708,
        fields=daily_fields)
    daily2, msg2 = api.daily(
        symbol="600030.SH",
        start_date=20170103,
        end_date=20170708,
        fields=daily_fields)
    assert msg == '0,'
    assert msg2 == '0,'
    assert daily.shape == (248, 8)
    assert daily2.shape == (124, 8)

    df, msg = api.bar(symbol="600030.SH",
                      trade_date=20170904,
                      freq='1m',
                      start_time=90000,
                      end_time=150000)
    # Check the bar query's status too; it was previously never asserted.
    assert msg == '0,'
    print(df.columns)
    assert df.shape == (240, 15)

    print("test passed")
# coding: utf-8

# In[1]:


from jaqs.data.dataapi import DataApi


# In[2]:


# Parameters of the quote request.
symbol = 'rb1805.SHF'
fields = 'open,high,low,last,volume'

# Connect to the data server; change the address to match your local setup.
api = DataApi(addr="tcp://192.168.1.117:8910")
api.login("25", "25")


# In[3]:


# Fetch the real-time quote, then show the data and the status message.
df, msg = api.quote(symbol=symbol, fields=fields)
print(df)
print(msg)


# In[14]:

Esempio n. 6
0
class RemoteDataService(DataService):
    """
    RemoteDataService is a concrete class using data from remote server's database.

    On construction it reads ``remote.address``, ``remote.username`` and
    ``remote.password`` from ``etc/data_config.json``. It then logs a
    ``DataApi`` in with them, and every query method below forwards to it.

    """
    __metaclass__ = Singleton

    # TODO no validity check for input parameters

    def __init__(self):
        DataService.__init__(self)

        dic = fileio.read_json(
            fileio.join_relative_path('etc/data_config.json'))
        address = dic.get("remote.address", None)
        username = dic.get("remote.username", None)
        password = dic.get("remote.password", None)
        if address is None or username is None or password is None:
            raise ValueError("no address, username or password available!")

        self.api = DataApi(address, use_jrpc=False)
        self.api.set_timeout(60)
        r, msg = self.api.login(username=username, password=password)
        if not r:
            print(msg)
        else:
            # The format string previously had no placeholder, so the address
            # passed to format() was silently dropped.
            print("DataAPI login success: {}".format(address))

        # Column used to order financial-statement queries.
        self.REPORT_DATE_FIELD_NAME = 'report_date'

    def daily(self, symbol, start_date, end_date, fields="", adjust_mode=None):
        """
        Query daily data through the remote DataApi.

        Returns
        -------
        df : pd.DataFrame
            Exact duplicate rows are removed.
        err_msg : str
            error code and error message, joined by ','

        """
        df, err_msg = self.api.daily(symbol=symbol,
                                     start_date=start_date,
                                     end_date=end_date,
                                     fields=fields,
                                     adjust_mode=adjust_mode,
                                     data_format="")
        # trade_status performance warning
        # TODO there will be duplicate entries when on stocks' IPO day
        df = df.drop_duplicates()
        return df, err_msg

    def bar(self,
            symbol,
            start_time=200000,
            end_time=160000,
            trade_date=None,
            freq='1m',
            fields=""):
        """
        Query intraday bars through the remote DataApi.

        ``freq`` is forwarded to the server; it used to be ignored and
        hard-coded to '1m'.

        Returns
        -------
        df : pd.DataFrame
        msg : str
            error code and error message, joined by ','

        """
        df, msg = self.api.bar(symbol=symbol,
                               fields=fields,
                               start_time=start_time,
                               end_time=end_time,
                               trade_date=trade_date,
                               freq=freq,
                               data_format="")
        return df, msg

    def query(self, view, filter="", fields="", **kwargs):
        """
        Get various reference data.
        
        Parameters
        ----------
        view : str
            data source.
        fields : str
            Separated by ','
        filter : str
            filter expressions.
        kwargs
            Forwarded unchanged to ``DataApi.query``.

        Returns
        -------
        df : pd.DataFrame
        msg : str
            error code and error message, joined by ','
        
        Examples
        --------
        res3, msg3 = ds.query("lb.secDailyIndicator", fields="price_level,high_52w_adj,low_52w_adj",
                              filter="start_date=20170907&end_date=20170907",
                              orderby="trade_date",
                              data_format='pandas')
            view does not change. fields can be any field predefined in reference data api.

        """
        df, msg = self.api.query(view,
                                 fields=fields,
                                 filter=filter,
                                 data_format="",
                                 **kwargs)
        return df, msg

    def get_suspensions(self):
        """Not supported by this data service; always returns None."""
        return None

    # TODO use Calendar instead
    def get_trade_date(self,
                       start_date,
                       end_date,
                       symbol=None,
                       is_datetime=False):
        """
        Return the trade dates between start_date and end_date.

        The dates are the ``trade_date`` values of the daily data of ``symbol``
        (default '000300.SH'). They are converted with
        ``dtutil.convert_int_to_datetime`` when ``is_datetime`` is True.

        """
        if symbol is None:
            symbol = '000300.SH'
        df, msg = self.daily(symbol, start_date, end_date, fields="close")
        res = df.loc[:, 'trade_date'].values
        if is_datetime:
            res = dtutil.convert_int_to_datetime(res)
        return res

    @staticmethod
    def _dic2url(d):
        """
        Convert a dict to str like 'k1=v1&k2=v2'
        
        Parameters
        ----------
        d : dict

        Returns
        -------
        str

        """
        l = ['='.join([key, str(value)]) for key, value in d.items()]
        return '&'.join(l)

    def query_lb_fin_stat(self,
                          type_,
                          symbol,
                          start_date,
                          end_date,
                          fields=""):
        """
        Helper function to call data_api.query with 'lb.income' more conveniently.
        
        Parameters
        ----------
        type_ : {'income', 'balance_sheet', 'cash_flow', 'fin_indicator'}
        symbol : str
            separated by ','
        start_date : int
            Announcement date in results will be no earlier than start_date
        end_date : int
            Announcement date in results will be no later than end_date
        fields : str, optional
            separated by ',', default ""

        Returns
        -------
        df : pd.DataFrame
            index date, columns fields
        msg : str

        Raises
        ------
        NotImplementedError
            If ``type_`` is not one of the supported statement types.

        """
        view_map = {
            'income': 'lb.income',
            'cash_flow': 'lb.cashFlow',
            'balance_sheet': 'lb.balanceSheet',
            'fin_indicator': 'lb.finIndicator'
        }
        view_name = view_map.get(type_, None)
        if view_name is None:
            raise NotImplementedError("type_ = {:s}".format(type_))

        dic_argument = {
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date,
            'update_flag': '0'  # 0 means first time, not update
        }
        if view_name != 'lb.finIndicator':
            # We do not use single-quarter data since it contains zeros.
            #   408001000: joint
            #   408002000: joint (single quarter)
            dic_argument.update({
                'report_type': '408001000'
            })

        filter_argument = self._dic2url(dic_argument)

        # NOTE(review): every other query in this class passes ``orderby``;
        # confirm which keyword the server honors.
        res, msg = self.query(view_name,
                              fields=fields,
                              filter=filter_argument,
                              order_by=self.REPORT_DATE_FIELD_NAME)
        # Best effort: cast the date columns to int. This is left unchanged if
        # the result has no columns attribute or the cast fails.
        try:
            cols = list(
                set.intersection({'ann_date', 'report_date'},
                                 set(res.columns)))
            dic_dtype = {col: int for col in cols}
            res = res.astype(dtype=dic_dtype)
        except Exception:
            pass

        return res, msg

    def query_lb_dailyindicator(self, symbol, start_date, end_date, fields=""):
        """
        Helper function to call data_api.query with 'lb.secDailyIndicator' more conveniently.
        
        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        fields : str, optional
            separated by ',', default ""

        Returns
        -------
        df : pd.DataFrame
            index date, columns fields
        msg : str
        
        """
        filter_argument = self._dic2url({
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date
        })

        return self.query("lb.secDailyIndicator",
                          fields=fields,
                          filter=filter_argument,
                          orderby="trade_date")

    def _get_index_comp(self, index, start_date, end_date):
        """
        Query the 'lb.indexCons' records of index between start_date and end_date.
        
        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        df_io : pd.DataFrame
            Raw rows, ordered by symbol.
        msg : str

        """
        filter_argument = self._dic2url({
            'index_code': index,
            'start_date': start_date,
            'end_date': end_date
        })

        df_io, msg = self.query("lb.indexCons",
                                fields="",
                                filter=filter_argument,
                                orderby="symbol")
        return df_io, msg

    def get_index_comp(self, index, start_date, end_date):
        """
        Return list of symbols that have been in index during start_date and end_date.
        
        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        list
            Sorted unique symbols.

        """
        df_io, msg = self._get_index_comp(index, start_date, end_date)
        if msg != '0,':
            print(msg)
        return list(np.unique(df_io.loc[:, 'symbol']))

    def get_index_comp_df(self, index, start_date, end_date):
        """
        Get index components on each day during start_date and end_date.
        
        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        res : pd.DataFrame
            index dates, columns all securities that have ever been components,
            values are 0 (not in) or 1 (in)

        """
        df_io, msg = self._get_index_comp(index, start_date, end_date)
        if msg != '0,':
            print(msg)

        # ``unicode`` only exists on Python 2.
        try:
            string_types = (str, unicode)
        except NameError:
            string_types = (str,)

        def str2int(s):
            # An empty date string means "no date"; map it to a far-future date.
            if isinstance(s, string_types):
                return int(s) if s else 99999999
            elif isinstance(s, (int, np.integer, float)):
                return s
            else:
                raise NotImplementedError("type s = {}".format(type(s)))

        df_io.loc[:, 'in_date'] = df_io.loc[:, 'in_date'].apply(str2int)
        df_io.loc[:, 'out_date'] = df_io.loc[:, 'out_date'].apply(str2int)

        dates = self.get_trade_date(start_date=start_date,
                                    end_date=end_date,
                                    symbol=index)

        dic = dict()
        gp = df_io.groupby(by='symbol')
        for sec, df in gp:
            mask = np.zeros_like(dates, dtype=int)
            for idx, row in df.iterrows():
                # NOTE(review): both bounds are exclusive, so the in_date itself
                # is not marked as a member; confirm this is intended.
                bool_index = np.logical_and(dates > row['in_date'],
                                            dates < row['out_date'])
                mask[bool_index] = 1
            dic[sec] = mask

        res = pd.DataFrame(index=dates, data=dic)

        return res

    @staticmethod
    def _group_df_to_dict(df, by):
        """Split df by the ``by`` column(s) into a {group key: sub-DataFrame} dict."""
        gp = df.groupby(by=by)
        res = {key: value for key, value in gp}
        return res

    def get_industry_daily(self, symbol, start_date, end_date, type_='SW'):
        """
        Get the industry of each security on each trade day in [start_date, end_date].
        
        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        type_ : {'SW', 'ZZ'}

        Returns
        -------
        res : pd.DataFrame
            index dates, columns symbols
            values are industry code

        """
        df_raw = self.get_industry_raw(symbol, type_=type_)

        dic_sec = self._group_df_to_dict(df_raw, by='symbol')
        dic_sec = {
            sec: df.sort_values(by='in_date', axis=0).reset_index()
            for sec, df in dic_sec.items()
        }

        df_ann = pd.concat([
            df.loc[:, 'in_date'].rename(sec)
            for sec, df in dic_sec.items()
        ],
                           axis=1)
        df_value = pd.concat([
            df.loc[:, 'industry1_code'].rename(sec)
            for sec, df in dic_sec.items()
        ],
                             axis=1)

        dates_arr = self.get_trade_date(start_date, end_date)
        df_industry = align.align(df_value, df_ann, dates_arr)

        # TODO before industry classification is available, we assume they belong to their first group.
        df_industry = df_industry.bfill()
        df_industry = df_industry.astype(str)

        return df_industry

    def get_industry_raw(self, symbol, type_='ZZ'):
        """
        Get daily industry of securities from ShenWanHongYuan or ZhongZhengZhiShu.
        
        Parameters
        ----------
        symbol : str
            separated by ','
        type_ : {'SW', 'ZZ'}

        Returns
        -------
        df : pd.DataFrame
            Columns include symbol, industry1_code, industry1_name and in_date.

        Raises
        ------
        ValueError
            If ``type_`` is neither 'SW' nor 'ZZ'.

        """
        if type_ == 'SW':
            src = u'申万研究所'
        elif type_ == 'ZZ':
            src = u'中证指数有限公司'
        else:
            raise ValueError("type_ must be one of SW or ZZ")
        if not isinstance(src, str):
            # Python 2 only: encode so that str() in _dic2url works. On Python 3
            # encoding would make str() produce "b'...'" in the filter.
            src = src.encode('utf-8')

        filter_argument = self._dic2url({
            'symbol': symbol,
            'industry_src': src
        })
        # in_date is cast below and used by get_industry_daily, so request it.
        fields_list = ['symbol', 'industry1_code', 'industry1_name', 'in_date']

        df_raw, msg = self.query("lb.secIndustry",
                                 fields=','.join(fields_list),
                                 filter=filter_argument,
                                 orderby="symbol")
        if msg != '0,':
            print(msg)

        df_raw = df_raw.astype(dtype={
            'in_date': int,
            # 'out_date': int
        })
        return df_raw.drop_duplicates()

    def get_adj_factor_daily(self, symbol, start_date, end_date, div=False):
        """
        Get the adjust factor of each security on each trade day in [start_date, end_date].
        
        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        div : bool
            False for normal adjust factor, True for the ratio to the previous day.

        Returns
        -------
        res : pd.DataFrame
            index dates, columns symbols
            values are adjust factors

        """
        df_raw = self.get_adj_factor_raw(symbol)

        dic_sec = self._group_df_to_dict(df_raw, by='symbol')
        dic_sec = {
            sec: df.loc[:, ['trade_date', 'adjust_factor']].set_index(
                'trade_date').iloc[:, 0]
            for sec, df in dic_sec.items()
        }

        res = pd.concat(dic_sec, axis=1)

        # align to every trade date
        s, e = df_raw.loc[:, 'trade_date'].min(), df_raw.loc[:, 'trade_date'].max()
        dates_arr = self.get_trade_date(s, e)
        res = res.reindex(dates_arr)

        res = res.ffill().bfill()

        if div:
            res = res.div(res.shift(1, axis=0)).fillna(1.0)

        res = res.loc[start_date:end_date, :]

        return res

    def get_adj_factor_raw(self, symbol, start_date=None, end_date=None):
        """
        Query adjust factor for symbols.
        
        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        df : pd.DataFrame

        """
        if start_date is None:
            start_date = ""
        if end_date is None:
            end_date = ""

        filter_argument = self._dic2url({
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date
        })
        fields_list = ['symbol', 'trade_date', 'adjust_factor']

        df_raw, msg = self.query("lb.secAdjFactor",
                                 fields=','.join(fields_list),
                                 filter=filter_argument,
                                 orderby="symbol")
        if msg != '0,':
            print(msg)
        df_raw = df_raw.astype(dtype={
            'symbol': str,
            'trade_date': int,
            'adjust_factor': float
        })
        return df_raw.drop_duplicates()

    def query_inst_info(self, symbol, inst_type="", fields=""):
        """
        Query 'jz.instrumentInfo' for symbols.

        An empty ``inst_type`` is replaced with "1,2,3,4,5,101,102,103,104".
        Columns symbol, list_date and delist_date are cast to str/int/int
        when present.

        Returns
        -------
        df : pd.DataFrame
        msg : str

        """
        if inst_type == "":
            inst_type = "1,2,3,4,5,101,102,103,104"

        filter_argument = self._dic2url({
            'symbol': symbol,
            'inst_type': inst_type
        })

        df_raw, msg = self.query("jz.instrumentInfo",
                                 fields=fields,
                                 filter=filter_argument,
                                 orderby="symbol")
        if msg != '0,':
            print(msg)

        dtype_map = {'symbol': str, 'list_date': int, 'delist_date': int}
        cols = set(df_raw.columns)
        dtype_map = {k: v for k, v in dtype_map.items() if k in cols}

        df_raw = df_raw.astype(dtype=dtype_map)
        return df_raw, msg
Esempio n. 7
0
jaqs_fxdayu.patch_all()
from jaqs.data import DataView
from jaqs.data import RemoteDataService
from jaqs_fxdayu.data.dataservice import LocalDataService
import os
import numpy as np
import alpha32_, alpha42_, alpha56_, alpha62_, alpha64_, alpha194, alpha195, alpha197, Beta3
import pandas as pd
import matplotlib.pyplot as plt
from jaqs_fxdayu.research import SignalDigger
from jaqs_fxdayu.research.signaldigger import analysis
from jaqs_fxdayu.research.signaldigger import multi_factor

# Log in to the remote data server.
# SECURITY(review): the login credentials are hard-coded in source.
api = DataApi(addr='tcp://data.tushare.org:8910')
api.login(
    "18523827661",
    'eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MjIxMTc0NDY1MzAiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTg1MjM4Mjc2NjEifQ.AO9Rp8jG_IWc6crPrBOC-ujMP0-g1S1c5kUlTs5qwrk'
)

# Sample period, as YYYYMMDD integers.
start, end = 20100101, 20180401

# Stock universe: union of the constituents of the two indices.
SH_id = dp.index_cons(api, "000300.SH", start, end)
SZ_id = dp.index_cons(api, "000905.SH", start, end)
stock_symbol = list(set(SH_id.symbol) | set(SZ_id.symbol))

# Fields to check, also as one comma-separated string.
factor_list = ['volume', 'float_mv', 'pe', 'ps']
check_factor = ','.join(factor_list)

# Storage locations and an empty DataView.
dataview_folder = '/Users/adam/Desktop/intern/test5/fxdayu_adam/data'
dataview_folder2 = 'muti_factor/'
dv = DataView()
# Local data service alternative:
# ds = LocalDataService(fp=dataview_folder)
data_config = {
Esempio n. 8
0
    "eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MzU1OTg2MTI0NzYiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTU1NjYwMjg1NjgifQ.ToTAnVWpKtweGj4yoXhVW0pzHds7a9qQzXC8qLBui2g"  #token是api令牌
}
trade_config = {
    "remote.trade.address": "tcp://gw.quantos.org:8901",
    "remote.trade.username": "******",
    "remote.trade.password": "******",
}

dataview_dir_path = '***'  # where the DataView is stored
backtest_result_dir_path = '***'  # where backtest results are stored

api = DataApi(addr='tcp://data.quantos.org:8910')
api.login("phone", "token")

# Add an extra field to an already-saved DataView. The field is taken from the
# fields listed in dataview.py; these are not factors. Example: turnover_ratio.
# After one run has added the field, this section can be commented out.
dv = DataView()
dv.load_dataview(folder_path=dataview_dir_path)
dv.add_field('turnover_ratio', ds)
dv.save_dataview(folder_path=dataview_dir_path)
dv.update_snapshot()


def my_selector(context, user_options=None):
    #筛选僵尸股和ST股
    df = context.dataview.data_inst['name']
    result = pd.DataFrame(df, columns=['ST'], index=df.index, dtype='bool')
    selector_volume = context.snapshot[
import pandas as pd
from jaqs_fxdayu.util import dp
from jaqs.data.dataapi import DataApi
import numpy as np
import talib as ta
from jaqs_fxdayu.research.signaldigger.process import neutralize

# Log in to the remote data server.
# SECURITY(review): the login credentials are hard-coded in source.
api = DataApi(addr='tcp://data.tushare.org:8910')
api.login(
    "13662241013",
    'eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MTc2NDQzMzg5MTIiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTM2NjIyNDEwMTMifQ.sVIzI5VLqq8fbZCW6yZZW0ClaCkcZpFqpiK944AHEow'
)

# Sample period, as YYYYMMDD integers.
start, end = 20130101, 20180101

# Stock universe: union of the constituents of the two indices.
SH_id = dp.index_cons(api, "000300.SH", start, end)
SZ_id = dp.index_cons(api, "000905.SH", start, end)
stock_symbol = list(set(SH_id.symbol) | set(SZ_id.symbol))

# Fields to load; adjust the number and choice of factors here.
factor_list = [
    'oper_rev',
    'oper_rev_ttm',
    'oper_rev_lyr',
    'total_oper_rev',
    'less_gerl_admin_exp',
    'AdminiExpenseRate',
    'volume',
    'gainvariance120',
    'index',
    'cash_recp_sg_and_rs',
]
check_factor = ','.join(factor_list)

import jaqs_fxdayu
jaqs_fxdayu.patch_all()
from jaqs.data import DataView
from jaqs.data import RemoteDataService
from jaqs_fxdayu.data.dataservice import LocalDataService
Esempio n. 10
0
# coding: utf-8

# In[2]:

import pandas as pd
import numpy as np

# In[3]:

from jaqs.data.dataapi import DataApi

# Log in to the remote data server.
# SECURITY(review): the login credentials are hard-coded in source.
api = DataApi(addr='tcp://data.tushare.org:8910')
api.login('18810695562', 'eyJhbGciOiJIUzI1NiJ9.eyJjcmVhdGVfdGltZSI6IjE1MTMzMTU1MTE3MDEiLCJpc3MiOiJhdXRoMCIsImlkIjoiMTg4MTA2OTU1NjIifQ.2qvBxyHzcIFfC4_5lP_MFkNwDGLYD2gwqxMLWMpOLw0')

# In[4]:

# Real-time quotes for two symbols in a single request.
df, msg = api.quote(
    "000001.SH, cu1802.SHF",
    fields="open,high,low,last,volume",
)

# In[5]:

# Notebook cell: evaluating `df` displays the quote table.
df

# In[5]:

df1, msg1 = api.query(view="jz.secTradeCal",
                      fields="date,market",
                      filter="start_date=20170901&end_date=20171101",
Esempio n. 11
0
class MyQuantosDataApi():
    """Download quantOS DataView fields per stock and stage them for MongoDB.

    Work in progress: ``_write_data_to_dbase`` is not implemented yet.
    """
    def __init__(self):
        self.symbol = ""
        self.fields = ""
        self.start_date = 0
        self.end_date = 0
        self._client = None  # pymongo.MongoClient, created by _login(mongodb=True)
        self._db = None
        self._col = None
        self._db_name = ""
        self._col_name = ""
        self._db_list = []  # non-default MongoDB database names
        self._col_list = []
        self._data_init_config = {}
        self._dataapi = None
        self._remote_data_service = None
        self._instrument_info = pd.DataFrame()
        self._symbol_set_all_A = set()
        self._dataview_data = pd.DataFrame()
        self._dbase_props = {}
        self._db_init_config = ['reference_daily_fields']
        # 1: first-time download, 2: incremental update (see _set_data_status)
        self._data_status = 0

    def _login(self, dataapi=False, remote=False, mongodb=False):
        """Open the requested connections: DataApi, RemoteDataService and/or MongoDB."""
        self._data_init_config = {
            "remote.data.address":
            "tcp://data.tushare.org:8910",
            "remote.data.username":
            "******",
            "remote.data.password":
            "******"
        }

        if dataapi:
            self._dataapi = DataApi(
                addr=self._data_init_config.get('remote.data.address'))
            self._dataapi.login(
                self._data_init_config.get('remote.data.username'),
                self._data_init_config.get('remote.data.password'))

        if remote:
            self._remote_data_service = RemoteDataService()
            self._remote_data_service.init_from_config(self._data_init_config)

        if mongodb:
            # Connect to the local MongoDB server.
            self._client = pymongo.MongoClient(host='localhost', port=27017)

    def _logout(self):
        """Close the MongoDB client, if one has been opened."""
        if self._client is not None:
            self._client.close()

    def _set_data_status(self):
        '''
        Decide whether this is a first-time download or an incremental update.

        Sets ``self._data_status`` to 1 for a first download (no user database
        exists) or 2 for an incremental update. ``self._db_list`` receives the
        non-default database names. For an update, ``self._col_list`` receives
        the collection names of the configured database.
        '''
        self._login(mongodb=True)
        try:
            # Drop MongoDB's built-in databases. list.remove() returns None, so
            # assigning its result (as before) lost the list, and remove() also
            # raised ValueError when a default database was missing.
            self._db_list = [
                name for name in self._client.list_database_names()
                if name not in ('admin', 'local', 'config')
            ]
            if not self._db_list:
                self._data_status = 1
            else:
                # self._db was never assigned before it was used here.
                self._db_name = self._db_init_config[0]
                self._db = self._client[self._db_name]
                # collection_names() was removed in PyMongo 4.
                self._col_list = self._db.list_collection_names()
                self._data_status = 2
        finally:
            self._client.close()

    def _prepare_data(self):
        """Load the configuration and instrument list, then decide the date range."""
        self._set_data_status()
        self._set_fields()
        self._login(dataapi=True, remote=True)
        self._get_instrumentInfo()
        self._get_symbol_set_all_A()

        # Both modes use one collection per symbol in the configured database.
        self._db_name = self._db_init_config[0]
        self._col_name = self.symbol

        if self._data_status == 1:
            self._get_start_end_date_first_time()

        if self._data_status == 2:
            # _get_start_end_date_update reads self._col, which was never opened.
            self._login(mongodb=True)
            try:
                self._col = self._client[self._db_name][self._col_name]
                self._get_start_end_date_update()
            finally:
                self._logout()

    def _set_fields(self):
        "Choose the query fields; the key also names the MongoDB database."
        dv = DataView()
        fields_init_config = {
            'reference_daily_fields': dv.reference_daily_fields
            # More entries may be added here, as long as the DataView supports them.
        }
        self.fields = fields_init_config.get('reference_daily_fields')

    def _get_instrumentInfo(self):
        '''
        Fetch basic information on Shanghai/Shenzhen A shares.

        inst_type=1  security type: stock
        status=1     listing status: listed
        The result (a pd.DataFrame) is stored in ``self._instrument_info``.
        '''
        df, msg = self._dataapi.query(
            view="jz.instrumentInfo",
            fields="status,list_date,delist_date,name,market,symbol",
            filter="inst_type=1&status=1",
            data_format='pandas')

        # Keep only the SH and SZ markets.
        df = df[(df['market'] == 'SH') | (df['market'] == 'SZ')]
        self._instrument_info = df

    def _get_symbol_set_all_A(self):
        "Store the set of symbols of all currently listed A shares."
        self._symbol_set_all_A = set(self._instrument_info['symbol'])

    def _get_start_end_date_first_time(self):
        "Set the date range for a first-time download of self.symbol."
        df = self._instrument_info
        row = df[df['symbol'] == self.symbol].iloc[0]
        list_date = row['list_date']
        delist_date = row['delist_date']

        # End at whichever comes first: today or the delisting date. Compare as
        # integers, because min() on a str and an int raises TypeError on Python 3.
        today = int(datetime.today().strftime('%Y%m%d'))
        self.start_date = int(list_date)
        self.end_date = min(today, int(delist_date))

    def _get_start_end_date_update(self):
        "Set the date range for an incremental update of self._col."
        # Trade dates already stored in the database.
        stored = pd.DataFrame(
            list(self._col.find({}, {'trade_date': 1, '_id': 0})))
        if stored.empty or 'trade_date' not in stored.columns:
            # Nothing stored for this symbol yet: download its full history.
            self._get_start_end_date_first_time()
            return

        # The update starts the day after the latest stored trade date.
        latest_date = datetime.strptime(str(stored['trade_date'].max()), '%Y%m%d')
        next_date = (latest_date + timedelta(days=1)).strftime('%Y%m%d')
        today = datetime.today().strftime('%Y%m%d')

        if today >= next_date:
            self.start_date = int(next_date)
            self.end_date = int(today)
        else:
            print("Trade date in Dbase is latest.")

    def _download_data(self):
        "Download self.fields for self.symbol via a quantOS DataView into self._dataview_data."
        dv = DataView()
        props = {
            'symbol': self.symbol,
            'fields': self.fields,
            'start_date': self.start_date,
            'end_date': self.end_date,
            'freq': 1
        }

        dv.init_from_config(props=props, data_api=self._remote_data_service)
        dv.prepare_data()
        self._dataview_data = dv.data_d

    def _dataframe_to_dbase_props(self):
        "Convert self._dataview_data into a list of records in self._dbase_props."
        # Work on a copy. The old code referenced the misspelled attribute
        # self._dataview_date (AttributeError) and modified the stored frame in place.
        df = self._dataview_data.copy()
        # Drop the outer column level, keeping only the field names.
        df.columns = df.columns.droplevel()
        df = df.reset_index()

        # DataView widens the requested range by a few days; trim it back.
        df = df[(df['trade_date'] >= self.start_date)
                & (df['trade_date'] <= self.end_date)]

        # If nothing but trade_date has values, the stored data is already up to date.
        is_df = df.drop('trade_date', axis=1).dropna(how='all')

        if is_df.empty:
            print("\n\nTrade date is latest.\n\n")
            return None
        else:
            props = df.to_dict(orient='records')
            self._dbase_props = props

    def _write_data_to_dbase(self):
        "Not implemented yet."
        pass
Esempio n. 12
0
class Calendar(object):
    """
    Trade-date helper backed by the remote ``jz.secTradeCal`` view.

    Attributes
    ----------
    data_api : DataApi
        Client used for calendar queries. It is either supplied by the
        caller or created and logged in from ``etc/data_config.json``.

    """
    def __init__(self, data_api=None):
        """
        Parameters
        ----------
        data_api : DataApi, optional
            An already usable client. If None, a new client is built from the
            'remote.address', 'remote.username', 'remote.password' and
            'timeout' (default 60) entries of ``etc/data_config.json``.

        Raises
        ------
        ValueError
            If address, username or password is missing or empty in the config.

        """
        if data_api is not None:
            self.data_api = data_api
            return

        props = jutil.read_json(
            jutil.join_relative_path('etc/data_config.json'))

        address = props.get("remote.address", "")
        username = props.get("remote.username", "")
        password = props.get("remote.password", "")
        # The lookups default to "", so an ``is None`` test could never fire;
        # treat empty values (and explicit nulls) as missing.
        if not address or not username or not password:
            raise ValueError("no address, username or password available!")
        time_out = props.get("timeout", 60)

        self.data_api = DataApi(address, use_jrpc=False)
        self.data_api.set_timeout(timeout=time_out)
        r, msg = self.data_api.login(username=username, password=password)
        # A failed login is reported, not raised.
        if not r:
            print("DataAPI login failed: msg = '{}'".format(msg))
        else:
            print("DataAPI login success : {}@{}".format(username, address))

    @staticmethod
    def _dic2url(d):
        """
        Convert a dict to str like 'k1=v1&k2=v2'

        Parameters
        ----------
        d : dict
            Keys must be str; values are converted with ``str``.

        Returns
        -------
        str
            Empty string for an empty dict.

        """
        # items() works on Python 2 and 3 (viewitems is Python-2 only).
        l = ['='.join([key, str(value)]) for key, value in d.items()]
        return '&'.join(l)

    def get_trade_date_range(self, start_date, end_date):
        """
        Get array of trade dates within given range.
        Return zero size array if no trade dates within range.

        Parameters
        ----------
        start_date : int
            YYYYmmdd
        end_date : int
            YYYYmmdd

        Returns
        -------
        trade_dates_arr : np.ndarray
            dtype = int

        """
        filter_argument = self._dic2url({
            'start_date': start_date,
            'end_date': end_date
        })

        df_raw, msg = self.data_api.query("jz.secTradeCal",
                                          fields="trade_date",
                                          filter=filter_argument,
                                          orderby="")
        # A failed query may not yield a frame at all; treat it like "no dates".
        if df_raw is None or df_raw.empty:
            return np.array([], dtype=int)

        trade_dates_arr = df_raw['trade_date'].values.astype(int)
        return trade_dates_arr

    def get_last_trade_date(self, date):
        """
        Return the last trade date strictly before ``date``.

        Only the two weeks before ``date`` are searched.

        Parameters
        ----------
        date : int
            YYYYmmdd

        Returns
        -------
        res : int

        Raises
        ------
        IndexError
            If there is no trade date in that two-week window.

        """
        dt = jutil.convert_int_to_datetime(date)
        delta = pd.Timedelta(weeks=2)
        dt_old = dt - delta
        date_old = jutil.convert_datetime_to_int(dt_old)

        dates = self.get_trade_date_range(date_old, date)
        mask = dates < date
        res = dates[mask][-1]

        return res

    def is_trade_date(self, date):
        """
        Check whether date is a trade date.

        Parameters
        ----------
        date : int
            YYYYmmdd

        Returns
        -------
        bool

        """
        dates = self.get_trade_date_range(date, date)
        return len(dates) > 0

    def get_next_trade_date(self, date):
        """
        Return the first trade date strictly after ``date``.

        Only the two weeks after ``date`` are searched.

        Parameters
        ----------
        date : int
            YYYYmmdd

        Returns
        -------
        res : int

        Raises
        ------
        IndexError
            If there is no trade date in that two-week window.

        """
        dt = jutil.convert_int_to_datetime(date)
        delta = pd.Timedelta(weeks=2)
        dt_new = dt + delta
        date_new = jutil.convert_datetime_to_int(dt_new)

        dates = self.get_trade_date_range(date, date_new)
        mask = dates > date
        res = dates[mask][0]

        return res
Esempio n. 13
0
class RemoteDataService(DataService):
    """
    RemoteDataService is a concrete class using data from remote server's database.

    Every query goes through ``self.data_api`` and trade-calendar lookups go
    through ``self.calendar``. Both are None until ``init_from_config`` has
    been called.

    """
    # Python 2 metaclass declaration; it has no effect under Python 3.
    __metaclass__ = Singleton

    # TODO no validity check for input parameters

    def __init__(self):
        DataService.__init__(self)

        self.data_api = None

        self.REPORT_DATE_FIELD_NAME = 'report_date'
        self.calendar = None

    def __del__(self):
        """Close the remote connection if one was opened."""
        # data_api stays None until init_from_config() runs. Guard it so that
        # collecting an unused instance does not raise AttributeError.
        data_api = getattr(self, 'data_api', None)
        if data_api is not None:
            data_api.close()

    def init_from_config(self, props=None):
        """
        Connect and log in to the remote data server.

        Parameters
        ----------
        props : dict, optional
            Keys read: 'remote.address', 'remote.username',
            'remote.password' and 'timeout' (default 60). Keys that are
            missing (or None) fall back to ``etc/data_config.json``.

        Raises
        ------
        ValueError
            If address, username or password is missing or empty.

        Notes
        -----
        If a connection already exists and ``props`` is empty, it is reused.
        Otherwise it is closed and a new one is opened. A failed login is
        printed, not raised. ``self.calendar`` is rebuilt on the new client.

        """
        if props is None:
            props = dict()

        if self.data_api is not None:
            if len(props) == 0:
                return
            self.data_api.close()

        def get_from_list_of_dict(l, key, default=None):
            """Return the first non-None ``dic[key]`` over ``l``, else ``default``."""
            for dic in l:
                res = dic.get(key, None)
                if res is not None:
                    return res
            return default

        props_default = jutil.read_json(
            jutil.join_relative_path('etc/data_config.json'))
        dic_list = [props, props_default]

        address = get_from_list_of_dict(dic_list, "remote.address", "")
        username = get_from_list_of_dict(dic_list, "remote.username", "")
        password = get_from_list_of_dict(dic_list, "remote.password", "")
        # The lookups default to "", so an ``is None`` test could never fire;
        # reject empty values as well.
        if not address or not username or not password:
            raise ValueError("no address, username or password available!")
        time_out = get_from_list_of_dict(dic_list, "timeout", 60)

        self.data_api = DataApi(address, use_jrpc=False)
        self.data_api.set_timeout(timeout=time_out)
        print("\nDataApi login: {}@{}".format(username, address))
        r, msg = self.data_api.login(username=username, password=password)
        if not r:
            print("    login failed: msg = '{}'\n".format(msg))
        else:
            print("    login success \n")

        self.calendar = Calendar(self.data_api)

    # -----------------------------------------------------------------------------------
    # Basic APIs
    def daily(self, symbol, start_date, end_date, fields="", adjust_mode=None):
        """
        Query daily bars through ``data_api.daily``.

        Returns
        -------
        df : pd.DataFrame
            Duplicate rows removed.
        err_msg : str

        """
        df, err_msg = self.data_api.daily(symbol=symbol,
                                          start_date=start_date,
                                          end_date=end_date,
                                          fields=fields,
                                          adjust_mode=adjust_mode,
                                          data_format="")
        # trade_status performance warning
        # TODO there will be duplicate entries when on stocks' IPO day
        df = df.drop_duplicates()
        return df, err_msg

    def bar(self,
            symbol,
            start_time=200000,
            end_time=160000,
            trade_date=None,
            freq='1M',
            fields=""):
        """
        Query intraday bars through ``data_api.bar``.

        Parameters
        ----------
        symbol : str
        start_time, end_time : int
            Presumably HHMMSS.
        trade_date : int, optional
        freq : str
            Bar frequency forwarded to ``data_api.bar``, default '1M'.
        fields : str
            Separated by ','.

        Returns
        -------
        df : pd.DataFrame
        msg : str

        """
        df, msg = self.data_api.bar(symbol=symbol,
                                    fields=fields,
                                    start_time=start_time,
                                    end_time=end_time,
                                    trade_date=trade_date,
                                    freq=freq,
                                    data_format="")
        return df, msg

    def query(self, view, filter="", fields="", **kwargs):
        """
        Get various reference data.

        Parameters
        ----------
        view : str
            data source.
        fields : str
            Separated by ','
        filter : str
            filter expressions.
        kwargs
            Forwarded verbatim to ``data_api.query`` (e.g. ``orderby``).

        Returns
        -------
        df : pd.DataFrame
        msg : str
            error code and error message, joined by ','

        Examples
        --------
        res3, msg3 = ds.query("lb.secDailyIndicator", fields="price_level,high_52w_adj,low_52w_adj",\
                              filter="start_date=20170907&end_date=20170907",\
                              orderby="trade_date",\
                              data_format='pandas')
            view does not change. fields can be any field predefined in reference data api.

        """
        df, msg = self.data_api.query(view,
                                      fields=fields,
                                      filter=filter,
                                      data_format="",
                                      **kwargs)
        return df, msg

    # -----------------------------------------------------------------------------------
    # Convenient Functions
    def get_trade_date_range(self, start_date, end_date):
        """Trade dates in ``[start_date, end_date]`` as an int array (delegates to calendar)."""
        return self.calendar.get_trade_date_range(start_date, end_date)

    @staticmethod
    def _dic2url(d):
        """
        Convert a dict to str like 'k1=v1&k2=v2'

        Parameters
        ----------
        d : dict
            Keys must be str; values are converted with ``str``.

        Returns
        -------
        str

        """
        # items() works on Python 2 and 3 (viewitems is Python-2 only).
        l = ['='.join([key, str(value)]) for key, value in d.items()]
        return '&'.join(l)

    def query_lb_fin_stat(self,
                          type_,
                          symbol,
                          start_date,
                          end_date,
                          fields="",
                          drop_dup_cols=None):
        """
        Helper function to query a financial-statement view more conveniently.

        Parameters
        ----------
        type_ : {'income', 'balance_sheet', 'cash_flow', 'fin_indicator'}
        symbol : str
            separated by ','
        start_date : int
            Announcement date in results will be no earlier than start_date
        end_date : int
            Announcement date in results will be no later than end_date
        fields : str, optional
            separated by ',', default ""
        drop_dup_cols : list or tuple, optional
            If given, sort by these columns and keep the first row of each
            duplicate group.

        Returns
        -------
        df : pd.DataFrame
            index date, columns fields
        msg : str

        Raises
        ------
        NotImplementedError
            For an unknown ``type_``.

        """
        view_map = {
            'income': 'lb.income',
            'cash_flow': 'lb.cashFlow',
            'balance_sheet': 'lb.balanceSheet',
            'fin_indicator': 'lb.finIndicator'
        }
        view_name = view_map.get(type_, None)
        if view_name is None:
            raise NotImplementedError("type_ = {}".format(type_))

        dic_argument = {
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date,
        }
        if view_name != 'lb.finIndicator':
            # report_type codes: 408001000 = joint,
            # 408002000 = joint (single quarter). Single-quarter reports are
            # not used since they contain zeros.
            dic_argument.update({'report_type': '408001000'})

        filter_argument = self._dic2url(dic_argument)

        # NOTE(review): every other query here passes ``orderby``; this one
        # passes ``order_by``. Both are forwarded verbatim through **kwargs,
        # so confirm which name the server honours before changing it.
        res, msg = self.query(view_name,
                              fields=fields,
                              filter=filter_argument,
                              order_by=self.REPORT_DATE_FIELD_NAME)

        # Best-effort cast of the date columns to int. Leave ``res`` as-is if
        # it is not a DataFrame (failed query) or a value cannot be converted.
        try:
            cols = list(
                set.intersection({'ann_date', 'report_date'},
                                 set(res.columns)))
            dic_dtype = {col: int for col in cols}
            res = res.astype(dtype=dic_dtype)
        except (AttributeError, TypeError, ValueError):
            pass

        if drop_dup_cols is not None:
            res = res.sort_values(by=drop_dup_cols, axis=0)
            res = res.drop_duplicates(subset=drop_dup_cols, keep='first')

        return res, msg

    def query_lb_dailyindicator(self, symbol, start_date, end_date, fields=""):
        """
        Helper function to call data_api.query with 'lb.secDailyIndicator' more conveniently.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        fields : str, optional
            separated by ',', default ""

        Returns
        -------
        df : pd.DataFrame
            index date, columns fields
        msg : str

        """
        filter_argument = self._dic2url({
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date
        })

        return self.query("lb.secDailyIndicator",
                          fields=fields,
                          filter=filter_argument,
                          orderby="trade_date")

    def get_index_weights(self, index, trade_date):
        """
        Return the component weights of ``index`` as of ``trade_date``.

        Parameters
        ----------
        index : str
            '000300.SH' is queried as '399300.SZ'.
        trade_date : int

        Returns
        -------
        pd.DataFrame
            Indexed by symbol. 'weight' is float and divided by 100
            (presumably percent -> fraction). 'trade_date' is int and is the
            date of the returned weight snapshot.

        """
        # NOTE(review): presumably the server stores CSI 300 weights only
        # under the SZ code.
        if index == '000300.SH':
            index = '399300.SZ'

        filter_argument = self._dic2url({
            'index_code': index,
            'trade_date': trade_date
        })

        df_io, msg = self.query("lb.indexWeight",
                                fields="",
                                filter=filter_argument)
        if msg != '0,':
            print(msg)
        df_io = df_io.set_index('symbol')
        df_io = df_io.astype({'weight': float, 'trade_date': int})
        df_io.loc[:, 'weight'] = df_io['weight'] / 100.
        return df_io

    def get_index_weights_daily(self, index, start_date, end_date):
        """
        Return daily component weights of ``index`` between start_date and end_date.

        Weights are queried once per month, starting at the first trade date.
        They are forward-filled onto every trade date. Symbols absent from a
        snapshot get weight 0.

        Parameters
        ----------
        index : str
        start_date : int
        end_date : int

        Returns
        -------
        res : pd.DataFrame
            Index is trade_date, columns are symbols.

        Raises
        ------
        IndexError
            If there is no trade date in the range.

        """
        # TODO: temparary api
        trade_dates = self.get_trade_date_range(start_date, end_date)
        start_date, end_date = trade_dates[0], trade_dates[-1]
        td = start_date

        dic = dict()
        symbols_set = set()
        while td <= end_date:
            df = self.get_index_weights(index, td)
            update_date = df['trade_date'].iat[0]
            # Ignore snapshots dated outside the requested window.
            if start_date <= update_date <= end_date:
                symbols_set.update(df.index)
                dic[td] = df['weight']
            td = jutil.get_next_period_day(td, 'month', 1)
        merge = pd.concat(dic, axis=1).T
        merge = merge.fillna(0.0)  # for those which are not components
        res = pd.DataFrame(index=trade_dates,
                           columns=sorted(symbols_set),
                           data=np.nan)
        res.update(merge)
        res = res.ffill()
        res = res.loc[start_date:end_date]
        return res

    def _get_index_comp(self, index, start_date, end_date):
        """
        Query the raw 'lb.indexCons' rows of ``index`` for the date range.

        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        df_io : pd.DataFrame
            Ordered by symbol.
        msg : str

        """
        filter_argument = self._dic2url({
            'index_code': index,
            'start_date': start_date,
            'end_date': end_date
        })

        df_io, msg = self.query("lb.indexCons",
                                fields="",
                                filter=filter_argument,
                                orderby="symbol")
        return df_io, msg

    def get_index_comp(self, index, start_date, end_date):
        """
        Return list of symbols that have been in index during start_date and end_date.

        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        list
            Unique, sorted symbols.

        """
        df_io, msg = self._get_index_comp(index, start_date, end_date)
        if msg != '0,':
            print(msg)
        return list(np.unique(df_io.loc[:, 'symbol']))

    def get_index_comp_df(self, index, start_date, end_date):
        """
        Get index components on each day during start_date and end_date.

        Parameters
        ----------
        index : str
            separated by ','
        start_date : int
        end_date : int

        Returns
        -------
        res : pd.DataFrame
            index dates, columns all securities that have ever been components,
            values are 0 (not in) or 1 (in)

        """
        df_io, msg = self._get_index_comp(index, start_date, end_date)
        if msg != '0,':
            print(msg)

        try:
            string_types = (str, unicode)  # Python 2
        except NameError:
            string_types = (str,)  # Python 3 has no separate unicode type

        def str2int(s):
            """Convert a date to int. An empty string means "no date" (99999999)."""
            if isinstance(s, string_types):
                return int(s) if s else 99999999
            # np.float was removed from NumPy; np.floating covers numpy floats.
            elif isinstance(s, (int, np.integer, float, np.floating)):
                return s
            else:
                raise NotImplementedError("type s = {}".format(type(s)))

        df_io.loc[:, 'in_date'] = df_io.loc[:, 'in_date'].apply(str2int)
        df_io.loc[:, 'out_date'] = df_io.loc[:, 'out_date'].apply(str2int)

        dates = self.get_trade_date_range(start_date=start_date,
                                          end_date=end_date)

        dic = dict()
        gp = df_io.groupby(by='symbol')
        for sec, df in gp:
            mask = np.zeros_like(dates, dtype=int)
            for _, row in df.iterrows():
                # NOTE(review): both endpoints are excluded (strict
                # inequalities); confirm in_date itself should not count.
                bool_index = np.logical_and(dates > row['in_date'],
                                            dates < row['out_date'])
                mask[bool_index] = 1
            dic[sec] = mask

        res = pd.DataFrame(index=dates, data=dic)

        return res

    def get_industry_daily(self,
                           symbol,
                           start_date,
                           end_date,
                           type_='SW',
                           level=1):
        """
        Get the industry code of each symbol on every trade date in the range.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        type_ : {'SW', 'ZZ'}
        level : {1, 2, 3, 4}

        Returns
        -------
        res : pd.DataFrame
            index dates, columns symbols
            values are industry code (as str)

        """
        df_raw = self.get_industry_raw(symbol, type_=type_, level=level)

        dic_sec = jutil.group_df_to_dict(df_raw, by='symbol')
        dic_sec = {
            sec: df.sort_values(by='in_date', axis=0).reset_index()
            for sec, df in dic_sec.items()
        }

        df_ann_tmp = pd.concat(
            {sec: df.loc[:, 'in_date']
             for sec, df in dic_sec.items()},
            axis=1)
        df_value_tmp = pd.concat(
            {
                sec: df.loc[:, 'industry{:d}_code'.format(level)]
                for sec, df in dic_sec.items()
            },
            axis=1)

        idx = np.unique(
            np.concatenate([df.index.values for df in dic_sec.values()]))
        symbol_arr = np.sort(symbol.split(','))
        df_ann = pd.DataFrame(index=idx, columns=symbol_arr, data=np.nan)
        df_ann.loc[df_ann_tmp.index, df_ann_tmp.columns] = df_ann_tmp
        df_value = pd.DataFrame(index=idx, columns=symbol_arr, data=np.nan)
        df_value.loc[df_value_tmp.index, df_value_tmp.columns] = df_value_tmp

        dates_arr = self.get_trade_date_range(start_date, end_date)
        df_industry = align.align(df_value, df_ann, dates_arr)

        # TODO before industry classification is available, we assume they belong to their first group.
        df_industry = df_industry.bfill()
        df_industry = df_industry.astype(str)

        return df_industry

    def get_industry_raw(self, symbol, type_='ZZ', level=1):
        """
        Get daily industry of securities from ShenWanZhiShu or ZhongZhengZhiShu.

        Parameters
        ----------
        symbol : str
            separated by ','
        type_ : {'SW', 'ZZ'}
        level : {1, 2, 3, 4}
            Use which level of industry index classification.

        Returns
        -------
        df : pd.DataFrame
            Duplicates removed; 'in_date' cast to int.

        Raises
        ------
        ValueError
            For an unknown ``type_`` or ``level``.

        """
        if type_ == 'SW':
            src = u'申万研究所'
            if level not in [1, 2, 3, 4]:
                raise ValueError("For [SW], level must be one of {1, 2, 3, 4}")
        elif type_ == 'ZZ':
            src = u'中证指数有限公司'
            if level not in [1, 2, 3, 4]:
                raise ValueError("For [ZZ], level must be one of {1, 2, 3, 4}")
        else:
            raise ValueError("type_ must be one of SW or ZZ")

        # Python 2 needs UTF-8 bytes here. On Python 3, str(bytes) in _dic2url
        # would yield "b'...'", so keep text there.
        if str is bytes:
            src = src.encode('utf-8')

        filter_argument = self._dic2url({
            'symbol': symbol,
            'industry_src': src
        })
        fields_list = [
            'symbol', 'industry{:d}_code'.format(level),
            'industry{:d}_name'.format(level)
        ]

        df_raw, msg = self.query("lb.secIndustry",
                                 fields=','.join(fields_list),
                                 filter=filter_argument,
                                 orderby="symbol")
        if msg != '0,':
            print(msg)

        # NOTE(review): 'in_date' is not in fields_list; this relies on the
        # server returning it anyway.
        df_raw = df_raw.astype(dtype={
            'in_date': int,
        })
        return df_raw.drop_duplicates()

    def get_adj_factor_daily(self, symbol, start_date, end_date, div=False):
        """
        Get the adjust factor of each symbol on every trade date.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        div : bool
            False for normal adjust factor, True for the ratio to the previous
            row (first row filled with 1.0).

        Returns
        -------
        res : pd.DataFrame
            index trade dates (min to max returned date), columns symbols,
            values are adjust factors

        """
        df_raw = self.get_adj_factor_raw(symbol,
                                         start_date=start_date,
                                         end_date=end_date)

        dic_sec = jutil.group_df_to_dict(df_raw, by='symbol')
        dic_sec = {
            sec: df.set_index('trade_date').loc[:, 'adjust_factor']
            for sec, df in dic_sec.items()
        }

        # TODO: duplicate codes with dataview.py: line 512
        res = pd.concat(dic_sec, axis=1)  # TODO: fillna ?

        idx = np.unique(
            np.concatenate([df.index.values for df in dic_sec.values()]))
        symbol_arr = np.sort(symbol.split(','))
        res_final = pd.DataFrame(index=idx, columns=symbol_arr, data=np.nan)
        res_final.loc[res.index, res.columns] = res

        # align to every trade date
        s = df_raw['trade_date'].min()
        e = df_raw['trade_date'].max()
        dates_arr = self.get_trade_date_range(s, e)
        if not len(dates_arr) == len(res_final.index):
            res_final = res_final.reindex(dates_arr)
            # NOTE(review): gaps are only filled when a reindex happened;
            # confirm that skipping the fill otherwise is intended.
            res_final = res_final.ffill().bfill()

        if div:
            res_final = res_final.div(res_final.shift(1, axis=0)).fillna(1.0)

        return res_final

    def get_adj_factor_raw(self, symbol, start_date=None, end_date=None):
        """
        Query adjust factor for symbols.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int, optional
            None is sent as an empty value.
        end_date : int, optional
            None is sent as an empty value.

        Returns
        -------
        df : pd.DataFrame
            Columns symbol (str), trade_date (int), adjust_factor (float),
            duplicates removed.

        """
        if start_date is None:
            start_date = ""
        if end_date is None:
            end_date = ""

        filter_argument = self._dic2url({
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date
        })
        fields_list = ['symbol', 'trade_date', 'adjust_factor']

        df_raw, msg = self.query("lb.secAdjFactor",
                                 fields=','.join(fields_list),
                                 filter=filter_argument,
                                 orderby="symbol")
        if msg != '0,':
            print(msg)
        df_raw = df_raw.astype(dtype={
            'symbol': str,
            'trade_date': int,
            'adjust_factor': float
        })
        return df_raw.drop_duplicates()

    def query_inst_info(self, symbol, inst_type="", fields=""):
        """
        Query 'jz.instrumentInfo' for ``symbol``.

        Parameters
        ----------
        symbol : str
            separated by ','
        inst_type : str, optional
            Defaults to "1,2,3,4,5,101,102,103,104".
        fields : str, optional

        Returns
        -------
        pd.DataFrame
            Indexed by symbol. Known columns (list_date, delist_date,
            inst_type) are cast to int when present.

        """
        if inst_type == "":
            inst_type = "1,2,3,4,5,101,102,103,104"

        filter_argument = self._dic2url({
            'symbol': symbol,
            'inst_type': inst_type
        })

        df_raw, msg = self.query("jz.instrumentInfo",
                                 fields=fields,
                                 filter=filter_argument,
                                 orderby="symbol")
        if msg != '0,':
            print(msg)

        dtype_map = {
            'symbol': str,
            'list_date': int,
            'delist_date': int,
            'inst_type': int
        }
        cols = set(df_raw.columns)
        dtype_map = {k: v for k, v in dtype_map.items() if k in cols}

        df_raw = df_raw.astype(dtype=dtype_map)

        res = df_raw.set_index('symbol')
        return res

    # -----------------------------------------------------------------------------------
    # subscribe for real time trading
    def subscribe(self, symbols):
        """
        Subscribe to real-time quotes; each quote goes to ``mkt_data_callback``.

        Parameters
        ----------
        symbols : str
            Separated by ,

        """
        self.data_api.subscribe(symbols, func=self.mkt_data_callback)

    def mkt_data_callback(self, key, quote):
        """Wrap ``quote`` in a MARKET_DATA event and put it on the context queue."""
        e = Event(EVENT_TYPE.MARKET_DATA)
        e.dic = {'quote': quote}
        # NOTE(review): self.ctx is not set in this class; presumably it is
        # provided by DataService or assigned by the caller.
        self.ctx.instance.put(e)
Esempio n. 14
0
class MyTushareApi():
    """Download DataApi reference views for one symbol into MongoDB.

    Every key of ``self.reference`` is a DataApi view name and is also used
    as the MongoDB database name. Each symbol gets its own collection. Use
    ``main(symbol)`` as the entry point.
    """

    # Databases MongoDB maintains by itself ('config' appears on newer
    # servers); they never hold downloaded data.
    _DEFAULT_DBS = ('admin', 'local', 'config')

    def __init__(self):
        self.universe = []
        self.symbol = ""
        self.fields = ""
        self.start_date = 0
        self.end_date = 0
        self._data_status = 0
        self._data_init_config = {}
        self._dataapi = None
        self._client = None
        self._db = None
        self._col = None
        self._db_list = []
        self._col_list = []
        self._folder_path = './log/'
        # view -> comma-separated fields to download (extend as needed)
        self.reference = {
            "lb.secDailyIndicator": "pe, pe_ttm, pb, ps, ps_ttm, net_assets",
        }

    def _login(self, dataapi=False, mongodb=False):
        """Open the requested connections.

        Parameters
        ----------
        dataapi : bool
            Log in to the tushare DataApi server (credentials are placeholders).
        mongodb : bool
            Connect to MongoDB on localhost:27017.
        """
        self._data_init_config = {
            "remote.data.address": "tcp://data.tushare.org:8910",
            "remote.data.username": "******",
            "remote.data.password": "******"
        }

        if dataapi:
            self._dataapi = DataApi(
                addr=self._data_init_config.get('remote.data.address'))
            self._dataapi.login(
                self._data_init_config.get('remote.data.username'),
                self._data_init_config.get('remote.data.password'))

        if mongodb:
            self._client = pymongo.MongoClient(host='localhost', port=27017)

    def _logout(self):
        """Close the MongoDB client, if one was opened."""
        if self._client is not None:
            self._client.close()

    def _set_data_status(self):
        """Decide between a first download and an incremental update.

        Sets ``self._data_status`` to 1 (first download: no user database
        exists) or 2 (incremental update). Also fills ``self._db_list``
        (user databases) and ``self._col_list`` (their collections).
        """
        names = self._client.list_database_names()
        # list.remove() raised ValueError if a default db was absent.
        self._db_list = [name for name in names
                         if name not in self._DEFAULT_DBS]
        if not self._db_list:
            self._data_status = 1
        else:
            # self._db is only chosen later (in _save_data), so collect the
            # collections of every user database instead.
            self._col_list = sorted({
                col
                for db_name in self._db_list
                for col in self._client[db_name].list_collection_names()
            })
            self._data_status = 2

    def _get_instrumentInfo(self):
        """Return basic info on listed stocks of the Shanghai/Shenzhen markets.

        Filter: inst_type=1 (stock), status=1 (listed).

        Returns
        -------
        pd.DataFrame
        """
        df, msg = self._dataapi.query(
            view="jz.instrumentInfo",
            fields="status,list_date,delist_date,name,market,symbol",
            filter="inst_type=1&status=1",
            data_format='pandas')

        # Keep only SH and SZ; the raw result also contains Hong Kong stocks.
        df = df[(df['market'] == 'SH') | (df['market'] == 'SZ')]
        return df

    def _set_universe(self):
        """Store the set of currently listed A-share symbols in ``self.universe``."""
        self.universe = set(self._get_instrumentInfo()['symbol'])

    def _set_start_end_date_first_time(self):
        """Set the date range for a first download: listing date to today.

        The range is capped at the delisting date when that is earlier.
        Dates are stored as int YYYYMMDD.
        """
        df = self._get_instrumentInfo()
        row = df[df['symbol'] == self.symbol].iloc[0]

        today = int(datetime.today().strftime('%Y%m%d'))
        self.start_date = int(row['list_date'])
        # Compare as ints: min() of a str and a number fails on Python 3.
        try:
            delist_date = int(row['delist_date'])
        except (TypeError, ValueError):
            delist_date = today  # no usable delisting date
        self.end_date = min(today, delist_date)

    def _set_start_end_date_update(self):
        """Set the date range for an incremental update.

        The range runs from the day after the newest stored ``trade_date``
        to today (int YYYYMMDD). If nothing is stored yet, the first-download
        range is used instead.
        """
        if self._col is None:
            # main() calls this before _save_data() has chosen a collection;
            # use the first configured view's collection for this symbol.
            first_view = next(iter(self.reference))
            self._col = self._client[first_view][self.symbol]

        stored = [doc['trade_date']
                  for doc in self._col.find({}, {'trade_date': 1, '_id': 0})
                  if 'trade_date' in doc]
        if not stored:
            self._set_start_end_date_first_time()
            return

        latest_date = max(int(d) for d in stored)

        today = datetime.today().strftime('%Y%m%d')
        next_day = (datetime.strptime(str(latest_date), '%Y%m%d')
                    + timedelta(days=1)).strftime('%Y%m%d')

        # YYYYMMDD strings compare correctly as text.
        if today >= next_day:
            self.start_date = int(next_day)
            self.end_date = int(today)
        else:
            print("Trade date in Dbase is latest.")

    def _download_data(self, view=None):
        """Query ``view`` for ``self.fields`` of ``self.symbol`` in the date range.

        Parameters
        ----------
        view : str, optional
            DataApi view name. Defaults to ``self.db``.

        Returns
        -------
        pd.DataFrame or None
            Whatever the query returned. An error message is printed when
            the query does not report '0,'.
        """
        if view is None:
            view = self.db
        # Field lists like "pe, pe_ttm" are normalised to "pe,pe_ttm".
        fields = ','.join(f.strip() for f in self.fields.split(','))
        filter_argument = 'symbol={}&start_date={}&end_date={}'.format(
            self.symbol, self.start_date, self.end_date)

        df, msg = self._dataapi.query(
            view=view,
            fields=fields,
            filter=filter_argument,
            data_format='pandas')
        if msg != '0,':
            print("Query {} for {} failed: {}".format(view, self.symbol, msg))
        return df

    def _log_error_symbol(self):
        """Append ``self.symbol`` to ``<_folder_path>error_list.csv``."""
        import os
        os.makedirs(self._folder_path, exist_ok=True)
        with open(os.path.join(self._folder_path, 'error_list.csv'), 'a') as f:
            f.write(self.symbol + ',')

    def _save_data(self):
        """Download every view in ``self.reference`` and insert it into MongoDB.

        Each view is stored in database ``<view>``, collection ``<symbol>``.
        Symbols that yield no data or fail to insert are appended to the
        error list. The MongoDB client is closed once all views are done.
        """
        try:
            for db_name, fields in self.reference.items():
                self.fields = fields
                df = self._download_data(view=db_name)

                self._db = self._client[db_name]
                self._col = self._db[self.symbol]

                if df is None or df.empty:
                    self._log_error_symbol()
                    continue

                props = df.to_dict(orient='records')
                try:
                    self._col.insert_many(props)
                    print("\n\nSave data to DBase...\nSymbol: {}\n\n".format(
                        self.symbol))
                except Exception:
                    # Best effort per symbol: record it and carry on.
                    self._log_error_symbol()
        finally:
            # Close once, after all views, not inside the loop.
            self._logout()

    def main(self, symbol):
        """Download (first time) or update the stored data for ``symbol``."""
        self.symbol = symbol
        self._login(dataapi=True, mongodb=True)
        self._set_data_status()
        self._set_universe()

        if self._data_status == 1:
            self._set_start_end_date_first_time()
            self._save_data()
        elif self._data_status == 2:
            self._set_start_end_date_update()
            self._save_data()