class BaseBarFeedCalculator(metaclass=abc.ABCMeta):
    '''
    Base class for factor-score calculators driven by a single instrument's bar feed.

    Subclasses must implement ``calScore``. The class is declared with
    ``metaclass=abc.ABCMeta`` so Python 3 actually enforces that; a
    ``__metaclass__`` class attribute is a Python 2 idiom that Python 3 ignores.
    '''
    __slots__ = ('factorManager', 'feed', 'outputScore', 'name')
    logger = logger.getLogger("BaseBarFeedCalculator")

    def __init__(self, factorManager, barFeed):
        '''
        :param factorManager: factor manager; subclasses may load panel-computed
            indicators from it
        :param barFeed: the bar feed; it is operated in the same way as in a CTA strategy
        '''
        self.factorManager = factorManager
        self.feed = barFeed
        self.name = self.__class__.__name__
        self.logger.debug("The factor manager: {}".format(factorManager))
        self.logger.debug("The feed type: {}".format(self.feed))

    @abc.abstractmethod
    def calScore(self, barFeed, dateTime, bar):
        '''
        Compute and return the latest factor score when data for a new time arrives.

        :param barFeed: the feed that produced the new bar
        :param dateTime: time of the new bar
        :param bar: the new bar
        :return: the score; fill with np.nan when there is no result
        '''
        raise NotImplementedError
class ReportWriter:
    '''
    Compute and store a factor's report figure and statistics table through
    factorTest.TestReportGenerator.

    Exactly one data source is used:
    - ``defaultFactorTest``: a DefaultFactorTest object whose panels are used directly;
    - ``h5BatchPanelReader``: a reader whose panels are loaded from h5 files.
    '''
    logger = logger.getLogger("ReportWriter")

    def __init__(self, factorName, defaultFactorTest=None, h5BatchPanelReader=None,
                 csvPanelReader=None):
        '''
        :param factorName: factor name
        :param defaultFactorTest: factor-test object
        :param h5BatchPanelReader: h5 file reader
        :param csvPanelReader: optional reader of the stored test settings (FactorUpdate.updateFactor
            passes one). It is kept on ``self.csvPanelReader`` but not forwarded, because
            TestReportGenerator is constructed with the two data sources only.
        :raises ValueError: when neither defaultFactorTest nor h5BatchPanelReader is given
        '''
        self.factorName = factorName
        self.defaultFactorTest = defaultFactorTest
        self.h5BatchPanelReader = h5BatchPanelReader
        self.csvPanelReader = csvPanelReader
        # Stays None when both sources are given; write() then skips.
        self.testReportGenerator = None
        if not (defaultFactorTest or h5BatchPanelReader):
            raise ValueError(
                "Either defaultFactorTest or h5BatchPanelReader must be provided")
        self.frequency = defaultFactorTest.frequency if defaultFactorTest \
            else h5BatchPanelReader.frequency
        if self.defaultFactorTest and self.h5BatchPanelReader:
            self.logger.info(
                "Either defaultFactorTest or h5BatchPanelReader must be None")
            return
        self.testReportGenerator = factorTest.TestReportGenerator(
            self.defaultFactorTest, self.h5BatchPanelReader)

    def write(self):
        '''
        Write the grouped-return figure (.png) and the grouped statistics table (.xls)
        into the factor's folder for this frequency.
        '''
        if self.testReportGenerator is None:
            self.logger.info(
                "No report generator for {}, nothing is written".format(self.factorName))
            return
        freqLabel = const.DataFrequency.freq2lable(self.frequency)

        # Grouped-return figure
        figName = self.factorName + '_Report_' + freqLabel + '.png'
        self.testReportGenerator.plotGroupret(
            _show=False, path=self._filePath(freqLabel, figName))

        # Grouped statistics table
        statisticFileName = self.factorName + '_Statistic_' + freqLabel + '.xls'
        self.testReportGenerator.statistic(
            path=self._filePath(freqLabel, statisticFileName))

    def _filePath(self, freqLabel, fileName):
        # Path of an output file inside this factor's data folder for the given frequency label.
        return pathSelector.PathSelector.getFactorFilePath(
            factorName=self.factorName,
            factorFrequency=freqLabel,
            fileName=fileName)
class CSVFactorPanelReader(CSVPanelReader):
    '''
    CSV panel reader whose directory is resolved from a factor's location.

    The directory comes from
    ``pathSelector.PathSelector.getFactorFilePath(poolName, factorName, sectionName)``;
    it and the remaining arguments are forwarded positionally to CSVPanelReader.
    '''
    logger = logger.getLogger('factorReader')

    def __init__(self, poolName, factorName, sectionName, fileName=None,
                 frequency=bar.Frequency.MINUTE, isInstrmentCol=True,
                 start=None, end=None):
        factorDir = pathSelector.PathSelector.getFactorFilePath(
            poolName, factorName, sectionName)
        # NOTE(review): elsewhere CSVPanelReader is called with keywords
        # (filePath=, fields=, frequency=, isInstrumentCol=, start=); passing fileName as
        # the second positional argument may bind it to ``fields`` — confirm the signature.
        forwarded = (factorDir, fileName, frequency, isInstrmentCol, start, end)
        super().__init__(*forwarded)
class BasePanelCalculator:
    '''
    Base class for factors computed on whole panels (matrix form); simple factors
    run faster this way.

    NOTE(review): ``__metaclass__`` is a Python 2 idiom that Python 3 ignores, so
    ``calSore`` is not enforced as abstract. ``calSore`` also looks like a typo of
    ``calScore`` (the name BaseBarFeedCalculator uses) — confirm which name callers
    invoke before renaming or enforcing it.
    '''
    __metaclass__ = abc.ABCMeta
    __slots__ = ('factorManager', 'feed', 'outputScore', 'name')
    logger = logger.getLogger("BasePanelCalculator")

    def __init__(self, factorManager, panelFeed):
        '''
        :param factorManager: factor manager; subclasses may load panel-computed
            indicators from it
        :param panelFeed: the panel feed supplying the data
        '''
        self.factorManager, self.feed = factorManager, panelFeed
        self.name = type(self).__name__
        debug = self.logger.debug
        debug("The factor manager used: {}".format(factorManager))
        debug("The feed type: {}".format(panelFeed))

    @abc.abstractmethod
    def calSore(self, panelFeed, dateTime, df):
        '''Compute the factor scores for the panel row at ``dateTime``.'''
        raise NotImplementedError
class TestReportGenerator:
    '''
    Build the figures and tables of a factor-test report.

    The data comes from exactly one of two sources:
    - a DefaultFactorTest object (when writing a new factor): its panels are used directly;
    - an H5BatchPanelReader (when appending to a factor): the panels read from h5 files are used.
    '''
    # Bound as a class attribute so that ``self.logger`` resolves.
    logger = logger.getLogger("TestReportGenerator")

    # Indicator names; each is stored as a DataFrame attribute of the same name.
    # A DefaultFactorTest exposes them as "<name>Panel"; h5 keys carry them as the
    # second "_"-separated token.
    _INDICATORS = ('groupRet', 'IC', 'rankIC', 'turn', 'cost', 'groupNumber')

    def __init__(self, defaultFactorTest, h5BatchPanelReader):
        '''
        :param defaultFactorTest: factor-test object, or None
        :param h5BatchPanelReader: h5 batch reader, or None
        '''
        self.defaultFactorTest = defaultFactorTest
        self.h5BatchPanelReader = h5BatchPanelReader
        if self.defaultFactorTest and self.h5BatchPanelReader:
            self.logger.info(
                "Either defaultFactorTest or h5BatchPanelReader must be None")
            return
        for name in self._INDICATORS:
            setattr(self, name, None)
        if self.defaultFactorTest:
            self._loadFromFactorTest()
        elif self.h5BatchPanelReader:
            self._loadFromH5()
        self._stringifyIndexes()

    def _loadFromFactorTest(self):
        # Convert each "<name>Panel" of the factor test into a DataFrame attribute.
        for name in self._INDICATORS:
            panel = getattr(self.defaultFactorTest, name + 'Panel')
            setattr(self, name, panel.to_frame())

    def _loadFromH5(self):
        # Take the panels from the reader's staticPanelDict; build it first if it is
        # still empty (prepareOutputData alone does not fill it).
        reader = self.h5BatchPanelReader
        panels = reader.staticPanelDict or reader.to_static_panel()
        for key, value in panels.items():
            indicator = key.split("_")[1]
            if indicator in self._INDICATORS:
                setattr(self, indicator, value.to_frame())

    def _stringifyIndexes(self):
        # Convert the index labels of every loaded indicator frame to strings.
        for name in self._INDICATORS:
            frame = getattr(self, name)
            if frame is not None:
                frame.index = [str(i) for i in frame.index.values]

    def _frequency(self):
        # Data frequency: from the factor test, or from the h5 reader when there is none.
        if self.defaultFactorTest:
            return self.defaultFactorTest.frequency
        return self.h5BatchPanelReader.frequency

    def _setting(self, name):
        # Factor-test parameter ``name`` (lag/fee/poolNum); NaN when only an h5 reader is available.
        if self.defaultFactorTest:
            return getattr(self.defaultFactorTest, name)
        return np.nan

    def _firstColumnRow(self, value, name):
        # Row named ``name`` holding ``value`` in the first group column and NaN elsewhere.
        columns = self.groupRet.columns
        return pd.Series([value] + [np.nan] * (len(columns) - 1),
                         index=columns, name=name)

    def _rowsToFrame(self, rows):
        # Stack rows (Series name = row label) into one DataFrame whose columns start with
        # the groupRet columns. Replaces DataFrame.append, removed in pandas 2.0.
        frame = pd.concat(
            [row.to_frame().T if isinstance(row, pd.Series) else row for row in rows])
        columns = self.groupRet.columns.union(frame.columns, sort=False)
        return frame.reindex(columns=columns)

    def plotGroupret(self, _show=True, path=None):
        '''
        Plot grouped returns on a 2x2 grid: total return per group (bar chart),
        cumulative return per group, cumulative IC and cumulative rankIC.

        :param _show: display the figure with plt.show()
        :param path: if not None, save the figure to this path
        '''
        fig = plt.figure(figsize=(24, 12))

        # Total return per group, bar chart
        ax1 = plt.subplot(221)
        self.groupRet.sum().plot(kind='bar', ax=ax1, rot=20)
        plt.title('GroupRet', fontsize=20)
        plt.tight_layout()

        # Cumulative return per group
        ax2 = plt.subplot(223)
        self.groupRet.cumsum().plot(ax=ax2, rot=15)
        plt.legend(loc='upper left', fontsize=10)
        plt.title('CumRet', fontsize=20)
        plt.tight_layout()

        # Cumulative IC
        ax3 = plt.subplot(222)
        self.IC.cumsum().plot(ax=ax3, rot=15)
        plt.title('IC_CUMSUM', fontsize=20)
        plt.legend(loc='upper left', fontsize=10)
        plt.tight_layout()

        # Cumulative rankIC
        ax4 = plt.subplot(224)
        self.rankIC.cumsum().plot(ax=ax4, rot=15)
        plt.title('RankIC_CUMSUM', fontsize=20)
        plt.legend(loc='upper left', fontsize=10)
        plt.tight_layout()

        # Save before showing: after plt.show() returns the figure may already be closed,
        # and savefig would then write an empty image.
        if path is not None:
            fig.savefig(path)
        if _show:
            plt.show()
        else:
            # Release the figure so batch report writing does not accumulate open figures.
            plt.close(fig)

    def plotGroupStat(self):
        '''Print and draw a table of return, Sharpe ratio and max drawdown per group.'''
        self.groupRet.index = [str(i) for i in self.groupRet.index.values]
        ret = sectionCalculator.RET(self.groupRet)
        sharp = sectionCalculator.SHARP(self.groupRet, self._frequency())
        maxdd = sectionCalculator.maxDrawDown(self.groupRet)
        statistic = self._rowsToFrame([ret, sharp, maxdd]).round(3)
        print(statistic)
        plt.table(
            cellText=statistic.values,
            rowLabels=statistic.index,
            colLabels=statistic.columns,
        )
        plt.show()

    def statistic(self, path=None):
        '''
        Compute the summary statistics and save them as an Excel file.

        Rows, in order: return, Sharpe ratio, max drawdown, mean IC, mean rankIC,
        mean turnover, summed cost, mean number of stocks per group, lag, fee, pool.
        lag/fee/pool are written into the first group column only (NaN elsewhere) and
        are NaN when the generator was built from an h5 reader.

        :param path: target Excel file; nothing is written when None
        :return: the statistics DataFrame
        '''
        # Turnover and trading cost
        turn = sectionCalculator.MeanTurn(self.turn)
        cost = sectionCalculator.SumCost(self.cost)
        # Return, Sharpe ratio and max drawdown
        ret = sectionCalculator.RET(self.groupRet)
        sharp = sectionCalculator.SHARP(self.groupRet, self._frequency())
        maxdd = sectionCalculator.maxDrawDown(self.groupRet)
        # IC and rankIC
        ic = sectionCalculator.MeanIC(self.IC)
        rankIc = sectionCalculator.MeanRankIC(self.rankIC)
        # Mean number of stocks held per group
        number = sectionCalculator.MeanNumber(self.groupNumber)
        # Test parameters
        lag = self._firstColumnRow(self._setting('lag'), 'lag')
        fee = self._firstColumnRow(self._setting('fee'), 'fee')
        pool = self._firstColumnRow(self._setting('poolNum'), 'pool')

        statistic = self._rowsToFrame(
            [ret, sharp, maxdd, ic, rankIc, turn, cost, number, lag, fee, pool])
        if path is not None:
            statistic.to_excel(path)
        return statistic

    def plotSumCurve(self, testSeries):
        '''Plot the cumulative sum of ``testSeries`` with its datetimes on the x axis.'''
        values = np.asarray(testSeries[:])
        self._plotCurve(values.cumsum(), 'cumsum', testSeries.getDateTimes())

    def plotProdCurve(self, testSeries):
        '''Plot the cumulative product of (1 + ``testSeries``) with its datetimes on the x axis.'''
        # asarray first: ``1 + list`` would raise TypeError.
        values = np.asarray(testSeries[:]) + 1
        self._plotCurve(values.cumprod(), 'prod', testSeries.getDateTimes())

    @staticmethod
    def _plotCurve(curve, label, allTimeList):
        # Draw ``curve`` and label integer x positions with the matching entry of allTimeList.
        fig, ax = plt.subplots(1, 1)
        ax.plot(curve, label=label)
        plt.legend(loc='best')

        def format_date(x, pos=None):
            # Positions outside the data have no datetime to show.
            if x < 0 or x > len(allTimeList) - 1:
                return ''
            return allTimeList[int(x)]

        ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_date))
        fig.show()
