Ejemplo n.º 1
0
    def __init__(self):
        """Prepare field lists and database handles, then log in to iFinD.

        A failed login is only logged; no exception is raised, so the
        instance is still created.
        """
        # Security codes to fetch data for.
        self.wsetData = [
            "000001.SH",
            "399300.SZ",
            "000016.SH",
            "000905.SH",
            "000906.SH",
        ]
        # Data fields to fetch for indexes.
        self.indexFieldName = [
            "open", "high", "low", "close", "volume",
            "amt", "chg", "pct_chg", "turn",
        ]
        self.fundFieldName = ["nav", "NAV_acc", "sec_name"]
        self.monetaryFund = [
            "mmf_annualizedyield",
            "mmf_unityield",
            "sec_name",
        ]
        self.stockFieldName = [
            "open", "high", "low", "close", "volume", "amt",
            "turn", "mkt_cap_ard", "pe_ttm", "ps_ttm", "pb_lf",
        ]

        # Two separate MysqlCon objects are created, one per handle.
        self.engine = MysqlCon().getMysqlCon(flag='engine')
        self.conn = MysqlCon().getMysqlCon(flag='connect')
        self.GetDataToMysqlDemo = GetDataToMysql()
        self.logger = mylog.logger

        # NOTE(review): the account and password are hard-coded in source;
        # consider loading them from configuration instead.
        login_status = THS_iFinDLogin('zszq5072', '754628')
        if login_status != 0:
            self.logger.error("同花顺账号登录异常,请检查!")
            return
        self.logger.info("同花顺账号登录成功!")
Ejemplo n.º 2
0
class GetDataToMysql:
    """Write pandas DataFrames into database tables with ``replace into``."""

    def __init__(self):
        self.conn = MysqlCon().getMysqlCon(flag='connect')
        self.logger = mylog.logger

    def GetMain(self, dataDf, tableName):
        """Insert every row of ``dataDf`` into ``tableName``.

        The DataFrame column names are used as the table column names.
        Null values (NaN/NaT) are sent as ``None``. The operation is
        best-effort: if a row fails, the error is logged with its traceback,
        the remaining rows are skipped, and whatever was executed before the
        failure is still committed.

        :param dataDf: DataFrame whose columns match the table columns.
        :param tableName: name of the target table.
        :return: None
        """
        columns = dataDf.columns.tolist()
        placeholders = ','.join(['%s'] * len(columns))
        sqlStr = "replace into %s(%s)VALUES(%s)" % (
            tableName, ','.join(columns), placeholders)

        # Cast to object first so that null cells can hold None.
        dataDf = dataDf.astype(object).where(pd.notnull(dataDf), None)

        cursor = self.conn.cursor()
        try:
            for row in dataDf.itertuples(index=False, name=None):
                cursor.execute(sqlStr, row)
        except Exception:
            # Only ordinary errors are swallowed; KeyboardInterrupt and
            # SystemExit propagate to the caller.
            self.logger.exception("插入数据到数据库错误,请检查!")
        finally:
            cursor.close()
        self.conn.commit()
Ejemplo n.º 3
0
class GetDataToMysql:
    """Write pandas DataFrames into database tables with ``replace into``."""

    def __init__(self):
        self.conn = MysqlCon().getMysqlCon(flag='connect')
        self.logger = mylog.set_log()

    def GetMain(
        self,
        dataDf,
        tableName,
    ):
        """Insert every row of ``dataDf`` into ``tableName`` and commit.

        The DataFrame column names are used as the table column names and
        null values (NaN/NaT) are sent as ``None``. Any error raised while
        executing a row propagates to the caller; in that case nothing is
        committed, but the cursor is still closed.

        :param dataDf: DataFrame whose columns match the table columns.
        :param tableName: name of the target table.
        :return: None
        """
        columns = dataDf.columns.tolist()
        placeholders = ','.join(['%s'] * len(columns))
        sqlStr = "replace into %s(%s)VALUES(%s)" % (
            tableName, ','.join(columns), placeholders)

        # Cast to object first so that null cells can hold None.
        dataDf = dataDf.astype(object).where(pd.notnull(dataDf), None)

        cursor = self.conn.cursor()
        try:
            for row in dataDf.itertuples(index=False, name=None):
                cursor.execute(sqlStr, row)
        finally:
            # Release the cursor even when a row fails.
            cursor.close()

        self.conn.commit()
        self.logger.info("数据存入mysql成功!")
Ejemplo n.º 4
0
 def __init__(self, data_resource='ifind'):
     """Record the data source, run ``log_init`` and open DB handles.

     :param data_resource: identifier of the data source; it is stored in
         ``dic_init`` and passed on to ``self.log_init``.
     """
     self.logger = mylog.set_log()
     # The dict must exist before log_init runs, so fill it in two steps.
     self.dic_init = {'data_resource': data_resource}
     self.dic_init['data_init_flag'] = self.log_init(data_resource)

     connector = MysqlCon()
     self.engine = connector.getMysqlCon(flag='engine')
     self.conn = connector.getMysqlCon(flag='connect')
     self.GetDataToMysqlDemo = GetDataToMysql()
Ejemplo n.º 5
0
 def __init__(self):
     """Set the report period, output folder, DB engine and column labels."""
     self.start_date, self.end_date = '2020-07-01', '2020-07-31'
     self.file_path = r"D:\\工作文件\\指数基金月报\\培训PPT\\0821\\"
     self.engine = MysqlCon().getMysqlCon('engine')
     # Column name -> display label.
     self.name_dic = dict(
         fund_code='基金代码',
         fund_type='基金类型',
         product_type='产品类型',
         fund_name='基金名称',
         establish_date='基金成立日',
         indx_sname='跟踪指数',
         class_classify='跟踪指数类型',
         index_code='跟踪指数代码',
     )
     self.logger = mylog.set_log()
Ejemplo n.º 6
0
 def __init__(self):
     """Prepare the code/field lists, database handles and logger."""
     # Security codes to fetch data for.
     self.wsetData = [
         "000001.SH",
         "399300.SZ",
         "000016.SH",
         "000905.SH",
         "000906.SH",
     ]
     # Data fields to fetch for indexes.
     self.indexFieldName = [
         "open", "high", "low", "close", "volume",
         "amt", "chg", "pct_chg", "turn",
     ]
     self.fundFieldName = ["nav", "NAV_acc", "sec_name"]
     self.monetaryFund = [
         "mmf_annualizedyield",
         "mmf_unityield",
         "sec_name",
     ]
     self.stockFieldName = [
         "open", "high", "low", "close", "volume", "amt",
         "turn", "mkt_cap_ard", "pe_ttm", "ps_ttm", "pb_lf",
     ]

     # Two separate MysqlCon objects are created, one per handle.
     self.engine = MysqlCon().getMysqlCon(flag='engine')
     self.conn = MysqlCon().getMysqlCon(flag='connect')
     self.GetDataToMysqlDemo = GetDataToMysql()
     self.logger = mylog.logger
