Esempio n. 1
0
 def get_data_to_mysql(self):
     """Load the 同花顺 thematic-fund Excel exports into the ths_topic_fund table.

     Every file in ``self.target_path`` whose name contains '同花顺主题基金'
     is opened once. For each sheet in it:

     - the sheet is parsed and its last two rows are dropped (presumably a
       footer in the export — confirm);
     - headers are renamed via ``name_dic`` and then ``self.mysql_name_dic``;
     - '--' placeholders are normalised;
     - the sheet name is stored as 主题名称 and today's date as ``record_time``;
     - the frame is written to MySQL.
     """
     GetDataToMysqlDemo = GetDataToMysql()
     # Raw headers carry a newline and a unit suffix; map them to plain names.
     name_dic = {"管理费率\n[单位]%": "管理费率百分", "托管费率\n[单位]%": "托管费率百分"}
     record_time = datetime.today().strftime("%Y-%m-%d")
     for file_name in os.listdir(self.target_path):
         if '同花顺主题基金' not in file_name:
             continue
         file_path = os.path.join(self.target_path, file_name)
         # Open the workbook once and parse every sheet from the same handle;
         # the context manager closes the file afterwards.
         with pd.ExcelFile(file_path) as excel_file:
             for sheet_name in excel_file.sheet_names:
                 self.logger.info("读取%s主题基金..." % sheet_name)
                 temp_df = pd.read_excel(excel_file,
                                         sheet_name=sheet_name).iloc[:-2]
                 temp_df.rename(columns=name_dic, inplace=True)
                 # '--' is presumably the export's "no value" placeholder.
                 temp_df.replace({'上市日期': '--'}, np.nan, inplace=True)
                 temp_df.replace({'是否分级基金': '--'}, '否', inplace=True)
                 temp_df.replace({'跟踪指数同花顺代码': '--'},
                                 np.nan,
                                 inplace=True)
                 temp_df['主题名称'] = sheet_name
                 temp_df.rename(columns=self.mysql_name_dic, inplace=True)
                 temp_df['record_time'] = record_time
                 GetDataToMysqlDemo.GetMain(temp_df,
                                            tableName='ths_topic_fund')
                 self.logger.info("%s主题基金存储数据库成功!" % sheet_name)