class FactorUpdate:
    '''Write and update factor-test data (h5 files, figures and tables).'''
    logger = logger.getLogger("factorUpdate")

    # Indicators computed by every DefaultFactorTest created here.
    INDICATORS = ('IC', 'rankIC', 'beta', 'gpIC', 'tbdf', 'turn', 'groupRet')

    def __init__(self, instruments, market=bar.Market.STOCK, start=None, end=None,
                 testFreq=None, isRelReturn=False, fee=0.003, lag=1):
        '''
        Initialise the factor-test parameters.

        :param instruments: stock pool, "SZ50", "HS300" or "ZZ500"
        :param market: bar.Market.STOCK or bar.Market.FUTURES
        :param start: test start time; None uses the data reader's default start
        :param end: test end time; None uses the data reader's default end
        :param testFreq: list of resample frequencies to test; defaults to
            5min, 30min, 1h and 2h
        :param isRelReturn: truthy for returns relative to the benchmark index,
            falsy for absolute returns
        :param fee: opening fee, used to compute the trading cost
        :param lag: lag passed to the return panels and the factor tests
        '''
        self.instruments = instruments
        self.market = market
        self.start = start
        self.end = end
        self.newFactor = []
        self.factorDefPath = pathSelector.PathSelector.getFactorDefPath()
        self.factorDataPath = pathSelector.PathSelector.getFactorFilePath()
        self.fee = fee
        self.isRelReturn = isRelReturn
        self.lag = lag
        # Frequencies to test; by default 5, 30, 60 and 120 minutes
        self.resampleFreqNum = [bar.Frequency.MINUTE5, bar.Frequency.MINUTE30,
                                bar.Frequency.HOUR, bar.Frequency.HOUR2] \
            if not testFreq else testFreq
        self.resampleFreqStr = [const.DataFrequency.freq2lable(freq)
                                for freq in self.resampleFreqNum]
        # Per-frequency objects, keyed by frequency label
        self.reasampleFeedDict = {}
        self._return_Dict = {}
        self.rawFactorDict = {}
        self.factorTesterDict = {}
        self.dictOldResultDict = {}
        self.dictFilePathDict = {}

    def getPanelFeed(self):
        '''Return a new minute-level panelFeed for the instruments between self.start and self.end.'''
        panelFeed = DataFeedFactory.getHistFeed(instruments=self.instruments,
                                                market=self.market,
                                                frequency=bar.Frequency.MINUTE,
                                                start=self.start,
                                                end=self.end)
        return panelFeed

    def getBenchPanel(self):
        '''
        Build the benchmark index panel for ``self.instruments``.

        :return: SequenceDataPanel read from the matching index-futures minute CSV,
            or None when the instruments have no benchmark
        '''
        benchNameDict = {"SZ50": "IH.CCFX.csv", "HS300": "IF.CCFX.csv", "ZZ500": "IC.CCFX.csv"}
        if self.instruments not in ["SZ50", "HS300", "ZZ500"]:
            self.logger.info("The input instruments do not have benchmark. Please re-input.")
            return None
        fileName = benchNameDict[self.instruments]
        filePath = pathSelector.PathSelector.getDataFilePath(market=const.DataMarket.FUTURES,
                                                             types=const.DataType.OHLCV,
                                                             frequency=const.DataFrequency.MINUTE,
                                                             fileName=fileName)
        indexReader = CSVPanelReader(filePath=filePath,
                                     fields=['open', 'high', 'low', 'close', 'volume'],
                                     frequency=bar.Frequency.MINUTE,
                                     isInstrumentCol=False,
                                     start=self.start)
        indexReader.loads()
        return series.SequenceDataPanel.from_reader(indexReader)

    def newFactorList(self):
        '''
        Set ``self.newFactor`` to the sorted names of factors that are defined in the
        factor-definition folder but have no folder in the factor-data folder yet.
        '''
        allFactors = [fileName.split('.')[0] for fileName in os.listdir(self.factorDefPath)
                      if fileName not in ('__init__.py', '__pycache__')]
        self.newFactor = sorted(set(allFactors) - set(os.listdir(self.factorDataPath)))
        if self.newFactor:
            self.logger.info("The new factors:{}".format(self.newFactor))
        else:
            self.logger.info("No new factors seen, the factor updating process will end soon")

    def writeNewFactor(self):
        '''
        Compute, test and store (h5 files, figure, table) every new factor.
        '''
        self.newFactorList()
        for factor in self.newFactor:
            if factor == 'broker':
                continue
            self.logger.info(
                "****************** Writing FactorData for {} ******************".format(factor))
            factorObject = self._loadFactorClass(factor)
            panelFeed = self.getPanelFeed()  # a fresh panelFeed for every new factor

            if not self.isRelReturn:
                # Absolute returns
                for freqNum, freqStr in zip(self.resampleFreqNum, self.resampleFreqStr):
                    self._buildAbsoluteTester(panelFeed, factorObject, freqNum, freqStr)
                panelFeed.run(_print=True)  # panelFeed drives every resampled feed
            else:
                # Returns relative to the benchmark
                advFeed = self._buildRelativeTesters(panelFeed, factorObject)
                if advFeed is None:
                    # The benchmark depends only on the instruments, so every factor would fail.
                    return
                advFeed.run(_print=True)  # advFeed drives every resampled feed

            # Do not store data that is too short for the factor test. The return panel
            # depends only on the price feed, which is the same for every factor, so
            # the remaining factors are not attempted either.
            if len(self._return_Dict[self.resampleFreqStr[0]]) <= 2 * self.lag:
                self.logger.warning(
                    "The length of the return panel <= 2 * the required lag. Data will not be saved.")
                return

            # Write the h5 files, the figures and the tables
            for freqStr in self.resampleFreqStr:
                h5PanelWriter = h5Writer.H5PanelWriter(factor, self.factorTesterDict[freqStr])
                h5PanelWriter.write(mode="new")
                reportWriter = ReportWriter(factorName=factor,
                                            defaultFactorTest=self.factorTesterDict[freqStr])
                reportWriter.write()

    @staticmethod
    def _loadFactorClass(factor):
        # Import cpa.factorPool.factors.<factor> and return its ``Factor`` class,
        # e.g. cpa.factorPool.factors.dmaEwv.Factor.
        module = importlib.import_module("cpa.factorPool.factors.{}".format(factor))
        return getattr(module, 'Factor')

    def _buildAbsoluteTester(self, panelFeed, factorObject, freqNum, freqStr):
        # Create the resampled feed, absolute-return panel, factor panel and tester for one frequency.
        resampledFeed = ResampledPanelFeed(panelFeed, freqNum)
        self.reasampleFeedDict[freqStr] = resampledFeed
        self._return_Dict[freqStr] = returns.Returns(resampledFeed, lag=self.lag, maxLen=1024)
        self.rawFactorDict[freqStr] = factorBase.FactorPanel(resampledFeed, factorObject)
        self.factorTesterDict[freqStr] = DefaultFactorTest(resampledFeed,
                                                           self.rawFactorDict[freqStr],
                                                           self._return_Dict[freqStr],
                                                           indicators=list(self.INDICATORS),
                                                           lag=self.lag,
                                                           cut=0.1,
                                                           fee=self.fee)

    def _buildRelativeTesters(self, panelFeed, factorObject):
        # Create relative-return testers for every frequency on an AdvancedFeed merging
        # panelFeed ("base"), the resampled feeds and the benchmark panel ("bench").
        # Returns that AdvancedFeed, or None when no benchmark is available.
        for freqNum, freqStr in zip(self.resampleFreqNum, self.resampleFreqStr):
            self.reasampleFeedDict[freqStr] = ResampledPanelFeed(panelFeed, freqNum)
        combinedDict = {"base": panelFeed, **self.reasampleFeedDict}
        benchPanel = self.getBenchPanel()
        if benchPanel is None:
            self.logger.warning(
                "No benchmark panel for {}; relative returns cannot be computed".format(
                    self.instruments))
            return None
        advFeed = AdvancedFeed(feedDict=combinedDict, panelDict={'bench': benchPanel})
        for freqStr in self.resampleFreqStr:
            self._return_Dict[freqStr] = returns.RelativeReturns(advFeed,
                                                                 isResample=True,
                                                                 resampleType=freqStr,
                                                                 lag=self.lag,
                                                                 maxLen=1024)
            self.rawFactorDict[freqStr] = factorBase.FactorPanel(self.reasampleFeedDict[freqStr],
                                                                 factorObject)
            self.factorTesterDict[freqStr] = DefaultFactorTest(advFeed,
                                                               self.rawFactorDict[freqStr],
                                                               self._return_Dict[freqStr],
                                                               isResample=True,
                                                               resampleType=freqStr,
                                                               indicators=list(self.INDICATORS),
                                                               lag=self.lag,
                                                               cut=0.1,
                                                               fee=self.fee)
        return advFeed

    @staticmethod
    def _readTestSettings(folderPath, freqNum):
        # Load the test parameters stored in the first .csv file of folderPath.
        csvFileName = [name for name in os.listdir(folderPath) if name.endswith(".csv")][0]
        fields = ["frequency", "lag", "nGroup", "cut", "fee", "poolNum"]
        settingReader = csvReader.CSVPanelReader(filePath=os.path.join(folderPath, csvFileName),
                                                 fields=fields,
                                                 frequency=freqNum,
                                                 isInstrumentCol=False)
        settingReader.loads()
        return settingReader

    @staticmethod
    def _archiveOldFiles(factor, freqStr, oldResultDict):
        # Move the files directly inside the factor's frequency folder into a sub-folder named
        # after characters [-16:-3] of the first old-result key (presumably the timestamp in
        # the h5 file name — confirm against H5PanelWriter).
        oldDateTime = list(oldResultDict.keys())[0][-16:-3]
        freqFolderPath = pathSelector.PathSelector.getFactorFilePath(factorName=factor,
                                                                     factorFrequency=freqStr)
        destFolderPath = os.path.join(freqFolderPath, oldDateTime)
        os.makedirs(destFolderPath, exist_ok=True)
        for name in os.listdir(freqFolderPath):
            sourceFilePath = os.path.join(freqFolderPath, name)
            if os.path.isfile(sourceFilePath):
                shutil.move(sourceFilePath, destFolderPath)

    def updateFactor(self, factor, nBizDaysAhead=30):
        '''
        Append new data to every file in one factor's folder.

        :param factor: factor name
        :param nBizDaysAhead: new data is computed starting this many business days before
            the latest end date of the stored data; adjust to the strategy (e.g. an MA20
            strategy on 2h bars needs at least 10 business days)

        Side effect: sets ``self.start``, which later getPanelFeed() calls also use.
        '''
        self.logger.info("****************** Updating FactorData for {} ******************".format(factor))
        factorReader = h5Reader.H5BatchPanelReader(factorName=factor, frequency=None, allFolders=True)
        factorReader.prepareOutputData()
        dateRangeDict = factorReader.getDateRange()  # {file: (start, end)}
        if not dateRangeDict:
            self.logger.warning("No stored data found for {}; nothing to update".format(factor))
            return
        # Latest end date over all stored files
        endDate = max(dateRange[1] for dateRange in dateRangeDict.values()).to_pydatetime()
        timeDiff = pd.tseries.offsets.BusinessDay(n=nBizDaysAhead)
        self.start = endDate - timeDiff  # start of the new computation
        self.logger.info("The end time in the original data is {}\n"
                         "The input time difference is {}\n"
                         "The start time for calculating the new data is {}\n"
                         "The end time for calculating the new data is {}\n"
                         .format(endDate, timeDiff, self.start, self.end))
        panelFeed = self.getPanelFeed()  # a new panelFeed starting at the new start
        factorObject = self._loadFactorClass(factor)

        settingReaderDict = {}  # freqNum -> reader of that frequency's test settings
        for freqNum, freqStr in zip(self.resampleFreqNum, self.resampleFreqStr):
            folderPath = pathSelector.PathSelector.getFactorFilePath(factorName=factor,
                                                                     factorFrequency=freqStr)
            settingReaderDict[freqNum] = self._readTestSettings(folderPath, freqNum)
            # Stored h5 files of this frequency
            freqReader = h5Reader.H5BatchPanelReader(factorName=factor, frequency=freqNum,
                                                     allFolders=False)
            freqReader.prepareOutputData()
            self.dictOldResultDict[freqStr] = freqReader.to_frame()
            self.dictFilePathDict[freqStr] = freqReader.getFilePath()
            self._buildAbsoluteTester(panelFeed, factorObject, freqNum, freqStr)
        panelFeed.run(_print=True)  # panelFeed drives every resampled feed

        for freqStr, oldResultDict in self.dictOldResultDict.items():
            self._archiveOldFiles(factor, freqStr, oldResultDict)
            # Write the new h5 files in append mode
            h5PanelWriter = h5Writer.H5PanelWriter(factorName=factor,
                                                   defaultFactorTest=self.factorTesterDict[freqStr])
            h5PanelWriter.write(mode="append", oldResultDict=oldResultDict)

        for freqNum in self.resampleFreqNum:
            # Write the new figure and table from the updated h5 files, using the settings
            # of the same frequency.
            secondReader = h5Reader.H5BatchPanelReader(factorName=factor, frequency=freqNum)
            secondReader.prepareOutputData()
            reportWriter = ReportWriter(factorName=factor,
                                        h5BatchPanelReader=secondReader,
                                        csvPanelReader=settingReaderDict[freqNum])
            reportWriter.write()

    def updateFactorPool(self, nBizDaysAhead=30):
        '''
        Append new data for every factor folder under the factor-data folder.

        :param nBizDaysAhead: see updateFactor
        '''
        factorNameList = [name for name in os.listdir(self.factorDataPath)
                          if os.path.isdir(os.path.join(self.factorDataPath, name))]
        for factor in factorNameList:
            self.updateFactor(factor, nBizDaysAhead=nBizDaysAhead)