Ejemplo n.º 7
0
 def __init__(self):
     """Create the logger and DB engine, then call ``w.start()``."""
     self.logger = mylog.set_log()
     self.engine = MysqlCon().getMysqlCon(flag='engine')
     w.start()
 def __init__(self):
     """Obtain the database engine from ``MysqlCon``."""
     self.engine = MysqlCon().getMysqlCon('engine')
Ejemplo n.º 9
0
class GetDataFromWindAndMySql:
    def __init__(self):
        """Prepare field lists and database handles, then log in to iFinD.

        A failed login is only logged; no exception is raised, so the
        instance is still created.
        """
        # Security codes to fetch data for.
        self.wsetData = [
            "000001.SH",
            "399300.SZ",
            "000016.SH",
            "000905.SH",
            "000906.SH",
        ]
        # Data fields to fetch for indexes.
        self.indexFieldName = [
            "open", "high", "low", "close", "volume",
            "amt", "chg", "pct_chg", "turn",
        ]
        self.fundFieldName = ["nav", "NAV_acc", "sec_name"]
        self.monetaryFund = [
            "mmf_annualizedyield",
            "mmf_unityield",
            "sec_name",
        ]
        self.stockFieldName = [
            "open", "high", "low", "close", "volume", "amt",
            "turn", "mkt_cap_ard", "pe_ttm", "ps_ttm", "pb_lf",
        ]

        # Two separate MysqlCon objects are created, one per handle.
        self.engine = MysqlCon().getMysqlCon(flag='engine')
        self.conn = MysqlCon().getMysqlCon(flag='connect')
        self.GetDataToMysqlDemo = GetDataToMysql()
        self.logger = mylog.logger

        # NOTE(review): the account and password are hard-coded in source;
        # consider loading them from configuration instead.
        login_status = THS_iFinDLogin('zszq5072', '754628')
        if login_status != 0:
            self.logger.error("同花顺账号登录异常,请检查!")
            return
        self.logger.info("同花顺账号登录成功!")

    def getStockMonthToMySql(self):
        start_date = '2011-11-01'
        end_date = '2019-12-31'
        total_trade_list = self.getTradeDay(startDate=start_date,
                                            endDate=end_date,
                                            Period='M')
        wsetdata = w.wset(
            "sectorconstituent",
            "date=%s;sectorid=a001010100000000" % total_trade_list[0])
        if wsetdata.ErrorCode != 0:
            self.logger.debug("获取全A股数据有误,错误代码" + str(wsetdata.ErrorCode))
        index_df = pd.DataFrame(wsetdata.Data,
                                index=wsetdata.Fields,
                                columns=wsetdata.Codes).T
        if index_df.empty:
            return

        for trade_date in total_trade_list:
            optionstr = "tradeDate=%s;cycle=M" % (
                trade_date[:4] + trade_date[5:7] + trade_date[8:])
            wssdata = w.wss(
                codes=index_df['wind_code'].tolist(),
                fields=["mkt_freeshares", "pe_ttm", "ps_ttm", "pct_chg"],
                options=optionstr)
            if wssdata.ErrorCode != 0:
                self.logger.debug("获取因子数据有误,错误代码" + str(wssdata.ErrorCode))
                return pd.DataFrame()
            resultDf = pd.DataFrame(wssdata.Data,
                                    index=wssdata.Fields,
                                    columns=wssdata.Codes).T

            df_list = []
            for col in resultDf:
                temp_df = pd.DataFrame(resultDf[col].values,
                                       index=resultDf.index,
                                       columns=['factor_value'])
                temp_df['update_time'] = trade_date
                temp_df['stock_code'] = resultDf.index.tolist()
                temp_df['factor_name'] = col
                df_list.append(temp_df)
            total_fa_df = pd.concat(df_list, axis=0, sort=True)
            self.GetDataToMysqlDemo.GetMain(total_fa_df,
                                            'stock_factor_month_value')
            self.logger.info("存储日期%s因子数据成功!" % trade_date)

    def getFactorValue(self,
                       code_list=[],
                       factor_list=[],
                       start_date='2019-04-01',
                       end_date='2019-05-01'):
        # 获取截面因子数据
        sqlStr = "select * from stock_factor_month_value where stock_code in %s and factor_name in %s and update_time='%s'" % (
            str(tuple(code_list)), str(tuple(factor_list)), start_date)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)

        if resultDf.empty:
            dateFormat = start_date[:4] + start_date[5:7] + start_date[8:]
            wssdata = w.wss(codes=code_list,
                            fields=factor_list,
                            options="unit=1;tradeDate=%s" % dateFormat)
            if wssdata.ErrorCode != 0:
                self.logger.debug("获取因子数据有误,错误代码" + str(wssdata.ErrorCode))
                return pd.DataFrame()
            resultDf = pd.DataFrame(wssdata.Data,
                                    index=wssdata.Fields,
                                    columns=wssdata.Codes).T
        else:
            df_list = []
            for factor, temp_df in resultDf.groupby(by='factor_name'):
                temp = pd.DataFrame(temp_df['factor_value'].values,
                                    index=temp_df['stock_code'],
                                    columns=[factor])
                df_list.append(temp)
            resultDf = pd.concat(df_list, sort=True, axis=1)
            return resultDf
        return resultDf

    def getIndexConstituent(self, indexCode='000300.SH', getDate='2019-06-06'):
        '''
        获取指数成分股
        :param indexCode:
        :param getDate:
        :return:
        '''
        if indexCode != '全A股':
            sqlStr = "select * from index_constituent where index_code='%s' and update_time='%s'" % (
                indexCode, getDate)
            resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
            if resultDf.empty:
                wsetdata = w.wset("indexconstituent",
                                  "date=%s;windcode=%s" % (getDate, indexCode))
                if wsetdata.ErrorCode != 0:
                    self.logger.error("获取指数成分股数据有误,错误代码" +
                                      str(wsetdata.ErrorCode))
                    return pd.DataFrame()

                resultDf = pd.DataFrame(wsetdata.Data, index=wsetdata.Fields).T
                if resultDf.empty:
                    wsetdata = w.wset(
                        "sectorconstituent",
                        "date=%s;windcode=%s" % (getDate, indexCode))
                    if wsetdata.ErrorCode != 0:
                        self.logger.error("获取板块指数成分股数据有误,错误代码" +
                                          str(wsetdata.ErrorCode))
                        return pd.DataFrame()

                    resultDf = pd.DataFrame(wsetdata.Data,
                                            index=wsetdata.Fields).T
                    if resultDf.empty:
                        self.logger.info("指定日期内,未找到有效成分股数据")
                    return resultDf

                dateList = [
                    datetampStr.strftime('%Y-%m-%d')
                    for datetampStr in resultDf['date'].tolist()
                ]
                resultDf['date'] = dateList
                nameDic = {
                    'date': 'adjust_time',
                    'wind_code': 'stock_code',
                    "sec_name": 'stock_name',
                    'i_weight': 'stock_weight'
                }
                resultDf.rename(columns=nameDic, inplace=True)
                resultDf['update_time'] = getDate
                resultDf['index_code'] = indexCode
                if 'industry' in resultDf:
                    resultDf.drop(labels='industry', inplace=True, axis=1)
                self.GetDataToMysqlDemo.GetMain(resultDf, 'index_constituent')
        else:
            sqlStr = "select distinct stock_code from stock_factor_month_value where update_time='%s'" % getDate
            resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        return resultDf

    def getLackDataToMySql(self, tempCode, startDate, endDate):
        tableStr = 'index_value'
        codeName = 'index_code'

        sqlStr = "select max(update_time),min(update_time) from %s where %s='%s'" % (
            tableStr, codeName, tempCode)
        cursor = self.conn.cursor()
        cursor.execute(sqlStr)
        dateStrTuple = cursor.fetchall()[0]
        maxDate = dateStrTuple[0]
        minDate = dateStrTuple[1]

        if not maxDate:
            self.get_hq_data_from_ths(tempCode,
                                      startDate=startDate,
                                      endDate=endDate)
            return

        if endDate < minDate or startDate > minDate:
            self.get_hq_data_from_ths(tempCode,
                                      startDate=startDate,
                                      endDate=endDate)
        elif startDate <= minDate:
            if minDate <= endDate < maxDate:
                if startDate != minDate:
                    self.get_hq_data_from_ths(tempCode,
                                              startDate=startDate,
                                              endDate=minDate)
            elif endDate >= maxDate:
                self.get_hq_data_from_ths(tempCode,
                                          startDate=startDate,
                                          endDate=minDate)
                if endDate != maxDate:
                    self.get_hq_data_from_ths(tempCode,
                                              startDate=maxDate,
                                              endDate=endDate)
        elif endDate > maxDate:
            self.get_hq_data_from_ths(tempCode,
                                      startDate=maxDate,
                                      endDate=endDate)

    def getDataFromMySql(self,
                         tempCode,
                         startDate,
                         endDate,
                         tableFlag='index',
                         nameList=['close_price']):
        if not nameList:
            self.logger.error('传入获取指数的字段不合法,请检查!')
        if tableFlag == 'index':
            tableStr = 'index_value'
            codeName = 'index_code'
        elif tableFlag == 'fund':
            codeName = 'fund_code'
            tableStr = 'fund_net_value'
        elif tableFlag == 'stock':
            codeName = 'stock_code'
            tableStr = 'stock_hq_value'
        elif tableFlag == 'private':
            codeName = 'fund_code'
            tableStr = 'private_net_value'
        elif tableFlag == 'monetary_fund':
            codeName = 'fund_code'
            tableStr = 'monetary_fund'

        sqlStr = "select %s,update_time from %s where %s='%s' and  update_time>='%s'" \
                 " and update_time<='%s'" % (','.join(nameList), tableStr, codeName, tempCode, startDate, endDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        resultDf = resultDf.drop_duplicates('update_time').sort_index()
        resultDf.set_index(keys='update_time', inplace=True, drop=True)
        return resultDf

    def getCurrentNameData(self,
                           tempCodeList,
                           startDate,
                           endDate,
                           tableFlag='stock',
                           nameStr='close_price'):
        '''
        获取指定字段的数据
        '''
        if tableFlag == 'stock':
            totalCodeStr = ''
            for stockCode in tempCodeList:
                totalCodeStr = totalCodeStr + stockCode + "','"

            sqlStr1 = "select max(update_time),min(update_time) from stock_hq_value where stock_code in ('%s')" % totalCodeStr[:
                                                                                                                               -3]
            cursor = self.conn.cursor()
            cursor.execute(sqlStr1)
            dateStrTuple = cursor.fetchall()[0]
            maxDate = dateStrTuple[0]
            minDate = dateStrTuple[1]

            if not maxDate:
                for tempCode in tempCodeList:
                    self.getDataFromWind(tempCode,
                                         startDate=startDate,
                                         endDate=endDate,
                                         tableFlag=tableFlag)
                    return
            else:
                if endDate < minDate or startDate > minDate:
                    for tempCode in tempCodeList:
                        self.getDataFromWind(tempCode,
                                             startDate=startDate,
                                             endDate=endDate,
                                             tableFlag=tableFlag)
                elif startDate <= minDate:
                    if minDate <= endDate < maxDate:
                        for tempCode in tempCodeList:
                            self.getDataFromWind(tempCode,
                                                 startDate=startDate,
                                                 endDate=minDate,
                                                 tableFlag=tableFlag)
                    elif endDate >= maxDate:
                        for tempCode in tempCodeList:
                            self.getDataFromWind(tempCode,
                                                 startDate=startDate,
                                                 endDate=minDate,
                                                 tableFlag=tableFlag)
                            self.getDataFromWind(tempCode,
                                                 startDate=maxDate,
                                                 endDate=endDate,
                                                 tableFlag=tableFlag)
                elif endDate >= maxDate:
                    for tempCode in tempCodeList:
                        self.getDataFromWind(tempCode,
                                             startDate=maxDate,
                                             endDate=endDate,
                                             tableFlag=tableFlag)

            sqlStr = "select %s,update_time,stock_code from stock_hq_value where stock_code in ('%s') and update_time<='%s' " \
                     "and update_time>='%s'" % (nameStr, totalCodeStr, endDate, startDate)
            resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
            dfList = []
            for code, tempDf in resultDf.groupby('stock_code'):
                df = pd.DataFrame(tempDf[nameStr].values,
                                  index=tempDf['update_time'],
                                  columns=[code])
                dfList.append(df)
            resultDf = pd.concat(dfList, axis=1)
            return resultDf

    def getCurrentDateData(self,
                           tempCodeList,
                           getDate,
                           tableFlag='stock',
                           nameList=['close_price']):
        '''
        获取指定日期的截面数据
        :return:
        '''
        if tableFlag == 'stock':
            totalCodeStr = ""
            for stockCode in tempCodeList:
                totalCodeStr = totalCodeStr + stockCode + "','"

            sqlStr = "select * from stock_hq_value where stock_code in ('%s') and update_time='%s'" % (
                totalCodeStr[:-3], getDate)
            resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
            if resultDf.empty:
                codes = tempCodeList
                fields = self.stockFieldName
                tradeDate = getDate
                wssData = w.wss(codes=codes,
                                fields=fields,
                                options="tradeDate=%s;priceAdj=F;cycle=D" %
                                tradeDate)
                if wssData.ErrorCode != 0:
                    self.logger.error("获取行情数据有误,错误代码" + str(wssData.ErrorCode))
                    return pd.DataFrame()
                tempDf = pd.DataFrame(wssData.Data,
                                      index=fields,
                                      columns=codes).T
                tempDf.dropna(inplace=True)
                if tempDf.empty:
                    self.logger.error("当前日期%s无行情" % getDate)
                    return pd.DataFrame()

                tempDf['update_time'] = getDate
                nameDic = {
                    "open": "open_price",
                    "high": "high_price",
                    "low": "low_price",
                    "close": "close_price",
                    "mkt_cap_ard": "market_value",
                }
                tempDf.rename(columns=nameDic, inplace=True)

                tempDf['stock_code'] = tempDf.index.tolist()
                self.GetDataToMysqlDemo.GetMain(tempDf, 'stock_hq_value')
                returnDf = tempDf[nameList]
                return returnDf
            else:
                resultDf.set_index('stock_code', drop=True, inplace=True)
                returnDf = resultDf[nameList]
                return returnDf

    def getFirstDayData(self, codeList, tableFlag='fund'):
        if tableFlag == 'fund':
            wssData = w.wss(codes=codeList, fields=["fund_setupDate"])
            if wssData.ErrorCode != 0:
                self.logger.error("getFirstDayData获取wind数据错误,错误代码%s" %
                                  wssData.ErrorCode)
                return pd.DataFrame()
            self.logger.debug("getFirstDayData获取wind数据成功!")
            resultDf = pd.DataFrame(wssData.Data,
                                    columns=wssData.Codes,
                                    index=wssData.Fields)
            return resultDf

    def get_hq_data_from_ths(
        self,
        tempCode,
        startDate='2019-04-01',
        endDate='2019-04-30',
    ):
        """Fetch daily quotes via ``THS_HistoryQuotes`` and store them.

        The result is converted with ``THS_Trans2DataFrame``, tagged with the
        code and an ``update_time`` column built from the index, renamed and
        written to the ``index_value`` table.

        NOTE(review): ``codeName`` and ``nameDic`` are not defined in this
        method. Unless they exist at module level, the lines using them raise
        NameError. Presumably ``codeName`` should be 'index_code' and
        ``nameDic`` a mapping from quote fields to table columns; confirm.

        :param tempCode: code passed as ``thscode``.
        :param startDate: passed as ``begintime``.
        :param endDate: passed as ``endtime``.
        :return: the stored DataFrame, or None when the response reports an
            error.
        """
        ths_data = THS_HistoryQuotes(
            thscode=tempCode,
            jsonIndicator=
            'open,high,low,close,volume,changeRatio,turnoverRatio,change',
            jsonparam=
            'Interval:D,CPS:6,baseDate:1900-01-01,Currency:YSHB,fill:Previous',
            begintime=startDate,
            endtime=endDate,
            outflag=True)
        if ths_data['errorcode'] != 0:
            self.logger.error("同花顺获取行情数据错误,请检查:%s" % ths_data['errmsg'])
            return

        tempDf = THS_Trans2DataFrame(ths_data)
        # Drop rows in which every value is missing.
        tempDf.dropna(how='all', inplace=True)
        # NOTE(review): codeName is not defined in this scope (see docstring).
        tempDf[codeName] = tempCode
        tempDf['update_time'] = tempDf.index.tolist()
        # NOTE(review): nameDic is not defined in this scope (see docstring).
        tempDf.rename(columns=nameDic, inplace=True)
        # Format update_time as 'YYYY-MM-DD' strings; assumes the index holds
        # objects with a strftime method. TODO confirm.
        dateList = [
            dateStr.strftime("%Y-%m-%d")
            for dateStr in tempDf['update_time'].tolist()
        ]
        tempDf['update_time'] = dateList
        self.GetDataToMysqlDemo.GetMain(tempDf, 'index_value')
        return tempDf

    def getHQData(self,
                  tempCode,
                  startDate='2019-03-01',
                  endDate='2019-05-30',
                  tableFlag='index',
                  nameList=['close_price']):
        '''
        #获取指数行情数据入口
        '''
        self.getLackDataToMySql(tempCode, startDate, endDate, tableFlag)
        resultDf = self.getDataFromMySql(tempCode,
                                         startDate,
                                         endDate,
                                         tableFlag=tableFlag,
                                         nameList=nameList)
        return resultDf

    def getTradeDay(self, startDate, endDate, Period='M'):
        """Return trade days for a period; wraps ``w.tdays``.

        :param startDate: passed as ``beginTime``.
        :param endDate: passed as ``endTime``.
        :param Period: '' day, W week, M month, Q quarter, S half-year,
            Y year.
        :return: list of 'YYYY-MM-DD' strings, or None when the request
            reports an error.
        """
        data = w.tdays(beginTime=startDate,
                       endTime=endDate,
                       options="Period=%s" % Period)
        if data.ErrorCode != 0:
            self.logger.error('wind获取交易日期错误,请检查!')
            return

        return [day.strftime('%Y-%m-%d') for day in data.Data[0]]
Ejemplo n.º 10
0
class GetDataFromWindAndMySql:
    def __init__(self):
        """Prepare the code/field lists, database handles and logger."""
        # Security codes to fetch data for.
        self.wsetData = [
            "000001.SH",
            "399300.SZ",
            "000016.SH",
            "000905.SH",
            "000906.SH",
        ]
        # Data fields to fetch for indexes.
        self.indexFieldName = [
            "open", "high", "low", "close", "volume",
            "amt", "chg", "pct_chg", "turn",
        ]
        self.fundFieldName = ["nav", "NAV_acc", "sec_name"]
        self.monetaryFund = [
            "mmf_annualizedyield",
            "mmf_unityield",
            "sec_name",
        ]
        self.stockFieldName = [
            "open", "high", "low", "close", "volume", "amt",
            "turn", "mkt_cap_ard", "pe_ttm", "ps_ttm", "pb_lf",
        ]

        # Two separate MysqlCon objects are created, one per handle.
        self.engine = MysqlCon().getMysqlCon(flag='engine')
        self.conn = MysqlCon().getMysqlCon(flag='connect')
        self.GetDataToMysqlDemo = GetDataToMysql()
        self.logger = mylog.logger

    def getBelongIndustry(self, codeList, tradeDate='2018-12-31'):
        '''
        获取股票所属的行业
        注:查询所属行业前,应确保查询日期大于股票ipo日期,该逻辑不在本方法验证。
        :return:
        '''
        sqlStr = "select * from stock_industry_value where stock_code in %s and update_time='%s'" % (
            tuple(codeList), tradeDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if resultDf.empty:
            self.logger.debug("getBelongIndustry从wind获取!")
            tradeDateParam = tradeDate[:4] + tradeDate[5:7] + tradeDate[8:]
            wssData = w.wss(codes=codeList,
                            fields=["industry_citic"],
                            options="tradeDate=%s;industryType=1" %
                            tradeDateParam)
            if wssData.ErrorCode != 0:
                self.logger.error("获取指数成分股数据有误,错误代码" + str(wssData.ErrorCode))
                return pd.DataFrame()
            df = pd.DataFrame(wssData.Data,
                              index=wssData.Fields,
                              columns=wssData.Codes).T
            df.rename(columns={"INDUSTRY_CITIC": "industry_name"},
                      inplace=True)
            df['stock_code'] = df.index.tolist()
            df['update_time'] = tradeDate
            df['industry_wind_code'] = ["industry_citic"] * df.shape[0]
            df['industry_flag'] = [1] * df.shape[0]
            self.GetDataToMysqlDemo.GetMain(df, 'stock_industry_value')
            resultDf = df[['stock_code', 'industry_name']]
        else:
            self.logger.debug("getBelongIndustry从本地数据库获取!")
            resultDf = resultDf[['stock_code', 'industry_name']]
        return resultDf

    def getFactorReportData(self,
                            codeList,
                            factors,
                            rptDate='2018-12-31',
                            backYears=0):
        # 单个年报数据的获取
        sqlStr = "select stock_code,item_value from stock_factor_value where stock_code in %s and update_time='%s' and item_wind_code='%s' " \
                 % (tuple(codeList), rptDate, factors[0])
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if resultDf.empty:
            self.logger.debug("getFactorReportData从wind数据库获取%s!" % factors[0])
            rptDateParam = rptDate[:4] + rptDate[5:7] + rptDate[8:]
            if backYears == 0:
                if "wgsd_assets" not in factors:
                    wssData = w.wss(codes=codeList,
                                    fields=factors,
                                    options="rptDate=%s" % rptDateParam)
                else:
                    wssData = w.wss(
                        codes=codeList,
                        fields=factors,
                        options="unit=1;rptDate=%s;rptType=1;currencyType=" %
                        rptDateParam)
            else:
                wssData = w.wss(codes=codeList,
                                fields=factors,
                                options="rptDate=%s;N=%s" %
                                (rptDateParam, str(backYears)))

            if wssData.ErrorCode != 0:
                self.logger.error("getFactorDailyData获取%s有误,错误代码%s" %
                                  (factors, str(wssData.ErrorCode)))
                return pd.DataFrame()
            df = pd.DataFrame(wssData.Data,
                              columns=wssData.Codes,
                              index=factors).T

            resultDf = df.copy()
            df['stock_code'] = df.index.tolist()
            df['update_time'] = rptDate
            df.rename(columns={factors: "item_value"}, inplace=True)
            df['item_wind_code'] = factors[0]
            df['rpt_flag'] = 1
            self.GetDataToMysqlDemo.GetMain(df, 'stock_factor_value')
        else:
            self.logger.debug("getFactorReportData从本地数据库获取%s!" % factors[0])
            resultDf = resultDf.set_index(
                "stock_code",
                drop=True).rename(columns={"item_value": factors[0]})
        return resultDf

    def getPetChg(self, codeList, startDate, endDate):
        '''
        获取股票区间涨跌幅数据
        :return:
        '''
        sqlStr = "select stock_code,pct_chg_value from stock_range_updown_value where stock_code in %s and start_date='%s' and end_date='%s'" % (
            tuple(codeList), startDate, endDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if not resultDf.empty:
            lackCode = [
                code for code in codeList
                if code not in resultDf['stock_code'].tolist()
            ]
        else:
            lackCode = codeList[:]

        if lackCode:
            self.logger.debug("getPetChg从wind获取!")
            startDateParam = startDate[:4] + startDate[5:7] + startDate[8:]
            endDateParam = endDate[:4] + endDate[5:7] + endDate[8:]
            wssData = w.wss(codes=lackCode,
                            fields=["pct_chg_per"],
                            options="startDate=%s;endDate=%s" %
                            (startDateParam, endDateParam))
            if wssData.ErrorCode != 0:
                self.logger.error("getPetChg获取pct_chg_per有误,错误代码%s" %
                                  (str(wssData.ErrorCode)))
                return pd.DataFrame()
            df = pd.DataFrame(wssData.Data,
                              index=["pct_chg_value"],
                              columns=wssData.Codes).T
            df['stock_code'] = df.index.tolist()
            df['start_date'] = startDate
            df['end_date'] = endDate
            self.GetDataToMysqlDemo.GetMain(df, 'stock_range_updown_value')
            resultDf = pd.concat([resultDf, df], axis=0,
                                 sort=True)[['stock_code', 'pct_chg_value']]
        else:
            self.logger.debug("getPetChg从本地数据库获取!")
        resultDf.set_index('stock_code', inplace=True, drop=True)
        return resultDf

    def getFactorDailyData(self, codeList, factors, tradeDate='2018-12-31'):
        """Return the value of one daily (non-annual-report) factor per stock.

        The local ``stock_factor_value`` table is consulted first, filtered by
        ``codeList``, ``tradeDate`` and ``factors[0]``.  Only when that query
        returns no rows is the data requested from Wind via ``w.wss`` and
        written back to the table through ``GetDataToMysqlDemo.GetMain``.

        :param codeList: stock codes to query.
        :param factors: Wind field names; ``factors[0]`` is the one used for
            the local lookup and for the rows written back.
        :param tradeDate: trade date, sliced as ``YYYY-MM-DD``.
        :return: DataFrame indexed by stock code.  From the local table it has
            a single column named ``factors[0]``; from Wind it has one column
            per entry of ``factors``.  An empty DataFrame is returned when
            ``codeList`` is empty or Wind reports an error.
        """
        if len(codeList) == 0:
            # An empty IN () list is not valid SQL and there is nothing to fetch.
            self.logger.error("getFactorDailyData传入的codeList为空,请检查!")
            return pd.DataFrame()

        # Build the IN list by hand: interpolating tuple(codeList) renders a
        # one-element list as ('code',) and the trailing comma breaks the SQL.
        codeSql = ','.join("'%s'" % code for code in codeList)
        sqlStr = ("select stock_code,item_value from stock_factor_value "
                  "where stock_code in (%s) and update_time='%s' "
                  "and item_wind_code='%s' "
                  % (codeSql, tradeDate, factors[0]))
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)

        if not resultDf.empty:
            self.logger.debug("getFactorDailyData从本地数据库获取%s!" % factors[0])
            return resultDf.set_index(
                "stock_code",
                drop=True).rename(columns={"item_value": factors[0]})

        self.logger.debug("getFactorDailyData从wind获取%s!" % factors[0])
        # Wind expects the date without separators (YYYYMMDD).
        tradeDateParam = tradeDate[:4] + tradeDate[5:7] + tradeDate[8:]
        if 'mkt_cap_float' in factors:
            # This field is requested with extra unit / currency options.
            options = "unit=1;tradeDate=%s;currencyType=" % tradeDateParam
        else:
            options = "tradeDate=%s" % tradeDateParam
        wssData = w.wss(codes=codeList, fields=factors, options=options)
        if wssData.ErrorCode != 0:
            self.logger.error("getFactorDailyData获取%s有误,错误代码%s" %
                              (factors, str(wssData.ErrorCode)))
            return pd.DataFrame()

        df = pd.DataFrame(wssData.Data, index=factors,
                          columns=wssData.Codes).T
        # Snapshot the caller-facing frame before adding the storage columns.
        resultDf = df.copy()
        df['stock_code'] = df.index.tolist()
        df['update_time'] = tradeDate
        df.rename(columns={factors[0]: "item_value"}, inplace=True)
        df['item_wind_code'] = factors[0]
        df['rpt_flag'] = 0
        self.GetDataToMysqlDemo.GetMain(df, 'stock_factor_value')
        return resultDf

    def getIndexConstituent(self,
                            indexCode='000300.SH',
                            getDate='2019-06-06',
                            indexOrSector='index'):
        """Return the constituents of an index or sector on a given date.

        The local ``index_constituent`` table is queried first by
        ``indexCode`` and ``getDate``.  When it has no matching rows the list
        is requested from Wind (``w.wset``), normalised to the table's column
        names and written back through ``GetDataToMysqlDemo.GetMain``.

        :param indexCode: Wind code of the index or sector.
        :param getDate: constituent date as ``YYYY-MM-DD``; also stored as
            ``update_time``.
        :param indexOrSector: ``'index'`` requests the ``indexconstituent``
            dataset, any other value requests ``sectorconstituent``.
        :return: DataFrame of constituents, or an empty DataFrame when Wind
            reports an error.
        """
        sqlStr = "select * from index_constituent where index_code='%s' and update_time='%s'" % (
            indexCode, getDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if not resultDf.empty:
            self.logger.debug("getIndexConstituent从本地获取!")
            return resultDf

        self.logger.debug("getIndexConstituent从wind获取!")
        datasetName = ("indexconstituent"
                       if indexOrSector == 'index' else "sectorconstituent")
        # Both datasets are requested for getDate.  The sector request used to
        # carry a hard-coded 2019-08-21 while the rows were stored under
        # update_time=getDate, so cached data could belong to the wrong day.
        wsetdata = w.wset(datasetName,
                          "date=%s;windcode=%s" % (getDate, indexCode))

        if wsetdata.ErrorCode != 0:
            self.logger.error("获取指数成分股数据有误,错误代码" + str(wsetdata.ErrorCode))
            return pd.DataFrame()

        resultDf = pd.DataFrame(wsetdata.Data, index=wsetdata.Fields).T
        # Store dates as plain YYYY-MM-DD strings.
        resultDf['date'] = [
            dateValue.strftime('%Y-%m-%d')
            for dateValue in resultDf['date'].tolist()
        ]
        nameDic = {
            'date': 'adjust_time',
            'wind_code': 'stock_code',
            "sec_name": 'stock_name',
            'i_weight': 'stock_weight'
        }
        resultDf.rename(columns=nameDic, inplace=True)
        resultDf['update_time'] = getDate
        resultDf['index_code'] = indexCode
        if 'stock_weight' not in resultDf.columns.tolist():
            # The response carried no weight field; keep the column anyway.
            resultDf['stock_weight'] = np.nan
        self.GetDataToMysqlDemo.GetMain(resultDf, 'index_constituent')
        return resultDf

    def getLackDataToMySql(self,
                           tempCode,
                           startDate,
                           endDate,
                           tableFlag='index'):
        """Download from Wind whatever part of a date range MySQL lacks.

        The earliest and latest ``update_time`` stored for ``tempCode`` are
        read from the table selected by ``tableFlag`` and compared with the
        requested range; only the missing head and/or tail is fetched through
        ``getDataFromWind``.  Dates between the stored minimum and maximum are
        treated as already present.

        :param tempCode: security code.
        :param startDate: first requested date.
        :param endDate: last requested date.
        :param tableFlag: 'index', 'fund', 'stock' or 'monetary_fund';
            'private' returns without doing anything.
        :raises ValueError: if ``tableFlag`` is none of the above.
        """
        if tableFlag == 'private':
            return

        tableInfo = {
            'index': ('index_value', 'index_code'),
            'fund': ('fund_net_value', 'fund_code'),
            'stock': ('stock_hq_value', 'stock_code'),
            'monetary_fund': ('monetary_fund', 'fund_code'),
        }
        if tableFlag not in tableInfo:
            raise ValueError("unsupported tableFlag: %r" % (tableFlag,))
        tableStr, codeName = tableInfo[tableFlag]

        sqlStr = "select max(update_time),min(update_time) from %s where %s='%s'" % (
            tableStr, codeName, tempCode)
        cursor = self.conn.cursor()
        try:
            cursor.execute(sqlStr)
            maxDate, minDate = cursor.fetchall()[0]
        finally:
            # Release the cursor even when the query fails.
            cursor.close()

        def fetchFromWind(beginDate, finishDate):
            self.getDataFromWind(tempCode,
                                 startDate=beginDate,
                                 endDate=finishDate,
                                 tableFlag=tableFlag)

        # Nothing stored for this code yet: download the whole range.
        if not maxDate:
            fetchFromWind(startDate, endDate)
            return

        # NOTE(review): the comparisons below assume update_time comes back in
        # a form comparable with startDate/endDate (e.g. 'YYYY-MM-DD' strings)
        # -- confirm against the table schema.

        # Requested range lies entirely outside what is stored.  The upper
        # bound must be maxDate: comparing startDate with minDate re-downloaded
        # ranges that were already stored and left the tail case unreachable.
        if endDate < minDate or startDate > maxDate:
            fetchFromWind(startDate, endDate)
            return

        # Overlapping ranges: top up the missing head and/or tail only.
        if startDate < minDate:
            fetchFromWind(startDate, minDate)
        if endDate > maxDate:
            fetchFromWind(maxDate, endDate)

    def getDataFromWind(self,
                        tempCode,
                        startDate='2019-04-01',
                        endDate='2019-04-30',
                        tableFlag='index'):
        """Download a date series for one code from Wind and store it in MySQL.

        ``tableFlag`` selects the target table, the code column, the Wind
        fields requested and the mapping from Wind's upper-case field names to
        the table's column names.  Stock requests additionally pass the
        ``PriceAdj=F`` option.

        :param tempCode: security code passed to ``w.wsd``.
        :param startDate: first date of the request.
        :param endDate: last date of the request.
        :param tableFlag: 'index', 'fund', 'stock' or 'monetary_fund'.
        :return: the DataFrame that was written to MySQL, or ``None`` when
            Wind reports an error.
        :raises ValueError: if ``tableFlag`` is not supported (previously this
            surfaced as an unbound-variable error further down).
        """
        if tableFlag == 'index':
            tableStr = 'index_value'
            nameDic = {
                "OPEN": "open_price",
                "HIGH": "high_price",
                "LOW": "low_price",
                "CLOSE": "close_price",
                "VOLUME": "volume",
                "AMT": "amt",
                "CHG": "chg",
                "PCT_CHG": "pct_chg",
                "TURN": "turn"
            }
            fields = self.indexFieldName
            codeName = 'index_code'
        elif tableFlag == 'fund':
            tableStr = 'fund_net_value'
            nameDic = {
                "NAV": "net_value",
                "NAV_ACC": "acc_net_value",
                "SEC_NAME": "fund_name"
            }
            fields = self.fundFieldName
            codeName = 'fund_code'
        elif tableFlag == 'stock':
            tableStr = 'stock_hq_value'
            nameDic = {
                "OPEN": "open_price",
                "HIGH": "high_price",
                "LOW": "low_price",
                "CLOSE": "close_price",
                "VOLUME": "volume",
                "AMT": "amt",
                "TURN": "turn",
                "MKT_CAP_ARD": "market_value",
                "PE_TTM": "pe_ttm",
                "PS_TTM": "ps_ttm",
                "PB_LF": "pb_lf"
            }
            fields = self.stockFieldName
            codeName = 'stock_code'
        elif tableFlag == 'monetary_fund':
            tableStr = 'monetary_fund'
            nameDic = {
                "MMF_ANNUALIZEDYIELD": "week_annual_return",
                "MMF_UNITYIELD": "wan_unit_return",
                "SEC_NAME": "fund_name"
            }
            fields = self.monetaryFund
            codeName = 'fund_code'
        else:
            raise ValueError("unsupported tableFlag: %r" % (tableFlag,))

        wsdKwargs = dict(codes=tempCode,
                         fields=fields,
                         beginTime=startDate,
                         endTime=endDate)
        if tableFlag == 'stock':
            # Only stock quotes are requested with a price-adjustment option.
            wsdKwargs['options'] = "PriceAdj=F"
        wsetdata = w.wsd(**wsdKwargs)

        if wsetdata.ErrorCode != 0:
            self.logger.error("获取行情数据有误,错误代码" + str(wsetdata.ErrorCode))
            return

        # Wind returns one list per field; transpose to one row per date.
        tempDf = pd.DataFrame(wsetdata.Data,
                              index=wsetdata.Fields,
                              columns=wsetdata.Times).T
        tempDf.dropna(how='all', inplace=True)
        tempDf[codeName] = tempCode
        # Dates are stored as YYYY-MM-DD strings.
        tempDf['update_time'] = [
            timeValue.strftime("%Y-%m-%d") for timeValue in tempDf.index
        ]
        tempDf.rename(columns=nameDic, inplace=True)
        self.GetDataToMysqlDemo.GetMain(tempDf, tableStr)
        return tempDf

    def getDataFromMySql(self,
                         tempCode,
                         startDate,
                         endDate,
                         tableFlag='index',
                         nameList=None):
        """Read stored columns for one code over a date range from MySQL.

        :param tempCode: security code to filter on.
        :param startDate: inclusive lower bound on ``update_time``.
        :param endDate: inclusive upper bound on ``update_time``.
        :param tableFlag: 'index', 'fund', 'stock', 'private' or
            'monetary_fund'; selects the table and its code column.
        :param nameList: column names to select; defaults to
            ``['close_price']``.
        :return: DataFrame indexed by ``update_time`` in ascending order with
            one row per date (the first row wins when a date is duplicated).
            An empty DataFrame is returned when ``nameList`` is empty.
        :raises ValueError: if ``tableFlag`` is not supported.
        """
        if nameList is None:
            nameList = ['close_price']
        if not nameList:
            self.logger.error('传入获取指数的字段不合法,请检查!')
            # Continuing would only build a malformed "select ,update_time".
            return pd.DataFrame()

        tableInfo = {
            'index': ('index_value', 'index_code'),
            'fund': ('fund_net_value', 'fund_code'),
            'stock': ('stock_hq_value', 'stock_code'),
            'private': ('private_net_value', 'fund_code'),
            'monetary_fund': ('monetary_fund', 'fund_code'),
        }
        if tableFlag not in tableInfo:
            raise ValueError("unsupported tableFlag: %r" % (tableFlag,))
        tableStr, codeName = tableInfo[tableFlag]

        sqlStr = "select %s,update_time from %s where %s='%s' and  update_time>='%s'" \
                 " and update_time<='%s'" % (','.join(nameList), tableStr, codeName, tempCode, startDate, endDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        resultDf = resultDf.drop_duplicates('update_time')
        resultDf.set_index(keys='update_time', inplace=True, drop=True)
        # Sort after update_time becomes the index; sorting the positional
        # index beforehand left the rows in whatever order the query returned.
        return resultDf.sort_index()

    def checkLackMonthData(self, tempDf, codeList):
        """List the codes whose monthly data in ``tempDf`` is incomplete.

        A code is incomplete when it has no row at all in ``tempDf`` or has
        fewer than two rows.

        :param tempDf: DataFrame with a ``stock_code`` column.
        :param codeList: codes that are expected to be present.
        :return: list of incomplete codes: first the codes with no rows (in
            ``codeList`` order, repeats kept), then the codes with a single
            row (in ``codeList`` order, each listed once).
        """
        # Count rows per code once instead of filtering the frame per code.
        rowCounts = tempDf['stock_code'].value_counts().to_dict()
        missingCodes = [code for code in codeList if code not in rowCounts]
        # dict.fromkeys de-duplicates while keeping codeList order, so the
        # result no longer depends on set iteration order.
        partialCodes = [
            code for code in dict.fromkeys(codeList)
            if code in rowCounts and rowCounts[code] < 2
        ]
        return missingCodes + partialCodes

    def getMonthData(self,
                     codeList=None,
                     startDate='2019-03-01',
                     endDate='2019-05-30'):
        """Return close prices of stocks on two dates.

        Rows are read from ``stock_month_value`` for ``startDate`` and
        ``endDate``.  Codes that ``checkLackMonthData`` reports as incomplete
        are requested from Wind (``w.wss``) for both dates, written back
        through ``GetDataToMysqlDemo.GetMain`` and merged into the result.

        :param codeList: stock codes to query.
        :param startDate: first date as ``YYYY-MM-DD``.
        :param endDate: second date as ``YYYY-MM-DD``.
        :return: DataFrame with columns ``stock_code``, ``close_price`` and
            ``update_time``.  When ``codeList`` is empty the frame has those
            columns and no rows; when Wind reports an error an empty
            DataFrame without columns is returned.
        """
        resultColumns = ['stock_code', 'close_price', 'update_time']
        if codeList is None or len(codeList) == 0:
            # An empty IN () list is not valid SQL and there is nothing to fetch.
            self.logger.error("getMonthData传入的codeList为空,请检查!")
            return pd.DataFrame(columns=resultColumns)

        totalTradeList = [startDate, endDate]
        # Build the IN lists by hand: interpolating tuple(codeList) renders a
        # one-element list as ('code',) and the trailing comma breaks the SQL.
        codeSql = ','.join("'%s'" % code for code in codeList)
        dateSql = ','.join("'%s'" % tradeDate for tradeDate in totalTradeList)
        sqlstr = "select * from stock_month_value where stock_code in (%s) and update_time in (%s)" % (
            codeSql, dateSql)
        tempDf = pd.read_sql(sql=sqlstr, con=self.engine)

        lackCode = self.checkLackMonthData(tempDf, codeList)
        if not lackCode:
            self.logger.debug("getMonthData从本地数据库获取!")
            return tempDf[resultColumns]

        self.logger.debug("getMonthData从wind获取,缺失code: %s" %
                          ','.join(lackCode))
        dfList = []
        for tradeDate in totalTradeList:
            # Wind expects the date without separators (YYYYMMDD).
            tradeDateStr = tradeDate[:4] + tradeDate[5:7] + tradeDate[8:]
            wssData = w.wss(codes=lackCode,
                            fields=['close', 'sec_name'],
                            options="tradeDate=%s;priceAdj=F;cycle=M" %
                            tradeDateStr)
            if wssData.ErrorCode != 0:
                self.logger.error("获取股票截面行情价格有误,错误代码" +
                                  str(wssData.ErrorCode))
                return pd.DataFrame()
            df = pd.DataFrame(wssData.Data,
                              columns=wssData.Codes,
                              index=wssData.Fields).T
            df.rename(columns={
                "CLOSE": "close_price",
                "SEC_NAME": "stock_name"
            },
                      inplace=True)
            df['update_time'] = [tradeDate] * len(df)
            df['stock_code'] = df.index.tolist()
            dfList.append(df)

        tempLackDf = pd.concat(dfList, axis=0, sort=True)
        self.GetDataToMysqlDemo.GetMain(tempLackDf, 'stock_month_value')
        tempDf = pd.concat([tempDf, tempLackDf], axis=0, sort=True)
        # Rows already in the database take precedence over re-fetched ones.
        tempDf = tempDf.drop_duplicates(subset=['stock_code', 'update_time'])
        return tempDf[resultColumns]

    def getHQData(self,
                  tempCode,
                  startDate='2019-03-01',
                  endDate='2019-05-30',
                  tableFlag='index',
                  nameList=['close_price']):
        """Entry point for reading quote data for one code.

        First lets ``getLackDataToMySql`` fill in whatever part of the range
        is not stored yet, then returns what ``getDataFromMySql`` reads back
        for the same code, range, table and columns.

        :param tempCode: security code.
        :param startDate: first requested date.
        :param endDate: last requested date.
        :param tableFlag: table selector forwarded to both helpers.
        :param nameList: columns to read; forwarded unchanged and not mutated.
        :return: whatever ``getDataFromMySql`` returns.
        """
        # Step 1: make sure the local table covers the requested range.
        self.getLackDataToMySql(tempCode, startDate, endDate, tableFlag)
        # Step 2: serve the request from the local table.
        return self.getDataFromMySql(tempCode,
                                     startDate,
                                     endDate,
                                     tableFlag=tableFlag,
                                     nameList=nameList)

    def getRiskFree(self, startDate='2019-03-01', endDate='2019-05-30'):
        """Fetch the ``close`` series of ``SHI3MS1Y.IR`` from Wind.

        :param startDate: first date of the request.
        :param endDate: last date of the request.
        :return: DataFrame with one row per returned time and one column per
            returned field, or ``None`` when Wind reports an error.
        """
        rateData = w.wsd(codes=["SHI3MS1Y.IR"],
                         fields=["close"],
                         beginTime=startDate,
                         endTime=endDate)
        if rateData.ErrorCode == 0:
            # Wind returns one list per field; transpose to one row per date.
            byField = pd.DataFrame(rateData.Data,
                                   index=rateData.Fields,
                                   columns=rateData.Times)
            return byField.T

        self.logger.error("获取行情数据有误,错误代码" + str(rateData.ErrorCode))
        return None

    def getTradeDay(self, startDate, endDate, Period=''):
        '''
        Get the trading days of a given period; wraps the Wind ``tdays`` call.

        :param startDate: first date of the range
        :param endDate: last date of the range
        :param Period: '' day, W week, M month, Q quarter, S half-year, Y year
        :return: DataFrame with columns tradeDate (formatted YYYY-MM-DD),
                 startDate, endDate and Period, or None when Wind reports an
                 error
        '''
        # w.start()
        tdaysResult = w.tdays(beginTime=startDate,
                              endTime=endDate,
                              options="Period=%s" % Period)
        if tdaysResult.ErrorCode != 0:
            self.logger.error('wind获取交易日期错误,请检查!')
            return None

        dayStrings = []
        for tradeDay in tdaysResult.Data[0]:
            dayStrings.append(tradeDay.strftime('%Y-%m-%d'))

        df = pd.DataFrame(dayStrings, columns=['tradeDate'])
        # Echo the request parameters on every row.
        rowCount = df.shape[0]
        for columnName, constant in (('startDate', startDate),
                                     ('endDate', endDate),
                                     ('Period', Period)):
            df[columnName] = [constant] * rowCount
        return df
Ejemplo n.º 11
0
 def __init__(self):
     """Store a MySQL handle and the shared logger on the instance."""
     # Handle returned by MysqlCon for the 'connect' flag.
     connection = MysqlCon().getMysqlCon(flag='connect')
     self.conn = connection
     self.logger = mylog.logger
Ejemplo n.º 12
0
 def __init__(self):
     MysqlConDemo = MysqlCon()
     self.engine = MysqlConDemo.getMysqlCon('engine')
     self.GetDataTotalMainDemo = GetDataTotalMain(data_resource='wind')
     self.industry_trade_limit = 20000000