Esempio n. 2
0
class GetDataFromWindAndMySql:
    """Fetch market data from Wind / iFinD (同花顺) and cache it in MySQL.

    Most getters look the requested slice up in a local MySQL table first
    and, when it is missing, download it from the vendor API, store it via
    ``GetDataToMysql.GetMain`` and return it.

    NOTE(review): SQL statements are built by string formatting; inputs are
    assumed to be trusted codes/dates, not user input.
    """

    # tableFlag -> (MySQL table, code column), used by getDataFromMySql.
    _TABLE_INFO = {
        'index': ('index_value', 'index_code'),
        'fund': ('fund_net_value', 'fund_code'),
        'stock': ('stock_hq_value', 'stock_code'),
        'private': ('private_net_value', 'fund_code'),
        'monetary_fund': ('monetary_fund', 'fund_code'),
    }

    def __init__(self):
        """Set up field lists, MySQL handles and the writer; log in to iFinD."""
        # Wind codes of the securities whose data is fetched.
        self.wsetData = [
            "000001.SH", "399300.SZ", "000016.SH", "000905.SH", "000906.SH"
        ]
        # Wind fields requested for index quotes.
        self.indexFieldName = [
            "open", "high", "low", "close", "volume", "amt", "chg", "pct_chg",
            "turn"
        ]
        self.fundFieldName = ["nav", "NAV_acc", "sec_name"]
        self.monetaryFund = [
            "mmf_annualizedyield", "mmf_unityield", "sec_name"
        ]
        # Wind fields requested by getCurrentDateData for stock quotes.
        self.stockFieldName = [
            "open", "high", "low", "close", "volume", "amt", "turn",
            "mkt_cap_ard", "pe_ttm", "ps_ttm", "pb_lf"
        ]
        # The 'engine' handle is passed to pd.read_sql; the 'connect' handle
        # is used for raw cursor queries.
        self.engine = MysqlCon().getMysqlCon(flag='engine')
        self.conn = MysqlCon().getMysqlCon(flag='connect')
        self.GetDataToMysqlDemo = GetDataToMysql()
        self.logger = mylog.logger

        # SECURITY(review): credentials are hard-coded in source; move them to
        # configuration or environment variables.
        log_state = THS_iFinDLogin('zszq5072', '754628')
        if log_state == 0:
            self.logger.info("同花顺账号登录成功!")
        else:
            # The instance is still created; later iFinD calls will
            # presumably fail.
            self.logger.error("同花顺账号登录异常,请检查!")

    @staticmethod
    def _sql_in_list(values):
        """Render *values* as a SQL ``IN`` list such as ``('a','b')``.

        ``str(tuple(values))`` is avoided because it renders a single value
        as ``('a',)`` and no values as ``()``, both invalid SQL. An empty
        input renders as ``(NULL)``, which is valid and matches no row.
        Embedded single quotes are doubled.
        """
        quoted = ["'%s'" % str(value).replace("'", "''") for value in values]
        return "(%s)" % ",".join(quoted) if quoted else "(NULL)"

    @staticmethod
    def _missing_ranges(startDate, endDate, minDate, maxDate):
        """Return the (start, end) ranges of [startDate, endDate] to download.

        [minDate, maxDate] is the span already stored locally; a falsy
        ``maxDate`` means nothing is stored. Range boundaries overlap the
        stored span by one day, as in the original download calls. All four
        dates must be mutually comparable — presumably 'YYYY-MM-DD' strings
        (a DATE column would return ``datetime.date``; confirm).
        """
        if not maxDate or endDate < minDate or startDate > maxDate:
            # Nothing stored yet, or the request does not overlap it.
            return [(startDate, endDate)]
        ranges = []
        if startDate < minDate:
            ranges.append((startDate, minDate))
        if endDate > maxDate:
            ranges.append((maxDate, endDate))
        return ranges

    def _stored_date_span(self, tableStr, codeName, codeCondition):
        """Return (maxDate, minDate) of update_time for rows matching codes.

        :param codeCondition: SQL fragment following ``codeName``, e.g.
            ``"='000300.SH'"`` or ``" in ('a','b')"``.
        """
        sqlStr = "select max(update_time),min(update_time) from %s where %s%s" % (
            tableStr, codeName, codeCondition)
        cursor = self.conn.cursor()
        try:
            cursor.execute(sqlStr)
            maxDate, minDate = cursor.fetchall()[0]
        finally:
            cursor.close()
        return maxDate, minDate

    def getStockMonthToMySql(self, start_date='2011-11-01',
                             end_date='2019-12-31'):
        """Download month-end factor values for all A shares into MySQL.

        The universe is the Wind sector a001010100000000 (全A股) as of the
        first monthly trading day returned by getTradeDay. For every
        monthly trading day, mkt_freeshares, pe_ttm, ps_ttm and pct_chg are
        downloaded. They are stored in long format (stock_code, factor_name,
        factor_value, update_time) in ``stock_factor_month_value``.

        :return: None normally; an empty DataFrame if a factor download
            fails (the remaining dates are skipped).
        """
        total_trade_list = self.getTradeDay(startDate=start_date,
                                            endDate=end_date,
                                            Period='M')
        if not total_trade_list:
            # getTradeDay returns None on a Wind error.
            self.logger.error("未获取到交易日期,请检查!")
            return
        wsetdata = w.wset(
            "sectorconstituent",
            "date=%s;sectorid=a001010100000000" % total_trade_list[0])
        if wsetdata.ErrorCode != 0:
            self.logger.error("获取全A股数据有误,错误代码" + str(wsetdata.ErrorCode))
            return
        index_df = pd.DataFrame(wsetdata.Data,
                                index=wsetdata.Fields,
                                columns=wsetdata.Codes).T
        if index_df.empty:
            return
        code_list = index_df['wind_code'].tolist()

        for trade_date in total_trade_list:
            # Wind expects tradeDate as YYYYMMDD.
            optionstr = "tradeDate=%s;cycle=M" % trade_date.replace('-', '')
            wssdata = w.wss(
                codes=code_list,
                fields=["mkt_freeshares", "pe_ttm", "ps_ttm", "pct_chg"],
                options=optionstr)
            if wssdata.ErrorCode != 0:
                self.logger.error("获取因子数据有误,错误代码" + str(wssdata.ErrorCode))
                return pd.DataFrame()
            resultDf = pd.DataFrame(wssdata.Data,
                                    index=wssdata.Fields,
                                    columns=wssdata.Codes).T

            # Reshape the wide (code x factor) frame into one row per
            # (code, factor) pair.
            df_list = []
            for col in resultDf:
                temp_df = pd.DataFrame(resultDf[col].values,
                                       index=resultDf.index,
                                       columns=['factor_value'])
                temp_df['update_time'] = trade_date
                temp_df['stock_code'] = resultDf.index.tolist()
                temp_df['factor_name'] = col
                df_list.append(temp_df)
            total_fa_df = pd.concat(df_list, axis=0, sort=True)
            self.GetDataToMysqlDemo.GetMain(total_fa_df,
                                            'stock_factor_month_value')
            self.logger.info("存储日期%s因子数据成功!" % trade_date)

    def getFactorValue(self,
                       code_list=None,
                       factor_list=None,
                       start_date='2019-04-01',
                       end_date='2019-05-01'):
        """Return a cross-section of factor values (stocks x factors).

        Rows of ``stock_factor_month_value`` with update_time == start_date
        are used when present. Otherwise the factors are downloaded from
        Wind (``unit=1``) and returned without being stored.

        :param code_list: stock codes (default: none).
        :param factor_list: factor names (default: none).
        :param start_date: cross-section date, 'YYYY-MM-DD'.
        :param end_date: unused; kept for interface compatibility.
        :return: DataFrame indexed by stock code with one column per factor,
            or an empty DataFrame on a Wind error.
        """
        code_list = [] if code_list is None else list(code_list)
        factor_list = [] if factor_list is None else list(factor_list)
        sqlStr = "select * from stock_factor_month_value where stock_code in %s and factor_name in %s and update_time='%s'" % (
            self._sql_in_list(code_list), self._sql_in_list(factor_list),
            start_date)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)

        if not resultDf.empty:
            # Pivot the stored long rows back to one column per factor.
            df_list = [
                pd.DataFrame(temp_df['factor_value'].values,
                             index=temp_df['stock_code'],
                             columns=[factor])
                for factor, temp_df in resultDf.groupby(by='factor_name')
            ]
            return pd.concat(df_list, sort=True, axis=1)

        dateFormat = start_date.replace('-', '')
        wssdata = w.wss(codes=code_list,
                        fields=factor_list,
                        options="unit=1;tradeDate=%s" % dateFormat)
        if wssdata.ErrorCode != 0:
            self.logger.error("获取因子数据有误,错误代码" + str(wssdata.ErrorCode))
            return pd.DataFrame()
        return pd.DataFrame(wssdata.Data,
                            index=wssdata.Fields,
                            columns=wssdata.Codes).T

    def getIndexConstituent(self, indexCode='000300.SH', getDate='2019-06-06'):
        """Return the constituents of ``indexCode`` on ``getDate``.

        For a regular index code, ``index_constituent`` is queried first. If
        it is empty, Wind ``indexconstituent`` is tried, then
        ``sectorconstituent``. Index constituents from Wind are renamed to
        the table's columns and stored; sector constituents are returned as
        downloaded and not stored.

        For indexCode '全A股', the distinct stock codes in
        ``stock_factor_month_value`` for ``getDate`` are returned.

        :return: constituent DataFrame, or an empty DataFrame on a Wind error.
        """
        if indexCode != '全A股':
            sqlStr = "select * from index_constituent where index_code='%s' and update_time='%s'" % (
                indexCode, getDate)
            resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
            if resultDf.empty:
                wsetdata = w.wset("indexconstituent",
                                  "date=%s;windcode=%s" % (getDate, indexCode))
                if wsetdata.ErrorCode != 0:
                    self.logger.error("获取指数成分股数据有误,错误代码" +
                                      str(wsetdata.ErrorCode))
                    return pd.DataFrame()

                resultDf = pd.DataFrame(wsetdata.Data, index=wsetdata.Fields).T
                if resultDf.empty:
                    # Not an index on Wind: fall back to sector constituents.
                    wsetdata = w.wset(
                        "sectorconstituent",
                        "date=%s;windcode=%s" % (getDate, indexCode))
                    if wsetdata.ErrorCode != 0:
                        self.logger.error("获取板块指数成分股数据有误,错误代码" +
                                          str(wsetdata.ErrorCode))
                        return pd.DataFrame()

                    resultDf = pd.DataFrame(wsetdata.Data,
                                            index=wsetdata.Fields).T
                    if resultDf.empty:
                        self.logger.info("指定日期内,未找到有效成分股数据")
                    return resultDf

                resultDf['date'] = [
                    datetampStr.strftime('%Y-%m-%d')
                    for datetampStr in resultDf['date'].tolist()
                ]
                nameDic = {
                    'date': 'adjust_time',
                    'wind_code': 'stock_code',
                    "sec_name": 'stock_name',
                    'i_weight': 'stock_weight'
                }
                resultDf.rename(columns=nameDic, inplace=True)
                resultDf['update_time'] = getDate
                resultDf['index_code'] = indexCode
                if 'industry' in resultDf:
                    resultDf.drop(labels='industry', inplace=True, axis=1)
                self.GetDataToMysqlDemo.GetMain(resultDf, 'index_constituent')
        else:
            sqlStr = "select distinct stock_code from stock_factor_month_value where update_time='%s'" % getDate
            resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        return resultDf

    def getLackDataToMySql(self, tempCode, startDate, endDate,
                           tableFlag='index'):
        """Download from iFinD the part of [startDate, endDate] not yet stored.

        Only index quotes can be back-filled here, because
        get_hq_data_from_ths always writes to ``index_value``. Other
        ``tableFlag`` values are logged and ignored.
        """
        if tableFlag != 'index':
            self.logger.error("getLackDataToMySql仅支持指数数据补全,tableFlag=%s" %
                              tableFlag)
            return
        maxDate, minDate = self._stored_date_span('index_value', 'index_code',
                                                  "='%s'" % tempCode)
        for rangeStart, rangeEnd in self._missing_ranges(
                startDate, endDate, minDate, maxDate):
            self.get_hq_data_from_ths(tempCode,
                                      startDate=rangeStart,
                                      endDate=rangeEnd)

    def getDataFromMySql(self,
                         tempCode,
                         startDate,
                         endDate,
                         tableFlag='index',
                         nameList=None):
        """Read stored columns for one code between startDate and endDate.

        :param tableFlag: one of the keys of ``_TABLE_INFO``.
        :param nameList: columns to read; defaults to ['close_price'].
        :return: DataFrame indexed and sorted by update_time (one row per
            date), or an empty DataFrame for invalid arguments.
        """
        if nameList is None:
            nameList = ['close_price']
        if not nameList:
            self.logger.error('传入获取指数的字段不合法,请检查!')
            return pd.DataFrame()
        if tableFlag not in self._TABLE_INFO:
            self.logger.error('不支持的tableFlag:%s' % tableFlag)
            return pd.DataFrame()
        tableStr, codeName = self._TABLE_INFO[tableFlag]

        sqlStr = "select %s,update_time from %s where %s='%s' and  update_time>='%s'" \
                 " and update_time<='%s'" % (','.join(nameList), tableStr, codeName, tempCode, startDate, endDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        resultDf = resultDf.drop_duplicates('update_time')
        resultDf.set_index(keys='update_time', inplace=True, drop=True)
        # Sort by date; sorting before set_index would only sort row numbers.
        return resultDf.sort_index()

    def getCurrentNameData(self,
                           tempCodeList,
                           startDate,
                           endDate,
                           tableFlag='stock',
                           nameStr='close_price'):
        """Return one field for several stocks over [startDate, endDate].

        Missing dates are downloaded first via getDataFromWind. They are
        based on the min/max update_time stored for the whole code list,
        not per code.

        :return: DataFrame indexed by update_time with one column per stock
            code, an empty DataFrame if nothing is stored, or None when
            tableFlag != 'stock'.
        """
        if tableFlag != 'stock':
            return None
        codeStr = self._sql_in_list(tempCodeList)
        maxDate, minDate = self._stored_date_span('stock_hq_value',
                                                  'stock_code',
                                                  " in %s" % codeStr)

        # NOTE(review): getDataFromWind is not defined in this class; these
        # calls raise AttributeError unless it is provided elsewhere.
        for rangeStart, rangeEnd in self._missing_ranges(
                startDate, endDate, minDate, maxDate):
            for tempCode in tempCodeList:
                self.getDataFromWind(tempCode,
                                     startDate=rangeStart,
                                     endDate=rangeEnd,
                                     tableFlag=tableFlag)

        sqlStr = "select %s,update_time,stock_code from stock_hq_value where stock_code in %s and update_time<='%s' " \
                 "and update_time>='%s'" % (nameStr, codeStr, endDate, startDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if resultDf.empty:
            return pd.DataFrame()
        dfList = [
            pd.DataFrame(tempDf[nameStr].values,
                         index=tempDf['update_time'],
                         columns=[code])
            for code, tempDf in resultDf.groupby('stock_code')
        ]
        return pd.concat(dfList, axis=1)

    def getCurrentDateData(self,
                           tempCodeList,
                           getDate,
                           tableFlag='stock',
                           nameList=None):
        """Return a cross-section of stock quote fields on ``getDate``.

        ``stock_hq_value`` is read first. If no row is stored for the date,
        ``self.stockFieldName`` is downloaded from Wind (options
        priceAdj=F;cycle=D), codes with any missing field are dropped, and
        the rest is stored and returned.

        :param nameList: columns to return; defaults to ['close_price'].
        :return: DataFrame indexed by stock code; an empty DataFrame on a
            Wind error or when no quotes exist; None when tableFlag != 'stock'.
        """
        if tableFlag != 'stock':
            return None
        nameList = ['close_price'] if nameList is None else list(nameList)

        sqlStr = "select * from stock_hq_value where stock_code in %s and update_time='%s'" % (
            self._sql_in_list(tempCodeList), getDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if not resultDf.empty:
            resultDf.set_index('stock_code', drop=True, inplace=True)
            return resultDf[nameList]

        fields = self.stockFieldName
        wssData = w.wss(codes=tempCodeList,
                        fields=fields,
                        options="tradeDate=%s;priceAdj=F;cycle=D" % getDate)
        if wssData.ErrorCode != 0:
            self.logger.error("获取行情数据有误,错误代码" + str(wssData.ErrorCode))
            return pd.DataFrame()
        tempDf = pd.DataFrame(wssData.Data,
                              index=fields,
                              columns=tempCodeList).T
        tempDf.dropna(inplace=True)
        if tempDf.empty:
            self.logger.error("当前日期%s无行情" % getDate)
            return pd.DataFrame()

        tempDf['update_time'] = getDate
        nameDic = {
            "open": "open_price",
            "high": "high_price",
            "low": "low_price",
            "close": "close_price",
            "mkt_cap_ard": "market_value",
        }
        tempDf.rename(columns=nameDic, inplace=True)
        tempDf['stock_code'] = tempDf.index.tolist()
        self.GetDataToMysqlDemo.GetMain(tempDf, 'stock_hq_value')
        return tempDf[nameList]

    def getFirstDayData(self, codeList, tableFlag='fund'):
        """Return Wind ``fund_setupDate`` for the given fund codes.

        :return: DataFrame with Wind fields as rows and codes as columns
            (not transposed, unlike most methods here); an empty DataFrame
            on a Wind error; None when tableFlag != 'fund'.
        """
        if tableFlag == 'fund':
            wssData = w.wss(codes=codeList, fields=["fund_setupDate"])
            if wssData.ErrorCode != 0:
                self.logger.error("getFirstDayData获取wind数据错误,错误代码%s" %
                                  wssData.ErrorCode)
                return pd.DataFrame()
            self.logger.debug("getFirstDayData获取wind数据成功!")
            return pd.DataFrame(wssData.Data,
                                columns=wssData.Codes,
                                index=wssData.Fields)

    def get_hq_data_from_ths(
        self,
        tempCode,
        startDate='2019-04-01',
        endDate='2019-04-30',
    ):
        """Download daily quotes from iFinD and store them in ``index_value``.

        :return: the stored DataFrame, or None if the iFinD call fails.
        """
        ths_data = THS_HistoryQuotes(
            thscode=tempCode,
            jsonIndicator=
            'open,high,low,close,volume,changeRatio,turnoverRatio,change',
            jsonparam=
            'Interval:D,CPS:6,baseDate:1900-01-01,Currency:YSHB,fill:Previous',
            begintime=startDate,
            endtime=endDate,
            outflag=True)
        if ths_data['errorcode'] != 0:
            self.logger.error("同花顺获取行情数据错误,请检查:%s" % ths_data['errmsg'])
            return

        tempDf = THS_Trans2DataFrame(ths_data)
        tempDf.dropna(how='all', inplace=True)
        # index_value is keyed by index_code (see getLackDataToMySql).
        tempDf['index_code'] = tempCode
        # Assumes THS_Trans2DataFrame returns a datetime-like index, since
        # strftime is called on it below — TODO confirm.
        tempDf['update_time'] = tempDf.index.tolist()
        # NOTE(review): `nameDic` is not defined in this class; unless it is a
        # module-level mapping defined elsewhere, this raises NameError.
        tempDf.rename(columns=nameDic, inplace=True)
        tempDf['update_time'] = [
            dateStr.strftime("%Y-%m-%d")
            for dateStr in tempDf['update_time'].tolist()
        ]
        self.GetDataToMysqlDemo.GetMain(tempDf, 'index_value')
        return tempDf

    def getHQData(self,
                  tempCode,
                  startDate='2019-03-01',
                  endDate='2019-05-30',
                  tableFlag='index',
                  nameList=None):
        """Entry point for quote data: back-fill MySQL, then read it.

        :param nameList: columns to return; defaults to ['close_price'].
        :return: DataFrame indexed by update_time (see getDataFromMySql).
        """
        self.getLackDataToMySql(tempCode, startDate, endDate, tableFlag)
        return self.getDataFromMySql(tempCode,
                                     startDate,
                                     endDate,
                                     tableFlag=tableFlag,
                                     nameList=nameList)

    def getTradeDay(self, startDate, endDate, Period='M'):
        """Return Wind trading days between startDate and endDate.

        :param Period: sampling period passed to w.tdays — '' daily,
            W weekly, M monthly, Q quarterly, S half-yearly, Y yearly.
        :return: list of 'YYYY-MM-DD' strings, or None on a Wind error.
        """
        data = w.tdays(beginTime=startDate,
                       endTime=endDate,
                       options="Period=%s" % Period)
        if data.ErrorCode != 0:
            self.logger.error('wind获取交易日期错误,请检查!')
            return
        return [tradeDay.strftime('%Y-%m-%d') for tradeDay in data.Data[0]]
class GetDataFromWindAndMySql:
    def __init__(self):
        """Prepare Wind field lists, the MySQL handles, the writer and logger."""
        # Wind codes of the securities whose data is fetched.
        self.wsetData = ["000001.SH", "399300.SZ", "000016.SH",
                         "000905.SH", "000906.SH"]
        # Wind fields requested for index quotes.
        self.indexFieldName = ["open", "high", "low", "close", "volume",
                               "amt", "chg", "pct_chg", "turn"]
        self.fundFieldName = ["nav", "NAV_acc", "sec_name"]
        self.monetaryFund = ["mmf_annualizedyield", "mmf_unityield",
                             "sec_name"]
        self.stockFieldName = ["open", "high", "low", "close", "volume",
                               "amt", "turn", "mkt_cap_ard", "pe_ttm",
                               "ps_ttm", "pb_lf"]

        # The 'engine' handle is passed to pd.read_sql; the 'connect' handle
        # is used for raw cursor queries.
        sql_engine = MysqlCon().getMysqlCon(flag='engine')
        self.engine = sql_engine
        raw_connection = MysqlCon().getMysqlCon(flag='connect')
        self.conn = raw_connection
        self.GetDataToMysqlDemo = GetDataToMysql()
        self.logger = mylog.logger

    def getBelongIndustry(self, codeList, tradeDate='2018-12-31'):
        '''
        获取股票所属的行业
        注:查询所属行业前,应确保查询日期大于股票ipo日期,该逻辑不在本方法验证。
        :return:
        '''
        sqlStr = "select * from stock_industry_value where stock_code in %s and update_time='%s'" % (
            tuple(codeList), tradeDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if resultDf.empty:
            self.logger.debug("getBelongIndustry从wind获取!")
            tradeDateParam = tradeDate[:4] + tradeDate[5:7] + tradeDate[8:]
            wssData = w.wss(codes=codeList,
                            fields=["industry_citic"],
                            options="tradeDate=%s;industryType=1" %
                            tradeDateParam)
            if wssData.ErrorCode != 0:
                self.logger.error("获取指数成分股数据有误,错误代码" + str(wssData.ErrorCode))
                return pd.DataFrame()
            df = pd.DataFrame(wssData.Data,
                              index=wssData.Fields,
                              columns=wssData.Codes).T
            df.rename(columns={"INDUSTRY_CITIC": "industry_name"},
                      inplace=True)
            df['stock_code'] = df.index.tolist()
            df['update_time'] = tradeDate
            df['industry_wind_code'] = ["industry_citic"] * df.shape[0]
            df['industry_flag'] = [1] * df.shape[0]
            self.GetDataToMysqlDemo.GetMain(df, 'stock_industry_value')
            resultDf = df[['stock_code', 'industry_name']]
        else:
            self.logger.debug("getBelongIndustry从本地数据库获取!")
            resultDf = resultDf[['stock_code', 'industry_name']]
        return resultDf

    def getFactorReportData(self,
                            codeList,
                            factors,
                            rptDate='2018-12-31',
                            backYears=0):
        # 单个年报数据的获取
        sqlStr = "select stock_code,item_value from stock_factor_value where stock_code in %s and update_time='%s' and item_wind_code='%s' " \
                 % (tuple(codeList), rptDate, factors[0])
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if resultDf.empty:
            self.logger.debug("getFactorReportData从wind数据库获取%s!" % factors[0])
            rptDateParam = rptDate[:4] + rptDate[5:7] + rptDate[8:]
            if backYears == 0:
                if "wgsd_assets" not in factors:
                    wssData = w.wss(codes=codeList,
                                    fields=factors,
                                    options="rptDate=%s" % rptDateParam)
                else:
                    wssData = w.wss(
                        codes=codeList,
                        fields=factors,
                        options="unit=1;rptDate=%s;rptType=1;currencyType=" %
                        rptDateParam)
            else:
                wssData = w.wss(codes=codeList,
                                fields=factors,
                                options="rptDate=%s;N=%s" %
                                (rptDateParam, str(backYears)))

            if wssData.ErrorCode != 0:
                self.logger.error("getFactorDailyData获取%s有误,错误代码%s" %
                                  (factors, str(wssData.ErrorCode)))
                return pd.DataFrame()
            df = pd.DataFrame(wssData.Data,
                              columns=wssData.Codes,
                              index=factors).T

            resultDf = df.copy()
            df['stock_code'] = df.index.tolist()
            df['update_time'] = rptDate
            df.rename(columns={factors: "item_value"}, inplace=True)
            df['item_wind_code'] = factors[0]
            df['rpt_flag'] = 1
            self.GetDataToMysqlDemo.GetMain(df, 'stock_factor_value')
        else:
            self.logger.debug("getFactorReportData从本地数据库获取%s!" % factors[0])
            resultDf = resultDf.set_index(
                "stock_code",
                drop=True).rename(columns={"item_value": factors[0]})
        return resultDf

    def getPetChg(self, codeList, startDate, endDate):
        '''
        获取股票区间涨跌幅数据
        :return:
        '''
        sqlStr = "select stock_code,pct_chg_value from stock_range_updown_value where stock_code in %s and start_date='%s' and end_date='%s'" % (
            tuple(codeList), startDate, endDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if not resultDf.empty:
            lackCode = [
                code for code in codeList
                if code not in resultDf['stock_code'].tolist()
            ]
        else:
            lackCode = codeList[:]

        if lackCode:
            self.logger.debug("getPetChg从wind获取!")
            startDateParam = startDate[:4] + startDate[5:7] + startDate[8:]
            endDateParam = endDate[:4] + endDate[5:7] + endDate[8:]
            wssData = w.wss(codes=lackCode,
                            fields=["pct_chg_per"],
                            options="startDate=%s;endDate=%s" %
                            (startDateParam, endDateParam))
            if wssData.ErrorCode != 0:
                self.logger.error("getPetChg获取pct_chg_per有误,错误代码%s" %
                                  (str(wssData.ErrorCode)))
                return pd.DataFrame()
            df = pd.DataFrame(wssData.Data,
                              index=["pct_chg_value"],
                              columns=wssData.Codes).T
            df['stock_code'] = df.index.tolist()
            df['start_date'] = startDate
            df['end_date'] = endDate
            self.GetDataToMysqlDemo.GetMain(df, 'stock_range_updown_value')
            resultDf = pd.concat([resultDf, df], axis=0,
                                 sort=True)[['stock_code', 'pct_chg_value']]
        else:
            self.logger.debug("getPetChg从本地数据库获取!")
        resultDf.set_index('stock_code', inplace=True, drop=True)
        return resultDf

    def getFactorDailyData(self, codeList, factors, tradeDate='2018-12-31'):
        # 单个非年报数据的获取
        sqlStr = "select stock_code,item_value from stock_factor_value where stock_code in %s and update_time='%s' and item_wind_code='%s' " \
                 % (tuple(codeList), tradeDate, factors[0])
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if resultDf.empty:
            self.logger.debug("getFactorDailyData从wind获取%s!" % factors[0])
            tradeDateParam = tradeDate[:4] + tradeDate[5:7] + tradeDate[8:]
            if 'mkt_cap_float' not in factors:
                wssData = w.wss(codes=codeList,
                                fields=factors,
                                options="tradeDate=%s" % tradeDateParam)
            else:
                wssData = w.wss(codes=codeList,
                                fields=factors,
                                options="unit=1;tradeDate=%s;currencyType=" %
                                tradeDateParam)
            if wssData.ErrorCode != 0:
                self.logger.error("getFactorDailyData获取%s有误,错误代码%s" %
                                  (factors, str(wssData.ErrorCode)))
                return pd.DataFrame()
            df = pd.DataFrame(wssData.Data,
                              index=factors,
                              columns=wssData.Codes).T
            resultDf = df.copy()
            df['stock_code'] = df.index.tolist()
            df['update_time'] = tradeDate
            df.rename(columns={factors[0]: "item_value"}, inplace=True)
            df['item_wind_code'] = factors[0]
            df['rpt_flag'] = 0
            self.GetDataToMysqlDemo.GetMain(df, 'stock_factor_value')
        else:
            self.logger.debug("getFactorDailyData从本地数据库获取%s!" % factors[0])
            resultDf = resultDf.set_index(
                "stock_code",
                drop=True).rename(columns={"item_value": factors[0]})
        return resultDf

    def getIndexConstituent(self,
                            indexCode='000300.SH',
                            getDate='2019-06-06',
                            indexOrSector='index'):
        '''
        获取指数成分股
        '''
        sqlStr = "select * from index_constituent where index_code='%s' and update_time='%s'" % (
            indexCode, getDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if resultDf.empty:
            self.logger.debug("getIndexConstituent从wind获取!")
            if indexOrSector == 'index':
                wsetdata = w.wset("indexconstituent",
                                  "date=%s;windcode=%s" % (getDate, indexCode))
            else:
                wsetdata = w.wset("sectorconstituent",
                                  "date=2019-08-21;windcode=%s" % indexCode)

            if wsetdata.ErrorCode != 0:
                self.logger.error("获取指数成分股数据有误,错误代码" + str(wsetdata.ErrorCode))
                return pd.DataFrame()

            resultDf = pd.DataFrame(wsetdata.Data, index=wsetdata.Fields).T
            dateList = [
                datetampStr.strftime('%Y-%m-%d')
                for datetampStr in resultDf['date'].tolist()
            ]
            resultDf['date'] = dateList
            nameDic = {
                'date': 'adjust_time',
                'wind_code': 'stock_code',
                "sec_name": 'stock_name',
                'i_weight': 'stock_weight'
            }
            resultDf.rename(columns=nameDic, inplace=True)
            resultDf['update_time'] = getDate
            resultDf['index_code'] = indexCode
            if 'stock_weight' not in resultDf.columns.tolist():
                resultDf['stock_weight'] = np.nan
            self.GetDataToMysqlDemo.GetMain(resultDf, 'index_constituent')

        else:
            self.logger.debug("getIndexConstituent从本地获取!")
        return resultDf

    def getLackDataToMySql(self,
                           tempCode,
                           startDate,
                           endDate,
                           tableFlag='index'):
        '''
        Download from Wind the part of [startDate, endDate] that the local
        table does not cover yet for tempCode.

        Local coverage is the span min(update_time)..max(update_time) of the
        code's rows; gaps inside that span are not detected.

        :param tempCode: security code
        :param startDate: first wanted date, 'YYYY-MM-DD'
        :param endDate: last wanted date, 'YYYY-MM-DD'
        :param tableFlag: 'index', 'fund', 'stock', 'monetary_fund' or
            'private' (private funds are never fetched from Wind)
        :return: None
        '''
        if tableFlag == 'private':
            return
        table_map = {
            'index': ('index_value', 'index_code'),
            'fund': ('fund_net_value', 'fund_code'),
            'stock': ('stock_hq_value', 'stock_code'),
            'monetary_fund': ('monetary_fund', 'fund_code'),
        }
        if tableFlag not in table_map:
            self.logger.error('不支持的tableFlag: %s' % tableFlag)
            return
        tableStr, codeName = table_map[tableFlag]

        sqlStr = "select max(update_time),min(update_time) from %s where %s='%s'" % (
            tableStr, codeName, tempCode)
        cursor = self.conn.cursor()
        try:
            cursor.execute(sqlStr)
            maxDate, minDate = cursor.fetchall()[0]
        finally:
            cursor.close()

        def fetch(begin, end):
            self.getDataFromWind(tempCode,
                                 startDate=begin,
                                 endDate=end,
                                 tableFlag=tableFlag)

        # Nothing stored yet, or the wanted range does not overlap the stored
        # one: download the whole wanted range.
        if not maxDate or endDate < minDate or startDate > maxDate:
            fetch(startDate, endDate)
            return

        # Missing head before the first stored date.
        if startDate < minDate:
            fetch(startDate, minDate)
        # Missing tail after the last stored date.
        if endDate > maxDate:
            fetch(maxDate, endDate)

    def getDataFromWind(self,
                        tempCode,
                        startDate='2019-04-01',
                        endDate='2019-04-30',
                        tableFlag='index'):
        '''
        Download daily data for one code from Wind (w.wsd) and store it.

        :param tempCode: Wind code
        :param startDate: first date, 'YYYY-MM-DD'
        :param endDate: last date, 'YYYY-MM-DD'
        :param tableFlag: 'index', 'fund', 'stock' or 'monetary_fund'; picks
            the target table, the Wind fields and the column renaming
        :return: the processed DataFrame (all-empty rows removed, update_time
            as 'YYYY-MM-DD'); it is only stored when non-empty. None on an
            unsupported tableFlag or a Wind error.
        '''
        if tableFlag == 'index':
            tableStr = 'index_value'
            nameDic = {
                "OPEN": "open_price",
                "HIGH": "high_price",
                "LOW": "low_price",
                "CLOSE": "close_price",
                "VOLUME": "volume",
                "AMT": "amt",
                "CHG": "chg",
                "PCT_CHG": "pct_chg",
                "TURN": "turn"
            }
            fields = self.indexFieldName
            codeName = 'index_code'
        elif tableFlag == 'fund':
            tableStr = 'fund_net_value'
            nameDic = {
                "NAV": "net_value",
                "NAV_ACC": "acc_net_value",
                "SEC_NAME": "fund_name"
            }
            fields = self.fundFieldName
            codeName = 'fund_code'
        elif tableFlag == 'stock':
            tableStr = 'stock_hq_value'
            nameDic = {
                "OPEN": "open_price",
                "HIGH": "high_price",
                "LOW": "low_price",
                "CLOSE": "close_price",
                "VOLUME": "volume",
                "AMT": "amt",
                "TURN": "turn",
                "MKT_CAP_ARD": "market_value",
                "PE_TTM": "pe_ttm",
                "PS_TTM": "ps_ttm",
                "PB_LF": "pb_lf"
            }
            fields = self.stockFieldName
            codeName = 'stock_code'
        elif tableFlag == 'monetary_fund':
            tableStr = 'monetary_fund'
            nameDic = {
                "MMF_ANNUALIZEDYIELD": "week_annual_return",
                "MMF_UNITYIELD": "wan_unit_return",
                "SEC_NAME": "fund_name"
            }
            fields = self.monetaryFund
            codeName = 'fund_code'
        else:
            # No Wind mapping for this flag (e.g. 'private'): report it instead
            # of failing later on unbound locals.
            self.logger.error('不支持的tableFlag: %s' % tableFlag)
            return None

        if tableFlag == 'stock':
            # Stock prices are requested forward-adjusted.
            wsetdata = w.wsd(codes=tempCode,
                             fields=fields,
                             beginTime=startDate,
                             endTime=endDate,
                             options="PriceAdj=F")
        else:
            wsetdata = w.wsd(codes=tempCode,
                             fields=fields,
                             beginTime=startDate,
                             endTime=endDate)

        if wsetdata.ErrorCode != 0:
            self.logger.error("获取行情数据有误,错误代码" + str(wsetdata.ErrorCode))
            return None

        tempDf = pd.DataFrame(wsetdata.Data,
                              index=wsetdata.Fields,
                              columns=wsetdata.Times).T
        # Drop rows with no value at all before the key columns are added.
        tempDf.dropna(how='all', inplace=True)
        tempDf[codeName] = tempCode
        tempDf['update_time'] = [
            dateStr.strftime("%Y-%m-%d") for dateStr in tempDf.index.tolist()
        ]
        tempDf.rename(columns=nameDic, inplace=True)
        if not tempDf.empty:
            self.GetDataToMysqlDemo.GetMain(tempDf, tableStr)
        return tempDf

    def getDataFromMySql(self,
                         tempCode,
                         startDate,
                         endDate,
                         tableFlag='index',
                         nameList=None):
        '''
        Read stored daily values of one code from the local table.

        :param tempCode: security code
        :param startDate: first date (inclusive), 'YYYY-MM-DD'
        :param endDate: last date (inclusive), 'YYYY-MM-DD'
        :param tableFlag: 'index', 'fund', 'stock', 'private' or
            'monetary_fund'
        :param nameList: columns to read; defaults to ['close_price']
        :return: DataFrame indexed by update_time (ascending, one row per
            date); empty on an empty nameList or unsupported tableFlag
        '''
        if nameList is None:
            nameList = ['close_price']
        if not nameList:
            self.logger.error('传入获取指数的字段不合法,请检查!')
            return pd.DataFrame()

        table_map = {
            'index': ('index_value', 'index_code'),
            'fund': ('fund_net_value', 'fund_code'),
            'stock': ('stock_hq_value', 'stock_code'),
            'private': ('private_net_value', 'fund_code'),
            'monetary_fund': ('monetary_fund', 'fund_code'),
        }
        if tableFlag not in table_map:
            self.logger.error('不支持的tableFlag: %s' % tableFlag)
            return pd.DataFrame()
        tableStr, codeName = table_map[tableFlag]

        sqlStr = "select %s,update_time from %s where %s='%s' and  update_time>='%s'" \
                 " and update_time<='%s'" % (','.join(nameList), tableStr, codeName, tempCode, startDate, endDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        # The query has no ORDER BY: index by date first, then sort by it.
        resultDf = resultDf.drop_duplicates('update_time')
        resultDf = resultDf.set_index(keys='update_time', drop=True).sort_index()
        return resultDf

    def checkLackMonthData(self, tempDf, codeList):
        '''
        List the codes whose rows in tempDf are incomplete.

        A code is lacking when it has no row in tempDf or fewer than two rows
        (getMonthData asks for two trade dates per code).

        :param tempDf: DataFrame with a 'stock_code' column
        :param codeList: codes that should be present
        :return: list of lacking codes: absent codes first (codeList order,
            duplicates kept), then under-populated codes (codeList order,
            each once). The order is deterministic.
        '''
        rowCounts = tempDf['stock_code'].value_counts().to_dict()
        lackCode = [code for code in codeList if code not in rowCounts]
        absent = set(lackCode)
        for code in dict.fromkeys(codeList):
            if code not in absent and rowCounts[code] < 2:
                lackCode.append(code)
        return lackCode

    def getMonthData(self,
                     codeList=None,
                     startDate='2019-03-01',
                     endDate='2019-05-30'):
        '''
        Return close prices of codeList on startDate and endDate.

        Rows are read from stock_month_value. Codes that are missing or
        incomplete there (see checkLackMonthData) are downloaded from Wind
        (w.wss with priceAdj=F;cycle=M), stored, and merged in.

        :param codeList: stock codes; None or empty yields an empty frame
        :param startDate: first trade date, 'YYYY-MM-DD'
        :param endDate: second trade date, 'YYYY-MM-DD'
        :return: DataFrame with stock_code, close_price and update_time
            columns; an empty DataFrame when a Wind request fails
        '''
        outColumns = ['stock_code', 'close_price', 'update_time']
        codeList = [] if codeList is None else list(codeList)
        if not codeList:
            return pd.DataFrame(columns=outColumns)

        def sqlInList(values):
            # Quote each value explicitly: str(tuple(...)) renders a single
            # value as "('a',)", and the trailing comma is invalid SQL.
            return "(%s)" % ",".join("'%s'" % value for value in values)

        totalTradeList = [startDate, endDate]
        sqlstr = "select * from stock_month_value where stock_code in %s and update_time in %s" % (
            sqlInList(codeList), sqlInList(totalTradeList))
        tempDf = pd.read_sql(sql=sqlstr, con=self.engine)
        lackCode = self.checkLackMonthData(tempDf, codeList)
        if not lackCode:
            self.logger.debug("getMonthData从本地数据库获取!")
            return tempDf[outColumns]

        self.logger.debug("getMonthData从wind获取,缺失code: %s" %
                          ','.join(lackCode))
        dfList = []
        for tradeDate in totalTradeList:
            # Wind expects the trade date as YYYYMMDD.
            tradeDateStr = tradeDate[:4] + tradeDate[5:7] + tradeDate[8:]
            wssData = w.wss(codes=lackCode,
                            fields=['close', 'sec_name'],
                            options="tradeDate=%s;priceAdj=F;cycle=M" %
                            tradeDateStr)
            if wssData.ErrorCode != 0:
                self.logger.error("获取股票截面行情价格有误,错误代码" +
                                  str(wssData.ErrorCode))
                return pd.DataFrame()
            df = pd.DataFrame(wssData.Data,
                              columns=wssData.Codes,
                              index=wssData.Fields).T
            df.rename(columns={
                "CLOSE": "close_price",
                "SEC_NAME": "stock_name"
            },
                      inplace=True)
            df['update_time'] = [tradeDate] * len(df)
            df['stock_code'] = df.index.tolist()
            dfList.append(df)
        tempLackDf = pd.concat(dfList, axis=0, sort=True)
        self.GetDataToMysqlDemo.GetMain(tempLackDf, 'stock_month_value')
        # Local rows come first, so they win over downloaded duplicates.
        tempDf = pd.concat([tempDf, tempLackDf], axis=0, sort=True)
        tempDf = tempDf.drop_duplicates(subset=['stock_code', 'update_time'])
        return tempDf[outColumns]

    def getHQData(self,
                  tempCode,
                  startDate='2019-03-01',
                  endDate='2019-05-30',
                  tableFlag='index',
                  nameList=['close_price']):
        '''
        Entry point for market data of one code.

        First back-fill the local table for the requested range
        (getLackDataToMySql), then read the requested columns back from
        MySQL (getDataFromMySql).

        :param tempCode: security code
        :param startDate: first date, 'YYYY-MM-DD'
        :param endDate: last date, 'YYYY-MM-DD'
        :param tableFlag: table selector forwarded to both steps
        :param nameList: columns to read, passed through unchanged
        :return: whatever getDataFromMySql returns for these arguments
        '''
        # Step 1: make the local table cover the range.
        self.getLackDataToMySql(tempCode, startDate, endDate, tableFlag)
        # Step 2: serve the data from the local table.
        return self.getDataFromMySql(tempCode,
                                     startDate,
                                     endDate,
                                     tableFlag=tableFlag,
                                     nameList=nameList)

    def getRiskFree(self, startDate='2019-03-01', endDate='2019-05-30'):
        '''
        Download the daily close of SHI3MS1Y.IR from Wind between two dates.

        :param startDate: first date, 'YYYY-MM-DD'
        :param endDate: last date, 'YYYY-MM-DD'
        :return: DataFrame indexed by Wind's dates with one column per
            returned field; None when Wind reports an error
        '''
        riskFreeData = w.wsd(codes=["SHI3MS1Y.IR"],
                             fields=["close"],
                             beginTime=startDate,
                             endTime=endDate)
        if riskFreeData.ErrorCode == 0:
            return pd.DataFrame(riskFreeData.Data,
                                index=riskFreeData.Fields,
                                columns=riskFreeData.Times).T
        self.logger.error("获取行情数据有误,错误代码" + str(riskFreeData.ErrorCode))
        return None

    def getTradeDay(self, startDate, endDate, Period=''):
        '''
        List trading days between two dates (wrapper around Wind's w.tdays).

        :param startDate: first date
        :param endDate: last date
        :param Period: '' daily, W weekly, M monthly, Q quarterly,
            S half-yearly, Y yearly
        :return: DataFrame with columns tradeDate ('YYYY-MM-DD'), startDate,
            endDate and Period; None when Wind reports an error
        '''
        tdaysResult = w.tdays(beginTime=startDate,
                              endTime=endDate,
                              options="Period=%s" % Period)
        if tdaysResult.ErrorCode != 0:
            self.logger.error('wind获取交易日期错误,请检查!')
            return None
        dayStrings = [day.strftime('%Y-%m-%d') for day in tdaysResult.Data[0]]
        rowCount = len(dayStrings)
        return pd.DataFrame({
            'tradeDate': dayStrings,
            'startDate': [startDate] * rowCount,
            'endDate': [endDate] * rowCount,
            'Period': [Period] * rowCount,
        })
Esempio n. 4
0
class GetFundFinanceReportData:
    '''
    Fund stock-holding details per report period, cached in the MySQL table
    fund_contain_stock_detail.
    '''

    def __init__(self):
        self.logger = mylog.set_log()
        self.GetDataToMysqlDemo = GetDataToMysql()

    def get_fund_stock_info(self,
                            third_conn,
                            engine,
                            total_date_list,
                            fund_code='100053.OF'):
        '''
        Return the stock-holding detail rows of fund_code for the given
        report dates.

        Report periods already in fund_contain_stock_detail are read locally.
        The rest are downloaded with third_conn.wset("allfundhelddetail"),
        renamed, stored, and appended.

        :param third_conn: Wind-style API object exposing
            wset(tablename=..., options=...) - presumably WindPy's w
        :param engine: connection/engine accepted by pandas.read_sql
        :param total_date_list: report dates as 'YYYY-MM-DD' strings
        :param fund_code: Wind fund code
        :return: DataFrame of holding rows (local rows plus downloaded ones);
            an empty DataFrame when total_date_list is empty or a Wind
            request fails
        '''
        if not total_date_list:
            # Nothing to look up, and "in ()" would be invalid SQL.
            return pd.DataFrame()

        # Labels as stored in rpt_date: dates ending in '06-30' are interim
        # reports (年中报), every other date an annual report (年年报).
        rpt_date_str_list = [
            rpt_date[:4] + ('年中报' if rpt_date[-5:] == '06-30' else '年年报')
            for rpt_date in total_date_list
        ]
        # Quote each label explicitly: str(tuple(...)) renders one item as
        # "('x',)", and that trailing comma is a SQL syntax error.
        in_clause = "(%s)" % ",".join("'%s'" % rpt for rpt in rpt_date_str_list)
        sql_str = "select * from fund_contain_stock_detail where rpt_date in %s and fund_code='%s'" % (
            in_clause, fund_code)
        result_df = pd.read_sql(sql=sql_str, con=engine)
        have_rpt_set = set(result_df['rpt_date'].tolist())
        # dict.fromkeys keeps order and avoids downloading a period twice.
        lack_rpt_list = [
            rpt_date for rpt_date in dict.fromkeys(rpt_date_str_list)
            if rpt_date not in have_rpt_set
        ]
        if not lack_rpt_list:
            return result_df

        name_mysql_dic = {
            'sec_name': 'fund_name',
            'marketvalueofstockholdings': 'market_value_of_stockholdings',
            'proportiontototalstockinvestments': 'pro_total_stock_inve',
            'proportiontonetvalue': 'pro_net_value',
            'proportiontoshareholdtocirculation': 'pro_sharehold_cir'
        }
        temp_df_list = []
        for lack_rpt in lack_rpt_list:
            # Map the label back to its date, then to Wind's YYYYMMDD form.
            lack_date = total_date_list[rpt_date_str_list.index(lack_rpt)]
            rptdate = ''.join(lack_date.split('-'))
            options = "rptdate=%s;windcode=%s" % (rptdate, fund_code)
            wset_data = third_conn.wset(tablename="allfundhelddetail",
                                        options=options)
            if wset_data.ErrorCode != 0:
                self.logger.error('wind获取基金持股明细数据错误,错误代码%s,请检查!' %
                                  wset_data.ErrorCode)
                return pd.DataFrame()
            temp_rpt_df = pd.DataFrame(wset_data.Data,
                                       index=wset_data.Fields,
                                       columns=wset_data.Codes).T
            if temp_rpt_df.empty:
                continue
            temp_rpt_df['fund_code'] = fund_code
            temp_rpt_df['record_time'] = datetime.today().strftime(
                "%Y-%m-%d")
            temp_rpt_df.rename(columns=name_mysql_dic, inplace=True)
            self.GetDataToMysqlDemo.GetMain(temp_rpt_df,
                                            'fund_contain_stock_detail')
            self.logger.info("存储%s,报告期%s持股数据成功!" % (fund_code, lack_rpt))
            temp_df_list.append(temp_rpt_df)
        if temp_df_list:
            temp_df = pd.concat(temp_df_list, axis=0, sort=True)
            result_df = pd.concat([result_df, temp_df], axis=0, sort=True)
        return result_df

    def get_main(self):
        '''Placeholder entry point; not implemented.'''
        pass
Esempio n. 5
0
class DataWash:
    '''
    One-off importer: reshape a wide Excel export (stock codes as the index,
    one column per period) into long rows of stock_factor_value.

    Earlier revisions of getMain kept commented-out importers for
    市净率 / 总市值 / 中信行业 workbooks that used the same per-column
    reshape with other item codes and date slices.
    '''

    def __init__(self):
        self.GetDataToMysqlDemo = GetDataToMysql()

    @staticmethod
    def _wide_to_factor_rows(df, item_wind_code='wgsd_assets', rpt_flag=1):
        '''
        Stack a wide frame into stock_factor_value rows.

        :param df: index = stock codes; each column is one period whose name
            holds a 4-digit year at characters [11:15] (layout of the 总资产
            export - confirm before reusing with other files)
        :param item_wind_code: value written to the item_wind_code column
        :param rpt_flag: value written to the rpt_flag column
        :return: long DataFrame with item_value, item_wind_code, rpt_flag,
            stock_code and update_time ('<year>-12-31'); empty (with those
            columns) when df has no columns
        '''
        frames = []
        for colName in df.columns:
            tempDf = df[[colName]].copy()
            tempDf['stock_code'] = tempDf.index.tolist()
            tempDf['item_wind_code'] = item_wind_code
            tempDf['rpt_flag'] = rpt_flag
            tempDf['update_time'] = colName[11:15] + '-12-31'
            tempDf.rename(columns={colName: "item_value"}, inplace=True)
            frames.append(tempDf)
        if not frames:
            return pd.DataFrame(columns=[
                'item_value', 'item_wind_code', 'rpt_flag', 'stock_code',
                'update_time'
            ])
        return pd.concat(frames, axis=0, sort=True)

    def getMain(self, file_path="总资产.xlsx"):
        '''
        Load the total-assets workbook and store it in stock_factor_value.

        :param file_path: Excel file whose first column holds stock codes
        '''
        df = pd.read_excel(file_path, index_col=[0])
        totalDf = self._wide_to_factor_rows(df,
                                            item_wind_code='wgsd_assets',
                                            rpt_flag=1)
        if totalDf.empty:
            return
        self.GetDataToMysqlDemo.GetMain(totalDf, 'stock_factor_value')
Esempio n. 6
0
class GetDataTotalMain:
    def __init__(self, data_resource='ifind'):
        '''
        Log in to the chosen data terminal and open the MySQL handles.

        :param data_resource: 'ifind' or 'wind'; log_init skips the login for
            any other value
        '''
        self.logger = mylog.set_log()
        # log_init logs through self.logger, so it has to run after it exists.
        self.dic_init = {
            'data_resource': data_resource,
            'data_init_flag': self.log_init(data_resource),
        }
        mysqlHelper = MysqlCon()
        self.engine = mysqlHelper.getMysqlCon(flag='engine')
        self.conn = mysqlHelper.getMysqlCon(flag='connect')
        self.GetDataToMysqlDemo = GetDataToMysql()

    def log_init(self, data_resource='ifind', ifind_user=None,
                 ifind_password=None):
        '''
        Log in to the data terminal client.

        :param data_resource: 'ifind' logs in to 同花顺 iFinD, 'wind' starts
            the Wind API; any other value does nothing
        :param ifind_user: iFinD account; falls back to the IFIND_USER
            environment variable, then to the built-in account
        :param ifind_password: iFinD password; falls back to IFIND_PASSWORD,
            then to the built-in password
        :return: True on success (or when nothing had to be done), False
            when the login/start failed
        '''
        import os

        flag = True
        if data_resource == 'ifind':
            # NOTE(review): the built-in credentials are a security risk;
            # move them to the environment/config and rotate them.
            user = ifind_user or os.environ.get('IFIND_USER', 'zszq5072')
            password = ifind_password or os.environ.get(
                'IFIND_PASSWORD', '754628')
            log_state = THS_iFinDLogin(user, password)
            if log_state == 0:
                self.logger.info("同花顺账号登录成功!")
            else:
                self.logger.error("同花顺账号登录异常,请检查!")
                flag = False
        elif data_resource == 'wind':
            try:
                start_result = w.start()
            except Exception:
                # Narrower than a bare except: Ctrl-C / SystemExit propagate.
                self.logger.error("wind启动失败")
                flag = False
            else:
                # w.start() may report failure through ErrorCode without
                # raising.
                if getattr(start_result, 'ErrorCode', 0) != 0:
                    self.logger.error("wind启动失败")
                    flag = False
        return flag

    def get_index_constituent(self,
                              indexCode='000300.SH',
                              getDate='2020-02-01'):
        '''
        Return the constituents of an index on a given date.

        The local index_constituent table is checked first. On a miss, the
        constituents are downloaded from the configured data source:
        - Wind: indexconstituent, falling back to sectorconstituent.
        - iFinD: the 'index' data pool.
        They are then renamed to the table's columns, stored, and returned.

        :param indexCode: index code
        :param getDate: date, 'YYYY-MM-DD'
        :return: DataFrame (index code, stock code, stock name, weight, ...);
            empty when nothing is found, the download fails or the data
            source is not supported
        '''
        sqlStr = "select * from index_constituent where index_code='%s' and update_time='%s'" % (
            indexCode, getDate)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if not resultDf.empty:
            return resultDf

        # Wind (lower-case) and iFinD (upper-case) column names -> table columns.
        nameDic = {
            'date': 'adjust_time',
            'wind_code': 'stock_code',
            "sec_name": 'stock_name',
            'i_weight': 'stock_weight',
            'DATE': 'adjust_time',
            'THSCODE': 'stock_code',
            "SECURITY_NAME": 'stock_name',
            'WEIGHT': 'stock_weight'
        }
        data_resource = self.dic_init['data_resource']
        if data_resource == 'wind':
            wsetdata = w.wset("indexconstituent",
                              "date=%s;windcode=%s" % (getDate, indexCode))
            if wsetdata.ErrorCode != 0:
                self.logger.error("获取指数成分股数据有误,错误代码" +
                                  str(wsetdata.ErrorCode))
                return pd.DataFrame()
            resultDf = pd.DataFrame(wsetdata.Data, index=wsetdata.Fields).T
            if resultDf.empty:
                # Not a regular index: retry as a Wind sector.
                wsetdata = w.wset(
                    "sectorconstituent",
                    "date=%s;windcode=%s" % (getDate, indexCode))
                if wsetdata.ErrorCode != 0:
                    self.logger.error("获取板块指数成分股数据有误,错误代码" +
                                      str(wsetdata.ErrorCode))
                    return pd.DataFrame()
                resultDf = pd.DataFrame(wsetdata.Data, index=wsetdata.Fields).T
            if resultDf.empty:
                self.logger.info("指定日期内,未找到有效成分股数据")
                return pd.DataFrame()
            resultDf['date'] = [
                datetampStr.strftime('%Y-%m-%d')
                for datetampStr in resultDf['date'].tolist()
            ]
            if 'industry' in resultDf:
                resultDf.drop(labels='industry', inplace=True, axis=1)
        elif data_resource == 'ifind':
            param_name = '%s;%s' % (getDate, indexCode)
            fun_option = 'date:Y,thscode:Y,security_name:Y,weight:Y'
            ths_data = THS_DataPool(DataPoolname='index',
                                    paramname=param_name,
                                    FunOption=fun_option,
                                    outflag=False)
            if ths_data['errorcode'] != 0:
                self.logger.error("同花顺获取指数成分股数据错误,请检查:%s" %
                                  ths_data['errmsg'])
                return pd.DataFrame()
            resultDf = THS_Trans2DataFrame(ths_data)
            if resultDf.empty:
                self.logger.info("指定日期内,未找到有效成分股数据")
                return pd.DataFrame()
        else:
            self.logger.error("不支持的数据源:%s" % data_resource)
            return pd.DataFrame()

        resultDf.rename(columns=nameDic, inplace=True)
        resultDf['update_time'] = getDate
        resultDf['index_code'] = indexCode
        self.GetDataToMysqlDemo.GetMain(resultDf, 'index_constituent')
        return resultDf

    def get_hq_data_to_Mysql(self,
                             code,
                             start_date='2019-04-01',
                             end_date='2019-04-30',
                             code_style='index'):
        '''
        Download daily market data for one code and store it in MySQL.

        The source is self.dic_init['data_resource'] ('ifind' or 'wind').
        Rows whose value columns are all empty are dropped before storing,
        and nothing is stored when no row is left.

        :param code: security code; for iFinD index quotes a '.CS' suffix is
            replaced by '.CB' before the request
        :param start_date: first date, 'YYYY-MM-DD'
        :param end_date: last date, 'YYYY-MM-DD'
        :param code_style: 'index' (table index_value), 'fund'
            (fund_net_value) or 'etf_fund' (etf_hq_value)
        :return: an empty DataFrame on a download error or an unsupported
            code_style / data source; otherwise None
        '''
        data_resource = self.dic_init['data_resource']
        if code_style not in ('index', 'fund', 'etf_fund'):
            self.logger.error("不支持的行情类型:%s" % code_style)
            return pd.DataFrame()
        if data_resource not in ('ifind', 'wind'):
            self.logger.error("不支持的数据源:%s" % data_resource)
            return pd.DataFrame()

        # iFinD request settings shared by index and ETF quotes.
        ths_quote_fields = 'open,high,low,close,volume,changeRatio,turnoverRatio,change'
        ths_quote_params = 'Interval:D,CPS:6,baseDate:1900-01-01,Currency:YSHB,fill:Previous'

        if code_style == 'index':
            code_style_name = 'index_code'
            table_str = 'index_value'
            if data_resource == 'ifind':
                if code[-3:] == '.CS':
                    code = code[:-3] + '.CB'
                name_dic = {
                    'open': "open_price",
                    "high": "high_price",
                    "low": "low_price",
                    "close": "close_price",
                    "changeRatio": "pct_chg",
                    "turnoverRatio": "turn",
                    "change": "chg",
                    'time': "update_time",
                    "thscode": "index_code"
                }
                ths_data = THS_HistoryQuotes(thscode=code,
                                             jsonIndicator=ths_quote_fields,
                                             jsonparam=ths_quote_params,
                                             begintime=start_date,
                                             endtime=end_date,
                                             outflag=False)
                if ths_data['errorcode'] != 0:
                    self.logger.error("同花顺获取行情数据错误,请检查:%s" %
                                      ths_data['errmsg'])
                    return pd.DataFrame()
                tempDf = THS_Trans2DataFrame(ths_data)
            else:
                name_dic = {
                    "OPEN": "open_price",
                    "HIGH": "high_price",
                    "LOW": "low_price",
                    "CLOSE": "close_price",
                    "VOLUME": "volume",
                    "AMT": "amt",
                    "CHG": "chg",
                    "PCT_CHG": "pct_chg",
                    "TURN": "turn",
                }
                wsetdata = w.wsd(codes=code,
                                 fields=[
                                     "open", "high", "low", "close", "volume",
                                     "amt", "chg", "pct_chg", "turn"
                                 ],
                                 beginTime=start_date,
                                 endTime=end_date,
                                 options="PriceAdj=F")
                if wsetdata.ErrorCode != 0:
                    self.logger.error("获取行情数据有误,错误代码" +
                                      str(wsetdata.ErrorCode))
                    return pd.DataFrame()
                tempDf = pd.DataFrame(wsetdata.Data,
                                      index=wsetdata.Fields,
                                      columns=wsetdata.Times).T
                tempDf['update_time'] = [
                    dateStr.strftime("%Y-%m-%d")
                    for dateStr in tempDf.index.tolist()
                ]
                tempDf[code_style_name] = code
        elif code_style == 'fund':
            code_style_name = 'fund_code'
            table_str = 'fund_net_value'
            if data_resource == 'ifind':
                name_dic = {
                    'time': "update_time",
                    "thscode": "fund_code",
                    "ths_unit_nv_fund": "net_value",
                    "ths_accum_unit_nv_fund": "acc_net_value"
                }
                ths_data = THS_DateSerial(
                    thscode=code,
                    jsonIndicator='ths_unit_nv_fund;ths_accum_unit_nv_fund',
                    globalparam='Days:Tradedays,Fill:Previous,Interval:D',
                    jsonparam=';',
                    begintime=start_date,
                    endtime=end_date,
                    outflag=False)
                if ths_data['errorcode'] != 0:
                    self.logger.error("同花顺获取基金净值数据错误,请检查:%s" %
                                      ths_data['errmsg'])
                    return pd.DataFrame()
                tempDf = THS_Trans2DataFrame(ths_data)
            else:
                name_dic = {
                    "NAV": "net_value",
                    "NAV_ACC": "acc_net_value",
                    "NAV_ADJ": "net_value_adj"
                }
                wsddata = w.wsd(code, "nav,NAV_acc,NAV_adj", start_date,
                                end_date, "")
                if wsddata.ErrorCode != 0:
                    self.logger.error("获取基金净值数据有误,错误代码" +
                                      str(wsddata.ErrorCode))
                    return pd.DataFrame()
                tempDf = pd.DataFrame(wsddata.Data,
                                      index=wsddata.Fields,
                                      columns=wsddata.Times).T
                tempDf['update_time'] = [
                    dateStr.strftime("%Y-%m-%d")
                    for dateStr in tempDf.index.tolist()
                ]
                tempDf[code_style_name] = code
        else:
            # code_style == 'etf_fund'
            code_style_name = 'fund_code'
            table_str = 'etf_hq_value'
            if data_resource == 'ifind':
                name_dic = {
                    'open': "open_price",
                    "high": "high_price",
                    "low": "low_price",
                    "close": "close_price",
                    "changeRatio": "pct_chg",
                    "turnoverRatio": "turn",
                    "change": "chg",
                    'time': "update_time",
                    "thscode": "fund_code",
                    "avgPrice": "vwap"
                }
                ths_data = THS_HistoryQuotes(thscode=code,
                                             jsonIndicator=ths_quote_fields,
                                             jsonparam=ths_quote_params,
                                             begintime=start_date,
                                             endtime=end_date,
                                             outflag=False)
                if ths_data['errorcode'] != 0:
                    self.logger.error("同花顺获取行情数据错误,请检查:%s" %
                                      ths_data['errmsg'])
                    return pd.DataFrame()
                tempDf = THS_Trans2DataFrame(ths_data)
            else:
                name_dic = {
                    "OPEN": "open_price",
                    "HIGH": "high_price",
                    "LOW": "low_price",
                    "CLOSE": "close_price",
                    "VOLUME": "volume",
                    "AMT": "amt",
                    "CHG": "chg",
                    "PCT_CHG": "pct_chg",
                    "TURN": "turn",
                    'VWAP': 'vwap'
                }
                wsetdata = w.wsd(codes=code,
                                 fields=[
                                     "open", "high", "low", "close", "volume",
                                     "amt", "chg", "pct_chg", "turn", "vwap"
                                 ],
                                 beginTime=start_date,
                                 endTime=end_date,
                                 options="PriceAdj=F")
                if wsetdata.ErrorCode != 0:
                    self.logger.error("获取行情数据有误,错误代码" +
                                      str(wsetdata.ErrorCode))
                    return pd.DataFrame()
                tempDf = pd.DataFrame(wsetdata.Data,
                                      index=wsetdata.Fields,
                                      columns=wsetdata.Times).T
                tempDf[tempDf.isnull()] = np.nan
                tempDf.dropna(how='all', inplace=True)
                if tempDf.empty:
                    self.logger.info("wind获取%s etf基金数据为空,请检查" % code)
                    return
                tempDf['update_time'] = [
                    dateStr.strftime("%Y-%m-%d")
                    for dateStr in tempDf.index.tolist()
                ]
                tempDf[code_style_name] = code
                tempDf['record_time'] = datetime.today().strftime("%Y-%m-%d")

        # Key columns are always filled, so a plain dropna(how='all') never
        # removes a row once they exist; judge emptiness by values only.
        key_columns = {
            'time', 'thscode', 'update_time', 'record_time', code_style_name
        }
        value_columns = [col for col in tempDf.columns if col not in key_columns]
        if value_columns:
            tempDf = tempDf.dropna(how='all', subset=value_columns)
        if tempDf.empty:
            self.logger.info("%s在%s至%s之间无有效行情数据" %
                             (code, start_date, end_date))
            return
        tempDf = tempDf.rename(columns=name_dic)
        self.GetDataToMysqlDemo.GetMain(tempDf, table_str)

    def get_lackdata_to_MySql(self,
                              code,
                              startDate,
                              endDate,
                              code_style='index'):
        '''
        Download the part of [startDate, endDate] the local table lacks for
        one code and store it via get_hq_data_to_Mysql.

        Local coverage is the span min(update_time)..max(update_time) of the
        code's rows; gaps inside that span are not detected. A missing range
        whose end is not after its start is skipped, so a single-day request
        (startDate == endDate) never triggers a download.

        :param code: code of an index, fund, stock or ETF
        :param startDate: first date, 'YYYY-MM-DD'
        :param endDate: last date, 'YYYY-MM-DD'
        :param code_style: 'index', 'fund', 'stock' or 'etf_fund'
        :return: None
        '''
        self.logger.info("检查本地%s,%s-%s缺失行情数据..." % (code, startDate, endDate))
        table_map = {
            'index': ('index_value', 'index_code'),
            'fund': ('fund_net_value', 'fund_code'),
            'stock': ('stock_hq_value', 'stock_code'),
            'etf_fund': ('etf_hq_value', 'fund_code'),
        }
        if code_style not in table_map:
            self.logger.error("不支持的行情类型:%s" % code_style)
            return
        tableStr, codeName = table_map[code_style]

        sqlStr = "select max(update_time),min(update_time) from %s where %s='%s'" % (
            tableStr, codeName, code)
        cursor = self.conn.cursor()
        try:
            cursor.execute(sqlStr)
            maxDate, minDate = cursor.fetchall()[0]
        finally:
            cursor.close()

        if not maxDate or endDate < minDate or startDate > maxDate:
            # Nothing stored yet, or no overlap with what is stored.
            missing_ranges = [(startDate, endDate)]
        else:
            missing_ranges = []
            if startDate < minDate:
                missing_ranges.append((startDate, minDate))
            if endDate > maxDate:
                missing_ranges.append((maxDate, endDate))

        for start_date, end_date in missing_ranges:
            # Normalize end_date before comparing; skip empty/inverted ranges.
            if datetime.strptime(end_date, "%Y-%m-%d").strftime(
                    "%Y-%m-%d") <= start_date:
                self.logger.info("行情数据起止时间间隔不够,不用补全。startdate: %s;enddate: %s" %
                                 (start_date, end_date))
                continue
            self.get_hq_data_to_Mysql(code,
                                      start_date=start_date,
                                      end_date=end_date,
                                      code_style=code_style)

    def get_date_from_MySql(self,
                            code,
                            start_date,
                            end_date,
                            code_style='index',
                            name_list=None):
        '''
        Read stored market data for ``code`` in [start_date, end_date].

        :param code: security code
        :param start_date: inclusive start date
        :param end_date: inclusive end date
        :param code_style: 'index', 'fund', 'stock' or 'etf_fund'
        :param name_list: columns to select; defaults to ['close_price']
        :return: DataFrame indexed by ``update_time`` (one row per date,
                 first occurrence kept), sorted ascending by date. An empty
                 DataFrame is returned for an empty ``name_list`` or an
                 unsupported ``code_style``.
        '''
        if name_list is None:
            name_list = ['close_price']
        if not name_list:
            self.logger.error('传入获取指数的字段不合法,请检查!')
            # Continuing would build "select ,update_time ..." which is invalid SQL.
            return pd.DataFrame()

        # code_style -> (table name, code column name)
        table_map = {
            'index': ('index_value', 'index_code'),
            'fund': ('fund_net_value', 'fund_code'),
            'stock': ('stock_hq_value', 'stock_code'),
            'etf_fund': ('etf_hq_value', 'fund_code'),
        }
        if code_style not in table_map:
            self.logger.error("不支持的code_style: %s,请检查!" % code_style)
            return pd.DataFrame()
        table_str, code_name = table_map[code_style]

        sqlStr = "select %s,update_time from %s where %s='%s' and  update_time>='%s'" \
                 " and update_time<='%s'" % (','.join(name_list), table_str, code_name, code, start_date, end_date)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        resultDf = resultDf.drop_duplicates('update_time')
        resultDf.set_index(keys='update_time', inplace=True, drop=True)
        # Sort only after update_time is the index; sorting the default
        # RangeIndex (as before) did not order the rows by date.
        return resultDf.sort_index()

    def get_hq_data(self,
                    code,
                    start_date,
                    end_date,
                    code_style='index',
                    name_list=None):
        '''
        Main entry point for market data: back-fill any missing rows into the
        local database, then read the requested range back from it.

        :param code: security code
        :param start_date: start date, "%Y-%m-%d"
        :param end_date: end date, "%Y-%m-%d"
        :param code_style: 'index', 'fund', 'stock' or 'etf_fund'
        :param name_list: columns to read; defaults to ['close_price']
        :return: whatever get_date_from_MySql returns
        '''
        # None sentinel instead of a shared mutable default list.
        if name_list is None:
            name_list = ['close_price']
        self.get_lackdata_to_MySql(code, start_date, end_date, code_style)
        return self.get_date_from_MySql(code, start_date, end_date,
                                        code_style, name_list)

    def get_tradeday(self, start_date, end_date, period='M'):
        '''
        Return the trading days between two dates at the given frequency.

        :param start_date: start date
        :param end_date: end date
        :param period: '' daily, W weekly, M monthly, Q quarterly,
                       S half-yearly, Y yearly
        :return: for wind, a list of "%Y-%m-%d" strings; for ifind, the
                 ``tables['time']`` entry of the THS_DateQuery result;
                 None on a provider error or an unsupported data source
        '''
        data_resource = self.dic_init['data_resource']
        if data_resource == 'wind':
            data = w.tdays(beginTime=start_date,
                           endTime=end_date,
                           options="Period=%s" % period)
            if data.ErrorCode != 0:
                self.logger.error('wind获取交易日期错误,请检查!')
                return
            return [tradeDay.strftime('%Y-%m-%d') for tradeDay in data.Data[0]]
        if data_resource == 'ifind':
            ths_data = THS_DateQuery(
                'SSE', 'dateType:0,period:%s,dateFormat:0' % period,
                start_date, end_date)
            if ths_data['errorcode'] != 0:
                self.logger.error("同花顺获取交易日期数据错误,请检查:%s" % ths_data['errmsg'])
                return
            return ths_data['tables']['time']
        # Previously an unknown source crashed with UnboundLocalError.
        self.logger.error('不支持的数据源%s,请检查!' % data_resource)
        return None

    def facort_wind_ifind_to_mysql(self):
        '''
        Map raw factor names from wind and iFinD onto the column names used
        when storing factors in the local database.

        :return: dict of provider factor name -> database factor name
        '''
        wind_names = {
            'mkt_freeshares': "MKT_FREESHARES",
            'pe_ttm': "PE_TTM",
            'ps_ttm': "PS_TTM",
            'pct_chg': "PCT_CHG",
            'rating_avg': "RATING_AVG",
        }
        ifind_names = {
            'ths_pe_ttm_stock': "PE_TTM",
            'ths_ps_ttm_stock': "PS_TTM",
            'ths_chg_ratio_m_stock': "PCT_CHG",
            'ths_compre_rating_value_stock': "RATING_AVG",
            'ths_free_float_mv_stock': "MKT_FREESHARES",
            'ths_pb_mrq_stock': 'PB_MRQ',
            'ths_beta_24m_stock': "BETA_24M",
        }
        return {**wind_names, **ifind_names}

    def get_stock_month_to_MySql(self,
                                 code_list,
                                 factor_list,
                                 start_date='2011-11-01',
                                 end_date='2019-12-31'):
        '''
        Fetch month-end factor data for the given A-share stocks and save it
        to the local table ``stock_factor_month_value`` (long format: one row
        per stock/factor/date).

        Notes: (1) check the factor parameters carefully; (2) this is not
        meant to be called from a backtest — run it once to fill the
        database, then read with get_factor_value.

        :param code_list: stock codes
        :param factor_list: factor names. NOTE: the wind branch always uses a
                            fixed factor set and the ifind branch a fixed
                            indicator string, so this argument is currently
                            not used (kept for interface compatibility).
        :param start_date: first date, "%Y-%m-%d"
        :param end_date: last date, "%Y-%m-%d"
        :return: None on success; an empty DataFrame on a provider error, an
                 unsupported data source, or an empty trade-day list
        '''
        self.logger.info("获取股票月度因子数据,请确保因子在wind与ifind的参数名称")
        data_resource = self.dic_init['data_resource']
        if data_resource not in ('wind', 'ifind'):
            # Previously this crashed later with an unbound ``resultDf``.
            self.logger.error("不支持的数据源%s,请检查!" % data_resource)
            return pd.DataFrame()

        total_trade_list = self.get_tradeday(start_date, end_date, period='M')
        if not total_trade_list:
            # get_tradeday returns None on provider errors; iterating it
            # would raise TypeError.
            self.logger.error("获取交易日期为空,请检查!")
            return pd.DataFrame()
        name_dic = self.facort_wind_ifind_to_mysql()

        # Loop-invariant request parameters.
        if data_resource == 'wind':
            factor_list = [
                "mkt_freeshares", "pe_ttm", "ps_ttm", "pct_chg",
                "rating_avg"
            ]
        else:
            thscode = ','.join(code_list)
            indicator = 'ths_pb_mrq_stock;ths_pe_ttm_stock;ths_ps_ttm_stock;ths_chg_ratio_m_stock;ths_compre_rating_value_stock;ths_free_float_mv_stock;ths_beta_24m_stock'

        for trade_date in total_trade_list:
            if data_resource == 'wind':
                # wind expects the trade date as YYYYMMDD.
                optionstr = "tradeDate=%s;cycle=M" % (
                    trade_date[:4] + trade_date[5:7] + trade_date[8:])
                wssdata = w.wss(codes=code_list,
                                fields=factor_list,
                                options=optionstr)
                if wssdata.ErrorCode != 0:
                    self.logger.error("wind获取因子数据有误,错误代码" +
                                      str(wssdata.ErrorCode))
                    return pd.DataFrame()
                resultDf = pd.DataFrame(wssdata.Data,
                                        index=wssdata.Fields,
                                        columns=wssdata.Codes).T
            else:
                # One parameter group per indicator, in the order of ``indicator``.
                optionstr = '%s;%s,100;%s,100;%s,100;%s,30;%s;%s' % (
                    trade_date, trade_date, trade_date, trade_date, trade_date,
                    trade_date, trade_date)
                ths_data = THS_BasicData(thsCode=thscode,
                                         indicatorName=indicator,
                                         paramOption=optionstr,
                                         outflag=False)
                if ths_data['errorcode'] != 0:
                    self.logger.error("同花顺获取因子数据错误,请检查:%s" %
                                      ths_data['errmsg'])
                    return pd.DataFrame()
                resultDf = THS_Trans2DataFrame(ths_data).set_index(
                    keys='thscode')
            resultDf.rename(columns=name_dic, inplace=True)

            # Reshape the wide (stock x factor) frame into long format.
            df_list = []
            for col in resultDf:
                temp_df = pd.DataFrame(resultDf[col].values,
                                       index=resultDf.index,
                                       columns=['factor_value'])
                temp_df['update_time'] = trade_date
                temp_df['stock_code'] = resultDf.index.tolist()
                temp_df['factor_name'] = col
                df_list.append(temp_df)
            if not df_list:
                # pd.concat([]) would raise; skip dates with no factor columns.
                self.logger.error("日期%s未获取到因子数据,跳过" % trade_date)
                continue
            total_fa_df = pd.concat(df_list, axis=0, sort=True)
            self.GetDataToMysqlDemo.GetMain(total_fa_df,
                                            'stock_factor_month_value')
            self.logger.info("存储日期%s因子数据成功!" % trade_date)

    def get_factor_value(self,
                         code_list=None,
                         factor_list=None,
                         get_date='2019-08-30'):
        '''
        Fetch a cross-section of monthly factor values from the local table
        ``stock_factor_month_value``.

        :param code_list: stock codes to select
        :param factor_list: factor names to select
        :param get_date: value matched against ``update_time``
        :return: DataFrame indexed by stock_code with one column per factor;
                 the (empty) raw query result if nothing matched; an empty
                 DataFrame if either list is empty
        '''
        if not code_list or not factor_list:
            # "in ()" is invalid SQL, so short-circuit instead of querying.
            self.logger.error("传入股票代码或因子列表为空,请检查!")
            return pd.DataFrame()

        def _sql_in(values):
            # Render an SQL IN-list. Unlike str(tuple(...)), this never emits
            # the trailing comma of a 1-tuple ("('a',)"), which MySQL rejects.
            quoted = ("'%s'" % str(v).replace("'", "''") for v in values)
            return "(" + ",".join(quoted) + ")"

        sqlStr = ("select * from stock_factor_month_value where stock_code in %s "
                  "and factor_name in %s and update_time='%s'") % (
            _sql_in(code_list), _sql_in(factor_list), get_date)
        resultDf = pd.read_sql(sql=sqlStr, con=self.engine)
        if not resultDf.empty:
            # Pivot long rows (stock, factor, value) into stock x factor.
            df_list = []
            for factor, temp_df in resultDf.groupby(by='factor_name'):
                temp = pd.DataFrame(temp_df['factor_value'].values,
                                    index=temp_df['stock_code'],
                                    columns=[factor])
                df_list.append(temp)
            resultDf = pd.concat(df_list, sort=True, axis=1)
        return resultDf

    def get_fund_size(self, code_list=None):
        '''
        Median daily turnover (wind field "amt") over the last 30 calendar
        days for each exchange-traded fund in ``code_list``.

        :param code_list: wind codes, e.g. ["159962.SZ", "512660.SH"]
        :return: Series named '日均成交额' indexed by code, or an empty
                 DataFrame when wind reports an error
        '''
        if code_list is None:
            code_list = []
        today = datetime.today()
        start_date = (today - timedelta(days=30)).strftime("%Y%m%d")
        end_date = today.strftime("%Y%m%d")
        # (A hard-coded debug w.wsd call on two fixed codes, whose result was
        # discarded, used to precede this request; it has been removed.)
        wsdData = w.wsd(codes=code_list,
                        fields="amt",
                        beginTime=start_date,
                        endTime=end_date)
        if wsdData.ErrorCode != 0:
            self.logger.error("wind获取场内基金日均成交额有误,错误代码" +
                              str(wsdData.ErrorCode))
            return pd.DataFrame()
        # Rows = dates, columns = codes after the transpose; median per code.
        result_Se = pd.DataFrame(wsdData.Data,
                                 index=wsdData.Codes,
                                 columns=wsdData.Times).T.median()
        result_Se.name = '日均成交额'
        return result_Se

    def get_fund_base_info(self, fund_code_list=None):
        '''
        Fetch basic descriptive fields for funds from wind.

        :param fund_code_list: wind fund codes
        :return: DataFrame indexed by code with columns 基金类型, 产品类型,
                 基金成立日, 跟踪指数代码, 基金全称; an empty DataFrame when
                 wind reports an error
        '''
        if fund_code_list is None:
            fund_code_list = []
        # wind field name -> column label used in the returned frame
        field_labels = {
            "fund_firstinvesttype": "基金类型",
            "fund_investtype": "产品类型",
            "fund_setupdate": "基金成立日",
            "fund_trackindexcode": "跟踪指数代码",
            "fund_fullname": "基金全称",
        }
        wssdata = w.wss(codes=fund_code_list, fields=list(field_labels))
        if wssdata.ErrorCode != 0:
            # The message previously (wrongly) referred to daily turnover.
            self.logger.error("wind获取场内基金基本信息有误,错误代码" +
                              str(wssdata.ErrorCode))
            return pd.DataFrame()
        resultDf = pd.DataFrame(wssdata.Data,
                                index=wssdata.Fields,
                                columns=wssdata.Codes).T
        # wind reports field names in upper case.
        name_dic = {field.upper(): label for field, label in field_labels.items()}
        resultDf.rename(columns=name_dic, inplace=True)
        return resultDf