class AdvancedFeed(PanelFeed):
    '''
    Merge one or more feeds (and static panels) from different data sources into a
    single, time-synchronised feed.

    The feed whose data source provides 'close' becomes the base (OHLCV) feed; a panel
    attached under the label 'bench' is used as the benchmark index.
    '''
    logger = logger.getLogger('AdvancedFeed')

    class EmptyDataSource:
        fields = []

        @classmethod
        def getFields(cls):
            return []

    def __init__(self, feedDict=None, panelDict=None):
        '''
        :param feedDict: {label: panelFeed}
        :param panelDict: {label: data panel}; the label 'bench' marks the benchmark
        '''
        super().__init__(self.EmptyDataSource, [], None, None)
        self.feedDict = feedDict if feedDict is not None else {}
        self.panelDict = panelDict if panelDict is not None else {}
        self.sortedFeeds = []  # both feeds and panels; the base OHLCV source comes first
        self.sortedLable = []
        self.dataSource = []
        self.instruments = None
        self.frequency = np.inf
        self.maxLen = 0
        self.synchronizedNextValue = {}  # label : (dateTime, value)
        self.isEof = False
        self.available = None
        self.benchPanel = None
        self.baseDataLable = None  # label of the feed that holds OHLCV

        multiDict = {**self.feedDict, **self.panelDict}
        for lable, value in multiDict.items():
            # self.feedDict is always a dict, so a None feedDict no longer raises here.
            if lable in self.feedDict:
                self.attachFeed(lable, value)
            else:
                self.attachPanel(lable, value)

    def __attachBaseInfo(self, lable, value, source):
        '''
        Register a feed or panel: order it (base OHLCV source first), record its data
        source, and update the shared instruments, frequency and maxLen.
        '''
        if isinstance(source, BaseDataReader) and 'close' in source.getFields():
            self.sortedFeeds.insert(0, value)
            self.sortedLable.insert(0, lable)
            self.baseDataLable = lable
        else:
            self.sortedFeeds.append(value)
            self.sortedLable.append(lable)
        self.dataSource.append(source)

        # Sources read without an instrument column do not constrain the instruments.
        if not (isinstance(source, BasePanelReader) and source.isInstrumentCol is False):
            if self.instruments is None:
                self.instruments = source.getRegisteredInstruments()
            elif self.instruments != source.getRegisteredInstruments():
                removeList = list(
                    set(self.instruments) - set(source.getRegisteredInstruments()))  # difference
                self.logger.info(
                    "Miss Data stock list:\n{}".format(removeList))
                self.instruments = list(
                    set(self.instruments) & set(source.getRegisteredInstruments()))  # intersection
        if source.getFrequency() < self.frequency:
            self.frequency = source.getFrequency()
        if value.getMaxLen() > self.maxLen:
            self.maxLen = value.getMaxLen()

    def hasOhlcv(self):
        '''
        :return: True when one of the attached sources provides OHLCV data
        '''
        return self.baseDataLable is not None

    def attachFeed(self, lable, otherFeed):
        '''
        Merge a feed from another data source into this one (e.g. OHLC data and
        financial data that come from different sources).

        :param lable: label of the feed
        :param otherFeed: the panelFeed to attach
        '''
        # Capture the frequency before merging: __attachBaseInfo may lower self.frequency
        # to this feed's frequency, which would hide the mismatch warned about below.
        previousFrequency = self.getFrequency()
        otherDataSource = otherFeed.getDataSource()
        self.__attachBaseInfo(lable, otherFeed, otherDataSource)
        self.feedDict[lable] = otherFeed
        for field in otherFeed.getFields():
            assert field not in self.fields, 'duplicate field attached {}'.format(
                field)
            self.fields.append(field)
        if 'close' in otherDataSource.getFields():
            # The OHLCV feed supplies the price/volume panels of the merged feed.
            self.openPanel = otherFeed.openPanel
            self.highPanel = otherFeed.highPanel
            self.closePanel = otherFeed.closePanel
            self.lowPanel = otherFeed.lowPanel
            self.volumePanel = otherFeed.volumePanel
        if otherFeed.getFrequency() != previousFrequency and previousFrequency != np.inf:
            # Warning text: synchronising sources of different periods desynchronises the
            # times inside the data; take care in calculations.
            warnings.warn(u'不同周期的数据元同步,会导致数据内部的时间不同步, 进行运算时请留意')
        for extraField in otherFeed.getExtra():
            self.extraPanel[extraField] = otherFeed.getExtra(extraField)
            self.fields.append(extraField)

    def attachBenchMark(self, benchMarkPanel):
        '''
        :param benchMarkPanel: benchmark index panel, attached under the label 'bench'
        '''
        self.attachPanel('bench', benchMarkPanel)

    def attachPanel(self, lable, otherPanel):
        '''
        :param lable: label of the panel; 'bench' makes it the benchmark
        :param otherPanel: data panel to attach
        '''
        otherDataSource = otherPanel.getDataSource()
        self.__attachBaseInfo(lable, otherPanel, otherDataSource)
        self.panelDict[lable] = otherPanel
        if lable == 'bench':
            self.benchPanel = otherPanel
        else:
            self.extraPanel[lable] = otherPanel
            assert lable not in self.fields, 'duplicate field attached {}'.format(
                lable)
            self.fields.append(lable)

    def getBench(self):
        '''
        :return: the benchmark panel, or None
        '''
        return getattr(self, 'benchPanel', None)

    def getBase(self):
        '''
        :return: the feed holding price/volume data
        '''
        assert 'close' in self.fields
        return self.sortedFeeds[0]

    def peekNextValues(self, available=None):
        '''
        Read the next row of every source in ``available`` (all sources when None) into
        synchronizedNextValue. Sources whose pending row was not consumed are not read again.
        Sets isEof when no source has a pending row.
        '''
        available = self.sortedLable if available is None else available
        for lable in available:
            source = self.sortedFeeds[self.sortedLable.index(lable)].getDataSource()
            if not source.eof():
                dateTime, value = source.getNextValues()
                # Align day-or-coarser data to 15:00 when the merged feed is intraday.
                if source.getFrequency() >= bar.Frequency.DAY > self.getFrequency():
                    dateTime = dateTime.replace(hour=15)
                self.synchronizedNextValue[lable] = (dateTime, value)
            else:
                # pop: the label may never have had a pending row.
                self.synchronizedNextValue.pop(lable, None)
        if not self.synchronizedNextValue:
            self.isEof = True
        self.available = available

    def peekNextDatetime(self, sychronizedNextValues):
        '''
        :param sychronizedNextValues: {label: (dateTime, value)} candidates
        :return: the earliest candidate dateTime
        '''
        return min(dateTime for dateTime, _ in sychronizedNextValues.values())

    def getAvailable(self):
        '''
        :return: labels of the feeds/panels that received data at the current time
        '''
        return self.available

    def getNextValues(self):
        '''
        Advance to the earliest pending time and push that row into every source that has it.

        :return: (dateTime, None); (None, None) when all sources are exhausted, when the base
            OHLCV source is exhausted (the feed is then stopped), or when the benchmark's
            next row is later than the current time
        '''
        self.peekNextValues(self.available)
        if self.isEof:
            # Every source is exhausted; there is no earliest time to pick.
            return None, None
        nextTime = self.peekNextDatetime(self.synchronizedNextValue)
        if self.hasOhlcv() and self.baseDataLable not in self.synchronizedNextValue:
            self.logger.info(
                'there is no extra ohlc values, strategy will end')
            self.stopped = True
            return None, None

        available = []
        for lable, (dateTime, value) in self.synchronizedNextValue.items():
            if dateTime != nextTime:
                continue
            available.append(lable)
            if lable in self.feedDict:
                self.feedDict[lable].appendNextValues(dateTime, value)
            else:
                self.panelDict[lable].appendWithDateTime(dateTime, value.values)
        self.available = available

        # .get: the benchmark may already be exhausted.
        benchNext = self.synchronizedNextValue.get('bench')
        if self.benchPanel is not None and benchNext is not None and benchNext[0] > nextTime:
            self.logger.info(
                'benchMark exists, lower than {} time is ignored, current time {}'
                .format(benchNext[0], nextTime))
            return None, None
        self.dispatchNewValueEvent(self, nextTime, None)
        return nextTime, None

    def eof(self):
        '''
        :return: True when every source is exhausted
        '''
        return self.isEof

    def run(self, stopCount=None, _print=False):
        '''
        Drive the feed until all sources are exhausted, the feed is stopped, or more than
        ``stopCount`` timestamps have been dispatched.

        :param stopCount: optional limit on dispatched timestamps
        :param _print: print each dispatched time and the labels that received data
        '''
        counter = 0
        while not self.eof() and not self.stopped:
            dateTime, df = self.getNextValues()
            if dateTime:
                counter += 1
                if _print:
                    print(dateTime, self.available)
            if stopCount is not None and counter > stopCount:
                break
class H5BatchPanelReader(BasePanelReader):
    '''
    h5 panel reader that reads several files at once, one H5PanelReader per file.
    '''
    logger = logger.getLogger("H5PanelReader")

    def __init__(self, factorName=None, frequency=None, start=None, end=None):
        '''
        :param factorName: factor name, used with pathSelector to locate the folder
        :param frequency: data frequency; also passed to pathSelector as factorFrequency
        :param start: start time of the data to read
        :param end: end time of the data to read
        '''
        super().__init__()
        self.frequency = frequency
        self.market = bar.Market.STOCK
        self.filePathDict = {}  # file name -> file path
        self.testResultDict = {}  # file name -> retrieved DataFrame
        self.readerDict = {}  # file name -> H5PanelReader
        self.nextValueDict = {}  # file name -> (dateTime, value) of its next row
        self.staticPanelDict = {}  # file name -> static panel/series
        self.start = start
        self.end = end
        self.isEof = False
        self.availFctList = None
        self.factorName = factorName
        self.setFilePath()

    def setFilePath(self):
        '''
        Set ``self.path``, the folder holding the h5 files, from the paths configured
        for pathSelector.
        '''
        pathSelctor = pathSelector.PathSelector()
        self.path = pathSelctor.getFactorFilePath(
            factorName=self.factorName, factorFrequency=self.frequency)

    def prepareOutputData(self):
        '''
        Open every h5 file and register its path, reader and DataFrame in the dictionaries.

        When ``self.path`` contains sub-folders (one per resample frequency), the h5 files of
        every sub-folder are read; otherwise the h5 files directly in ``self.path`` are read.
        Files with the same name in different sub-folders overwrite each other.
        '''
        folderNameList = [
            name for name in os.listdir(self.path)
            if os.path.isdir(os.path.join(self.path, name))
        ]
        self.logger.debug("Sub-folders of {}: {}".format(self.path, folderNameList))
        if folderNameList:
            for folderName in folderNameList:
                self._loadFolder(os.path.join(self.path, folderName))
        else:
            self._loadFolder(self.path)

    def _loadFolder(self, folderDir):
        # Register path, H5PanelReader and retrieved DataFrame for every h5 file in folderDir.
        for fileName in os.listdir(folderDir):
            if ".h5" not in fileName:
                continue
            filePath = os.path.join(folderDir, fileName)
            self.filePathDict[fileName] = filePath
            self.readerDict[fileName] = H5PanelReader(
                filePath, self.frequency, self.start, self.end)
            self.testResultDict[fileName] = self.readerDict[fileName].retrieve()

    def getDir(self):
        '''Return the folder path.'''
        return self.path

    def getFrequency(self):
        '''Return the data frequency.'''
        return self.frequency

    def getFilePath(self):
        '''Return the {file name: file path} dictionary.'''
        return self.filePathDict

    def getTestResult(self):
        '''Return the {file name: DataFrame} dictionary of test results.'''
        return self.testResultDict

    def getReader(self):
        '''Return the {file name: H5PanelReader} dictionary.'''
        return self.readerDict

    def getDataShape(self):
        '''Return {file name: data shape} (also stored in self.dataShapeDict).'''
        self.dataShapeDict = {
            key: reader.getDataShape() for key, reader in self.readerDict.items()}
        return self.dataShapeDict

    def getDateRange(self):
        '''Return {file name: date range} (also stored in self.dateRangeDict).'''
        self.dateRangeDict = {
            key: reader.getDateRange() for key, reader in self.readerDict.items()}
        return self.dateRangeDict

    def getRegisteredInstruments(self):
        '''Return {file name: instruments} (also stored in self.instrumentDict).'''
        self.instrumentDict = {
            key: reader.getRegisteredInstruments() for key, reader in self.readerDict.items()}
        return self.instrumentDict

    def getIterator(self):
        '''Return {file name: iterator} (also stored in self.iterDict).'''
        self.iterDict = {
            key: reader.getIterator() for key, reader in self.readerDict.items()}
        return self.iterDict

    def peekNextValues(self, availFctList=None):
        '''
        Read the next row of every listed file (all files when None) into nextValueDict.

        Files whose reader is exhausted are removed from the list and from nextValueDict,
        so nextValueDict only holds rows that exist. When it becomes empty, every file has
        been read and eof() turns True.
        '''
        availFctList = list(
            self.readerDict.keys()) if availFctList is None else availFctList
        # Iterate over a snapshot: exhausted files are removed from availFctList below,
        # and removing from the list being iterated would skip the following file.
        for factor in list(availFctList):
            reader = self.readerDict[factor]
            if not reader.eof():
                dateTime, value = reader.getNextValues()
                self.nextValueDict[factor] = (dateTime, value)
            else:
                availFctList.remove(factor)
                self.nextValueDict.pop(factor, None)
        if not self.nextValueDict:
            self.isEof = True
        self.availFctList = availFctList

    def getNextValues(self):
        '''Return the dictionary holding the next row of every remaining file.'''
        self.peekNextValues(self.availFctList)
        return self.nextValueDict

    def eof(self):
        '''Return True when every file has been read.'''
        return self.isEof

    def to_static_panel(self):
        '''
        Convert each file's full data to a static SequenceDataPanel/SequenceDataSeries.

        :return: {file name: static panel}
        '''
        for key, value in self.readerDict.items():
            self.staticPanelDict[key] = value.to_static_panel()
        return self.staticPanelDict

    def to_frame(self):
        '''
        The data is already stored as DataFrames, so no conversion is done.

        :return: the {file name: DataFrame} dictionary (same object as getTestResult())
        '''
        return self.testResultDict
class ResampledPanelFeed(baseFeed.PanelFeed):
    """
    Panel feed that resamples another panelFeed to the given frequency; it is driven by
    the source feed's new-panel events.

    market is cpa.resamplebase.Market.STOCK or cpa.resamplebase.Market.CTP
    (default: bar.Market.STOCK).
    """
    logger = logger.getLogger('resampleFeed')

    def __init__(self, panelFeed, frequency, marketType=bar.Market.STOCK, maxLen=None):
        '''
        :param panelFeed: source baseFeed.PanelFeed
        :param frequency: target frequency, e.g. bar.Frequency.MINUTE5
        :param marketType: market passed to resamplebase.build_range
        :param maxLen: maximum panel length; defaults to panelFeed.maxLen
        :raises Exception: when panelFeed is not a baseFeed.PanelFeed
        '''
        if not isinstance(panelFeed, baseFeed.PanelFeed):
            raise Exception("panelFeed must be a baseFeed.panelFeed instance")
        if maxLen is None:
            maxLen = panelFeed.maxLen
        super(ResampledPanelFeed, self).__init__(panelFeed.getDataSource(),
                                                 panelFeed.getInstruments(),
                                                 frequency, maxLen)
        self.dataSource = None  # reset dataSource; data arrives through panelFeed's events
        self.market = marketType
        self.panelFeed = panelFeed
        self.grouper = PanelGrouper(panelFeed)
        # NOTE: this instance attribute shadows the isResampleFeed() method below, so read
        # it as an attribute (feed.isResampleFeed); calling it raises TypeError.
        self.isResampleFeed = True
        self.range = None
        self.__currentDateTime = np.nan
        self.__needUpdateResampleBar = ResampleState.NOTHING
        self.__laggedTime = None
        self.__nearlyEndingEvent = observer.Event()
        self.__updateValuesEvent = observer.Event()
        panelFeed.getNewPanelsEvent(
            priority=panelFeed.EventPriority.RESAMPLE).subscribe(self.__onNewValues)

    def isResampleFeed(self):
        # Unreachable on instances: shadowed by the attribute set in __init__.
        return self.isResampleFeed

    def __onNewValues(self, panelFeed, dateTime, df):
        '''
        Handle a new bar of the source feed: count it into the current resample range and
        emit (or update) a resampled bar when the range closes.

        :param panelFeed: the source feed
        :param dateTime: time of the new source bar
        :param df: the new source row (unused here; the grouper reads panelFeed)
        '''
        if self.range is None:
            # No open range: open one at this bar.
            self.range = resamplebase.build_range(
                dateTime, self.getFrequency(), market=self.market)
            self.grouper.addCounter()
        elif self.range.belongs(dateTime):
            self.grouper.addCounter()
            # For 15min and coarser bars, once per day near the market close publish a bar
            # stamped 15:00; LAG_ONE makes the next appendValues update it instead of appending.
            if self.getFrequency() >= bar.Frequency.MINUTE * 15 and self.range.nearlyMarketEnding(dateTime):
                if self.__laggedTime is None or dateTime.day != self.__laggedTime.day:
                    self.appendValues(dateTime.replace(
                        hour=15, minute=0), self.grouper.getGrouped())
                    self.__needUpdateResampleBar = ResampleState.LAG_ONE
                    self.__laggedTime = dateTime
        elif not self.range.belongs(dateTime):
            self.grouper.addCounter()
            if self.range.outEnding(dateTime):
                # The bar lies past the range end: emit the range (end=-1 presumably excludes
                # this bar) and open a new range at this bar.
                grouped = self.grouper.getGrouped(end=-1)
                self.appendValues(dateTime, grouped)
                self.grouper.resetCounter()
                self.range = resamplebase.build_range(
                    dateTime, self.getFrequency(), market=self.market)
            else:
                self.range = None
                self.appendValues(dateTime, self.grouper.getGrouped())
                self.grouper.resetCounter()

    def appendValues(self, dateTime, grouped):
        '''
        Append a resampled bar, or update the last one when a provisional bar is pending.

        :param dateTime: current time
        :param grouped: (groupedTime, values); groupedTime may differ from dateTime
        '''
        if self.__needUpdateResampleBar == ResampleState.LAG_ONE:
            self.updateWithDateTime(dateTime, grouped)
            self.__needUpdateResampleBar = ResampleState.READY
        else:
            self.appendWithDateTime(dateTime, grouped)

    def appendWithDateTime(self, dateTime, grouped):
        '''Append the grouped OHLCV row to the panels and dispatch the new-value event.'''
        self.__currentDateTime = dateTime
        groupedTime, grouped = grouped
        self.openPanel.appendWithDateTime(groupedTime, grouped['open'])
        self.highPanel.appendWithDateTime(groupedTime, grouped['high'])
        self.lowPanel.appendWithDateTime(groupedTime, grouped['low'])
        self.closePanel.appendWithDateTime(groupedTime, grouped['close'])
        self.volumePanel.appendWithDateTime(groupedTime, grouped['volume'])
        self.dispatchNewValueEvent(self, dateTime, None)
        # '{}' placeholders: the '%s' template used with .format never substituted values.
        self.logger.debug('ResampledTime {}: {}'.format(
            const.DataFrequency.freq2lable(self.frequency),
            groupedTime.strftime('%Y-%m-%d %H:%M:%S')))

    def updateWithDateTime(self, dateTime, grouped):
        '''Overwrite the last panel row with the grouped OHLCV row (no event is dispatched).'''
        self.__currentDateTime = dateTime
        groupedTime, grouped = grouped
        self.openPanel.updateWithDateTime(groupedTime, grouped['open'])
        self.highPanel.updateWithDateTime(groupedTime, grouped['high'])
        self.lowPanel.updateWithDateTime(groupedTime, grouped['low'])
        self.closePanel.updateWithDateTime(groupedTime, grouped['close'])
        self.volumePanel.updateWithDateTime(groupedTime, grouped['volume'])

    def getCurrentDatetime(self):
        return self.__currentDateTime

    def getNearlyEndingEvent(self):
        return self.__nearlyEndingEvent

    def eof(self):
        return self.panelFeed.eof()

    def getNextValues(self):
        dateTime, df = self.panelFeed.getNextValues()
        return dateTime, df
class FinanceReader(BaseDataReader):
    '''
    Reader for the three main financial statements; yields one cross-section
    (rows = stock codes, columns = selected fields) per announcement date.
    '''
    logger = logger.getLogger('FinanceReader')

    class FinanceType:
        '''
        File names of the three financial statements.
        '''
        BALANCESHEET = 'ASHAREBALANCESHEET.csv'
        INCOME = 'ASHAREINCOME.csv'
        CASHFLOW = 'ASHARECASHFLOW.csv'

    def __init__(self, fileName, instruments, fields, start=None, end=None):
        '''
        param fileName: file name, either "BALANCESHEET" or "ASHAREBALANCESHEET.csv"
        param instruments: selected stock codes
        param fields: selected columns
        param start: start date (exclusive)
        param end: end date (exclusive)
        '''
        super().__init__(instruments, fields, start)
        # Accept both forms, e.g. "BALANCESHEET" or "ASHAREBALANCESHEET.csv".
        fileName = fileName if "." in fileName else getattr(
            self.FinanceType, fileName.upper())
        self.filePath = pathSelector.PathSelector.getDataFilePath(
            market=const.DataMarket.STOCK,
            types=const.DataType.FINANCE,
            frequency=const.DataFrequency.QUARTER,
            fileName=fileName)
        self.end = end
        self.isEof = False
        self.valGen = None

    def loads(self):
        '''
        Read the csv into a DataFrame indexed by (ANN_DT, WIND_CODE) and prepare
        the value generator.
        return: the processed DataFrame
        '''
        self.df = pd.read_csv(self.filePath)
        self.df = self.df[pd.notnull(self.df["ANN_DT"])]  # drop rows without announcement date
        self.df["ANN_DT"] = pd.to_datetime(self.df["ANN_DT"], format="%Y%m%d")
        self.df["ANN_DT"] = self.df["ANN_DT"].dt.date
        if self.start:  # keep rows strictly after start
            self.df = self.df[
                self.df["ANN_DT"] > pd.Timestamp(self.start).date()]
        if self.end:  # keep rows strictly before end
            self.df = self.df[
                self.df["ANN_DT"] < pd.Timestamp(self.end).date()]
        # Strip the exchange suffix, keeping only the numeric code.
        self.df[["WIND_CODE", "temp"]] = self.df["WIND_CODE"].str.split(
            ".", expand=True)
        self.df.drop(columns="temp", inplace=True)
        # Sorted intersection of requested codes and codes present in the file.
        self.availInstruments = sorted(
            set(self.df["WIND_CODE"]) & set(self.instruments))
        self.df = self.df[self.df["WIND_CODE"].isin(self.availInstruments)]
        self.df.sort_values(by=["WIND_CODE", "ANN_DT"], inplace=True)
        self.df.reset_index(inplace=True)
        # Several reports of one stock on one date: defer later ones by a day.
        self._deferDuplicateAnnDates()
        self.df.set_index(["ANN_DT", "WIND_CODE"], inplace=True)
        self.df.sort_index(axis=0, inplace=True)
        self.df = self.df[self.fields]  # keep only the requested columns
        self.allDate = list(
            np.unique(self.df.index.get_level_values(level="ANN_DT")))
        # Nothing to emit: report end-of-file immediately.
        self.isEof = len(self.allDate) == 0
        self.valGen = self.valueGenerator()
        return self.df

    def _deferDuplicateAnnDates(self):
        '''
        Make ANN_DT strictly increasing within each stock.

        Expects self.df sorted by (WIND_CODE, ANN_DT) with a 0..n-1 index. A
        row whose date is not after the previous row of the same stock is moved
        to that previous date + 1 day; shifts cascade so no (date, code) pair
        repeats, and other stocks are never touched.
        '''
        codes = self.df["WIND_CODE"].tolist()
        dates = self.df["ANN_DT"].tolist()
        oneDay = datetime.timedelta(days=1)
        for i in range(1, len(dates)):
            if codes[i] == codes[i - 1] and dates[i] <= dates[i - 1]:
                dates[i] = dates[i - 1] + oneDay
        self.df["ANN_DT"] = dates

    def valueGenerator(self):
        '''
        Generator over announcement dates.
        return: yields (date, DataFrame); rows are every available stock code
                (NaN rows for stocks without an announcement that day), columns
                are the selected fields.
        '''
        lastIdx = len(self.allDate) - 1
        for idx, date in enumerate(self.allDate):
            tempDF = self.df.loc[date]
            present = set(tempDF.index)
            emptyInsList = [
                ins for ins in self.availInstruments if ins not in present
            ]
            emptyDF = pd.DataFrame(columns=self.fields, index=emptyInsList)
            # pd.concat instead of DataFrame.append (removed in pandas 2.0).
            tempDF = pd.concat([tempDF, emptyDF])
            tempDF.sort_index(inplace=True)
            if idx == lastIdx:
                # Flag EOF while handing out the last date (also for one date).
                self.isEof = True
            yield date, tempDF

    def getDir(self):
        '''
        return: file path
        '''
        return self.filePath

    def getFrequency(self):
        '''
        return: data frequency (bar.Frequency.QUARTER)
        '''
        return bar.Frequency.QUARTER

    def getFields(self):
        '''
        return: list of selected fields
        '''
        return self.fields

    def getDataShape(self):
        '''
        return: shape tuple of the DataFrame before NaN rows are padded in
        '''
        return self.df.shape

    def getDateRange(self):
        '''
        return: (first date, last date) after filtering
        '''
        return (self.allDate[0], self.allDate[-1])

    def getRegisteredInstruments(self):
        '''
        return: sorted intersection of requested codes and codes in the file
        '''
        return self.availInstruments

    def getNextValues(self):
        '''
        return: next (datetime, value) from the generator
        '''
        return next(self.valGen)

    def eof(self):
        '''
        return: True once the last date has been handed out
        '''
        return self.isEof
class CSVFutureDataReader(BaseDataReader):
    '''
    CSV reader for futures bar data; the data folder comes from pathSelector.
    '''
    logger = logger.getLogger('CSVFunterDataReader')

    def __init__(self,
                 instruments,
                 fields,
                 startTime,
                 endTime,
                 frequency=bar.Frequency.MINUTE,
                 limit=None):
        super().__init__(instruments, fields, startTime)
        self.instruments = sorted(instruments) if instruments else None
        self.frequency = frequency
        self.market = bar.Market.FUTURES
        # Default window when start/end are not given.
        self.start = pd.to_datetime('20150601') if startTime is None \
            else pd.to_datetime(startTime)
        self.end = pd.to_datetime('20200101') if endTime is None \
            else pd.to_datetime(endTime)
        self.limit = limit  # maximum number of timestamps to output
        if fields is None:
            # Default fields: open, high, low, close, volume. Otherwise
            # self.fields presumably comes from BaseDataReader — TODO confirm.
            self.fields = ['open', 'high', 'low', 'close', 'volume']
        self.registeredInstruments = 'Apply prepareOutputData function first and then check the registered instruments.'
        self.isEof = False  # True once every timestamp has been handed out
        self.valGen = None  # value generator, created by prepareGenerator()
        self.currentInstruments = []
        self.path = pathSelector.PathSelector.getDataFilePath(
            market=const.DataMarket.FUTURES,
            types=const.DataType.OHLCV,
            frequency=const.DataFrequency.MINUTE,
            fileName=None)
        self.getAllFutureNames()

    def setDir(self, path):
        self.path = path

    def getAllFutureNames(self):
        '''
        Index the local futures files.
        :return: None. Sets self.allFutures ({short name: file name}) and
                 self.registeredInstruments (requested names that have a local
                 file, or every local name when none were requested).
        '''
        fileLists = os.listdir(self.path)
        self.allFutures = {file.split('.')[0]: file for file in fileLists}
        if self.instruments:
            registeredInstruments = [
                name for name in self.instruments if name in self.allFutures
            ]
        else:
            registeredInstruments = sorted(self.allFutures)
        self.logger.debug("RGT: {}".format(registeredInstruments))
        self.registeredInstruments = registeredInstruments

    def readCSVFile(self, fileName):
        '''
        Read one futures csv file.
        :param fileName: file name inside self.path
        :return: DataFrame indexed by datetime within [start, end], with the
                 requested fields plus 'symbol' (exchange suffix stripped)
        '''
        # os.path.join instead of a hard-coded Windows separator.
        df = pd.read_csv(os.path.join(self.path, fileName))
        df['datetime'] = pd.to_datetime(df['datetime'].str.slice(0, 19))
        df = df.set_index('datetime')
        df = df[(df.index >= self.start) & (df.index <= self.end)]
        df = df[self.fields + ['symbol']]
        df['symbol'] = df['symbol'].apply(lambda x: x.split(".")[0])
        return df

    def prepareOutputData(self):
        '''
        Load every registered future and build the union of timestamps.
        :return: None
        '''
        self.futureDict = {}  # future -> DataFrame
        self.firstDict = {}  # future -> first timestamp of its DataFrame
        wholeIndex = []  # union of all timestamps (before deduplication)
        removeList = []
        for future in self.registeredInstruments:
            df = self.readCSVFile(self.allFutures[future])
            if len(df) > 0:
                self.futureDict[future] = df
                self.firstDict[future] = df.index[0]
                wholeIndex.extend(df.index)
                self.logger.info('{} data get'.format(future))
            else:
                # No rows inside [start, end]: drop this future.
                removeList.append(future)
                self.logger.info(
                    '{} data is not available before end time.'.format(future))
        for future in removeList:
            self.registeredInstruments.remove(future)
        # Sorted unique trading timestamps between start and end.
        self.wholeIndex = np.unique(wholeIndex)
        if len(self.wholeIndex) == 0:
            self.isEof = True
            self.logger.info('No available data within the given period.')
            return
        if self.limit is not None and len(self.wholeIndex) > self.limit:
            self.wholeIndex = self.wholeIndex[:self.limit]
        self.actualStart = self.wholeIndex[0]
        self.actualEnd = self.wholeIndex[-1]
        # Drop futures whose first record lies after the (limited) actual end.
        lateList = [
            future for future in self.registeredInstruments
            if self.firstDict[future] > self.actualEnd
        ]
        for future in lateList:
            self.registeredInstruments.remove(future)
            self.logger.info(
                '{} data not available within the first {} output.'.format(
                    future, self.limit))

    def prepareGenerator(self):
        '''
        Load the data and create the value generator.
        :return: None
        '''
        self.prepareOutputData()
        self.valGen = self.valueGenerator()

    def valueGenerator(self):
        '''
        :return: yields (timestamp, DataFrame) for each timestamp in wholeIndex
        '''
        lastIdx = len(self.wholeIndex) - 1
        for idx, date in enumerate(self.wholeIndex):
            df = self.adjustData(date)
            if idx == lastIdx:
                # Flag EOF while handing out the last row (also for one row).
                self.isEof = True
            yield date, df

    def adjustData(self, date):
        '''
        Build the cross-section for one timestamp.
        :param date: timestamp to query
        :return: DataFrame indexed by symbol with the requested fields; futures
                 without data at this timestamp get a NaN row
        '''
        frames = []
        for future in self.registeredInstruments:
            source = self.futureDict[future]
            tmp = source[source.index == date].reset_index()
            if len(tmp) > 0:
                # Some futures have two records at 11:30:00; keep the first.
                tmp = pd.DataFrame(tmp.iloc[0]).T
            else:
                tmp = pd.DataFrame(index=[0], columns=tmp.columns)
            tmp.loc[0, 'symbol'] = future
            frames.append(tmp)
        combined = pd.concat(frames)
        combined = combined.drop(['datetime'], axis=1).set_index('symbol')
        # Futures with a complete (no-NaN) row at this timestamp.
        self.currentInstruments = list(
            combined.dropna(axis=0, how='any').index)
        return combined

    def getDir(self):
        '''
        :return: local data folder of the futures files
        '''
        return self.path

    def getFrequency(self):
        '''
        :return: data frequency
        '''
        return self.frequency

    def getDataShape(self):
        '''
        return: dict {'timeLength', 'futureNumber', 'fieldNumber'}, i.e. number
                of timestamps, number of futures, number of fields
        '''
        return {
            'timeLength': len(self.wholeIndex),
            'futureNumber': len(self.registeredInstruments),
            'fieldNumber': len(self.fields)
        }

    def getDateRange(self):
        '''
        :return: (first timestamp, last timestamp), or a message when empty
        '''
        if len(self.wholeIndex) == 0:
            return 'No data get in the given period.'
        return self.wholeIndex[0], self.wholeIndex[-1]

    def getRegisteredInstruments(self):
        '''
        :return: the instruments actually used
        '''
        return self.registeredInstruments

    def getCurrentInstruments(self):
        '''
        :return: futures whose latest row had no NaN
        '''
        return self.currentInstruments

    def getNextValues(self):
        '''
        :return: next value of the generator
        '''
        return next(self.valGen)

    def eof(self):
        '''
        :return: True once all data (or the limit) has been handed out
        '''
        return self.isEof

    def getFields(self):
        '''
        :return: list of requested fields
        '''
        return self.fields
class CSVPanelReader(BasePanelReader):
    '''
    csv panel data reader (no longer called directly).
    '''
    logger = logger.getLogger("CSVPanelReader")

    def __init__(self,
                 path=None,
                 fileName=None,
                 frequency=bar.Frequency.MINUTE,
                 isInstrmentCol=True,
                 start=None,
                 end=None):
        '''
        param path: folder containing the csv file
        param fileName: csv file name
        param frequency: data frequency
        param isInstrmentCol: whether the column names are instrument codes
        param start: requested start time (defaults to the data's first row)
        param end: requested end time (defaults to the data's last row)
        '''
        super().__init__()
        self.path = path
        self.fileName = fileName
        self.isInstrumentCol = isInstrmentCol
        self.frequency = frequency
        self.start = start
        self.end = end
        self.isEof = False
        self.count = 0  # rows handed out so far
        self.staticPanel = None
        self.staticSeries = None
        self.iterator = None
        self.df = None

    def setDir(self, path):
        self.path = path

    def loads(self):
        '''
        Read the csv (first column parsed as datetime index), slice it to
        [start, end] and create the row iterator.
        '''
        self.df = pd.read_csv(os.path.join(self.path, self.fileName),
                              index_col=0,
                              parse_dates=[0])
        # Fall back to the data's own first timestamp.
        self.start = pd.to_datetime(
            self.start) if self.start else self.df.index[0]
        if self.start < self.df.index[0]:
            self.logger.warning(
                "The input start date {} is before the data's start date {}".
                format(self.start, self.df.index[0]))
        # Fall back to the data's own last timestamp.
        self.end = pd.to_datetime(self.end) if self.end else self.df.index[-1]
        if self.end > self.df.index[-1]:
            self.logger.warning(
                "The input end date {} is after the data's end date {}".format(
                    self.end, self.df.index[-1]))
        self.df = self.df.loc[self.start:self.end]
        self.logger.info("{} to {} {} data got.".format(
            str(self.start)[:10],
            str(self.end)[:10], self.fileName))
        # A fresh load restarts iteration.
        self.count = 0
        self.isEof = False
        self.getIterator()

    def getIterator(self):
        '''
        Create (and return) the iterator over rows of the loaded data.
        '''
        # DataFrame -> iterrows; Series -> items (iteritems was removed in pandas 2.0).
        self.iterator = self.df.iterrows() if len(
            self.df.shape) == 2 else self.df.items()
        return self.iterator

    def getDir(self):
        '''
        return: the folder path
        '''
        return self.path

    def getFrequency(self):
        '''
        return: data frequency
        '''
        return self.frequency

    def getDataShape(self):
        '''
        return: shape of the loaded data
        '''
        return self.df.shape

    def getDateRange(self):
        '''
        return: (first index, last index) of the loaded data
        '''
        return (self.df.index[0], self.df.index[-1])

    def getRegisteredInstruments(self):
        '''
        return: column names when they are instrument codes, otherwise None
        '''
        if len(self.df.shape) == 2 and self.isInstrumentCol:
            return list(self.df)
        return None

    def getColumns(self):
        return list(self.df)

    def getNextValues(self):
        '''
        return: next (datetime, value) pair
        '''
        self.count += 1
        return next(self.iterator)

    def eof(self):
        '''
        return: True once every row has been handed out
        '''
        if self.count == len(self.df.index):
            self.isEof = True
        return self.isEof

    def to_static_panel(self, maxLen=None):
        '''
        Convert all loaded data into a static SequenceDataPanel (DataFrame) or
        SequenceDataSeries (Series) and return it.
        '''
        if len(self.df.shape) == 2:
            self.staticPanel = series.SequenceDataPanel(
                self.getRegisteredInstruments(),
                maxLen=maxLen,
                dtype=np.float32)
            for index, row in self.df.iterrows():
                self.staticPanel.appendWithDateTime(index, row)
            return self.staticPanel
        self.staticSeries = series.SequenceDataSeries(maxLen=maxLen)
        for index, row in self.df.items():
            self.staticSeries.appendWithDateTime(index, row)
        return self.staticSeries

    def to_frame(self):
        '''
        The loaded data is already a DataFrame; returned unchanged.
        '''
        return self.df
class FactorPanel(SequenceDataPanel):
    '''
    SequenceDataPanel of raw factor values (one column per instrument),
    appended each time the source panelFeed publishes new values.
    '''
    logger = logger.getLogger("FactorPanel")

    def __init__(self, panelFeed, factorCalculatorCls, maxLen=None, **kwargs):
        '''
        :param panelFeed: source data feed
        :param factorCalculatorCls: factor calculator class (subclass of
            factorModel.BaseBarFeedCalculator or factorModel.BasePanelCalculator)
        :param maxLen: maximum number of stored factor rows
        :param kwargs: extra arguments forwarded to factorCalculatorCls
        :raises TypeError: if factorCalculatorCls is not a supported calculator class
        '''
        super(FactorPanel, self).__init__(colNames=panelFeed.getInstruments(),
                                          maxLen=maxLen)
        self.panelFeed = panelFeed
        self.instruments = panelFeed.getInstruments()
        self.colLen = len(self.instruments)
        self.kwargs = kwargs
        self.factorName = factorCalculatorCls.__name__
        self.initializeFactorCals(factorCalculatorCls)
        # Subscribe with FACTOR priority.
        self.panelFeed.getNewPanelsEvent(
            priority=baseFeed.PanelFeed.EventPriority.FACTOR).subscribe(
                self.onNewValues)
        self.logger.debug("The maximum length for factors: {}".format(maxLen))

    def attachFactorNormalizer(
            self,
            normType=factorNormalizer.NormalizedFeed.NormalizedType.ZSCORE):
        '''
        :param normType: normalization applied to this factor; the resulting
            feed is stored as self.zPanel (ZSCORE) or self.rPanel (RANK)
        '''
        if normType == factorNormalizer.NormalizedFeed.NormalizedType.ZSCORE:
            self.zPanel = factorNormalizer.NormalizedFeed(self,
                                                          normType=normType,
                                                          inplace=False)
        elif normType == factorNormalizer.NormalizedFeed.NormalizedType.RANK:
            self.rPanel = factorNormalizer.NormalizedFeed(self,
                                                          normType=normType,
                                                          inplace=False)

    def getZscoreFactor(self):
        return getattr(self, 'zPanel', None)

    def getRankFactor(self):
        return getattr(self, 'rPanel', None)

    def getRawFactor(self):
        return self

    def initializeFactorCals(self, factorCalculatorCls):
        '''
        Create the factor calculator(s).
        1. Bar-feed calculators: one per instrument, stored as a list in
           self.barFeedFactorCal, in the same order as self.instruments.
        2. Panel calculators: a single one stored in self.panelFeedFactorCal.
        :raises TypeError: for any other class
        '''
        self.barFeedFactorCal, self.panelFeedFactorCal = None, None
        if issubclass(factorCalculatorCls, factorModel.BaseBarFeedCalculator):
            # Follow self.instruments order: calRawScores and
            # getFactorCalculator index calculators by that position.
            factorCalculators = [
                factorCalculatorCls(self, self.panelFeed.barFeeds[instrument],
                                    **self.kwargs)
                for instrument in self.instruments
            ]
            self.barFeedFactorCal = factorCalculators
            self.logger.debug(
                "The factor calculator type: {}".format(factorCalculators))
        elif issubclass(factorCalculatorCls, factorModel.BasePanelCalculator):
            factorCalculator = factorCalculatorCls(self, self.panelFeed,
                                                   **self.kwargs)
            self.panelFeedFactorCal = factorCalculator
            self.logger.debug(
                "The factor calculator type: {}".format(factorCalculator))
        else:
            raise TypeError(
                "factorCalculatorCls must subclass BaseBarFeedCalculator or "
                "BasePanelCalculator, got {}".format(factorCalculatorCls))

    def calRawScores(self, panelFeed, dateTime, df):
        '''
        :param panelFeed: feed that triggered the update
        :param dateTime: current time
        :param df: new values passed through to panel calculators
        :return: one raw score per instrument
        '''
        if self.barFeedFactorCal is not None:
            score = np.zeros((self.colLen, ))
            for j, instrument in enumerate(self.instruments):
                sliceFeed = self.panelFeed.barFeeds[instrument]
                score[j] = self.barFeedFactorCal[j].calScore(
                    sliceFeed, dateTime, sliceFeed.getLastBar())
        else:
            score = self.panelFeedFactorCal.calScore(panelFeed, dateTime, df)
        return score

    def onNewValues(self, panelFeed, dateTime, df):
        rawScore = self.calRawScores(panelFeed, dateTime, df)
        self.appendWithDateTime(dateTime, rawScore)

    def factorPlot(self, codes, seriesName):
        '''
        Plot the normalized (first row = 1) series of the chosen codes.
        :param codes: code or list of codes to plot
        :param seriesName: 'score' for the factor itself, or one of
            'open', 'close', 'high', 'low' for prices
        :raises ValueError: for an unknown seriesName
        '''
        if seriesName == 'score':
            plotData = pd.DataFrame(index=self.getDateTimes(),
                                    columns=self.getColumnNames(),
                                    data=self[:, :])
        else:
            panelAttr = {
                'open': 'openPanel',
                'close': 'closePanel',
                'high': 'highPanel',
                'low': 'lowPanel'
            }.get(seriesName)
            if panelAttr is None:
                raise ValueError(
                    "seriesName must be 'score', 'open', 'close', 'high' or "
                    "'low', got {!r}".format(seriesName))
            self.priceData = getattr(self.panelFeed, panelAttr)
            plotData = pd.DataFrame(index=self.priceData.getDateTimes(),
                                    columns=self.priceData.getColumnNames(),
                                    data=self.priceData[:, :])
        plotData = plotData.dropna(axis=1, how='any')  # drop codes with any NaN
        plotData = plotData / plotData.iloc[0, :]  # remove scale: start at 1
        thisPlotData = plotData[codes]
        timeList = list(thisPlotData.index)
        ind = np.arange(len(thisPlotData))
        fig, ax = plt.subplots(1, 1)
        ax.plot(ind, thisPlotData)
        plt.legend(list(thisPlotData.columns), loc='best')

        def format_date(x, pos=None):
            # Map integer x positions back to timestamps for tick labels.
            if x < 0 or x > len(timeList) - 1:
                return ''
            return timeList[int(x)]

        ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_date))
        plt.show()

    def getFactorCalculator(self, instrument=None):
        '''
        :param instrument: for bar-feed calculators, the instrument whose
            calculator to return; None returns the whole list
        :return: calculator(s); for panel calculators, the single calculator
        '''
        # Only one kind of calculator may exist.
        assert self.panelFeedFactorCal is None or self.barFeedFactorCal is None
        if self.barFeedFactorCal:
            if instrument is None:
                return self.barFeedFactorCal
            return self.barFeedFactorCal[np.where(
                np.array(self.instruments) == instrument)[0][0]]
        return self.panelFeedFactorCal
class DefaultFactorTest:
    """
    Default factor evaluation module.
    """
    logger = logger.getLogger("DefaultFactorTest")

    def __init__(self,
                 panelFeed,
                 factorPanel,
                 returnPanel,
                 lag=1,
                 indicators=None,
                 nGroup=10,
                 cut=0.1,
                 fee=0.003):
        '''
        :param panelFeed: feed whose FACTOR-priority event drives the evaluation
        :param factorPanel: factor values panel
        :param returnPanel: returns panel
        :param lag: number of periods the factor leads the return
        :param indicators: indicators to compute, e.g.
            ['IC', 'rankIC', 'beta', 'gpIC', 'tbdf', 'turn', 'groupRet'];
            None computes nothing
        :param nGroup: number of groups
        :param cut: top/bottom fraction, 0.1 means top 10% - bottom 10%
        :param fee: trading fee per unit turnover, used to compute cost
        '''
        self.frequency = panelFeed.frequency
        self.panelFeed = panelFeed
        self.factorPanel = factorPanel
        self.returnPanel = returnPanel
        self.indicators = indicators if indicators is not None else []
        self.lag = lag
        self.nGroup = nGroup
        self.cut = cut
        self.fee = fee
        self.maxLen = panelFeed.dataSource.getDataShape()[0]

        def newGroupPanel():
            return series.SequenceDataPanel(
                ['Group%s' % i for i in range(1, self.nGroup + 1)],
                self.maxLen)

        # Per-period indicator storage.
        self.betaSeries = series.SequenceDataSeries(self.maxLen)
        self.gpICSeries = series.SequenceDataSeries(self.maxLen)
        self.tbdfSeries = series.SequenceDataSeries(
            self.maxLen)  # top-group mean return - bottom-group mean return
        self.turnSeries = series.SequenceDataSeries(self.maxLen)
        self.ICPanel = newGroupPanel()
        self.rankICPanel = newGroupPanel()
        self.groupRetPanel = newGroupPanel()
        self.turnPanel = newGroupPanel()
        self.costPanel = newGroupPanel()
        self.groupNumberPanel = newGroupPanel()
        self.groupingCodesDict = {}
        self.lastGroupingCodes = None  # grouping of the previous 'turn' update
        self.topCodes = {}
        self.botCodes = {}
        self.poolNum = 0
        # Subscribe with FACTOR priority.
        self.panelFeed.getNewPanelsEvent(
            priority=baseFeed.PanelFeed.EventPriority.FACTOR).subscribe(
                self.updateIndicators)
        self.obervedRank = []

    def updateIndicators(self, panelFeed, dateTime, df):
        '''
        Compute this period's indicators (IC, rankIC, beta, gpIC, tbdf, turn,
        groupRet) from the lagged factor and the latest returns.
        '''
        # Early on, returns and lag-shifted factors are incomplete.
        if len(self.returnPanel) <= 2 * self.lag:
            return
        self.dateTime = dateTime
        # Missing-value handling.
        thisAllReturn = self.returnPanel[-1, :]  # latest returns
        lastAllFactor = self.factorPanel[-self.lag - 1, :]  # factor lag periods earlier
        returnNotNan = np.argwhere(~np.isnan(thisAllReturn))
        factorNotNan = np.argwhere(~np.isnan(lastAllFactor))
        # No valid lagged factor: nothing to group.
        # NOTE(review): the original comment said "< 2 stocks"; threshold is 1.
        if len(factorNotNan) < 1:
            return
        notNanLocate = np.intersect1d(returnNotNan, factorNotNan)
        self.thisReturn = np.nan_to_num(
            thisAllReturn[notNanLocate].reshape((len(notNanLocate), )))
        self.lastFactor = np.nan_to_num(
            lastAllFactor[notNanLocate].reshape((len(notNanLocate), )))

        # Indicator computation.
        if 'IC' in self.indicators:
            self.ICPanel.appendWithDateTime(
                dateTime,
                sectionCalculator.ICGrouping(self.lastFactor, self.thisReturn,
                                             self.nGroup))
        if 'rankIC' in self.indicators:
            self.rankICPanel.appendWithDateTime(
                dateTime,
                sectionCalculator.RankICGrouping(self.lastFactor,
                                                 self.thisReturn, self.nGroup))
        if 'gpIC' in self.indicators:
            # Correlation after splitting into nGroup groups.
            gpIC = sectionCalculator.GPIC(self.lastFactor, self.thisReturn,
                                          self.nGroup)
            if not np.isnan(gpIC):
                self.gpICSeries.appendWithDateTime(self.dateTime, gpIC)
        if 'beta' in self.indicators:
            # Single-factor regression slope.
            beta = sectionCalculator.BETA(self.lastFactor, self.thisReturn)
            if not np.isnan(beta):
                self.betaSeries.appendWithDateTime(self.dateTime, beta)
        if 'tbdf' in self.indicators:
            # Top mean return - bottom mean return.
            tbdf = sectionCalculator.TBDF(self.lastFactor, self.thisReturn,
                                          self.cut)
            if not np.isnan(tbdf):
                self.tbdfSeries.appendWithDateTime(self.dateTime, tbdf)
        # Without 'turn' there is no turnover estimate, so no cost is deducted.
        cost = [0.0] * self.nGroup
        if 'turn' in self.indicators:
            cost = self._updateTurnover(notNanLocate)
        if 'groupRet' in self.indicators:
            groupRet = sectionCalculator.GroupRet(self.lastFactor,
                                                  self.thisReturn,
                                                  self.nGroup)
            self.groupRetPanel.appendWithDateTime(
                dateTime, list(np.array(groupRet) - np.array(cost)))

    def _updateTurnover(self, notNanLocate):
        '''
        Record this period's grouping, turnover and trading cost per group.
        :param notNanLocate: positions (in the instrument list) of valid stocks
        :return: list of per-group trading costs for this period
        '''
        instruments = np.array(self.panelFeed.getInstruments())
        if self.poolNum == 0:
            self.poolNum = len(instruments)  # total number of assets in the pool
        # Each element covers only this period's non-NaN stocks.
        groupResult = sectionCalculator.Grouping(self.lastFactor, self.nGroup)
        # Number of holdings in each group this period.
        self.groupNumberPanel.appendWithDateTime(
            self.dateTime, [sum(i) for i in groupResult])
        # Expand back to the full instrument length.
        adjustedGroupResult = sectionCalculator.adjustShape(
            groupResult, len(instruments), notNanLocate)
        groupCodeSlice = [list(instruments[i]) for i in adjustedGroupResult]
        self.groupingCodesDict[self.dateTime] = groupCodeSlice
        if self.lastGroupingCodes is None:
            # First grouping: 100% turnover, so cost and return panels align;
            # drop it when averaging turnover.
            turnGroup = [1] * self.nGroup
        else:
            # Compare with the previously recorded grouping; skipped periods
            # have no entry in groupingCodesDict.
            turnGroup = sectionCalculator.TurnGrouping(
                self.lastGroupingCodes, groupCodeSlice, self.nGroup)
        self.lastGroupingCodes = groupCodeSlice
        self.turnPanel.appendWithDateTime(self.dateTime, turnGroup)
        cost = list(np.array(turnGroup) * self.fee)
        self.costPanel.appendWithDateTime(self.dateTime, cost)
        return cost

    def getIndicators(self):
        '''Collect the requested indicators' storage objects into a dict.'''
        self.indicatorDict = {}
        for indicator in self.indicators:
            suffix = "Panel" if indicator in ['groupRet', 'IC', 'rankIC', 'turn'] else "Series"
            self.indicatorDict[indicator] = getattr(self, indicator + suffix)
            if indicator == 'turn':
                self.indicatorDict['cost'] = getattr(self, 'costPanel')
                self.indicatorDict['groupNumber'] = getattr(
                    self, "groupNumberPanel")
        return self.indicatorDict
class CSVSampleDataReader(BaseDataReader):
    '''
    Read local sample minute bars for stocks.
    :return: each getNextValues call yields one cross-section DataFrame
             (rows = codes, columns = OHLCV fields)
    '''
    logger = logger.getLogger("csvReader")

    def __init__(self,
                 frequency=bar.Frequency.MINUTE,
                 instruments=None,
                 start=None):
        super().__init__(instruments, fields=const.DataField.OHLCV, start=start)
        self.frequency = frequency
        self.instruments = instruments
        self.dfs = []
        self.isEof = False
        self.valGen = None
        self.path = pathSelector.PathSelector.getDataFilePath(
            market=const.DataMarket.STOCK,
            types=const.DataType.SAMPLE,
            frequency=const.DataFrequency.MINUTE)

    def _iter_(self):
        return self

    def setDir(self, path):
        '''
        :param path: folder to read files from
        '''
        self.path = path

    def get_file_list(self):
        '''
        :return: base names of the csv files in the folder, excluding index files
        '''
        self.logger.debug("Data path: {}".format(self.path))
        return [
            file.split('.')[0] for file in os.listdir(self.path)
            if '.csv' in file.lower() and 'index' not in file.lower()
        ]

    def loads(self):
        '''
        Read the csv file of every available instrument and prepare the generator.
        '''
        fileLists = self.get_file_list()
        if self.instruments is None:
            self.instruments = fileLists
        removeList = list(set(self.instruments) - set(fileLists))  # missing files
        if len(removeList) > 0:
            self.logger.info("Miss Data stock list: {}".format(removeList))
        self.instruments = sorted(set(self.instruments) & set(fileLists))
        self.logger.info("The stock list: {}".format(self.instruments))
        self.codeFrame = pd.DataFrame({'code': self.instruments})
        # Fresh list: a repeated loads() must not append to the previous result.
        frames = []
        for instrument in self.instruments:
            self.logger.info('Loading Data {}'.format(instrument))
            frames.append(read_csv_with_filter(self.path, instrument,
                                               self.start))
        if frames:
            self.dfs = pd.concat(frames, axis=0)
        else:
            self.logger.warning("No sample data found in {}".format(self.path))
            self.dfs = pd.DataFrame(columns=['date', 'code'])
        self.dfs.sort_values(['date', 'code'], ascending=True, inplace=True)
        self.allTime = np.unique(self.dfs['date'])
        self.isEof = len(self.allTime) == 0
        self.valGen = self.valueGenerator()

    def valueGenerator(self):
        '''
        :return: yields (timestamp, DataFrame indexed by code with OHLCV columns)
                 for each timestamp; missing codes get NaN rows
        '''
        lastIdx = len(self.allTime) - 1
        for idx, _date in enumerate(self.allTime):
            ret = self.dfs[self.dfs['date'] == _date]
            ret = pd.merge(ret, self.codeFrame, on='code', how='outer')
            ret = ret.set_index('code')
            del ret['date']
            if idx == lastIdx:
                # Flag EOF while handing out the last row (also for one row).
                self.isEof = True
            yield pd.to_datetime(_date), ret

    def getDir(self):
        return self.path

    def getFrequency(self):
        '''
        :return: data frequency
        '''
        return self.frequency

    def getDataShape(self):
        '''
        :return: (number of timestamps, number of instruments, number of fields)
        '''
        return self.allTime.shape[0], len(self.instruments), len(self.fields)

    def getDateRange(self):
        '''
        :return: (first timestamp, last timestamp)
        '''
        return self.allTime[0], self.allTime[-1]

    def getRegisteredInstruments(self):
        return self.instruments

    def getNextValues(self):
        return next(self.valGen)

    def getFields(self):
        return self.fields

    def eof(self):
        return self.isEof
class AdvancedFeed:
    '''
    Combine one or more feeds/panels, possibly from different data sources,
    into a single time-synchronized feed.
    Label 'base' holds the OHLCV feed; label 'bench' holds the benchmark index panel.
    '''
    logger = logger.getLogger('AdvancedFeed')

    def __init__(self, feedDict=None, panelDict=None):
        '''
        :param feedDict: {lable: panelFeed}
        :param panelDict: {lable: SequenceDataPanel}
        '''
        # Copy so attaching later feeds does not mutate the caller's dicts.
        self.feedDict = dict(feedDict) if feedDict is not None else {}
        self.panelDict = dict(panelDict) if panelDict is not None else {}
        self.resamples = {}  # resampled feeds (they have no own data source)
        self.dataSource = {}
        self.instruments = None
        self.frequency = np.inf
        self.maxLen = 0
        self.__currentDatetime = None
        self.setUseEventDateTimeInLogs(True)
        self.synchronizedNextValue = {}  # lable : (dateTime, value)
        self.isEof = False
        self.stopped = False
        self.available = None
        self.allAvailable = None
        self.__panelEvents = collections.OrderedDict({
            e: observer.Event()
            for e in PanelFeed.EventPriority.getEventsType()
        })
        for lable, value in list(self.feedDict.items()):
            self.attachFeed(lable, value)
        for lable, value in list(self.panelDict.items()):
            self.attachPanel(lable, value)

    def __attachBaseInfo(self, lable, value):
        '''
        Register value's data source and merge its instruments, frequency and
        maxLen into this feed.
        '''
        source = value.getDataSource()
        if source is None:  # resampled data has no data source of its own
            self.resamples[lable] = value
            return
        self.dataSource[lable] = source
        if not (isinstance(source, BasePanelReader)
                and source.isInstrumentCol is False):
            otherInstruments = value.getInstruments()
            if self.instruments is None:
                self.instruments = otherInstruments
            elif self.instruments != otherInstruments:
                otherSet = set(otherInstruments)
                removeList = [i for i in self.instruments if i not in otherSet]
                self.logger.info(
                    "Miss Data stock list:\n{}".format(removeList))
                # Intersection that keeps the existing (deterministic) order.
                self.instruments = [i for i in self.instruments if i in otherSet]
        sourceFrequency = source.getFrequency()
        # Warn only when sources of different frequencies are combined.
        if self.frequency != np.inf and sourceFrequency != self.frequency:
            warnings.warn(u'不同周期的数据元同步,会导致数据内部的时间不同步, 进行运算时请留意')
        if sourceFrequency < self.frequency:
            self.frequency = sourceFrequency
        if value.getMaxLen() > self.maxLen:
            self.maxLen = value.getMaxLen()

    def getNewPanelsEvent(self, priority=None):
        assert priority in PanelFeed.EventPriority.getEventsType()
        return self.__panelEvents[priority]

    def dispatchNewValueEvent(self, *args, **kwargs):
        for key, evt in self.__panelEvents.items():
            evt.emit(*args, **kwargs)

    def hasOhlcv(self):
        '''
        :return: whether an OHLCV ('base') feed is attached
        '''
        return self.getBase() is not None

    def hasBench(self):
        return self.getBench() is not None

    def getFrequency(self):
        return self.frequency

    def getInstruments(self):
        return self.instruments

    def getMaxLen(self):
        return self.maxLen

    def attachFeed(self, lable, otherFeed):
        '''
        Add a feed from another data source (e.g. OHLC and financial data).
        :param lable: label
        :param otherFeed: panelFeed to attach
        '''
        self.feedDict[lable] = otherFeed
        self.__attachBaseInfo(lable, otherFeed)

    def attachBaseFeed(self, baseFeed):
        self.attachFeed('base', baseFeed)

    def attachBenchMark(self, benchMarkPanel):
        '''
        :param benchMarkPanel: benchmark index panel
        '''
        self.attachPanel('bench', benchMarkPanel)

    def attachPanel(self, lable, otherPanel):
        '''
        :param lable: label
        :param otherPanel: SequenceDataPanel to attach
        '''
        self.panelDict[lable] = otherPanel
        self.__attachBaseInfo(lable, otherPanel)

    def getBench(self):
        '''
        :return: the benchmark panel, or None
        '''
        return self.panelDict.get('bench', None)

    def getBase(self):
        '''
        :return: the OHLCV feed, or None
        '''
        return self.feedDict.get('base', None)

    def getCurrentDateTime(self):
        return self.__currentDatetime

    def setUseEventDateTimeInLogs(self, useEventDateTime):
        if useEventDateTime:
            logger.Formatter.DATETIME_HOOK = self.getCurrentDateTime
        else:
            logger.Formatter.DATETIME_HOOK = None

    def peekNextValues(self, available=None):
        '''
        Read the next row from each source in ``available`` (all when None);
        sources whose pending value was not consumed are not re-read.
        '''
        labels = self.dataSource.keys() if available is None else available
        for lable in labels:
            source = self.dataSource[lable]
            if not source.eof():
                dateTime, value = source.getNextValues()
                # Align day-level data to 15:00 when mixed with intraday data.
                if source.getFrequency() >= bar.Frequency.DAY > self.getFrequency():
                    dateTime = dateTime.replace(hour=15)
                self.synchronizedNextValue[lable] = (dateTime, value)
            else:
                # pop: the source may never have produced a value.
                self.synchronizedNextValue.pop(lable, None)
        if not self.synchronizedNextValue:
            self.isEof = True
        self.available = labels

    def peekNextDatetime(self, sychronizedNextValues):
        '''
        :param sychronizedNextValues: pending {lable: (dateTime, value)}
        :return: the earliest pending dateTime
        '''
        return min(dateTime for dateTime, _ in sychronizedNextValues.values())

    def getAvailable(self):
        '''
        :return: labels of feeds/panels updated at the current time
        '''
        return self.allAvailable

    def getNextValues(self):
        '''
        Advance to the earliest pending time, push values into the matching
        feeds/panels and dispatch events.
        :return: (nextTime, None), or (None, None) when nothing was dispatched
        '''
        self.peekNextValues(self.available)
        if self.isEof:
            # Every source is exhausted.
            return None, None
        nextTime = self.peekNextDatetime(self.synchronizedNextValue)
        self.__currentDatetime = nextTime
        if self.hasOhlcv() and 'base' not in self.synchronizedNextValue:
            self.logger.info(
                'there is no extra ohlc values, strategy will end')
            self.stopped = True
            return None, None
        available = []
        for lable, (dateTime, value) in self.synchronizedNextValue.items():
            if dateTime != nextTime:
                continue
            available.append(lable)
            if lable in self.feedDict:
                self.feedDict[lable].appendNextValues(dateTime, value)
                self.feedDict[lable].dispatchNewValueEvent(
                    self, dateTime, value)
            else:
                self.panelDict[lable].appendWithDateTime(
                    dateTime, value.values)
        self.available = available
        self.allAvailable = available.copy()
        for lable, feed in self.resamples.items():
            if feed.getCurrentDatetime() == nextTime:
                self.allAvailable.append(lable)
        benchNext = self.synchronizedNextValue.get('bench')
        if self.getBench() is not None and benchNext is not None \
                and benchNext[0] > nextTime:
            self.logger.info(
                'benchMark exists, lower than {} time is ignored, current time {}'
                .format(benchNext[0], nextTime))
            return None, None
        self.dispatchNewValueEvent(self, nextTime, None)
        return nextTime, None

    def eof(self):
        '''
        :return: True when every source is exhausted
        '''
        return self.isEof

    def run(self, stopCount=None, _print=False, callBack=None):
        '''
        Drive the feed until EOF or stop.
        :param stopCount: stop once more than stopCount times were dispatched
        :param _print: print each dispatched time and available labels
        :param callBack: called after each dispatched time
        '''
        counter = 0
        while not self.eof() and not self.stopped:
            dateTime, df = self.getNextValues()
            if dateTime is not None:
                counter += 1
                if _print and callBack is None:
                    print(dateTime, self.getAvailable())
                elif _print and callBack is not None:
                    print(dateTime, self.getAvailable(), end='\t')
                    callBack()
                elif callBack is not None:
                    callBack()
            if stopCount is not None and counter > stopCount:
                break
class H5DataReader(BaseDataReader):
    '''
    Feed-data reader for monthly HDF5 stock bar files; paths come from pathSelector.

    Each month is stored as ``<root>/<year>/<month>/<field>.h5`` (see
    ``fileterH5File``). Data is loaded one month at a time and emitted bar by
    bar through ``valueGenerator``.
    '''
    logger = logger.getLogger("H5DataReader")

    # Index universes resolved through the constituent pickle, mapped to the
    # key used in ``self.indexConstituent``.
    _INDEX_CODES = {'SZ50': 50, 'HS300': 300, 'ZZ500': 500}

    def __init__(self, frequency=bar.Frequency.MINUTE, instruments=None, fields=None, start=None, end=None, limit=-1):
        '''
        :param frequency: bar frequency of the stored data
        :param instruments: None (whole market), an index name ('SZ50', 'HS300',
            'ZZ500') or a list of stock codes
        :param fields: fields to read (open/high/low/close/volume ...); defaults to OHLCV
        :param start: first datetime to read; None means the earliest local data
        :param end: last datetime to read; None means the latest local data
        :param limit: maximum number of bars to emit; a non-positive value means no limit
        '''
        super().__init__(instruments, fields, start)
        self.frequency = frequency
        self.market = bar.Market.STOCK
        # Fall back to the boundaries of the locally stored data set.
        self.start = pd.to_datetime('20150601 0930' if start is None else start)
        self.end = pd.to_datetime('20190430 1500' if end is None else end)
        self.registeredInstruments = 'Apply prepareOutputData function first and then check the registered instruments.'
        # One month-start timestamp per month between start and end, clipped to
        # the months available locally (2015-06 .. 2019-04).
        self.appliedTimeLine = list(
            pd.date_range(
                max(self.start.strftime('%Y%m'), '201506') + '01',
                min(self.end.strftime('%Y%m'), '201904') + '01',
                freq='MS'))
        if fields is None:
            # Otherwise ``fields`` is presumably stored by the base class (TODO confirm).
            self.fields = ['open', 'high', 'low', 'close', 'volume']
        self.listedCodes = []
        self.limit = limit
        self.outputTimes = 1  # 1 + number of bars emitted so far
        self.totalLength = limit
        self.initialSignal = 0  # set to 1 once registeredInstruments has been fixed
        self.currentTimeLine = []  # bar times of the currently loaded month
        self.isEof = False  # True once every bar has been emitted
        self.valGen = None  # active generator created by prepareGenerator
        self.actualStart = None
        self.actualEnd = None
        self.currentInstruments = []  # codes present in the most recent bar
        self.totalIndex = []
        self.setFilePath()

    def _indexCode(self):
        '''Return the constituent-table key when ``instruments`` names an index, else None.'''
        if isinstance(self.instruments, str):
            return self._INDEX_CODES.get(self.instruments)
        return None

    def setFilePath(self, path=None):
        '''
        Set the data root and load the metadata files stored beside the data.

        :param path: explicit data root; when None, the root configured in
            datapath.ini for stock OHLCV data at ``self.frequency`` is used
        '''
        if path is not None:
            self.path = path
        else:
            self.path = pathSelector.PathSelector.getDataFilePath(
                const.DataMarket.STOCK, const.DataType.OHLCV,
                const.DataFrequency.freq2lable(self.frequency))
        # Listing / delisting dates of every stock (the 'index' column holds the code).
        self.startEnd = pd.read_excel(os.path.join(self.path, 'start_end_date.xlsx'),
                                      dtype={'index': str})
        # NOTE: unpickling can execute arbitrary code; only point ``path`` at trusted local data.
        self.indexConstituent = pd.read_pickle(
            os.path.join(self.path, 'indexconstituent.pickle'))
        self.totalLength = self.getTotalLength()

    def getTotalLength(self):
        '''
        Scan the monthly 'trdstat' files and collect every bar time within [start, end].

        Side effects: sets ``totalIndex``, ``actualStart`` and ``actualEnd``, and
        ``registeredInstruments`` when ``instruments`` is None.

        :return: number of bars to emit, capped by ``limit`` when it is positive
        '''
        totalIndex = []
        df = None
        self.actualStart = None
        self.actualEnd = None
        for month in list(self.appliedTimeLine):
            df = self.fileterH5File(self.path, month, 'trdstat')
            inRange = [j for j in np.unique(df.index)
                       if self.start <= pd.to_datetime(j) <= self.end]
            totalIndex.extend(inRange)
            self.logger.debug("Checking timeline {}.".format(month))
            if not inRange:
                continue
            if self.actualStart is None:
                self.actualStart = inRange[0]
            self.actualEnd = inRange[-1]
            # A non-positive limit (the default -1) means "no limit"; comparing it
            # against the running count would otherwise always truncate.
            if self.limit is not None and self.limit > 0 and len(totalIndex) >= self.limit:
                totalIndex = totalIndex[:self.limit]
                self.actualEnd = totalIndex[-1]
                break
        if self.instruments is None and df is not None:
            self.registeredInstruments = list(df.columns)
        self.totalIndex = totalIndex
        return len(totalIndex)

    def indexConstituentList(self, index, date):
        '''
        :param index: index key in the constituent table (50, 300 or 500)
        :param date: lookup date; only its first 10 characters (YYYY-MM-DD) are used
        :return: constituent codes truncated to their six leading digits
        '''
        date = str(date)[:10]
        constituentStock = self.indexConstituent[index][date]['code'].tolist()
        # Drop the exchange suffix, keeping only the 6-digit code.
        return [code[:6] for code in constituentStock]

    def fileterH5File(self, folder, date, dataName):
        '''
        Read one field for one month.

        :param folder: root folder of the stock bar data
        :param date: any datetime inside the wanted month
        :param dataName: field name, e.g. 'close'
        :return: DataFrame indexed by bar datetime, one column per stock code
        '''
        year = str(date)[0:4]
        month = str(int(str(date)[5:7]))  # month folders are not zero-padded
        path = os.path.join(folder, year, month, str(dataName) + '.h5')
        df = pd.read_hdf(path)
        if self.instruments is None or self._indexCode() is not None:
            cols = df.columns
        else:
            # Keep only the requested codes. ``Index & list`` was the legacy pandas
            # spelling of intersection; newer pandas treats ``&`` as a logical AND.
            cols = df.columns.intersection(self.instruments)
        if "xuefu" in platformSectionSelector():
            # Server data: already indexed by datetime.
            df = df[cols]
        else:
            # Baidu-netdisk layout: rebuild the index from the 'date' and 'time' columns.
            df = df[cols].reset_index()
            df['datetime'] = pd.to_datetime(df['date'].dt.strftime('%Y%m%d') +
                                            df['time'].apply(str).str.zfill(4))
            df = df.drop(['date', 'time'], axis=1).set_index('datetime')
        if dataName == 'volume':
            return (df / 100).round(0) * 100
        elif dataName == 'amount':
            return df.round(0)
        else:
            return df.round(2)

    def getNeededData(self, dataName):
        '''
        Read one field for the first month left in ``appliedTimeLine``.

        :param dataName: field name, e.g. 'close'
        :return: DataFrame restricted to rows at or after ``start``, with an
            extra 'field' column holding ``dataName``
        '''
        date = self.appliedTimeLine[0]
        df = self.fileterH5File(self.path, date, dataName)
        # copy() so the new column is set on a real frame, not a slice view.
        df = df[df.index >= self.start].copy()
        df['field'] = dataName
        year = str(date)[0:4]
        month = str(int(str(date)[5:7]))
        self.logger.info("{} {} {} data got.".format(year, month, dataName))
        return df

    def getAllFields(self):
        '''
        Read every requested field for the current month, then advance the month cursor.

        :return: the per-field frames stacked vertically
        '''
        frames = [self.getNeededData(field) for field in self.fields]
        df = pd.concat(frames) if frames else pd.DataFrame()
        if len(self.appliedTimeLine) > 0:
            self.appliedTimeLine.pop(0)
        return df

    def prepareOutputData(self):
        '''
        Load the next month into ``self.dfs`` and refresh ``currentTimeLine``.

        On the first call this also fixes ``registeredInstruments``. For index
        universes it holds the constituents at ``actualStart``. Otherwise it holds
        the candidate codes that were listed at some point in [actualStart, actualEnd].
        '''
        dfs = self.getAllFields()
        knownCodes = set(self.startEnd['index'])
        # Drop codes missing from the listing table, but keep the 'field' tag column.
        toDel = [col for col in dfs.columns if col not in knownCodes and col != 'field']
        dfs = dfs.drop(toDel, axis=1)
        dfs = dfs.sort_values(['datetime', 'field'])
        if self.initialSignal == 0:
            if self._indexCode() is not None:
                self.registeredInstruments = self.setRegisteredInstruments()
            else:
                if self.instruments is not None:
                    candidates = list(dfs.columns)[:-1]  # 'field' is the last column
                else:
                    candidates = [j for j in self.registeredInstruments if j in knownCodes]
                beforeDelisted = set(
                    self.startEnd[self.startEnd['end_date'] >= self.actualStart]['index'])
                afterListed = set(
                    self.startEnd[self.startEnd['start_date'] <= self.actualEnd]['index'])
                self.registeredInstruments = sorted(
                    j for j in candidates if j in beforeDelisted and j in afterListed)
            self.initialSignal = 1
            self.currentInstruments = self.registeredInstruments
        if self.instruments is None:
            dfs = pd.DataFrame(dfs, columns=self.registeredInstruments + ['field'])
        self.dfs = dfs
        # Bar times of this month that do not exceed ``end``.
        self.currentTimeLine = [j for j in np.unique(dfs.index) if pd.to_datetime(j) <= self.end]

    def prepareGenerator(self):
        '''
        Load the next month and install a fresh ``valGen``.

        Sets ``isEof`` when no further bars exist.
        '''
        if len(self.currentTimeLine) > 0:
            self.prepareOutputData()
        else:
            # ``start`` may fall in a period without bars: keep loading months
            # until one has bars or the local months are exhausted.
            while len(self.currentTimeLine) == 0 and self.appliedTimeLine:
                self.prepareOutputData()
        if len(self.currentTimeLine) > 0:
            self.valGen = self.valueGenerator()
        else:
            self.isEof = True

    def valueGenerator(self):
        '''
        Yield ``(datetime, DataFrame)`` for each bar of the loaded month.

        Each DataFrame has one row per code and one column per field. On the
        month's last bar the next month is loaded, and a new generator is
        installed in ``valGen`` before yielding. ``getNextValues`` therefore
        continues seamlessly with the new month.
        '''
        if self.limit - self.outputTimes == 0:
            self.isEof = True
        timeLine = self.currentTimeLine
        for idx, date in enumerate(timeLine):
            tmp = self.dfs[self.dfs.index == date]
            ret = self.adjustData(tmp, date)
            ret = ret.set_index('field').T
            ret.index.names = ['code']
            if idx == len(timeLine) - 1:
                if len(self.appliedTimeLine) == 0:
                    self.isEof = True  # no local months left
                else:
                    self.prepareGenerator()
            self.outputTimes += 1
            yield pd.to_datetime(date), ret
            if self.outputTimes == self.limit:
                self.isEof = True

    def setRegisteredInstruments(self):
        '''
        :return: sorted constituents of the requested index at ``actualStart``
        :raises ValueError: when ``instruments`` is not a supported index name
        '''
        index = self._indexCode()
        if index is None:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError('%s未记录成分股' % self.instruments)
        return sorted(self.indexConstituentList(index=index, date=self.actualStart))

    def adjustData(self, df, date):
        '''
        Align one bar's rows to the registered instrument columns.

        :param df: rows of ``self.dfs`` for one bar time
        :param date: the bar time
        :return: DataFrame with the registered codes plus 'field' as columns
        '''
        if self._indexCode() is not None:
            # Emit exactly the columns the feed was registered with (constituents
            # at actualStart), instead of re-deriving them every bar from ``start``.
            df = df[list(self.registeredInstruments) + ['field']]
            self.currentInstruments = list(self.registeredInstruments)
        else:
            # Blank out codes already delisted at this time, then restore the
            # registered column layout (delisted columns come back as NaN).
            notDelisted = list(self.startEnd[self.startEnd['end_date'] > date]['index'])
            df = pd.DataFrame(df, columns=notDelisted + ['field'])
            df = pd.DataFrame(df, columns=self.registeredInstruments + ['field'])
            complete = df.dropna(axis=1, how='any')
            # list.remove() returns None; build the list explicitly instead.
            self.currentInstruments = [col for col in complete.columns if col != 'field']
        return df

    def getRegisteredInstruments(self):
        '''
        :return: the instrument list used for output columns. For an index this
            is its constituents at the start of the data; otherwise it is the
            requested codes that were listed during the queried range.
        '''
        return self.registeredInstruments

    def getNextValues(self):
        '''
        :return: the next ``(datetime, DataFrame)`` from the active generator
        '''
        return next(self.valGen)

    def eof(self):
        '''
        :return: True once output should stop: no more data, ``limit`` reached,
            or everything between start and end has been emitted
        '''
        return self.isEof

    def getDateRange(self):
        '''
        :return: [actual first bar time, actual last bar time], or a message
            string when the period contains no data
        '''
        if len(self.totalIndex) == 0:
            return 'No data get in the given period.'
        return [self.actualStart, self.actualEnd]

    def getDataShape(self):
        '''
        :return: (number of bars, number of registered instruments, number of fields)
        '''
        shape = self.totalLength, len(self.registeredInstruments), len(self.fields)
        return shape

    def getDir(self):
        '''
        :return: the data root folder
        '''
        return self.path

    def getFields(self):
        '''
        :return: the requested field list
        '''
        return self.fields

    def getFrequency(self):
        '''
        :return: the data frequency
        '''
        return self.frequency

    def getCurrentInstruments(self):
        '''
        :return: the codes in the most recent bar. For non-index universes these
            are the codes with complete data in that bar; for an index they are
            the registered constituents.
        '''
        return self.currentInstruments
class H5PanelWriter(BaseWriter):
    '''
    HDF5 writer for factor values and factor-test results.

    @Time: 2019/8/10 20:30
    @Author: Yanggang Fang
    Supports two modes: "new" writes fresh files, and "append" extends the
    files of a previous run. Driven by the corresponding functions in
    cpa.factorPool.factorUpdate.
    '''
    logger = logger.getLogger("H5PanelWriter")

    # Indicators whose result objects are stored via ``to_frame()``; all other
    # indicators are stored via ``to_series()``.
    _FRAME_INDICATORS = ('groupRet', 'IC', 'rankIC', 'turn', 'cost', 'groupNumber')

    def __init__(self, defaultFactorTest, factorName):
        '''
        :param defaultFactorTest: a DefaultFactorTest instance from factorTest.py
        :param factorName: factor name, used for folder and file names
        '''
        self.defaultFactorTest = defaultFactorTest
        self.testReportGenerator = factorTest.TestReportGenerator(defaultFactorTest=self.defaultFactorTest,
                                                                  h5BatchPanelReader=None)
        self.frequency = defaultFactorTest.frequency
        self.factorName = factorName
        self.count = 0  # number of files written in "append" mode
        self.name = self.__class__.__name__

    def getDir(self):
        '''
        Paths come from pathSelector, so this class takes no directory argument.
        '''
        pass

    def _factorFilePath(self, fileName=None):
        '''Resolve the factor folder (or a file inside it) for ``self.frequency``.'''
        freqLabel = const.DataFrequency.freq2lable(self.frequency)
        if fileName is None:
            return pathSelector.PathSelector.getFactorFilePath(factorName=self.factorName,
                                                               factorFrequency=freqLabel)
        return pathSelector.PathSelector.getFactorFilePath(factorName=self.factorName,
                                                           factorFrequency=freqLabel,
                                                           fileName=fileName)

    @staticmethod
    def _canAppend(old, new):
        '''True when ``new`` starts strictly inside the time span covered by ``old``.'''
        return len(old) > 0 and len(new) > 0 and old.index[0] < new.index[0] < old.index[-1]

    @staticmethod
    def _mergeKeepFirst(old, new):
        '''Concatenate ``old`` and ``new``; on duplicated timestamps keep the old row.'''
        # DataFrame/Series.append was removed in pandas 2.0; concat is equivalent.
        merged = pd.concat([old, new])
        return merged.loc[~merged.index.duplicated(keep="first")]

    def _writeReport(self):
        '''Save the layered-return plot (.png) and the layered statistics (.xls).'''
        freqLabel = const.DataFrequency.freq2lable(self.frequency)
        figName = self.factorName + '_Report_' + freqLabel + '.png'
        self.testReportGenerator.plotGroupret(_show=False, path=self._factorFilePath(figName))
        statisticFileName = self.factorName + '_Statistic_' + freqLabel + '.xls'
        self.testReportGenerator.statistic(path=self._factorFilePath(statisticFileName))

    def _writeNew(self, calFileName, calFilePath, freqLabel, stamp):
        '''Write the factor values and every non-empty indicator to fresh files.'''
        tester = self.testReportGenerator.defaultFactorTest
        tester.factorPanel.to_frame().to_hdf(path_or_buf=calFilePath,
                                             key=self.factorName,
                                             format="table", data_columns=True, mode="w")
        self.logger.info("The file {} has been saved".format(calFileName))
        for key, value in tester.getIndicators().items():
            testFileName = self.factorName + "_" + key + "_" + freqLabel + stamp + ".h5"
            testFilePath = self._factorFilePath(testFileName)
            if len(value):
                data = value.to_frame() if key in self._FRAME_INDICATORS else value.to_series()
                data.to_hdf(path_or_buf=testFilePath, key=key,
                            format="table", data_columns=True, mode="w")
                self.logger.info("The file {} has been saved".format(testFileName))
            else:
                # An empty indicator is not stored; only logged.
                self.logger.info("The calculation of {} failed".format(key))

    def _writeAppend(self, oldResultDict, calFileName, calFilePath, factorFolderPath, freqLabel, stamp):
        '''Merge the new factor values and indicators into the old results and write them.'''
        tester = self.testReportGenerator.defaultFactorTest

        # Factor values. NOTE(review): any key containing "factor" matches and the
        # last match wins; a factor name containing "factor" would make this ambiguous.
        oldCalFileName = None
        for oldKey in oldResultDict.keys():
            if "factor" in oldKey:
                oldCalFileName = oldKey
        appendDataFrame = tester.factorPanel.to_frame()
        if oldCalFileName is None:
            self.logger.info("No previous factor data found; {} will not be written.".format(calFileName))
        elif self._canAppend(oldResultDict[oldCalFileName], appendDataFrame):
            newDataFrame = self._mergeKeepFirst(oldResultDict[oldCalFileName], appendDataFrame)
            newDataFrame.to_hdf(path_or_buf=calFilePath, key=self.factorName,
                                format="table", data_columns=True, mode="w")
            self.count += 1
            self.logger.info("The file {} has been saved".format(calFileName))
        else:
            self.logger.info("The earliest time of {} is before the one of {}. "
                             "The new dataframe will not be appended.".format(calFileName, oldCalFileName))

        # Indicator results.
        for key, value in tester.getIndicators().items():
            testFileName = self.factorName + "_" + key + "_" + freqLabel + stamp + ".h5"
            testFilePath = os.path.join(factorFolderPath, testFileName)
            if not len(value):
                self.logger.info("The calculation of {} failed".format(key))
                continue
            appendData = value.to_frame() if key in self._FRAME_INDICATORS else value.to_series()
            # Old files match on "<factorName><indicator>" (the first two '_' parts).
            prefix = "".join(testFileName.split("_")[0:2])
            for oldKey, oldData in oldResultDict.items():
                if "".join(oldKey.split("_")[0:2]) != prefix:
                    continue
                if self._canAppend(oldData, appendData):
                    newData = self._mergeKeepFirst(oldData, appendData)
                    # Store under the indicator name, consistent with "new" mode
                    # (the old code used the shadowed old-file key here).
                    newData.to_hdf(path_or_buf=testFilePath, key=key,
                                   format="table", data_columns=True, mode="w")
                    self.count += 1
                    self.logger.info("The file {} has been saved".format(testFileName))
                else:
                    self.logger.info("The earliest time of {} is before the one of {}. "
                                     "The new series will not be appended.".format(testFileName, oldKey))

    def write(self, mode, oldResultDict=None):
        '''
        Write factor values, indicator results and the report files.

        :param mode: "new" writes fresh files; "append" merges the new results
            into the previous ones held in ``oldResultDict``
        :param oldResultDict: {name: DataFrame/Series} with the previous file
            contents, produced by the h5 panel reader; required for "append"
        :raises ValueError: for any other ``mode``, or "append" without ``oldResultDict``
        '''
        if mode not in ("new", "append"):
            raise ValueError("An argument except 'new' or 'append' was passed into the write() function for the mode")
        if mode == "append" and oldResultDict is None:
            raise ValueError("oldResultDict is required when mode is 'append'")
        currentDT = datetime.datetime.now()
        stamp = currentDT.strftime("_%Y%m%d_%H%M")
        freqLabel = const.DataFrequency.freq2lable(self.frequency)
        factorFolderPath = self._factorFilePath()
        calFileName = self.factorName + "_factor_" + freqLabel + stamp + ".h5"
        calFilePath = self._factorFilePath(calFileName)
        if mode == "new":
            self._writeNew(calFileName, calFilePath, freqLabel, stamp)
        else:
            self._writeAppend(oldResultDict, calFileName, calFilePath, factorFolderPath, freqLabel, stamp)
        self._writeReport()
class H5PanelReader(BasePanelReader):
    '''
    Reader for a single HDF5 file holding a DataFrame (panel) or a Series.
    '''
    logger = logger.getLogger("H5PanelReader")

    def __init__(self, dir, frequency=bar.Frequency.MINUTE, start=None, end=None):
        '''
        :param dir: path of the HDF5 file
        :param frequency: data frequency
        :param start: first timestamp to keep; None means the data's own start
        :param end: last timestamp to keep; None means the data's own end
        '''
        super().__init__()
        self.frequency = frequency
        self.start = start
        self.end = end
        self.dir = dir
        self.isEof = False
        self.count = 0  # values handed out by getNextValues
        self.staticPanel = None
        self.staticSeries = None
        self.iterator = None

    def retrieve(self):
        '''
        Read the file and keep only rows within [start, end].

        :return: the loaded DataFrame or Series
        '''
        self.df = pd.read_hdf(path_or_buf=self.dir)
        self.start = pd.Timestamp(self.start) if self.start else self.df.index[0]
        if self.start < self.df.index[0]:
            self.logger.warning(
                "The input start date {} is before the data's start date {}".
                format(self.start, self.df.index[0]))
        self.end = pd.Timestamp(self.end) if self.end else self.df.index[-1]
        if self.end > self.df.index[-1]:
            self.logger.warning(
                "The input end date {} is after the data's end date {}".format(
                    self.end, self.df.index[-1]))
        self.df = self.df.loc[self.start:self.end]
        return self.df

    def _rowIterator(self):
        '''Return (index, row) pairs: iterrows() for a DataFrame, items() for a Series.'''
        # Series.iteritems was removed in pandas 2.0; items() works in all versions.
        return self.df.iterrows() if self.df.ndim == 2 else self.df.items()

    def getIterator(self):
        '''
        :return: a fresh (index, value) iterator over the loaded data
        '''
        self.iterator = self._rowIterator()
        return self.iterator

    def getDir(self):
        '''
        :return: the file path
        '''
        return self.dir

    def getFrequency(self):
        '''
        :return: the data frequency
        '''
        return self.frequency

    def getDataShape(self):
        '''
        :return: shape of the loaded data
        '''
        return self.df.shape

    def getDateRange(self):
        '''
        :return: (first index value, last index value) of the loaded data
        '''
        return (self.df.index[0], self.df.index[-1])

    def getRegisteredInstruments(self):
        '''
        :return: column labels for a DataFrame, or an empty list for a Series
        '''
        self.instrumentList = []
        if self.df.ndim == 2:
            self.instrumentList = self.df.columns.values.tolist()
        return self.instrumentList

    def getNextValues(self):
        '''
        :return: the next (datetime, value) pair from the iterator
        '''
        self.count += 1
        return next(self.iterator)

    def eof(self):
        '''
        :return: True once every row has been handed out by getNextValues
        '''
        if self.count >= len(self.df.index):
            self.isEof = True
        return self.isEof

    def to_static_panel(self, maxLen=None):
        '''
        Convert all loaded data into a static sequence object.

        :param maxLen: capacity passed to the sequence container
        :return: a SequenceDataPanel for a DataFrame, or a SequenceDataSeries for a Series
        '''
        if self.df.ndim == 2:
            self.staticPanel = series.SequenceDataPanel(
                self.getRegisteredInstruments(), maxLen=maxLen, dtype=np.float32)
            for index, row in self.df.iterrows():
                self.staticPanel.appendWithDateTime(index, row)
            return self.staticPanel
        self.staticSeries = series.SequenceDataSeries(maxLen=maxLen)
        for index, row in self.df.items():
            self.staticSeries.appendWithDateTime(index, row)
        return self.staticSeries

    def to_frame(self):
        '''
        The loaded data is already a pandas object; nothing to convert.
        '''
        pass
class FactorUpdate:
    """Compute, test and store factor data, or extend previously stored data."""
    logger = logger.getLogger("factorUpdate")

    # Indicators computed by every DefaultFactorTest built here.
    _INDICATORS = ('IC', 'rankIC', 'beta', 'gpIC', 'tbdf', 'turn', 'groupRet')

    def __init__(self, instruments, market=bar.Market.STOCK, start=None, end=None,
                 dataFreq=bar.Frequency.MINUTE, testFreq=None, fee=0.003):
        '''
        :param instruments: universe, e.g. "SZ50", "HS300" or "ZZ500"
        :param market: bar.Market.STOCK or bar.Market.FUTURES
        :param start: start of the test period; None uses H5DataReader's default start
        :param end: end of the test period; None uses H5DataReader's default end
        :param dataFreq: frequency of the raw data
        :param testFreq: resample frequencies to test; None means 5, 30, 60 and 120 minutes
        :param fee: transaction fee passed to DefaultFactorTest
        '''
        self.instruments = instruments
        self.market = market
        self.start = start
        self.end = end
        self.newFactor = []
        self.factorDefPath = pathSelector.PathSelector.getFactorDefPath()
        self.factorDataPath = pathSelector.PathSelector.getFactorFilePath()
        self.fee = fee
        self.dataFreq = dataFreq
        self.testFreq = [bar.Frequency.MINUTE5, bar.Frequency.MINUTE30, bar.Frequency.HOUR,
                         bar.Frequency.HOUR2] if not testFreq else testFreq

    def getPanelFeed(self):
        '''Return a new minute-frequency panel feed for the configured universe and period.'''
        panelFeed = DataFeedFactory.getHistFeed(instruments=self.instruments,
                                                market=self.market,
                                                frequency=bar.Frequency.MINUTE,
                                                start=self.start,
                                                end=self.end)
        return panelFeed

    def newFactorList(self):
        '''Set ``newFactor`` to the defined factors that have no data folder yet.'''
        allFactors = [name.split('.')[0] for name in os.listdir(self.factorDefPath)
                      if name not in ['__init__.py', '__pycache__']]
        self.newFactor = sorted(set(allFactors) - set(os.listdir(self.factorDataPath)))
        if self.newFactor:
            self.logger.info("The new factors:{}".format(self.newFactor))
        else:
            self.logger.info("No new factors seen, the factor updating process will end soon")

    def _importFactorModule(self, factor):
        '''Import ``cpa.factorPool.factors.<factor>`` and log the success.'''
        modulePath = "cpa.factorPool.factors.{}".format(factor)
        module = importlib.import_module(modulePath)
        self.logger.info("The module {} has been imported successfully".format(factor))
        return module

    def writeNewFactor(self, F=1):
        '''
        Compute, test and store every new factor.

        :param F: rebalancing lag (in bars of each tested frequency)
        '''
        self.newFactorList()
        for factor in self.newFactor:
            if factor == 'broker':
                continue
            self.logger.info(
                "************************ Writing FactorData for {} ************************".format(factor))
            module = self._importFactorModule(factor)
            factorObject = getattr(module, 'Factor')
            panelFeed = self.getPanelFeed()  # a fresh feed for every factor
            resampleFeedDict = {}
            returnDict = {}
            rawFactorDict = {}
            factorTesterDict = {}
            for freq in self.testFreq:
                resampleFeedDict[freq] = ResampledPanelFeed(panelFeed, freq)
                returnDict[freq] = returns.Returns(resampleFeedDict[freq], lag=F, maxLen=1024)
                rawFactorDict[freq] = factorBase.FactorPanel(resampleFeedDict[freq], factorObject)
                factorTesterDict[freq] = DefaultFactorTest(resampleFeedDict[freq], rawFactorDict[freq],
                                                           returnDict[freq],
                                                           indicators=list(self._INDICATORS),
                                                           lag=F,
                                                           cut=0.1,
                                                           fee=self.fee)
            panelFeed.run(2000)
            if len(returnDict[self.testFreq[0]]) <= 2 * F:
                # Every factor runs on identical data, so the remaining ones
                # would fail the same check; stop here.
                self.logger.warning(
                    "The length of the return panel <= 2 * the required lag. Data will not be saved.")
                return
            for freq in self.testFreq:
                h5PanelWriter = h5Writer.H5PanelWriter(factorTesterDict[freq], factor)
                h5PanelWriter.write(mode="new")

    def updateFactor(self, factor, removeOld=True, F=1):
        '''
        Extend every stored file of one factor with newly computed data.

        Recomputation starts 2*F bars (converted to business days, plus one)
        before the earliest stored end date, so the new data overlaps the old.
        NOTE: this overwrites ``self.start``.

        :param factor: factor name
        :param removeOld: whether to delete the old files (currently unused:
            removal is disabled)
        :param F: rebalancing lag
        '''
        self.logger.info("************************Updating FactorData for {}************************".format(factor))
        factorReader = h5Reader.H5BatchPanelReader(factorName=factor, frequency=None)
        factorReader.prepareOutputData()
        dateRangeDict = factorReader.getDateRange()  # {name: (start, end)}
        endDateList = sorted(dateRange[1] for dateRange in dateRangeDict.values())
        if not endDateList:
            self.logger.warning("No stored data found for {}, nothing to update".format(factor))
            return
        firstEndTime = endDateList[0].to_pydatetime()
        # BusinessDay requires an integer n: 2*F bars converted to days, plus one.
        timeDiff = pd.tseries.offsets.BusinessDay(n=int(np.floor(2 * F * self.dataFreq / 86400)) + 1)
        self.start = firstEndTime - timeDiff
        panelFeed = self.getPanelFeed()  # feed starting at the new start
        module = self._importFactorModule(factor)
        factorObject = getattr(module, 'Factor')
        resampleFeedDict = {}
        returnDict = {}
        rawFactorDict = {}
        factorTesterDict = {}
        dictOldResultDict = {}
        for resample in self.testFreq:
            frequencyStr = const.DataFrequency.freq2lable(resample)
            resampleReader = h5Reader.H5BatchPanelReader(factorName=factor, frequency=frequencyStr)
            resampleReader.prepareOutputData()
            key = str(resample).split(".")[-1]
            dictOldResultDict[key] = resampleReader.getTestResult()
            resampleFeedDict[key] = ResampledPanelFeed(panelFeed, resample)
            returnDict[key] = returns.Returns(resampleFeedDict[key], lag=F, maxLen=1024)
            rawFactorDict[key] = factorBase.FactorPanel(resampleFeedDict[key], factorObject)
            # Use the configured fee here too, as writeNewFactor does.
            factorTesterDict[key] = DefaultFactorTest(panelFeed=resampleFeedDict[key],
                                                      factorPanel=rawFactorDict[key],
                                                      returnPanel=returnDict[key],
                                                      indicators=list(self._INDICATORS),
                                                      lag=F,
                                                      cut=0.1,
                                                      fee=self.fee)
        panelFeed.run(2000)
        for key, oldResultDict in dictOldResultDict.items():
            h5PanelWriter = h5Writer.H5PanelWriter(factorTesterDict[key], factor)
            h5PanelWriter.write(mode="append", oldResultDict=oldResultDict)
        # TODO: deleting the old files after a successful append (removeOld) is not implemented.

    def updateFactorPool(self, removeOld=True):
        '''
        Update every factor folder under the factor data directory.

        :param removeOld: forwarded to updateFactor
        '''
        factorNameList = [name for name in os.listdir(self.factorDataPath)
                          if os.path.isdir(os.path.join(self.factorDataPath, name))]
        for factor in factorNameList:
            self.updateFactor(factor, removeOld=removeOld, F=1)
class DataFeedFactory:
    '''
    Central factory for building data feeds.
    '''
    logger = logger.getLogger('feedFactory')

    # Stock-index futures for which resampled (non-minute) feeds are supported.
    _INDEX_FUTURES = frozenset(['IH', 'IF', 'IC'])

    @classmethod
    def getHistFeed(cls, market=bar.Market.STOCK, frequency=bar.Frequency.MINUTE, instruments=None,
                    field=const.DataField.OHLCV, start=None, end=None,
                    types=const.DataType.OHLCV, maxLen=1024):
        '''
        :param market: market type (bar.Market.STOCK or bar.Market.FUTURES)
        :param frequency: bar frequency
        :param instruments: instrument list (or a universe name for stocks)
        :param field: data fields (OHLCV by default)
        :param start: start time
        :param end: end time
        :param types: data category; price/volume by default, or sample/finance data for stocks
        :param maxLen: buffer length of the feed
        :return: a PanelFeed (or ResampledPanelFeed); None for an unsupported market
        :raises ValueError: for an unsupported data type or frequency
        '''
        if field is None:
            field = const.DataField.OHLCV
        if types is None:
            types = const.DataType.OHLCV
        if market == bar.Market.STOCK:
            return cls._stockFeed(frequency, instruments, field, start, end, types, maxLen)
        if market == bar.Market.FUTURES:
            return cls._futureFeed(frequency, instruments, start, end, maxLen)
        cls.logger.warning("Unsupported market {}, no feed created".format(market))
        return None

    @classmethod
    def _stockFeed(cls, frequency, instruments, field, start, end, types, maxLen):
        '''Build a PanelFeed for stock data of the given category.'''
        if types == const.DataType.OHLCV:
            dataReader = h5Reader.H5DataReader(frequency=frequency, instruments=instruments,
                                               fields=field, start=start, end=end)
            dataReader.prepareGenerator()
        elif types == const.DataType.SAMPLE:
            if frequency != bar.Frequency.MINUTE:
                raise ValueError('样本数据只包含1分钟频率')
            dataReader = csvReader.CSVSampleDataReader(frequency, instruments=None)
            dataReader.loads()
        elif types == const.DataType.FINANCE:
            if frequency != bar.Frequency.QUARTER:
                raise ValueError('财务数据为季度')
            dataReader = csvReader.CSVFinanceReader.FinanceReader(
                fileName=None, instruments=None, fields=None, start=None)
        else:
            # Previously dataReader became None and crashed with AttributeError below.
            raise ValueError('Unsupported data type: {}'.format(types))
        return PanelFeed(dataReader, dataReader.getRegisteredInstruments(), maxLen=maxLen)

    @classmethod
    def _isIndexFutures(cls, instruments):
        '''True when ``instruments`` (a code or a collection of codes) only holds index futures.'''
        if instruments is None:
            return False
        codes = {instruments} if isinstance(instruments, str) else set(instruments)
        return bool(codes) and codes <= cls._INDEX_FUTURES

    @classmethod
    def _futureMinuteFeed(cls, instruments, start, end, maxLen):
        '''Build a one-minute futures PanelFeed.'''
        dataReader = csvReader.CSVFutureDataReader(instruments, const.DataField.OHLCV, start, end)
        dataReader.prepareGenerator()
        return PanelFeed(dataReader, dataReader.getRegisteredInstruments(), maxLen=maxLen)

    @classmethod
    def _futureFeed(cls, frequency, instruments, start, end, maxLen):
        '''Build a futures feed; non-minute frequencies are only supported for index futures.'''
        if frequency == bar.Frequency.MINUTE:
            return cls._futureMinuteFeed(instruments, start, end, maxLen)
        if cls._isIndexFutures(instruments):
            panelFeed = cls._futureMinuteFeed(instruments, start, end, maxLen)
            return ResampledPanelFeed(panelFeed, frequency=frequency)
        raise ValueError('商品期货数据仅支持1分钟频率,后续请李霄完善')

    @classmethod
    def getLiveFeed(cls, source='tq', instruments=None, start=None, end=None, maxLen=1024):
        '''
        Placeholder for a real-time feed; not implemented and returns None.

        :param source: real-time data source
        :param instruments:
        :param start:
        :param end:
        :param maxLen:
        :return: None
        '''
        pass