Exemplo n.º 1
0
class PYTDXService():
    """Data service backed by a pytdx quote (``TdxHq_API``) connection.

    ``SETTINGS``, ``TdxHq_API``, ``exchange_map`` and ``pd`` are expected to be
    provided by the enclosing module.  Symbols are written as
    ``"<code>.<exchange>"``; the exchange part is translated through
    ``exchange_map``.
    """

    def __init__(self):
        """Create the service in a disconnected state."""
        self.connected = False  # True once connect_api() has succeeded
        self.hq_api = None  # quote API client, created by connect_api()

    def connect_api(self):
        """Connect the quote API unless it is already connected.

        Host and port are read from ``SETTINGS["TDX_HOST"]`` and
        ``SETTINGS["TDX_PORT"]``.

        Returns:
            True when the service is (already or newly) connected.

        Raises:
            ConnectionError: if anything fails while connecting; the
                underlying exception is chained as ``__cause__``.
        """
        try:
            if not self.connected:
                host = SETTINGS["TDX_HOST"]
                port = SETTINGS["TDX_PORT"]
                self.hq_api = TdxHq_API()
                self.hq_api.connect(host, port)
                self.connected = True
            return True
        except Exception as exc:
            raise ConnectionError("pytdx连接错误") from exc

    def get_realtime_data(self, symbol: str):
        """Return the current quote of ``symbol`` as a DataFrame.

        Raises:
            ValueError: if the symbol cannot be parsed or the request fails
                (including when the service was never connected); the
                underlying exception is chained as ``__cause__``.
        """
        try:
            symbols = self.generate_symbols(symbol)
            return self.hq_api.to_df(self.hq_api.get_security_quotes(symbols))
        except Exception as exc:
            raise ValueError("股票数据获取失败") from exc

    def get_history_transaction_data(self, symbol, date):
        """Return the historical tick-by-tick trades of ``symbol`` on ``date``.

        The underlying call takes (market, code, start offset, count, date),
        e.g. ``get_history_transaction_data(TDXParams.MARKET_SZ, '000001', 0,
        10, 20170209)``.  Four pages of 2000 rows are requested at offsets
        6000, 4000, 2000 and 0, concatenated in that order, and exact
        duplicate rows are removed.

        Args:
            symbol: ``"<code>.<exchange>"`` string.
            date: trading day convertible with ``int()``, e.g. ``20170209``.

        Returns:
            DataFrame of the concatenated pages.

        Raises:
            TypeError: if ``symbol`` is empty (``check_symbol`` returns False,
                which cannot be unpacked).
        """
        code, market = self.check_symbol(symbol)

        check_date = int(date)
        count = 2000  # rows requested per page
        # Largest offset first, so the concatenated pages keep the order the
        # original loop produced.
        pages = [
            self.hq_api.to_df(
                self.hq_api.get_history_transaction_data(
                    market, code, start, count, check_date))
            for start in (6000, 4000, 2000, 0)
        ]

        df = pd.concat(pages)
        # NOTE(review): this also removes genuinely distinct ticks that happen
        # to be identical in every column - confirm that is acceptable.
        df.drop_duplicates(inplace=True)
        return df

    @staticmethod
    def generate_symbols(symbol: str):
        """Build the ``[(market, code)]`` list expected by pytdx quote calls.

        Raises:
            KeyError: if the exchange suffix is not in ``exchange_map``.
            ValueError: if ``symbol`` does not split into exactly two parts.
        """
        code, exchange = symbol.split('.')
        return [(exchange_map[exchange], code)]

    @staticmethod
    def check_symbol(symbol: str):
        """Split ``symbol`` into ``(code, market)``.

        ``market`` is ``exchange_map.get(<exchange>)`` and is therefore None
        for an unknown exchange.  Returns False when ``symbol`` is empty.
        """
        if not symbol:
            return False

        code, market = symbol.split('.')
        return code, exchange_map.get(market)

    def close(self):
        """Mark the service disconnected and close the API connection, if any."""
        self.connected = False
        # hq_api stays None until connect_api() runs; closing a service that
        # never connected must be a no-op, not an AttributeError.
        if self.hq_api is not None:
            self.hq_api.disconnect()
Exemplo n.º 2
0
class TBStockData:

    # Candidate quote servers; __init__ rebinds this per instance to ``hq_hosts``.
    # Entries are indexed as [0] (printed first by showAllIP), [1] ip, [2] port.
    __serverList = []
    # Server entry in use: loaded from the best.ip JSON file or chosen by
    # getBestIP(); _connect() reads [1] as host and [2] as port.
    _bestIP = []
    # Path of the JSON file caching the chosen server (set in __init__).
    __bestIPFile = ''
    # TdxHq_API client, created lazily by _connect().
    __tdx = None
    # Bars produced by the most recent getDays/getMins/getIndexDays/getIndexMins
    # call; getData() and _setVolRaito() read it.
    _lastBaseHistList = pd.DataFrame()
    # Result of the most recent successful getXdxr() call.
    _xdxrData = None


    def __init__(self, autoIP = False):
        """Set up the server list and pick the server entry to use.

        With ``autoIP`` true every known server is probed (getBestIP) and the
        winner is cached in ``best.ip`` next to this source file.  Otherwise
        the cached entry is loaded when that file exists; if it does not,
        ``_bestIP`` keeps its class default and the choice is deferred to the
        first ``_connect()``.
        """
        self.__serverList = hq_hosts
        moduleDir = os.path.dirname(os.path.realpath(__file__))
        self.__bestIPFile = moduleDir + '/best.ip'

        if autoIP:
            self.getBestIP()
            return

        if not os.path.exists(self.__bestIPFile):
            return

        with open(self.__bestIPFile, 'r') as cacheFile:
            self._bestIP = json.loads(cacheFile.read())

    def ping(self, ip, port):
        """Measure how long one server takes to answer a security-list request.

        Parameters: ``ip`` host address; ``port`` anything ``int()`` accepts.

        Returns the elapsed ``datetime.timedelta`` when the server answers
        with more than 800 entries.  Any failure, or a shorter answer, yields
        the sentinel ``timedelta(9, 9, 0)`` (9 days, 9 seconds) so the server
        sorts last in getBestIP().
        """
        unusable = datetime.timedelta(9, 9, 0)
        api = TdxHq_API()
        time1 = datetime.datetime.now()

        try:
            with api.connect(ip, int(port)):
                if len(api.get_security_list(0, 1)) > 800:
                    return datetime.datetime.now() - time1
                return unusable
        except Exception:
            # Only ordinary failures mark the server unusable; a bare except
            # here also swallowed KeyboardInterrupt/SystemExit, making a long
            # server scan impossible to interrupt.
            return unusable

    def getBestIP(self):
        """Probe every known server, keep the fastest, and cache it on disk.

        The first server with the smallest ping() result becomes ``_bestIP``
        and is written as JSON to the best-IP cache file.  Raises ValueError
        when the server list is empty.
        """
        latencies = []
        for server in self.__serverList:
            latencies.append(self.ping(server[1], server[2]))

        fastest = min(latencies)
        # list.index returns the first match, so ties go to the earliest server.
        self._bestIP = self.__serverList[latencies.index(fastest)]

        with open(self.__bestIPFile, 'w') as cacheFile:
            cacheFile.write(json.dumps(self._bestIP))

    def showAllIP(self):
        for item in self.__serverList:
            print item[0],'\t', item[1], '\t', item[2]

    def _connect(self):

        if self.__tdx is None:
            if not self._bestIP:
                self.getBestIP()

            #self.__tdx = TdxHq_API(heartbeat=True, auto_retry=True)
            self.__tdx = TdxHq_API(auto_retry=True)
            self.__tdx.connect(self._bestIP[1], int(self._bestIP[2]))

    #计算量比
    def _setVolRaito(self, row):
        date = row.name
        histList = self._lastBaseHistList[:date]
        if len(histList) < 6:
            return np.nan

        return round((histList['vol'].values[-1] / 240) / (histList[-6:-1]['vol'].sum() / 1200), 3)

    #计算各种指标
    def getData(self, df = pd.DataFrame(), indexs=['turnover', 'vol', 'ma', 'macd', 'kdj', 'cci', 'bbi', 'sar', 'trix']):
        """Return a copy of the bars with technical-indicator columns added.

        Parameters
        ----------
        df : DataFrame, optional
            Bars to enrich.  When empty, the frame cached by the last
            getDays/getMins/getIndexDays/getIndexMins call is used.  The code
            reads the columns ``open``, ``high``, ``low``, ``close``, ``vol``
            (plus ``code`` for turnover and ``amount`` for ``cyc``).
        indexs : list of str
            Case-insensitive indicator groups to compute: ``turnover``,
            ``vol``, ``ma``, ``macd``, ``kdj``, ``cci``, ``bbi``, ``sar``,
            ``trix``, ``cyc``, ``boll``.

        Returns
        -------
        DataFrame or None
            None when neither ``df`` nor the cached frame holds any rows.
            ``p_change``, ``vol_ratio`` and ``amp`` are always added.
            ``*_cross_status`` columns are 1 where the first series moves from
            below to at/above the second, -1 for the opposite move, else 0.

        Notes
        -----
        ``vol_ratio`` is computed from the cached ``_lastBaseHistList`` (see
        _setVolRaito), also when ``df`` is given.  Turnover is only computed
        for daily-or-longer bars, and is skipped when no share-capital data
        is available for the security.
        """
        indexs = [x.lower() for x in indexs]
        histList = pd.DataFrame()

        if not df.empty:
            histList = df.copy()
        elif not self._lastBaseHistList.empty:
            histList = self._lastBaseHistList.copy()

        if histList.empty:
            return None

        # Bars more than 12 hours apart are treated as daily-or-longer.  If the
        # index labels cannot be parsed as "date time" (e.g. plain date
        # strings, or fewer than two rows) the data is assumed to be daily too.
        dayKStatus = False
        try:
            if int(time.mktime(time.strptime(str(histList.index[-1]), "%Y-%m-%d %X"))) - int(time.mktime(time.strptime(str(histList.index[-2]), "%Y-%m-%d %X"))) > 43200:
                dayKStatus = True
        except Exception:
            dayKStatus = True

        # Percentage change of the close versus the previous bar.
        histList['p_change'] = histList['close'].pct_change().round(5) * 100

        # Volume ratio.
        histList['vol_ratio'] = histList.apply(self._setVolRaito, axis=1)

        # Amplitude: high-low range as a percentage of the previous close.
        histList['amp'] = ((histList['high'] - histList['low']) / histList.shift()['close'] * 100).round(3)

        # Turnover rate.  Share-capital data is only looked up when it is
        # actually needed, so minute bars no longer trigger (or fail on) an
        # xdxr request whose result was never used.
        if 'turnover' in indexs and dayKStatus:
            circulate = self._getDailyCirculate(histList)
            if circulate is not None:
                histList['turnover'] = (histList['vol'] * 100 / circulate).round(5) * 100
                histList['turnover5'] = talib.MA(histList['turnover'].values, timeperiod=5).round(3)

        # stockstats frame, used for KDJ and CCI: per the original author,
        # talib's KDJ deviates noticeably from common charting software.
        ss = StockDataFrame.retype(histList[['high','low','open','close']])

        # MACD
        if 'macd' in indexs:
            difList, deaList, macdList = talib.MACD(histList['close'].values, fastperiod=12, slowperiod=26, signalperiod=9)
            macdList = macdList * 2
            histList['macd_dif'] = difList.round(3)
            histList['macd_dea'] = deaList.round(3)
            histList['macd_value'] = macdList.round(3)
            histList['macd_value_ma'] = 0
            try:
                histList['macd_value_ma'] = talib.MA(histList['macd_value'].values, timeperiod=5).round(3)
            except Exception:
                # Best effort: keep the 0 placeholder when talib rejects the input.
                pass
            histList['macd_cross_status'] = 0
            macdPosList = histList['macd_dif'] > histList['macd_dea']
            histList.loc[macdPosList[(macdPosList == True) & (macdPosList.shift() == False)].index, 'macd_cross_status'] = 1
            histList.loc[macdPosList[(macdPosList == False) & (macdPosList.shift() == True)].index, 'macd_cross_status'] = -1

        # KDJ
        if 'kdj' in indexs:
            histList['kdj_k'] = ss['kdjk'].round(3)
            histList['kdj_d'] = ss['kdjd'].round(3)
            histList['kdj_j'] = ss['kdjj'].round(3)
            histList['kdj_cross_status'] = 0
            kdjPosList = histList['kdj_k'] >= histList['kdj_d']
            histList.loc[kdjPosList[(kdjPosList == True) & (kdjPosList.shift() == False)].index, 'kdj_cross_status'] = 1
            histList.loc[kdjPosList[(kdjPosList == False) & (kdjPosList.shift() == True)].index, 'kdj_cross_status'] = -1

        # CCI
        if 'cci' in indexs:
            histList['cci'] = ss['cci'].round(3)

        # Moving averages of the close; warm-up NaNs are replaced by 0.
        if 'ma' in indexs:
            histList['ma5'] = talib.MA(histList['close'].values, timeperiod=5).round(3)
            histList['ma10'] = talib.MA(histList['close'].values, timeperiod=10).round(3)
            histList['ma20'] = talib.MA(histList['close'].values, timeperiod=20).round(3)
            histList['ma30'] = talib.MA(histList['close'].values, timeperiod=30).round(3)
            histList['ma60'] = talib.MA(histList['close'].values, timeperiod=60).round(3)
            histList['ma240'] = talib.MA(histList['close'].values, timeperiod=240).round(3)
            histList[['ma5', 'ma10', 'ma20', 'ma30', 'ma60', 'ma240']] = histList[['ma5', 'ma10', 'ma20', 'ma30', 'ma60', 'ma240']].fillna(0)

        # Volume moving averages and their crossings.
        if 'vol' in indexs:
            histList['vol5'] = talib.MA(histList['vol'].values, timeperiod=5).round(3)
            histList['vol10'] = talib.MA(histList['vol'].values, timeperiod=10).round(3)
            histList['vol20'] = talib.MA(histList['vol'].values, timeperiod=20).round(3)
            histList['vol_zoom'] = (histList['vol'] / histList['vol5'] * 1.0).round(3)
            histList['vol5_vol10_cross_status'] = 0
            volumePosList = histList['vol5'] >= histList['vol10']
            histList.loc[volumePosList[(volumePosList == True) & (volumePosList.shift() == False)].index, 'vol5_vol10_cross_status'] = 1
            histList.loc[volumePosList[(volumePosList == False) & (volumePosList.shift() == True)].index, 'vol5_vol10_cross_status'] = -1
            del volumePosList
            histList['vol5_vol20_cross_status'] = 0
            volumePosList = histList['vol5'] >= histList['vol20']
            histList.loc[volumePosList[(volumePosList == True) & (volumePosList.shift() == False)].index, 'vol5_vol20_cross_status'] = 1
            histList.loc[volumePosList[(volumePosList == False) & (volumePosList.shift() == True)].index, 'vol5_vol20_cross_status'] = -1
            del volumePosList
            histList['vol10_vol20_cross_status'] = 0
            volumePosList = histList['vol10'] >= histList['vol20']
            histList.loc[volumePosList[(volumePosList == True) & (volumePosList.shift() == False)].index, 'vol10_vol20_cross_status'] = 1
            histList.loc[volumePosList[(volumePosList == False) & (volumePosList.shift() == True)].index, 'vol10_vol20_cross_status'] = -1

        # BBI: mean of the 3/6/12/24-period moving averages.
        if 'bbi' in indexs:
            ma3 = talib.MA(histList['close'].values, timeperiod=3)
            ma6 = talib.MA(histList['close'].values, timeperiod=6)
            ma12 = talib.MA(histList['close'].values, timeperiod=12)
            ma24 = talib.MA(histList['close'].values, timeperiod=24)
            histList['bbi'] = (ma3 + ma6 + ma12 + ma24) / 4
            histList['bbi'] = histList['bbi'].round(3)

        # SAR
        if 'sar' in indexs:
            sarList = talib.SAR(histList['high'].values, histList['low'].values, acceleration=0.04, maximum=0.2)
            histList['sar'] = sarList.round(3)
            histList['sar_cross_status'] = 0
            sarPosList = histList['close'] >= histList['sar']
            histList.loc[sarPosList[(sarPosList == True) & (sarPosList.shift() == False)].index, 'sar_cross_status'] = 1
            histList.loc[sarPosList[(sarPosList == False) & (sarPosList.shift() == True)].index, 'sar_cross_status'] = -1

        # TRIX
        if 'trix' in indexs:
            histList['trix'] = np.nan
            histList['trma'] = np.nan
            histList['trix_diff'] = np.nan
            try:
                trix = talib.TRIX(histList['close'].values, 12)
                trma = talib.MA(trix, timeperiod=20)
                histList['trix'] = trix.round(3)
                histList['trma'] = trma.round(3)
                histList['trix_diff'] = histList['trix'] - histList['trma']
                histList['trix_cross_status'] = 0
                trixPosList = histList['trix'] >= histList['trma']
                histList.loc[trixPosList[(trixPosList == True) & (trixPosList.shift() == False)].index, 'trix_cross_status'] = 1
                histList.loc[trixPosList[(trixPosList == False) & (trixPosList.shift() == True)].index, 'trix_cross_status'] = -1
            except Exception:
                # Best effort: leave whatever TRIX columns were filled so far.
                pass

        # CYC: moving averages of the average traded price (amount / shares).
        if 'cyc' in indexs:
            avePrice = histList['amount'] / (histList['vol'] * 100)
            histList['cyc5'] = talib.MA(avePrice.values, timeperiod=5).round(3)
            histList['cyc13'] = talib.MA(avePrice.values, timeperiod=13).round(3)
            histList['cyc34'] = talib.MA(avePrice.values, timeperiod=34).round(3)
            histList['cyc5_cyc13_cross_status'] = 0
            volumePosList = histList['cyc5'] >= histList['cyc13']
            histList.loc[volumePosList[(volumePosList == True) & (volumePosList.shift() == False)].index, 'cyc5_cyc13_cross_status'] = 1
            histList.loc[volumePosList[(volumePosList == False) & (volumePosList.shift() == True)].index, 'cyc5_cyc13_cross_status'] = -1
            del volumePosList
            histList['cyc13_cyc34_cross_status'] = 0
            volumePosList = histList['cyc13'] >= histList['cyc34']
            histList.loc[volumePosList[(volumePosList == True) & (volumePosList.shift() == False)].index, 'cyc13_cyc34_cross_status'] = 1
            histList.loc[volumePosList[(volumePosList == False) & (volumePosList.shift() == True)].index, 'cyc13_cyc34_cross_status'] = -1
            del volumePosList

        # Bollinger bands: 20 periods, 2 standard deviations, simple MA.
        if 'boll' in indexs:
            up, mid, low = talib.BBANDS(
                histList['close'].values,
                timeperiod=20,
                # number of non-biased standard deviations from the mean
                nbdevup=2,
                nbdevdn=2,
                # Moving average type: simple moving average here
                matype=0)
            histList['boll_up'] = up.round(3)
            histList['boll_mid'] = mid.round(3)
            histList['boll_low'] = low.round(3)

        return histList

    def _getDailyCirculate(self, histList):
        """Return circulating shares per calendar day over histList's date span.

        The cached xdxr data is reused only when it belongs to the same
        security as ``histList``; the cache is per instance, not per code, so
        reusing it blindly mixed in another security's share capital.
        Returns None when no xdxr data is available.
        """
        xdxrData = self._xdxrData
        if xdxrData is not None and len(xdxrData) > 0 \
                and 'code' in xdxrData.columns and 'code' in histList.columns \
                and str(xdxrData['code'].values[0]) != str(histList['code'].values[0]):
            xdxrData = None
        if xdxrData is None:
            xdxrData = self.getXdxr(str(histList['code'].values[0]))
        if xdxrData is None:
            return None

        info = xdxrData[xdxrData['liquidity_after'] > 0][['liquidity_after', 'shares_after']]

        startDate = str(histList.index[0])[0:10]
        endDate = str(histList.index[-1])[0:10]
        # Last record at or before the first bar, then every later record,
        # forward-filled to one row per calendar day.
        info1 = info[info.index <= startDate][-1:]
        info = pd.concat([info1, info[info.index >= startDate]]).drop_duplicates()
        info = info.reindex(pd.date_range(info1.index[-1], endDate))
        info = info.resample('1D').last().ffill()[startDate:endDate]

        return info['liquidity_after'] * 10000

    #整理开始,结束时间,并计算相差天数
    def _getDate(self, start, end):
        if not end:
            end = time.strftime('%Y-%m-%d',time.localtime())

        if not start:
            t = int(time.mktime(time.strptime(str(end), '%Y-%m-%d'))) - 86400 * 800
            start = str(time.strftime('%Y-%m-%d',time.localtime(t)))

        startTimestamp = int(time.mktime(time.strptime(str(start), '%Y-%m-%d')))
        endTimestamp = int(time.mktime(time.strptime(str(end), '%Y-%m-%d')))
        diffDayNum = int((time.time() - startTimestamp) / 86400)
        if diffDayNum <= 0:
            diffDayNum = 1

        return start, end, diffDayNum

    #得到市场代码
    def getMarketCode(self, code):
        """Return the market number (1 or 0) used in quote requests for ``code``.

        1 is returned for codes starting with 5, 6 or 9 and for the listed
        three-digit prefixes, 0 otherwise.  (Presumably 1 = Shanghai and
        0 = Shenzhen in pytdx's numbering - TODO confirm.)  An empty code
        raises IndexError.
        """
        text = str(code)
        marketOnePrefixes = ("009", "126", "110", "201", "202", "203", "204")

        if text[0] in ('5', '6', '9'):
            return 1
        if text[:3] in marketOnePrefixes:
            return 1
        return 0

    #时间整理
    def _dateStamp(self, date):
        datestr = str(date)[0:10]
        date = time.mktime(time.strptime(datestr, '%Y-%m-%d'))
        return date

    #整理时间
    def _timeStamp(self, _time):
        if len(str(_time)) == 10:
            # yyyy-mm-dd格式
            return time.mktime(time.strptime(_time, '%Y-%m-%d'))
        elif len(str(_time)) == 16:
                # yyyy-mm-dd hh:mm格式
            return time.mktime(time.strptime(_time, '%Y-%m-%d %H:%M'))
        else:
            timestr = str(_time)[0:19]
            return time.mktime(time.strptime(timestr, '%Y-%m-%d %H:%M:%S'))


    #得到除权信息
    def getXdxr(self, code):

        self._connect()

        category = {
            '1': '除权除息', '2': '送配股上市', '3': '非流通股上市', '4': '未知股本变动', '5': '股本变化',
            '6': '增发新股', '7': '股份回购', '8': '增发新股上市', '9': '转配股上市', '10': '可转债上市',
            '11': '扩缩股', '12': '非流通股缩股', '13':  '送认购权证', '14': '送认沽权证'}

        data = self.__tdx.to_df(self.__tdx.get_xdxr_info(self.getMarketCode(code), code))

        if len(data) >= 1:
            data = data\
                .assign(date=pd.to_datetime(data[['year', 'month', 'day']], format='%Y-%m-%d'))\
                .drop(['year', 'month', 'day'], axis=1)\
                .assign(category_meaning=data['category'].apply(lambda x: category[str(x)]))\
                .assign(code=str(code))\
                .rename(index=str, columns={'panhouliutong': 'liquidity_after',
                                            'panqianliutong': 'liquidity_before', 'houzongguben': 'shares_after',
                                            'qianzongguben': 'shares_before'})\
                .set_index('date', drop=False, inplace=False)

            xdxrData = data.assign(date=data['date'].apply(lambda x: str(x)[0:10]))
            #xdxrData = xdxrData.set_index('date')
            self._xdxrData = xdxrData
            return xdxrData
        else:
            return None


    #得到股本
    def getGuben(self, code):
        self._connect()

        if self._xdxrData is None:
            xdxrData = self.getXdxr(code)
        else:
            xdxrData = self._xdxrData
        info = xdxrData[xdxrData['liquidity_after'] > 0][['liquidity_after', 'shares_after']]

        circulate = info['liquidity_after'].values[-1] * 10000
        capital = info['shares_after'].values[-1] * 10000

        return capital,circulate


    #按天得到标准数据
    '''
    ktype = D(天)/W(周)/M(月)/Q(季)/Y(年)
    autype = bfq(不复权)/hfq(后复权)/qfq(前复权)
    '''
    def getDays(self, code, ktype = 'D', start = '', end = '', autype = 'qfq', indexs = ['turnover', 'vol', 'ma', 'macd', 'kdj', 'cci', 'bbi', 'sar', 'trix']):
        startDate, endDate, diffDayNum = self._getDate(start, end)

        self._connect()

        ktypeCode = 9
        if ktype.lower() == 'd':
            ktypeCode = 9
        elif ktype.lower() == 'w':
            ktypeCode = 5
        elif ktype.lower() == 'm':
            ktypeCode = 6
        elif ktype.lower() == 'q':
            ktypeCode = 10
        elif ktype.lower() == 'y':
            ktypeCode = 11

        histList = pd.concat([self.__tdx.to_df(self.__tdx.get_security_bars(ktypeCode, self.getMarketCode(code), code, (int(diffDayNum / 800) - i) * 800, 800)) for i in range(int(diffDayNum / 800) + 1)], axis=0)

        if histList.empty:
            return None

        histList = histList[histList['open'] != 0]
        histList = histList[histList['vol'] > 1]

        if not autype or autype == 'bfq':
            histList = histList.assign(date=histList['datetime'].apply(lambda x: str(x[0:10]))).assign(code=str(code))\
                    .assign(date_stamp=histList['datetime'].apply(lambda x: self._dateStamp(str(x)[0:10])))

            histList = histList.drop(['year', 'month', 'day', 'hour', 'minute', 'datetime', 'date_stamp'], axis=1)
            histList = histList.set_index('date')
            histList = histList[startDate:endDate]
            self._lastBaseHistList = histList

            histList['p_change'] = histList['close'].pct_change().round(5) * 100

            if indexs:
                return self.getData(indexs=indexs)
            else:
                return histList

        elif autype == 'qfq':

            bfqData = histList.assign(date=pd.to_datetime(histList['datetime'].apply(lambda x: str(x[0:10])))).assign(code=str(code))\
                .assign(date_stamp=histList['datetime'].apply(lambda x: self._dateStamp(str(x)[0:10])))
            bfqData = bfqData.set_index('date')

            bfqData = bfqData.drop(
                ['year', 'month', 'day', 'hour', 'minute', 'datetime'], axis=1)

            xdxrData = self.getXdxr(code)
            if xdxrData is not None:
                info = xdxrData[xdxrData['category'] == 1]
                bfqData['if_trade'] = True
                data = pd.concat([bfqData, info[['category']]
                                  [bfqData.index[0]:]], axis=1)

                #data['date'] = data.index
                data['if_trade'].fillna(value=False, inplace=True)
                data = data.fillna(method='ffill')
                data = pd.concat([data, info[['fenhong', 'peigu', 'peigujia',
                                              'songzhuangu']][bfqData.index[0]:]], axis=1)
                data = data.fillna(0)

                data['preclose'] = (data['close'].shift(1) * 10 - data['fenhong'] + data['peigu']
                                    * data['peigujia']) / (10 + data['peigu'] + data['songzhuangu'])
                data['adj'] = (data['preclose'].shift(-1) /
                               data['close']).fillna(1)[::-1].cumprod()
                data['open'] = data['open'] * data['adj']
                data['high'] = data['high'] * data['adj']
                data['low'] = data['low'] * data['adj']
                data['close'] = data['close'] * data['adj']
                data['preclose'] = data['preclose'] * data['adj']
                data = data[data['if_trade']]

                histList = data.drop(['fenhong', 'peigu', 'peigujia', 'songzhuangu', 'if_trade', 'category', 'preclose', 'date_stamp', 'adj'], axis=1)
                histList = histList[startDate:endDate]
                self._lastBaseHistList = histList

                histList['p_change'] = histList['close'].pct_change().round(5) * 100

                if indexs:
                    return self.getData(indexs=indexs)
                else:
                    return histList
            else:
                bfqData['preclose'] = bfqData['close'].shift(1)
                bfqData['adj'] = 1

                histList = bfqData.drop(['preclose', 'date_stamp', 'adj'], axis=1)
                histList = histList[startDate:endDate]
                self._lastBaseHistList = histList

                if indexs:
                    return self.getData(indexs=indexs)
                else:
                    return histList

        elif autype == 'hfq':
            xdxrData = self.getXdxr(code)

            info = xdxrData[xdxrData['category'] == 1]

            bfqData = histList.assign(date=histList['datetime'].apply(lambda x: x[0:10])).assign(code=str(code))\
                .assign(date_stamp=histList['datetime'].apply(lambda x: self._dateStamp(str(x)[0:10])))
            bfqData = bfqData.set_index('date')

            bfqData = bfqData.drop(
                ['year', 'month', 'day', 'hour', 'minute', 'datetime'], axis=1)

            bfqData['if_trade'] = True
            data = pd.concat([bfqData, info[['category']]
                              [bfqData.index[0]:]], axis=1)

            data['if_trade'].fillna(value=False, inplace=True)
            data = data.fillna(method='ffill')
            data = pd.concat([data, info[['fenhong', 'peigu', 'peigujia',
                                          'songzhuangu']][bfqData.index[0]:]], axis=1)
            data = data.fillna(0)

            data['preclose'] = (data['close'].shift(1) * 10 - data['fenhong'] + data['peigu']
                                * data['peigujia']) / (10 + data['peigu'] + data['songzhuangu'])
            data['adj'] = (data['preclose'].shift(-1) /
                           data['close']).fillna(1).cumprod()
            data['open'] = data['open'] / data['adj']
            data['high'] = data['high'] / data['adj']
            data['low'] = data['low'] / data['adj']
            data['close'] = data['close'] / data['adj']
            data['preclose'] = data['preclose'] / data['adj']
            data = data[data['if_trade']]

            histList = data.drop(['fenhong', 'peigu', 'peigujia', 'songzhuangu', 'if_trade', 'category', 'preclose', 'date_stamp', 'adj'], axis=1)
            histList = histList[startDate:endDate]
            self._lastBaseHistList = histList

            histList['p_change'] = histList['close'].pct_change().round(5) * 100

            if indexs:
                return self.getData(indexs=indexs)
            else:
                return histList

    #按分钟得到标准数据
    '''
    ktype = 1/5/15/30/60  分钟
    '''
    def getMins(self, code, ktype = 1, start = '', end = '', indexs=['vol', 'ma', 'macd', 'kdj', 'cci', 'bbi', 'sar', 'trix']):
        startDate, endDate, diffDayNum = self._getDate(start, end)

        self._connect()

        ktypeCode = 8
        if int(ktype) == 1:
            ktypeCode = 8
            diffDayNum = 240 * diffDayNum
        elif int(ktype) == 5:
            ktypeCode = 0
            diffDayNum = 48 * diffDayNum
        elif int(ktype) == 15:
            ktypeCode = 1
            diffDayNum = 16 * diffDayNum
        elif int(ktype) == 30:
            ktypeCode = 2
            diffDayNum = 8 * diffDayNum
        elif int(ktype) == 60:
            ktypeCode = 3
            diffDayNum = 4 * diffDayNum

        if diffDayNum > 20800:
            diffDayNum = 20800

        histList = pd.concat([self.__tdx.to_df(self.__tdx.get_security_bars(ktypeCode, self.getMarketCode(
            str(code)), str(code), (int(diffDayNum / 800) - i) * 800, 800)) for i in range(int(diffDayNum / 800) + 1)], axis=0)

        if histList.empty:
            return None

        histList = histList\
            .assign(datetime=pd.to_datetime(histList['datetime']), code=str(code))\
            .assign(date=histList['datetime'].apply(lambda x: str(x)[0:10]))\
            .assign(date_stamp=histList['datetime'].apply(lambda x: self._dateStamp(x)))\
            .assign(time_stamp=histList['datetime'].apply(lambda x: self._timeStamp(x)))

        histList['date'] = histList['datetime']
        histList = histList.drop(['year', 'month', 'day', 'hour', 'minute', 'datetime', 'date_stamp', 'time_stamp'], axis=1)
        histList = histList.set_index('date')
        histList = histList[startDate:endDate]
        self._lastBaseHistList = histList

        histList['p_change'] = histList['close'].pct_change().round(5) * 100
        histList['vol'] = histList['vol'] / 100.0

        if indexs:
            return self.getData(indexs=indexs)
        else:
            return histList


    #按天得到指数日k线
    '''
    ktype = D(天)/W(周)/M(月)/Q(季)/Y(年)
    '''
    def getIndexDays(self, code, ktype = 'D', start = '', end = '', indexs=['turnover', 'vol', 'ma', 'macd', 'kdj', 'cci', 'bbi', 'sar', 'trix']):
        startDate, endDate, diffDayNum = self._getDate(start, end)

        self._connect()

        ktypeCode = 9
        if ktype.lower() == 'd':
            ktypeCode = 9
        elif ktype.lower() == 'w':
            ktypeCode = 5
        elif ktype.lower() == 'm':
            ktypeCode = 6
        elif ktype.lower() == 'q':
            ktypeCode = 10
        elif ktype.lower() == 'y':
            ktypeCode = 11

        if str(code)[0] in ['5', '1']:  # ETF
            data = pd.concat([self.__tdx.to_df(self.__tdx.get_security_bars(
                ktypeCode, 1 if str(code)[0] in ['0', '8', '9', '5'] else 0, code, (int(diffDayNum / 800) - i) * 800, 800)) for i in range(int(diffDayNum / 800) + 1)], axis=0)
        else:
            data = pd.concat([self.__tdx.to_df(self.__tdx.get_index_bars(
                ktypeCode, 1 if str(code)[0] in ['0', '8', '9', '5'] else 0, code, (int(diffDayNum / 800) - i) * 800, 800)) for i in range(int(diffDayNum / 800) + 1)], axis=0)

        histList = data.assign(date=data['datetime'].apply(lambda x: str(x[0:10]))).assign(code=str(code))\
            .assign(date_stamp=data['datetime'].apply(lambda x: self._dateStamp(str(x)[0:10])))\
            .assign(code=code)

        if histList.empty:
            return None

        histList = histList.drop(['year', 'month', 'day', 'hour', 'minute', 'datetime', 'date_stamp', 'up_count', 'down_count'], axis=1)
        histList = histList.set_index('date')
        histList = histList[startDate:endDate]
        self._lastBaseHistList = histList

        histList['p_change'] = histList['close'].pct_change().round(5) * 100

        if indexs:
            return self.getData(indexs=indexs)
        else:
            return histList

    #按分钟得到标准数据
    '''
    ktype = 1/5/15/30/60  分钟
    '''
    def getIndexMins(self, code, ktype = 1, start = '', end = '', indexs=['vol', 'ma', 'macd', 'kdj', 'cci', 'bbi', 'sar', 'trix']):
        startDate, endDate, diffDayNum = self._getDate(start, end)

        self._connect()

        ktypeCode = 8
        if int(ktype) == 1:
            ktypeCode = 8
            diffDayNum = 240 * diffDayNum
        elif int(ktype) == 5:
            ktypeCode = 0
            diffDayNum = 48 * diffDayNum
        elif int(ktype) == 15:
            ktypeCode = 1
            diffDayNum = 16 * diffDayNum
        elif int(ktype) == 30:
            ktypeCode = 2
            diffDayNum = 8 * diffDayNum
        elif int(ktype) == 60:
            ktypeCode = 3
            diffDayNum = 4 * diffDayNum

        if diffDayNum > 20800:
            diffDayNum = 20800

        if str(code)[0] in ['5', '1']:  # ETF
            data = pd.concat([self.__tdx.to_df(self.__tdx.get_security_bars(
                ktypeCode, 1 if str(code)[0] in ['0', '8', '9', '5'] else 0, code, (int(diffDayNum / 800) - i) * 800, 800)) for i in range(int(diffDayNum / 800) + 1)], axis=0)
        else:
            data = pd.concat([self.__tdx.to_df(self.__tdx.get_index_bars(
                ktypeCode, 1 if str(code)[0] in ['0', '8', '9', '5'] else 0, code, (int(diffDayNum / 800) - i) * 800, 800)) for i in range(int(diffDayNum / 800) + 1)], axis=0)

        histList = data.assign(datetime=pd.to_datetime(data['datetime']), code=str(code))\
            .assign(date=data['datetime'].apply(lambda x: str(x)[0:10]))\
            .assign(date_stamp=data['datetime'].apply(lambda x: self._dateStamp(x)))\
            .assign(time_stamp=data['datetime'].apply(lambda x: self._timeStamp(x)))

        if histList.empty:
            return None

        histList['date'] = histList['datetime']
        histList = histList.drop(['year', 'month', 'day', 'hour', 'minute', 'datetime', 'date_stamp', 'time_stamp', 'up_count', 'down_count'], axis=1)
        histList = histList.set_index('date')
        histList = histList[startDate:endDate]
        self._lastBaseHistList = histList

        histList['p_change'] = histList['close'].pct_change().round(5) * 100

        if indexs:
            return self.getData(indexs=indexs)
        else:
            return histList

    #实时逐笔
    '''
    0买 1卖 2中性
    '''
    def getRealtimeTransaction(self, code):
        """Fetch today's tick-by-tick trades for ``code``.

        Three pages of up to 2000 ticks are requested (offsets 4000, 2000, 0)
        and concatenated in that order.

        :param code: security code; ``str(code)`` is used to look up the market.
        :return: DataFrame with the server's tick columns plus ``date``,
            ``datetime``, ``code``, ``money`` (price * vol * 100) and ``type``
            ('B' buy / 'S' sell / 'N' neutral, mapped from ``buyorsell``
            0 / 1 / 2), or ``None`` if anything goes wrong (best effort).
        """
        self._connect()

        try:
            market = self.getMarketCode(str(code))
            pages = [
                self.__tdx.to_df(
                    self.__tdx.get_transaction_data(market, code, start, 2000))
                for start in (4000, 2000, 0)
            ]
            data = pd.concat(pages, axis=0)
            if 'value' in data.columns:
                data = data.drop(['value'], axis=1)
            data = data.dropna()

            day = str(datetime.date.today())
            histList = data.assign(
                date=day,
                datetime=pd.to_datetime(
                    data['time'].apply(lambda x: day + ' ' + str(x))),
                code=str(code))

            histList['money'] = histList['price'] * histList['vol'] * 100
            # Assign the mapped column back instead of an in-place replace on
            # a column selection, which is not guaranteed to modify the frame.
            histList['type'] = histList['buyorsell'].replace(
                [0, 1, 2], ['B', 'S', 'N'])

            return histList.drop(['buyorsell'], axis=1).reset_index()
        except Exception:
            # Deliberate best effort: callers test the result against None.
            return None

    #历史逐笔
    '''
    0买 1卖 2中性
    '''
    def getHistoryTransaction(self, code, date):
        """Fetch the tick-by-tick trades of ``code`` on a past ``date``.

        Three pages of up to 2000 ticks are requested (offsets 4000, 2000, 0)
        and concatenated in that order.

        :param code: security code; ``str(code)`` is used to look up the market.
        :param date: trading day; dashes are stripped and the rest converted
            to an int for the request (e.g. ``'2017-02-09'`` -> ``20170209``).
        :return: DataFrame with the server's tick columns plus ``date``,
            ``datetime``, ``code``, ``money`` (price * vol * 100) and ``type``
            ('B' buy / 'S' sell / 'N' neutral, mapped from ``buyorsell``
            0 / 1 / 2), or ``None`` if anything goes wrong (best effort).
        """
        self._connect()

        try:
            market = self.getMarketCode(str(code))
            dateNum = int(str(date).replace('-', ''))
            pages = [
                self.__tdx.to_df(self.__tdx.get_history_transaction_data(
                    market, code, start, 2000, dateNum))
                for start in (4000, 2000, 0)
            ]
            data = pd.concat(pages, axis=0)
            if 'value' in data.columns:
                data = data.drop(['value'], axis=1)
            data = data.dropna()

            day = str(date)
            histList = data.assign(
                date=day,
                datetime=pd.to_datetime(
                    data['time'].apply(lambda x: day + ' ' + str(x))),
                code=str(code))

            histList['money'] = histList['price'] * histList['vol'] * 100
            # Assign the mapped column back instead of an in-place replace on
            # a column selection, which is not guaranteed to modify the frame.
            histList['type'] = histList['buyorsell'].replace(
                [0, 1, 2], ['B', 'S', 'N'])

            return histList.drop(['buyorsell'], axis=1).reset_index()
        except Exception:
            # Deliberate best effort: callers test the result against None.
            return None

    #实时分时数据
    def getRealtimeMinuteTime(self, code):
        """Fetch today's per-minute price/volume series for ``code``.

        The server rows carry no timestamp, so they are labelled positionally
        with the trading minutes 09:31-11:30 followed by 13:01-15:00.

        :param code: security code; ``str(code)`` is used to look up the market.
        :return: DataFrame with the server columns plus ``money``
            (price * vol * 100), ``ave`` (cumulative money / cumulative
            volume, a non-standard average price), ``datetime``, ``date``
            and ``time``; the original index is kept as a column.
        """
        self._connect()

        date = str(time.strftime('%Y-%m-%d', time.localtime()))

        # Only the timestamps are needed, so join the two session indexes
        # directly (DataFrame.append no longer exists in pandas 2.x).
        morningData = pd.date_range(start=date + ' 09:31', end=date + ' 11:30', freq='min')
        afternoonData = pd.date_range(start=date + ' 13:01', end=date + ' 15:00', freq='min')
        timeIndex = morningData.append(afternoonData)

        histList = self.__tdx.to_df(self.__tdx.get_minute_time_data(
            self.getMarketCode(str(code)), code))

        # Non-standard average price: running turnover / running volume.
        money = histList['price'] * histList['vol'] * 100
        histList['money'] = money.round(2)
        totalMoney = money.cumsum()
        totalVol = histList['vol'].cumsum()
        histList['ave'] = (totalMoney / (totalVol * 100)).round(3)

        histList['datetime'] = timeIndex[0:len(histList)]
        histList['date'] = histList['datetime'].apply(lambda x: x.strftime('%Y-%m-%d'))
        histList['time'] = histList['datetime'].apply(lambda x: x.strftime('%H:%M'))

        return histList.reset_index()

    #历史分时数据
    def getHistoryMinuteTime(self, code, date):
        """Fetch the per-minute price/volume series of ``code`` on ``date``.

        The server rows carry no timestamp, so they are labelled positionally
        with the trading minutes 09:31-11:30 followed by 13:01-15:00.

        :param code: security code; ``str(code)`` is used to look up the market.
        :param date: trading day as ``YYYY-MM-DD``; dashes are stripped and
            the rest converted to an int for the request.
        :return: DataFrame with the server columns plus ``money``
            (price * vol * 100), ``ave`` (cumulative money / cumulative
            volume, a non-standard average price), ``datetime``, ``date``
            and ``time``; the original index is kept as a column.
        """
        self._connect()

        day = str(date)

        # Only the timestamps are needed, so join the two session indexes
        # directly (DataFrame.append no longer exists in pandas 2.x).
        morningData = pd.date_range(start=day + ' 09:31', end=day + ' 11:30', freq='min')
        afternoonData = pd.date_range(start=day + ' 13:01', end=day + ' 15:00', freq='min')
        timeIndex = morningData.append(afternoonData)

        histList = self.__tdx.to_df(self.__tdx.get_history_minute_time_data(
            self.getMarketCode(str(code)), code, int(day.replace('-', ''))))

        # Non-standard average price: running turnover / running volume.
        money = histList['price'] * histList['vol'] * 100
        histList['money'] = money.round(2)
        totalMoney = money.cumsum()
        totalVol = histList['vol'].cumsum()
        histList['ave'] = (totalMoney / (totalVol * 100)).round(3)

        histList['datetime'] = timeIndex[0:len(histList)]
        histList['date'] = histList['datetime'].apply(lambda x: x.strftime('%Y-%m-%d'))
        histList['time'] = histList['datetime'].apply(lambda x: x.strftime('%H:%M'))

        return histList.reset_index()


    #实时报价(五档行情)
    '''
    market => 市场
    active1 => 活跃度
    price => 现价
    last_close => 昨收
    open => 开盘
    high => 最高
    low => 最低
    reversed_bytes0 => 保留
    reversed_bytes1 => 保留
    vol => 总量
    cur_vol => 现量
    amount => 总金额
    s_vol => 内盘
    b_vol => 外盘
    reversed_bytes2 => 保留
    reversed_bytes3 => 保留
    bid1 => 买一价
    ask1 => 卖一价
    bid_vol1 => 买一量
    ask_vol1 => 卖一量
    bid2 => 买二价
    ask2 => 卖二价
    bid_vol2 => 买二量
    ask_vol2 => 卖二量
    bid3 => 买三价
    ask3 => 卖三价
    bid_vol3 => 买三量
    ask_vol3 => 卖三量
    bid4 => 买四价
    ask4 => 卖四价
    bid_vol4 => 买四量
    ask_vol4 => 卖四量
    bid5 => 买五价
    ask5 => 卖五价
    bid_vol5 => 买五量
    ask_vol5 => 卖五量
    reversed_bytes4 => 保留
    reversed_bytes5 => 保留
    reversed_bytes6 => 保留
    reversed_bytes7 => 保留
    reversed_bytes8 => 保留
    reversed_bytes9 => 涨速
    active2 => 活跃度
    '''
    def getRealtimeQuotes(self, codeList):
        """Return the real-time quote snapshot for every code in ``codeList``.

        :param codeList: iterable of security codes.
        :return: DataFrame of quotes indexed by ``code``.
        """
        self._connect()

        # The API expects (market, code) pairs.
        marketCodePairs = [(self.getMarketCode(item), item) for item in codeList]
        quotes = self.__tdx.to_df(self.__tdx.get_security_quotes(marketCodePairs))

        return quotes.set_index('code')

    #计算指定日期成交量细节
    def getVolAnalysis(self, code, date):
        """Summarise the trades of ``code`` on ``date`` by order size and side.

        Ticks come from ``getRealtimeTransaction`` when ``date`` is today and
        the local time is not later than 16:00, otherwise from
        ``getHistoryTransaction``. Size thresholds are derived from the
        circulating share count returned by ``getGuben`` and from recent
        daily bars.

        NOTE(review): the daily bars are taken from ``self._lastBaseHistList``
        whenever it is non-empty, without checking that it belongs to
        ``code`` -- confirm callers always refresh it for the same security.

        :param code: security code.
        :param date: trading day as ``YYYY-MM-DD``.
        :return: dict of statistics (lot counts per size bucket, money totals
            in units of 1e8, rates in percent, average prices), or ``None``
            when no tick or daily data is available. Ratios whose
            denominator is zero are reported as 0.
        """
        self._connect()

        if str(time.strftime('%Y-%m-%d', time.localtime())) == str(date):
            if int(time.strftime('%H%M', time.localtime())) > 1600:
                volList = self.getHistoryTransaction(code, date)
            else:
                volList = self.getRealtimeTransaction(code)
        else:
            volList = self.getHistoryTransaction(code, date)

        if volList is None:
            return None

        _, circulate = self.getGuben(code)

        if not self._lastBaseHistList.empty:
            histList = self._lastBaseHistList.copy()
        else:
            histList = self.getDays(code, end=date, indexs=[])

        if histList is None or histList.empty:
            return None

        lastClose = histList['close'].values[-1]

        # Size thresholds, in the same unit as the tick 'vol' column.
        limitVol = round(histList[-5:]['vol'].mean() * 0.0618)
        # Converted to market value and back to volume, as in the original.
        superVol = float(circulate) * float(lastClose) * 0.000618 / float(lastClose) / 100
        bigVol = round(superVol * 0.518)
        middleVol = round(superVol * 0.382)
        smallVol = round(superVol * 0.191)

        def _toYi(amount):
            """Scale a money amount down by 1e8, rounded to 4 decimals."""
            return round(amount / 10000 / 10000, 4)

        def _lotCount(trades, lower, upper=None):
            """Ceil of (volume traded in ticks with lower <= vol < upper) / lower."""
            if not lower:
                # A zero threshold would divide by zero; report no lots.
                return 0
            mask = trades['vol'] >= lower
            if upper is not None:
                mask = mask & (trades['vol'] < upper)
            return math.ceil(trades[mask]['vol'].sum() / lower)

        def _safeRound(numerator, denominator, ndigits, scale=1):
            """round(numerator / denominator * scale, ndigits), or 0 when denominator is 0."""
            # numpy scalars yield NaN/inf instead of raising, so test explicitly.
            if not denominator:
                return 0
            return round(numerator / denominator * scale, ndigits)

        def _sideStats(trades):
            """Volume, lot-count and money statistics for one side (buy or sell)."""
            mainTrades = trades[trades['vol'] >= bigVol]
            return {
                'total_vol': trades['vol'].sum(),
                'main_vol': mainTrades['vol'].sum(),
                'limit': _lotCount(trades, limitVol),
                'super': _lotCount(trades, superVol, limitVol),
                'big': _lotCount(trades, bigVol, superVol),
                'middle': _lotCount(trades, middleVol, bigVol),
                'small': _lotCount(trades, smallVol, middleVol),
                'micro': len(trades[trades['vol'] < smallVol]),
                'money': _toYi(trades['money'].sum()),
                'main_money': _toYi(mainTrades['money'].sum()),
            }

        buy = _sideStats(volList[volList['type'] == 'B'])
        sell = _sideStats(volList[volList['type'] == 'S'])

        # Accumulation reference lines: 0.1% and 0.3% of circulating value.
        mainBaseMoney = round(lastClose * circulate * 0.001 / 10000 / 10000, 4)
        mainBigMoney = round(lastClose * circulate * 0.003 / 10000 / 10000, 4)

        totalMoney = _toYi(volList['money'].sum())
        totalBuyMoney = buy['money']
        totalSellMoney = sell['money']
        mainMoney = _toYi(volList[volList['vol'] >= bigVol]['money'].sum())
        mainBuyMoney = buy['main_money']
        mainSellMoney = sell['main_money']

        # Number of ticks.
        volNum = len(volList)

        statData = {}
        statData['limit_buy_vol_num'] = buy['limit']
        statData['super_buy_vol_num'] = buy['super']
        statData['big_buy_vol_num'] = buy['big']
        statData['middle_buy_vol_num'] = buy['middle']
        statData['small_buy_vol_num'] = buy['small']
        statData['micro_buy_vol_num'] = buy['micro']
        statData['limit_sell_vol_num'] = sell['limit']
        statData['super_sell_vol_num'] = sell['super']
        statData['big_sell_vol_num'] = sell['big']
        statData['middle_sell_vol_num'] = sell['middle']
        statData['small_sell_vol_num'] = sell['small']
        statData['micro_sell_vol_num'] = sell['micro']
        statData['total_abs_money'] = round(totalBuyMoney - totalSellMoney, 3)
        statData['main_abs_money'] = round(mainBuyMoney - mainSellMoney, 3)
        statData['total_money'] = totalMoney
        statData['total_buy_money'] = totalBuyMoney
        statData['total_sell_money'] = totalSellMoney
        statData['main_money'] = mainMoney
        statData['main_buy_money'] = mainBuyMoney
        statData['main_sell_money'] = mainSellMoney
        statData['main_rate'] = _safeRound(
            mainBuyMoney + mainSellMoney, totalMoney, 2, scale=100)
        statData['main_buy_rate'] = _safeRound(
            mainBuyMoney, mainBuyMoney + mainSellMoney, 2, scale=100)
        statData['trade_num'] = volNum
        statData['vol_num'] = volList['vol'].sum()
        # Average money per tick.
        statData['ave_trade_price'] = _safeRound(
            totalMoney, volNum, 2, scale=10000 * 10000)
        statData['main_base_money'] = mainBaseMoney
        statData['main_big_money'] = mainBigMoney
        # Average price per share = money / (volume * 100).
        statData['ave_per_share_buy_price'] = _safeRound(
            totalBuyMoney * 10000 * 10000, buy['total_vol'] * 100, 3)
        statData['ave_per_share_sell_price'] = _safeRound(
            totalSellMoney * 10000 * 10000, sell['total_vol'] * 100, 3)
        statData['main_ave_per_share_buy_price'] = _safeRound(
            mainBuyMoney * 10000 * 10000, buy['main_vol'] * 100, 3)
        statData['main_ave_per_share_sell_price'] = _safeRound(
            mainSellMoney * 10000 * 10000, sell['main_vol'] * 100, 3)
        statData['circulate_money'] = round(circulate * lastClose / 10000 / 10000, 4)

        return statData

    #输出ebk文件
    def outputEbk(self, stockList, ebkPath = ''):
        """Append ``stockList`` to an .ebk file, one prefixed code per line.

        Each code is written with a leading '1' when ``getMarketCode``
        returns 1 and '0' otherwise. A blank first line is written before
        the codes. Lines end with '\\r\\n'.

        :param stockList: list of security code strings.
        :param ebkPath: target file; when empty, defaults to
            ``<cwd>/<script name without extension>.<YYYYMMDD>.ebk``.
        :return: ``False`` if ``stockList`` is not a list, else ``True``.
        """
        if len(ebkPath) <= 0:
            # Strip the real extension rather than assuming a 3-char '.py'.
            scriptBase = os.path.splitext(sys.argv[0])[0]
            stamp = str(time.strftime('%Y%m%d', time.localtime()))
            ebkPath = os.path.join(os.getcwd(), scriptBase + '.' + stamp + '.ebk')

        if not isinstance(stockList, list):
            return False

        # 'with' guarantees the handle is closed even if a lookup fails.
        with open(ebkPath, "a") as fp:
            fp.write('\r\n')    # the first line of an ebk file is blank
            for code in stockList:
                prefix = '1' if self.getMarketCode(code) == 1 else '0'
                fp.write(prefix + code + '\r\n')

        return True

    #输出sel文件
    def outputSel(self, stockList, selPath = ''):
        """Append ``stockList`` to a binary .sel file.

        Layout written: a little-endian unsigned 16-bit record count, then
        one record per code made of ``0x07``, a market byte (``0x11`` when
        ``getMarketCode`` returns 1, else ``0x21``) and the code text.

        :param stockList: list of security code strings.
        :param selPath: target file; when empty, defaults to
            ``<cwd>/<sys.argv[0] minus last 3 chars>.<YYYYMMDD>.sel``.
        :return: ``False`` if ``stockList`` is not a list, else ``True``.
        """
        import struct

        if len(selPath) <= 0:
            selPath = os.getcwd() + '/' + sys.argv[0][0:-3] + '.' + str(time.strftime('%Y%m%d',time.localtime())) + '.sel'

        if not isinstance(stockList,list):
            return False

        records = []
        for code in stockList:
            marker = b'\x07\x11' if self.getMarketCode(code) == 1 else b'\x07\x21'
            records.append(marker + code.encode())

        # Assemble bytes directly: round-tripping the packed count through
        # str fails with UnicodeDecodeError when a count byte is >= 0x80.
        with open(selPath, 'ab') as fp:
            fp.write(struct.pack('<H', len(records)) + b''.join(records))

        return True

    #ebk to sel
    def ebk2sel(self, ebkPath):
        """Convert an .ebk text file into a binary .sel file next to it.

        Every non-blank line of the input is treated as a one-character
        market prefix followed by the code; the prefix is discarded and the
        market is looked up again through ``getMarketCode``. The output
        (overwritten if present) holds a little-endian unsigned 16-bit
        record count, then per code ``0x07``, a market byte (``0x11`` for
        market 1, else ``0x21``) and the code text.

        :param ebkPath: path of the .ebk file.
        :return: ``False`` if ``ebkPath`` does not exist, else ``True``.
        """
        import struct

        if not os.path.exists(ebkPath):
            return False

        # Swap only the extension; str.replace would also rewrite any
        # '.ebk' occurring inside directory names.
        selPath = os.path.splitext(ebkPath)[0] + '.sel'

        records = []
        with open(ebkPath, 'r') as ebkfp:
            for line in ebkfp:
                line = line.strip()
                if not line:
                    continue
                code = line[1:]
                marker = b'\x07\x11' if self.getMarketCode(code) == 1 else b'\x07\x21'
                records.append(marker + code.encode())

        # Assemble bytes directly: round-tripping the packed count through
        # str fails with UnicodeDecodeError when a count byte is >= 0x80.
        with open(selPath, 'wb') as selfp:
            selfp.write(struct.pack('<H', len(records)) + b''.join(records))

        return True

    #sel to ebk
    def sel2ebk(self, selPath):
        """Convert a binary .sel file into an .ebk text file next to it.

        The input is read as a little-endian unsigned 16-bit record count
        followed by 8-byte records (``0x07``, market byte, 6-character code).
        Each record is appended to the output as '1' + code when the market
        byte is ``0x11`` and '0' + code otherwise, terminated by '\\r\\n'.
        Reading stops early at a truncated record.

        :param selPath: path of the .sel file.
        :return: ``False`` if ``selPath`` does not exist, else ``True``.
        """
        import struct

        if not os.path.exists(selPath):
            return False

        # Swap only the extension; str.replace would also rewrite '.sel'
        # inside directory names and, for a path without '.sel', would make
        # the output path equal to the input path.
        ebkPath = os.path.splitext(selPath)[0] + '.ebk'

        with open(selPath, 'rb') as selfp, open(ebkPath, "a") as ebkfp:
            cnt = struct.unpack('<H', selfp.read(2))[0]
            for _ in range(cnt):
                record = selfp.read(8)
                if len(record) < 8:
                    break
                exch = '1' if record[1:2] == b'\x11' else '0'
                ebkfp.write(exch + record[2:].decode() + '\r\n')

        return True
Exemplo n.º 3
0
class TDX(object):
    '''
    TongDaXin (TDX) data source built on the pytdx ``TdxHq_API`` client.

    It downloads stock data such as K-lines of several periods, tick data,
    security lists and block information for the currently selected ``code``.
    '''
    def __init__(self):
        self.tdx_api = TdxHq_API()
        self.__ip = '119.147.212.81'  # quote server IP
        self.__port = 7709  # quote server port
        self.__code = '600200'
        self.__market = 1  # market id: 0 = Shenzhen, 1 = Shanghai
        self._startdate = "2017-01-01"
        self.today = datetime.date.today()
        self._enddate = datetime.datetime.strftime(self.today, '%Y-%m-%d')

        # Two-character code prefix that identifies each supported board.
        self.__mkt_segment = {
            'sh': '60',
            "sz": '00',
            "cyb": "30",
        }

    def __str__(self):
        return 'TDX object (code : %s)' % self.code

    @property
    def IP(self):
        """Quote server IP address (read-only)."""
        return self.__ip

    @property
    def PORT(self):
        """Quote server port (read-only)."""
        return self.__port

    @property
    def code(self):
        """Currently selected six-character stock code."""
        return self.__code

    @code.setter
    def code(self, code_input):
        """
        Set the stock code and derive the market id from its prefix.

        '60' selects market 1; '00' and '30' select market 0.
        Raises ValueError for non-strings, wrong length or other prefixes.
        """
        if not isinstance(code_input, str):
            raise ValueError('the code must string!')
        if len(code_input) != 6:
            raise ValueError('the code value error,the len must SIX !')
        if code_input.startswith('60'):
            self.__market = 1
        elif code_input.startswith(('00', '30')):
            self.__market = 0
        else:
            raise ValueError('this code is not stock code')
        self.__code = code_input

    @staticmethod
    def _check_date(date_input):
        """Validate a date string and return it unchanged.

        Accepts the documented 'xxxx-xx-xx' form (10 characters, the same
        form as the defaults set in ``__init__``). 8-character strings are
        still accepted because earlier versions only allowed that length.
        """
        if not isinstance(date_input, str):
            raise ValueError('the date must string!')
        if len(date_input) not in (8, 10):
            raise ValueError(
                'the date value error,the date formet must xxxx-xx-xx !')
        return date_input

    @property
    def startdate(self):
        """Start date used by ``get_day_data_tdx``."""
        return self._startdate

    @startdate.setter
    def startdate(self, date_input):
        """
        The setter of the start date property
        """
        self._startdate = self._check_date(date_input)

    @property
    def enddate(self):
        """End date used by ``get_day_data_tdx``."""
        return self._enddate

    @enddate.setter
    def enddate(self, date_input):
        """
        The setter of the end date property
        """
        self._enddate = self._check_date(date_input)

    def get_day_data_tdx(self):
        """Return daily K-line data for ``code`` between the start and end dates.

        The ``date`` column is converted from 'YYYY-MM-DD' strings to
        ``datetime`` objects.
        """
        with self.tdx_api.connect(self.IP, self.PORT):
            data = self.tdx_api.get_k_data(self.code, self.startdate,
                                           self.enddate)
            data = pandas.DataFrame(data)
            data.date = data.date.apply(
                lambda x: datetime.datetime.strptime(x, "%Y-%m-%d"))
        return data

    # TODO: currently always fetches 800 bars; to be refined later
    def get_k_data_tdx(self, k_mode=9):
        """
        Fetch the K-line of the current code, 800 bars starting at offset 0.

        Parameters
        ----------
        k_mode= 0-11
                    0 5-minute K-line
                    1 15-minute K-line
                    2 30-minute K-line
                    3 1-hour K-line
                    4 daily K-line
                    5 weekly K-line
                    6 monthly K-line
                    7 1-minute
                    8 1-minute K-line
                    9 daily K-line
                    10 quarterly K-line
                    11 yearly K-line

        Returns
        -------
        pandas.DataFrame built from the returned bars.
        """
        with self.tdx_api.connect(self.IP, self.PORT):
            data = self.tdx_api.get_security_bars(k_mode, self.__market,
                                                  self.code, 0, 800)

            data = pandas.DataFrame(data)
        return data

    def len_market(self):
        """Return the number of securities in the current market."""
        with self.tdx_api.connect(self.IP, self.PORT):
            _len = self.tdx_api.get_security_count(self.__market)
        return _len

    def get_page_tdx(self, block=None):
        """Download the security-list pages associated with ``block``.

        ``block`` may be None (market 1, page 0), 'sh'/'SH' (market 1,
        pages 13-14), 'sz'/'SZ' (market 0, pages 0-1) or 'cyb'/'CYB'
        (market 0, pages 7-8). Each page starts at offset ``page * 1000``.

        Raises ValueError for any other ``block``.
        Returns the concatenated pages as one DataFrame.
        """
        if block is None:
            market = 1
            page = [0]
        elif block in ['sh', 'SH']:
            market = 1
            page = [13, 14]
        elif block in ['sz', 'SZ']:
            print('block for shenzhen')
            market = 0
            page = [0, 1]
        elif block in ['cyb', 'CYB']:
            print('block for chuang ye ban')
            market = 0
            page = [7, 8]
        else:
            # Previously fell through and failed later on unbound locals.
            raise ValueError('unknown block: %r' % (block,))

        frames = []
        with self.tdx_api.connect(self.IP, self.PORT):
            for pn in page:
                data = self.tdx_api.get_security_list(market, pn * 1000)
                data = pandas.DataFrame(data)
                print(data)
                frames.append(data)
        # DataFrame.append no longer exists in pandas 2.x.
        return pandas.concat(frames, ignore_index=True)

    def get_base_finace_tdx(self):
        """Print the finance info of the hard-coded security (market 0, '000001')."""
        with self.tdx_api.connect(self.IP, self.PORT):
            data = self.tdx_api.get_finance_info(0, '000001')
            data = pandas.Series(data)
            print(data)

    def get_min_data(self):
        """Print the minute data of the current code for the hard-coded day 20161209."""
        from pytdx.params import TDXParams
        with self.tdx_api.connect(self.IP, self.PORT):
            data = self.tdx_api.get_history_minute_time_data(
                TDXParams.MARKET_SH, self.code, 20161209)
            data = pandas.DataFrame(data)
            print(data)

    # TODO: confirm the meaning of buyorsell (0: buy, 1: sell)
    def get_tick_data(self):
        """
        Historical tick trades (time, price, vol, buyorsell) of the
        hard-coded security '600547' on 20160308.

        Two pages of 2000 ticks are fetched, offset 2000 first, then 0.

        Returns
        -------
        pandas.DataFrame with both pages concatenated.
        """
        # Imported locally, as get_min_data does.
        from pytdx.params import TDXParams
        frames = []
        with self.tdx_api.connect(self.IP, self.PORT):
            for start in [2000, 0]:
                df = self.tdx_api.get_history_transaction_data(
                    TDXParams.MARKET_SH, "600547", start, 2000, 20160308)
                frames.append(pandas.DataFrame(df))

        return pandas.concat(frames, ignore_index=True)

    def get_tick_today(self):
        """
        Today's tick trades of the current code.

        Each request returns at most 2000 rows, so two pages are fetched
        (offset 0 first, then 2000) and concatenated.
        NOTE(review): get_tick_data fetches the larger offset first; confirm
        which page order callers expect here.

        Returns
        -------
        pandas.DataFrame with both pages concatenated.
        """
        frames = []
        with self.tdx_api.connect(self.IP, self.PORT):
            for start in [0, 2000]:
                df = self.tdx_api.get_transaction_data(self.__market,
                                                       self.code, start, 2000)
                frames.append(pandas.DataFrame(df))

        return pandas.concat(frames, ignore_index=True)

    def get_block(self):
        """Print the block information parsed from 'block.dat'."""
        with self.tdx_api.connect(self.IP, self.PORT):
            data = self.tdx_api.get_and_parse_block_info("block.dat")
            data = pandas.DataFrame(data)
            print(data)

    def get_market_segment_list(self, mkt):
        """Return the rows of the ``mkt`` pages whose code has the board prefix.

        ``mkt`` is 'sh', 'sz' or 'cyb' (upper case is accepted too). The
        result is also stored in ``self.code_list``.
        """
        # The attribute is private (name-mangled), so it must be reached as
        # self.__mkt_segment; look it up first to fail before any download.
        prefix = self.__mkt_segment[mkt.lower()]
        data = self.get_page_tdx(mkt)
        keep = [
            code.startswith(prefix, 0, 2) for code in tqdm(data.code)
        ]
        self.code_list = data[keep].reset_index(drop=True)
        return self.code_list

    def get_sh_list(self):
        """Securities whose code starts with '60'."""
        return self.get_market_segment_list('sh')

    def get_sz_list(self):
        """Securities whose code starts with '00'."""
        return self.get_market_segment_list('sz')

    def get_cyb_list(self):
        """Securities whose code starts with '30'."""
        return self.get_market_segment_list('cyb')
Exemplo n.º 4
0
class StdQuotes(object):
    """Real-time stock market quotes served through ``TdxHq_API``.

    Every public method opens a connection to ``self.bestip`` for the
    duration of the call.
    """
    bestip = ('47.103.48.45', 7709)

    def __init__(self, **kwargs):
        """Create the client; ``kwargs`` are forwarded to ``TdxHq_API``.

        The server address is read from the ``BESTIP``/``HQ`` configuration,
        defaulting to the first ``SERVER``/``HQ`` entry of the settings. If
        either lookup fails, the class-level ``bestip`` stays in effect.
        """
        try:
            default = settings.get('SERVER').get('HQ')[0]
            self.bestip = config.get('BESTIP').get('HQ', default)
        except (AttributeError, IndexError, KeyError, TypeError, ValueError):
            # Missing or malformed sections surface as these errors, not
            # only ValueError; keep the class-level default address.
            self.config = None

        self.client = TdxHq_API(**kwargs)

    def traffic(self):
        """Return the traffic statistics reported by the client."""
        with self.client.connect(*self.bestip):
            return self.client.get_traffic_stats()

    def quotes(self, symbol=None):
        '''
        Get real-time quotes.

        :param symbol: a stock code or a list of stock codes
        :return: pd.dataFrame or None
        '''
        if symbol is None:
            symbol = []
        elif isinstance(symbol, str):
            symbol = [symbol]

        with self.client.connect(*self.bestip):
            symbol = get_stock_markets(symbol)
            result = self.client.get_security_quotes(symbol)

            return to_data(result)

    def bars(self, symbol='000001', frequency='9', start='0', offset='100'):
        '''
        Get K-line bars of a stock.

        :param symbol: stock code
        :param frequency: K-line category
        :param start: start position
        :param offset: number of bars to fetch
        :return: pd.dataFrame or None
        '''
        with self.client.connect(*self.bestip):
            market = get_stock_market(symbol)
            result = self.client.get_security_bars(int(frequency), int(market),
                                                   str(symbol), int(start),
                                                   int(offset))

            return to_data(result)

    def stock_count(self, market=MARKET_SH):
        '''
        Get the number of securities in a market.

        :param market: market code (Shanghai or Shenzhen)
        :return: the count returned by the server
        '''
        with self.client.connect(*self.bestip):
            result = self.client.get_security_count(market=market)
            return result

    def stocks(self, market=MARKET_SH):
        '''
        Get the security list of a market, fetched in pages of 1000.

        :param market: market code
        :return: the pages joined into one frame, or None if the count is 0
        '''
        with self.client.connect(*self.bestip):
            counts = self.client.get_security_count(market=market)
            pages = []

            for start in tqdm(range(0, counts, 1000)):
                result = self.client.get_security_list(market=market,
                                                       start=start)
                pages.append(to_data(result))

            if not pages:
                return None
            if len(pages) == 1:
                return pages[0]
            return pandas.concat(pages, ignore_index=True)

    def index_bars(self,
                   symbol='000001',
                   frequency='9',
                   start='0',
                   offset='100'):
        '''
        Get K-line bars of an index.

        NOTE(review): the market is derived with get_stock_market(symbol);
        confirm it resolves index codes as intended (``index`` takes the
        market explicitly).

        :param symbol: index code
        :param frequency: K-line category
        :param start: start position
        :param offset: number of bars to fetch
        :return: pd.dataFrame or None
        '''
        with self.client.connect(*self.bestip):
            market = get_stock_market(symbol)
            # The defaults are strings; convert like bars() and index() do.
            result = self.client.get_index_bars(frequency=int(frequency),
                                                market=int(market),
                                                code=str(symbol),
                                                start=int(start),
                                                count=int(offset))

            return to_data(result)

    def minute(self, symbol=''):
        '''
        Get today's minute-by-minute data.

        :param symbol: stock code
        :return: pd.DataFrame
        '''
        with self.client.connect(*self.bestip):
            market = get_stock_market(symbol)
            result = self.client.get_minute_time_data(market=market,
                                                      code=symbol)
            return to_data(result)

    def minutes(self, symbol='', date='20191023'):
        '''
        Get historical minute-by-minute data.

        :param symbol: stock code
        :param date: trading day as YYYYMMDD (string or int)
        :return: pd.dataFrame or None
        '''
        with self.client.connect(*self.bestip):
            market = get_stock_market(symbol)
            result = self.client.get_history_minute_time_data(market=market,
                                                              code=symbol,
                                                              date=int(date))

            return to_data(result)

    def transaction(self, symbol='', start=0, offset=10):
        '''
        Query today's tick-by-tick trades.

        :param symbol: stock code
        :param start: start position
        :param offset: number of rows requested
        :return: pd.dataFrame or None
        '''
        with self.client.connect(*self.bestip):
            market = get_stock_market(symbol)
            result = self.client.get_transaction_data(int(market), symbol,
                                                      int(start), int(offset))

            return to_data(result)

    def transactions(self, symbol='', start=0, offset=10, date='20170209'):
        '''
        Query historical tick-by-tick trades.

        :param symbol: stock code
        :param start: start position
        :param offset: number of rows requested
        :param date: trading day as YYYYMMDD (string or int)
        :return: pd.dataFrame or None
        '''
        with self.client.connect(*self.bestip):
            market = get_stock_market(symbol, string=False)
            # Numeric arguments are converted like transaction() does; the
            # default date is a string.
            result = self.client.get_history_transaction_data(
                market=market,
                code=symbol,
                start=int(start),
                count=int(offset),
                date=int(date))

            return to_data(result)

    def F10C(self, symbol=''):
        '''
        Query the company-information (F10) category list.

        :param symbol: stock code
        :return: the category entries returned by the server
        '''
        with self.client.connect(*self.bestip):
            market = get_stock_market(symbol)
            result = self.client.get_company_info_category(int(market), symbol)

            return result

    def F10(self, symbol='', name=''):
        '''
        Read company-information (F10) content.

        :param name: F10 section title; when given and found, only that
            section's content is returned
        :param symbol: stock code
        :return: the content of the named section, otherwise a dict mapping
            every section title to its content
        '''
        with self.client.connect(*self.bestip):
            market = get_stock_market(symbol, string=False)

            categories = self.client.get_company_info_category(
                int(market), symbol)

            def _content(entry):
                return self.client.get_company_info_content(
                    market=market,
                    code=symbol,
                    filename=entry['filename'],
                    start=entry['start'],
                    length=entry['length'])

            if name:
                for entry in categories:
                    if entry['name'] == name:
                        return _content(entry)

            # No name given, or the name was not found: return everything.
            result = {}
            for entry in categories:
                result[entry['name']] = _content(entry)

            return result

    def xdxr(self, symbol=''):
        '''
        Read ex-rights / ex-dividend information.

        :param symbol: stock code
        :return: pd.dataFrame or None
        '''
        with self.client.connect(*self.bestip):
            market = get_stock_market(symbol)
            result = self.client.get_xdxr_info(int(market), symbol)

            return to_data(result)

    def finance(self, symbol='000001'):
        '''
        Read financial information.

        :param symbol: stock code
        :return: pd.dataFrame or None
        '''
        with self.client.connect(*self.bestip):
            market = get_stock_market(symbol)
            result = self.client.get_finance_info(market=market, code=symbol)

            return to_data(result)

    def k(self, symbol='', begin=None, end=None):
        '''
        Read K-line data for a date range.

        :param symbol: stock code
        :param begin: start date
        :param end: end date
        :return: pd.dataFrame or None
        '''
        with self.client.connect(*self.bestip):
            result = self.client.get_k_data(symbol, begin, end)
            return to_data(result)

    def index(self,
              symbol='000001',
              market=MARKET_SH,
              frequency='9',
              start=1,
              offset=2):
        '''
        Get K-line bars of an index in an explicitly given market.

        K-line categories:
        - 0 5-minute K-line
        - 1 15-minute K-line
        - 2 30-minute K-line
        - 3 1-hour K-line
        - 4 daily K-line
        - 5 weekly K-line
        - 6 monthly K-line
        - 7 1-minute
        - 8 1-minute K-line
        - 9 daily K-line
        - 10 quarterly K-line
        - 11 yearly K-line

        :param symbol: index code
        :param frequency: K-line category
        :param market: market code
        :param start: start position
        :param offset: number of bars to fetch
        :return: pd.dataFrame or None
        '''
        with self.client.connect(*self.bestip):
            result = self.client.get_index_bars(int(frequency), int(market),
                                                str(symbol), int(start),
                                                int(offset))
            return to_data(result)

    def block(self, tofile="block.dat"):
        '''
        Get security block (sector) information.

        :param tofile: block file name passed to the server
        :return: pd.dataFrame or None
        '''
        with self.client.connect(*self.bestip):
            result = self.client.get_and_parse_block_info(tofile)
            return to_data(result)
Exemplo n.º 5
0
class Engine:
    """High-level helper around ``TdxHq_API`` for quotes, bars and ticks.

    Keyword options consumed by the constructor (everything else is
    forwarded to ``TdxHq_API``):

    * ``ip``: server address to connect to; takes precedence over
      ``best_ip``.
    * ``best_ip``: when true and no ``ip`` is given, use the address returned
      by ``select_best_ip()``.
    * ``thread_num``: number of additional API connections used by
      :meth:`stock_quotes`; concurrency is only enabled when not on Python 2.
    """

    DEFAULT_IP = '14.17.75.71'
    # Largest number of securities sent in one get_security_quotes request.
    QUOTE_BATCH_SIZE = 80

    def __init__(self, *args, **kwargs):
        """Read the engine options from ``kwargs`` and create the API clients."""
        # Pop both options unconditionally so neither leaks into TdxHq_API.
        use_best_ip = kwargs.pop('best_ip', False)
        explicit_ip = kwargs.pop('ip', None)
        if explicit_ip is not None:
            self.ip = explicit_ip
        elif use_best_ip:
            # Previously this value was always overwritten by the default.
            self.ip = self.best_ip
        else:
            self.ip = self.DEFAULT_IP

        self.thread_num = kwargs.pop('thread_num', 1)
        self.use_concurrent = (not PY2) and self.thread_num != 1

        # Unpack the remaining arguments; passing the tuple and dict
        # positionally would hand them to the first two client parameters.
        self.api = TdxHq_API(*args, **kwargs)
        if self.use_concurrent:
            self.apis = [
                TdxHq_API(*args, **kwargs) for _ in range(self.thread_num)
            ]
            self.executor = ThreadPoolExecutor(self.thread_num)

    def connect(self):
        """Connect every API client to ``self.ip`` and return ``self``."""
        self.api.connect(self.ip)
        if self.use_concurrent:
            for api in self.apis:
                api.connect(self.ip)
        return self

    def __enter__(self):
        """Return ``self``; the connection itself is made by :meth:`connect`."""
        return self

    def exit(self):
        """Disconnect every API client."""
        self.api.disconnect()
        if self.use_concurrent:
            for api in self.apis:
                api.disconnect()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Disconnect on leaving a ``with`` block; exceptions propagate."""
        self.exit()

    def _batches(self, codes):
        """Split ``codes`` into request-sized slices, never yielding an empty one."""
        size = self.QUOTE_BATCH_SIZE
        return [codes[pos:pos + size] for pos in range(0, len(codes), size)]

    @staticmethod
    def _fetch_all(fetch_page, page_size):
        """Page through ``fetch_page(start, count)`` until a page comes back empty.

        Each later page is placed in front of the earlier ones, as the
        original paging loops did.
        """
        rows = []
        start = 0
        while True:
            page = fetch_page(start, page_size)
            if not page:
                break
            rows = page + rows
            start += page_size
        return rows

    def quotes(self, code):
        """Return quotes for one code or a list of codes as a DataFrame.

        Codes are matched against ``security_list``; an empty DataFrame is
        returned when nothing matches.
        """
        wanted = code if isinstance(code, list) else [code]
        listing = self.security_list
        keys = listing[listing.code.isin(wanted)].index.tolist()
        frames = [
            self.api.to_df(self.api.get_security_quotes(batch))
            for batch in self._batches(keys)
        ]
        if not frames:
            return pd.DataFrame()
        return pd.concat(frames)

    def stock_quotes(self):
        """Return quotes for every entry of ``stock_list`` as a DataFrame.

        With concurrency enabled the batches are spread round-robin over the
        extra API connections and collected in submission order.
        """
        batches = self._batches(self.stock_list.index.tolist())
        if not batches:
            return pd.DataFrame()

        if self.use_concurrent:
            futures = [
                self.executor.submit(
                    self.apis[pos % self.thread_num].get_security_quotes,
                    batch)
                for pos, batch in enumerate(batches)
            ]
            return pd.concat(
                [self.api.to_df(future.result()) for future in futures])

        return pd.concat([
            self.api.to_df(self.api.get_security_quotes(batch))
            for batch in batches
        ])

    @lazyval
    def security_list(self):
        """Security listing of markets 0 and 1, indexed by ``(sse, code)``."""
        frames = []
        for market in range(2):
            pages = int(self.api.get_security_count(market) / 1000) + 1
            for page in range(pages):
                rows = self.api.get_security_list(market, page * 1000)
                if not rows:
                    # The page count over-shoots when the security count is
                    # an exact multiple of 1000; skip the empty page.
                    continue
                frame = self.api.to_df(rows).assign(
                    sse=0 if market == 0 else 1)
                frames.append(frame.set_index(['sse', 'code'], drop=False))
        return pd.concat(frames, axis=0)

    @lazyval
    def stock_list(self):
        """Rows of ``security_list`` whose index entry passes ``stock_filter``."""
        listing = self.security_list
        mask = [stock_filter(key) for key in listing.index.tolist()]
        return listing[mask]

    @lazyval
    def best_ip(self):
        """Server chosen by ``select_best_ip()``."""
        return select_best_ip()

    @lazyval
    def concept(self):
        """Block information parsed from ``TDXParams.BLOCK_GN``."""
        return self.api.to_df(
            self.api.get_and_parse_block_info(TDXParams.BLOCK_GN))

    @lazyval
    def index(self):
        """Block information parsed from ``TDXParams.BLOCK_SZ``."""
        return self.api.to_df(
            self.api.get_and_parse_block_info(TDXParams.BLOCK_SZ))

    @lazyval
    def fengge(self):
        """Block information parsed from ``TDXParams.BLOCK_FG``."""
        return self.api.to_df(
            self.api.get_and_parse_block_info(TDXParams.BLOCK_FG))

    @lazyval
    def customer_block(self):
        """Customer-defined blocks read from ``CUSTOMER_BLOCK_PATH``."""
        return CustomerBlockReader().get_df(CUSTOMER_BLOCK_PATH)

    @lazyval
    def gbbq(self):
        """Rows of the GBBQ file with ``category == 1`` and parsed dates."""
        # Copy so the column assignment never targets a view of the source.
        df = GbbqReader().get_df(GBBQ_PATH).query('category == 1').copy()
        df['datetime'] = pd.to_datetime(df['datetime'], format='%Y%m%d')
        return df

    def get_security_type(self, code):
        """Return the ``sse`` market flag of ``code`` from ``security_list``.

        :raises SecurityNotExists: when ``code`` is not listed
        """
        listing = self.security_list
        matched = listing[listing.code == code]
        if matched.empty:
            raise SecurityNotExists()
        # ``Series.as_matrix`` no longer exists in current pandas.
        return matched['sse'].values[0]

    def get_security_bars(self, code, freq, index=False):
        """Download all available bars of ``code`` at the given frequency.

        :param code: security or index code
        :param freq: ``'1d'``/``'day'`` or ``'1m'``/``'min'``
        :param index: use the index-bar API and look the market up in
            ``security_list`` instead of ``get_stock_type``
        :return: DataFrame indexed by ``datetime``; empty when the server
            returned no bars
        :raises Exception: for any other ``freq``
        """
        if index:
            exchange = self.get_security_type(code)
            func = self.api.get_index_bars
        else:
            exchange = get_stock_type(code)
            func = self.api.get_security_bars

        if freq in ['1d', 'day']:
            freq = 9
        elif freq in ['1m', 'min']:
            freq = 8
        else:
            raise Exception("1d and 1m frequency supported only")

        rows = self._fetch_all(
            lambda start, count: func(freq, exchange, code, start, count),
            800)
        if not rows:
            return pd.DataFrame()

        df = self.api.to_df(rows).drop(
            ['year', 'month', 'day', 'hour', 'minute'], axis=1)
        df['datetime'] = pd.to_datetime(df.datetime)
        df['code'] = code
        return df.set_index('datetime')

    def _get_transaction(self, code, date):
        """Download every historical tick of ``code`` for ``date``.

        :return: DataFrame indexed by timestamp; empty when there is no data
        """
        rows = self._fetch_all(
            lambda start, count: self.api.get_history_transaction_data(
                get_stock_type(code), code, start, count, date),
            2000)
        if not rows:
            return pd.DataFrame()

        df = self.api.to_df(rows).assign(date=date)
        df.index = pd.to_datetime(str(date) + " " + df["time"])
        df['code'] = code
        return df.drop("time", axis=1)

    def time_and_price(self, code):
        """Download today's ticks of ``code`` indexed by timestamp.

        :return: DataFrame indexed by ``time``; empty when there is no data
        """
        exchange = self.get_security_type(code)
        rows = self._fetch_all(
            lambda start, count: self.api.get_transaction_data(
                exchange, code, start, count),
            2000)
        if not rows:
            return pd.DataFrame()

        df = self.api.to_df(rows)
        today = str(pd.to_datetime('today').date())
        df['time'] = pd.to_datetime(today + " " + df['time'])
        if len(df) > 1:
            # The first row is stamped with the second row's time, as in the
            # original code; a single-row result has nothing to copy from.
            df.loc[0, 'time'] = df.time[1]
        return df.set_index('time')

    @classmethod
    def minute_bars_from_transaction(cls, transaction, freq):
        """Resample a tick DataFrame into OHLC bars with summed volume.

        :param transaction: ticks with ``price``, ``vol`` and ``code`` columns
            on a datetime index
        :param freq: pandas resampling rule
        :return: bars passed through ``fillna``; empty DataFrame for no ticks
        """
        if transaction.empty:
            return pd.DataFrame()

        bars = transaction['price'].resample(
            freq, label='right', closed='left').ohlc()
        bars['volume'] = transaction['vol'].resample(
            freq, label='right', closed='left').sum()
        # Positional access: ``[0]`` on a datetime-indexed Series is a label
        # lookup in current pandas.
        bars['code'] = transaction['code'].iloc[0]

        return fillna(bars)

    def get_k_data(self, code, start, end, freq):
        """Build bars for ``code`` between ``start`` and ``end`` from ticks.

        Every calendar day in the range is queried; days without ticks are
        skipped.

        :param freq: ``'1m'``, ``'1d'`` or any pandas resampling rule
        :return: concatenated bars, or an empty DataFrame
        """
        if isinstance(start, str) or isinstance(end, str):
            start = pd.Timestamp(start)
            end = pd.Timestamp(end)
        sessions = pd.date_range(start, end)

        if freq == '1m':
            freq = '1 min'
        elif freq == '1d':
            freq = '24 H'

        frames = []
        for day in sessions.strftime("%Y%m%d"):
            bars = Engine.minute_bars_from_transaction(
                self._get_transaction(code, int(day)), freq)
            if not bars.empty:
                frames.append(bars)

        if frames:
            return pd.concat(frames)
        return pd.DataFrame()
            hour, minute, pos = get_time(body_buf, pos)

            price_raw, pos = get_price(body_buf, pos)
            vol, pos = get_price(body_buf, pos)
            buyorsell, pos = get_price(body_buf, pos)
            _, pos = get_price(body_buf, pos)

            last_price = last_price + price_raw

            tick = OrderedDict([
                ("time", "%02d:%02d" % (hour, minute)),
                ("price", float(last_price) / 100),
                ("vol", vol),
                ("buyorsell", buyorsell),
            ])

            ticks.append(tick)

        return ticks


if __name__ == '__main__':
    # Manual smoke run: fetch 10 historical ticks of '000001' on market 0 for
    # 2017-08-11 and print them as a DataFrame.
    from pytdx.hq import TdxHq_API
    api = TdxHq_API()
    with api.connect():
        ticks = api.get_history_transaction_data(0, '000001', 0, 10, 20170811)
        print(api.to_df(ticks))
Exemplo n.º 7
0
def test_all_functions(multithread, heartbeat, auto_retry, raise_exception):
    """Call each ``TdxHq_API`` query once over a real connection.

    The four arguments are forwarded to the ``TdxHq_API`` constructor.
    """

    def expect_list(value, size=None, non_empty=False):
        # Shared checks: a real ``list`` (exact type), optionally with an
        # exact length or at least one element.
        assert value is not None
        assert type(value) is list
        if size is not None:
            assert len(value) == size
        if non_empty:
            assert len(value) > 0

    api = TdxHq_API(multithread=multithread, heartbeat=heartbeat,
                    auto_retry=auto_retry, raise_exception=raise_exception)
    with api.connect(time_out=30):
        log.info("获取股票行情")
        expect_list(api.get_security_quotes([(0, "000001"), (1, "600300")]))
        # Second call form: market and code as separate arguments.
        expect_list(api.get_security_quotes(0, "000001"))
        # Third call form: a single (market, code) tuple.
        expect_list(api.get_security_quotes((0, "000001")))

        log.info("获取k线")
        expect_list(api.get_security_bars(9, 0, '000001', 4, 3), size=3)

        log.info("获取 深市 股票数量")
        assert api.get_security_count(0) > 0

        log.info("获取股票列表")
        expect_list(api.get_security_list(1, 0), non_empty=True)

        log.info("获取指数k线")
        expect_list(api.get_index_bars(9, 1, '000001', 1, 2), size=2)

        log.info("查询分时行情")
        minute_data = api.get_minute_time_data(TDXParams.MARKET_SH, '600300')
        assert minute_data is not None

        log.info("查询历史分时行情")
        expect_list(
            api.get_history_minute_time_data(
                TDXParams.MARKET_SH, '600300', 20161209),
            non_empty=True)

        log.info("查询分时成交")
        expect_list(
            api.get_transaction_data(TDXParams.MARKET_SZ, '000001', 0, 30))

        log.info("查询历史分时成交")
        expect_list(
            api.get_history_transaction_data(
                TDXParams.MARKET_SZ, '000001', 0, 10, 20170209),
            size=10)

        log.info("查询公司信息目录")
        categories = api.get_company_info_category(
            TDXParams.MARKET_SZ, '000001')
        expect_list(categories, non_empty=True)

        # The first category entry supplies the slice read in the next query.
        first = categories[0]
        start = first['start']
        length = first['length']
        log.info("读取公司信息-最新提示")
        content = api.get_company_info_content(
            0, '000001', '000001.txt', start, length)
        assert content is not None
        assert len(content) > 0

        log.info("读取除权除息信息")
        expect_list(api.get_xdxr_info(1, '600300'), non_empty=True)

        log.info("读取财务信息")
        finance = api.get_finance_info(0, '000001')
        assert finance is not None
        assert type(finance) is OrderedDict
        assert len(finance) > 0

        log.info("日线级别k线获取函数")
        k_data = api.get_k_data('000001', '2017-07-01', '2017-07-10')
        assert type(k_data) is pd.DataFrame
        assert len(k_data) == 6

        log.info("获取板块信息")
        expect_list(api.get_and_parse_block_info(TDXParams.BLOCK_FG),
                    non_empty=True)
Exemplo n.º 8
0
class StdQuotes(object):
    """Real-time stock market quotes served through a ``TdxHq_API`` client.

    Every query opens a connection to ``self.bestip`` for the duration of the
    call and returns the client response converted with ``client.to_df``.
    """

    DEFAULT_SERVER = '202.108.253.131:7709'

    def __init__(self, **kwargs):
        """Load the optional user config and create the API client.

        :param kwargs: forwarded unchanged to ``TdxHq_API``
        """
        self.config = self._load_config()
        self.client = TdxHq_API(**kwargs)

        server = None
        if isinstance(self.config, dict):
            server = self.config.get('SERVER')

        if isinstance(server, (list, tuple)) and len(server) == 2:
            # ``connect(*self.bestip)`` needs a (host, port) pair.
            self.bestip = server
        else:
            # No usable config entry: fall back to the MOOTDX_SERVER
            # environment variable (set to the default when absent).
            address = os.environ.setdefault("MOOTDX_SERVER",
                                            self.DEFAULT_SERVER)
            self.bestip = address.split(':')
            self.bestip[1] = int(self.bestip[1])

    @staticmethod
    def _load_config(path=None):
        """Return the parsed JSON config file, or ``None`` if unavailable.

        The original code handed the file *path* to ``json.loads`` (and
        misspelt the file name), so the config could never be loaded.

        :param path: file to read; defaults to ``~/.mootdx/config.json``
        """
        if path is None:
            path = os.path.join(os.path.expanduser('~'), '.mootdx',
                                'config.json')
        try:
            with open(path) as handle:
                return json.load(handle)
        except (IOError, OSError, ValueError):
            # Missing, unreadable or malformed file: behave as "no config".
            return None

    def bars(self, symbol='000001', category='9', start='0', offset='100'):
        '''
        Fetch K-line bars for a security.

        :param symbol: security code
        :param category: K-line category
        :param start: start position
        :param offset: number of bars to fetch
        :return: ``client.to_df`` of the response
        '''
        market = get_stock_market(symbol)

        with self.client.connect(*self.bestip):
            data = self.client.get_security_bars(int(category), int(market),
                                                 str(symbol), int(start),
                                                 int(offset))
            return self.client.to_df(data)

    def minute(self, symbol=''):
        '''
        Fetch the current intraday minute data of a security.

        :param symbol: security code
        :return: ``client.to_df`` of the response
        '''
        market = get_stock_market(symbol)

        with self.client.connect(*self.bestip):
            data = self.client.get_minute_time_data(int(market), symbol)
            return self.client.to_df(data)

    def minute_his(self, symbol='', datetime='20161209'):
        '''
        Fetch historical intraday minute data of a security.

        :param symbol: security code
        :param datetime: trading date, passed through to the client
        :return: ``client.to_df`` of the response
        '''
        market = get_stock_market(symbol)

        with self.client.connect(*self.bestip):
            data = self.client.get_history_minute_time_data(
                int(market), symbol, datetime)
            return self.client.to_df(data)

    def trans(self, symbol='', start=0, offset=10):
        '''
        Query tick-by-tick transactions of the current day.

        :param symbol: security code
        :param start: start position
        :param offset: number of records requested
        :return: ``client.to_df`` of the response
        '''
        market = get_stock_market(symbol)

        with self.client.connect(*self.bestip):
            # The last argument is the record count; the original passed the
            # market id here, so ``offset`` was silently ignored.
            data = self.client.get_transaction_data(int(market), symbol,
                                                    int(start), int(offset))
            return self.client.to_df(data)

    def trans_his(self, symbol='', start=0, offset=10, date=''):
        '''
        Query historical tick-by-tick transactions.

        :param symbol: security code
        :param start: start position
        :param offset: number of records requested
        :param date: trading date, passed through to the client
        :return: ``client.to_df`` of the response
        '''
        market = get_stock_market(symbol)

        with self.client.connect(*self.bestip):
            data = self.client.get_history_transaction_data(
                int(market), symbol, int(start), int(offset), date)
            return self.client.to_df(data)

    def company(self, symbol='', detail='category', *args, **kwargs):
        '''
        Placeholder for company information retrieval.

        Not implemented: the arguments are ignored and ``None`` is returned.
        '''
        pass

    def company_category(self, symbol=''):
        '''
        Query the company information directory of a security.

        :param symbol: security code
        :return: ``client.to_df`` of the response
        '''
        market = get_stock_market(symbol)

        with self.client.connect(*self.bestip):
            data = self.client.get_company_info_category(int(market), symbol)
            return self.client.to_df(data)

    def company_content(self, symbol='', file='', start=0, offset=10):
        '''
        Read one section of company information.

        :param symbol: security code
        :param file: file name
        :param start: start position
        :param offset: length to read
        :return: ``client.to_df`` of the response
        '''
        market = get_stock_market(symbol)

        with self.client.connect(*self.bestip):
            data = self.client.get_company_info_content(
                int(market), symbol, file, int(start), int(offset))
            return self.client.to_df(data)

    def xdxr(self, symbol=''):
        '''
        Read ex-rights / ex-dividend information.

        :param symbol: security code
        :return: ``client.to_df`` of the response
        '''
        market = get_stock_market(symbol)

        with self.client.connect(*self.bestip):
            data = self.client.get_xdxr_info(int(market), symbol)
            return self.client.to_df(data)

    def k(self, symbol='', begin=None, end=None):
        '''
        Read K-line data for a date range.

        :param symbol: security code
        :param begin: start date
        :param end: end date
        :return: the ``get_k_data`` response, returned without conversion
        '''
        with self.client.connect(*self.bestip):
            return self.client.get_k_data(symbol, begin, end)

    def index(self,
              symbol='000001',
              market='sh',
              category='9',
              start='0',
              offset='100'):
        '''
        Fetch K-line bars for an index.

        K-line categories:
        - 0: 5-minute bars
        - 1: 15-minute bars
        - 2: 30-minute bars
        - 3: 1-hour bars
        - 4: daily bars
        - 5: weekly bars
        - 6: monthly bars
        - 7: 1 minute
        - 8: 1-minute bars
        - 9: daily bars
        - 10: quarterly bars
        - 11: yearly bars

        :param symbol: index code
        :param market: ``'sh'`` selects market 1; anything else selects 0
        :param category: K-line category (see the list above)
        :param start: start position
        :param offset: number of bars to fetch
        :return: ``client.to_df`` of the response
        '''
        market = 1 if market == 'sh' else 0

        with self.client.connect(*self.bestip):
            data = self.client.get_index_bars(int(category), int(market),
                                              str(symbol), int(start),
                                              int(offset))
            return self.client.to_df(data)

    def block(self, tofile="block.dat"):
        '''
        Fetch security block (sector) information.

        :param tofile: block file name handed to the client
        :return: ``client.to_df`` of the response
        '''
        with self.client.connect(*self.bestip):
            data = self.client.get_and_parse_block_info(tofile)
            return self.client.to_df(data)

    def batch(self, method='', offset=100, *args, **kwargs):
        '''
        Placeholder for batch downloads.

        Not implemented: the arguments are ignored and ``None`` is returned.
        '''
        pass
Exemplo n.º 9
0
# Demo script: calls one TdxHq_API query after another and prints each
# response. ``api`` is created above this excerpt (not visible here); the
# calls presumably need it to be connected already - TODO confirm.
# Each Chinese print() literal is the label for the call that follows it.

# Security list of market 1, starting at position 255.
stocks = api.get_security_list(1, 255)
print(stocks)
# Index K-line bars.
print("获取指数k线")
data = api.get_index_bars(9, 1, '000001', 1, 2)
print(data)
# Intraday minute data.
print("查询分时行情")
data = api.get_minute_time_data(1, '600300')
print(data)
# Historical intraday minute data.
print("查询历史分时行情")
data = api.get_history_minute_time_data(1, '600300', 20161209)
print(data)
# Tick-by-tick transactions.
print("查询分时成交")
data = api.get_transaction_data(1, '000002', 0, 30)
print(data)
# Historical tick-by-tick transactions.
# NOTE(review): market id 2 is passed here while every other call in this
# script uses 0 or 1 - confirm it is intentional.
print("查询历史分时成交")
data = api.get_history_transaction_data(2, '600302', 0, 10, 20170209)
print(data)
# Company information directory.
print("查询公司信息目录")
data = api.get_company_info_category(1, '000003')
print(data)
# Company information content ("latest notice" section).
print("读取公司信息-最新提示")
data = api.get_company_info_content(0, '000001', '000001.txt', 0, 10)
print(data)
# Ex-rights / ex-dividend information.
print("读取除权除息信息")
data = api.get_xdxr_info(1, '600300')
print(data)
# Financial information.
print("读取财务信息")
data = api.get_finance_info(0, '000001')
print(data)
# Daily K-line helper; unlike the calls above, the result is not printed
# within this excerpt.
print("日线级别k线获取函数")
data = api.get_k_data('000001', '2005-07-01', '2017-07-10')
Exemplo n.º 10
0
class TdxStockData(object):
    """Download stock bars and tick data from a TDX quote server.

    The security list and the chosen server are kept in a config cache
    (``TDX_STOCK_CONFIG``); tick data can additionally be cached on disk as
    bz2-compressed pickles.
    """

    def __init__(self, strategy=None):
        """
        Constructor.

        :param strategy: upper-level strategy, used only for ``write_log()``
        """
        self.api = None

        self.connection_status = False  # connection state

        self.strategy = strategy
        self.best_ip = None
        self.symbol_exchange_dict = {}  # tdx code -> symbol as given by the caller
        self.symbol_market_dict = {}  # tdx code -> tdx market id

        self.config = get_cache_config(TDX_STOCK_CONFIG)
        self.symbol_dict = self.config.get('symbol_dict', {})
        self.cache_time = self.config.get('cache_time', datetime.now() - timedelta(days=7))

        # Rebuild the security cache when it is empty or older than one day.
        if len(self.symbol_dict) == 0 or self.cache_time < datetime.now() - timedelta(days=1):
            self.cache_config()

    def write_log(self, content):
        """Log through the strategy when there is one, otherwise print."""
        if self.strategy:
            self.strategy.write_log(content)
        else:
            print(content)

    def write_error(self, content):
        """Log an error through the strategy, otherwise print to stderr."""
        if self.strategy:
            self.strategy.write_log(content, level=ERROR)
        else:
            print(content, file=sys.stderr)

    def connect(self, is_reconnect: bool = False):
        """
        Create the API client and connect it to the quote server.

        Failures are logged, not raised; afterwards ``self.api`` is ``None``
        so callers can detect that no connection is available.

        :param is_reconnect: re-read the server from the config cache
        :return: None
        """
        try:
            if self.api is None or not self.connection_status:
                self.write_log(u'开始连接通达信股票行情服务器')
                self.api = TdxHq_API(heartbeat=True, auto_retry=True, raise_exception=True)

                # Pick the server: cached value first, probing as a fallback.
                if is_reconnect or self.best_ip is None:
                    self.best_ip = self.config.get('best_ip', {})

                if len(self.best_ip) == 0:
                    from pytdx.util.best_ip import select_best_ip
                    self.best_ip = select_best_ip()
                    self.config.update({'best_ip': self.best_ip})
                    save_cache_config(self.config, TDX_STOCK_CONFIG)

                self.api.connect(self.best_ip.get('ip'), self.best_ip.get('port'))
                self.write_log(f'创建tdx连接, : {self.best_ip}')
                self.connection_status = True

        except Exception as ex:
            self.write_log(u'连接服务器tdx异常:{},{}'.format(str(ex), traceback.format_exc()))
            # Do not keep a client that never connected: the ``self.api is
            # None`` checks elsewhere rely on it.
            self.api = None
            self.connection_status = False
            return

    def disconnect(self):
        """Close the connection (best effort) and reset the connection state."""
        if self.api is not None:
            try:
                self.api.disconnect()
            except Exception as ex:
                self.write_log('disconnect tdx api failed: {}'.format(str(ex)))
            self.api = None
        self.connection_status = False

    def cache_config(self):
        """Download the security list of both markets into the config cache."""
        for market_id in range(2):
            print('get market_id:{}'.format(market_id))
            for security in self.get_security_list(market_id):
                tdx_symbol = security.get('code', None)
                if tdx_symbol:
                    self.symbol_dict.update({f'{tdx_symbol}_{market_id}': security})

        self.config.update({'symbol_dict': self.symbol_dict, 'cache_time': datetime.now()})
        save_cache_config(data=self.config, config_file_name=TDX_STOCK_CONFIG)

    def get_security_list(self, market_id: int = 0):
        """
        Download the complete security list of one market.

        :param market_id: 1 = Shanghai exchange, 0 = Shenzhen exchange
        :return: list of security records; empty when nothing could be fetched
        """
        if self.api is None:
            self.connect()

        results = []
        if self.api is None:
            return results

        start = 0
        # The server returns the list in pages; keep asking from the current
        # offset until a request fails or comes back empty (or None).
        while True:
            try:
                result = self.api.get_security_list(market_id, start)
            except Exception:
                break
            if not result:
                break
            start += len(result)
            results.extend(result)

        return results

    def _resolve_symbol(self, symbol: str):
        """Split ``symbol`` into ``(tdx_code, market_code)`` and remember both.

        ``code.exchange`` symbols map to market 1 for ``XSHG`` or the SSE
        exchange value and to market 0 otherwise; plain codes are resolved
        with ``get_tdx_market_code``.
        """
        if '.' in symbol:
            tdx_code, market_str = symbol.split('.')
            market_code = 1 if market_str.upper() in ['XSHG', Exchange.SSE.value] else 0
        else:
            market_code = get_tdx_market_code(symbol)
            tdx_code = symbol
        self.symbol_exchange_dict.update({tdx_code: symbol})
        self.symbol_market_dict.update({tdx_code: market_code})
        return tdx_code, market_code

    # ----------------------------------------------------------------------
    def get_bars(self,
                 symbol: str,
                 period: str,
                 callback=None,
                 bar_freq: int = 1,
                 start_dt: datetime = None):
        """
        Download K-line bars and convert them to ``BarData`` objects.

        :param symbol: stock symbol, e.g. ``000001.XG`` or a plain code
        :param period: a key of ``PERIOD_MAPPING``, e.g. 1min, 5min, 15min,
            30min, 1hour, 1day
        :param callback: optional ``callback(bar, bar_is_completed, freq)``
            invoked for every returned bar
        :param bar_freq: ``freq`` value handed to ``callback``
        :param start_dt: earliest bar time wanted; defaults to 10 days ago
        :return: ``(success, bars)``
        """
        if not self.api:
            self.connect()
        ret_bars = []
        if self.api is None:
            return False, []

        tdx_code, market_code = self._resolve_symbol(symbol)

        # period => tdx_period
        if period not in PERIOD_MAPPING:
            self.write_error(u'{} 周期{}不在下载清单中: {}'
                             .format(datetime.now(), period, list(PERIOD_MAPPING.keys())))
            return False, ret_bars
        tdx_period = PERIOD_MAPPING.get(period)

        # start_dt => query window
        if start_dt is None:
            self.write_log(u'没有设置开始时间,缺省为10天前')
            qry_start_date = datetime.now() - timedelta(days=10)
            start_dt = qry_start_date
        else:
            qry_start_date = start_dt
        qry_end_date = datetime.now()
        if qry_start_date > qry_end_date:
            qry_start_date = qry_end_date

        self.write_log('{}开始下载tdx股票:{} {}数据, {} to {}.'
                       .format(datetime.now(), tdx_code, tdx_period, qry_start_date, qry_end_date))

        try:
            # Page backwards from the newest bar until the oldest bar of a
            # page is no later than the requested start.
            _start_date = qry_end_date
            _bars = []
            _pos = 0
            while _start_date > qry_start_date:
                _res = self.api.get_security_bars(
                    category=tdx_period,
                    market=market_code,
                    code=tdx_code,
                    start=_pos,
                    count=QSIZE)
                if not _res:
                    break
                _bars = _res + _bars
                _pos += QSIZE
                _start_date = datetime.strptime(_res[0]['datetime'], '%Y-%m-%d %H:%M')
                self.write_log(u'分段取数据开始:{}'.format(_start_date))

            if len(_bars) == 0:
                self.write_error('{} Handling {}, len1={}..., continue'.format(
                    str(datetime.now()), tdx_code, len(_bars)))
                return False, ret_bars

            current_datetime = datetime.now()
            data = self.api.to_df(_bars)
            data = data.assign(datetime=to_datetime(data['datetime']))
            data['symbol'] = symbol
            data = data.drop(
                ['year', 'month', 'day', 'hour', 'minute', 'price', 'ticker'],
                errors='ignore',
                axis=1)
            # NOTE(review): 'amount' becomes the bar volume here - confirm
            # this is the intended source column for stocks.
            data = data.rename(
                index=str,
                columns={'amount': 'volume'})
            if len(data) == 0:
                print('{} Handling {}, len2={}..., continue'.format(
                    str(datetime.now()), tdx_code, len(data)))
                return False, ret_bars

            # TDX stamps a bar with its end time while vnpy uses the start
            # time, so subtract the bar length in minutes.
            data['datetime'] = data['datetime'].apply(
                lambda x: x - timedelta(minutes=NUM_MINUTE_MAPPING.get(period, 1)))
            data['trading_date'] = data['datetime'].apply(lambda x: (x.strftime('%Y-%m-%d')))
            data['date'] = data['datetime'].apply(lambda x: (x.strftime('%Y-%m-%d')))
            data['time'] = data['datetime'].apply(lambda x: (x.strftime('%H:%M:%S')))
            data = data.set_index('datetime', drop=False)

            # Positional access: ``[-1]`` on a datetime-indexed Series is a
            # label lookup in current pandas.
            last_index = data['datetime'].iloc[-1]

            for index, row in data.iterrows():
                add_bar = BarData()
                try:
                    add_bar.symbol = symbol
                    add_bar.datetime = index
                    add_bar.date = row['date']
                    add_bar.time = row['time']
                    add_bar.trading_date = row['trading_date']
                    add_bar.open_price = float(row['open'])
                    add_bar.high_price = float(row['high'])
                    add_bar.low_price = float(row['low'])
                    add_bar.close_price = float(row['close'])
                    add_bar.volume = float(row['volume'])
                except Exception as ex:
                    self.write_error('error when convert bar:{},ex:{},t:{}'
                                     .format(row, str(ex), traceback.format_exc()))
                    # Keep the (success, bars) shape of every other exit.
                    return False, ret_bars

                if start_dt is not None and index < start_dt:
                    continue
                ret_bars.append(add_bar)

                if callback is not None:
                    freq = bar_freq
                    bar_is_completed = True
                    if period != '1min' and index == last_index:
                        # The last bar may be incomplete; adjust it.
                        # - for 5min the adjusted freq is basically right
                        # - for 1day vnpy does not care how many bars arrived
                        # - for other minute periods the freq may be wrong
                        if index > current_datetime:
                            bar_is_completed = False
                            # Counted in seconds this would need +1, e.g. at
                            # 13:31 freq=31 means the 31st bar.
                            freq = NUM_MINUTE_MAPPING[period] - int((index - current_datetime).total_seconds() / 60)
                    callback(add_bar, bar_is_completed, freq)

            return True, ret_bars
        except Exception as ex:
            self.write_error('exception in get:{},{},{}'.format(tdx_code, str(ex), traceback.format_exc()))
            self.write_log(u'重置连接')
            self.api = None
            self.connect(is_reconnect=True)
            return False, ret_bars

    def save_cache(self,
                   cache_folder: str,
                   cache_symbol: str,
                   cache_date: str,
                   data_list: list):
        """Pickle ``data_list`` to ``<folder>/<YYYYMM>/<symbol>_<date>.pkb2``."""

        os.makedirs(cache_folder, exist_ok=True)

        if not os.path.exists(cache_folder):
            self.write_error('缓存目录不存在:{},不能保存'.format(cache_folder))
            return
        cache_folder_year_month = os.path.join(cache_folder, cache_date[:6])
        os.makedirs(cache_folder_year_month, exist_ok=True)

        save_file = os.path.join(cache_folder_year_month, '{}_{}.pkb2'.format(cache_symbol, cache_date))
        with bz2.BZ2File(save_file, 'wb') as f:
            pickle.dump(data_list, f)
            self.write_log(u'缓存成功:{}'.format(save_file))

    def load_cache(self,
                   cache_folder: str,
                   cache_symbol: str,
                   cache_date: str):
        """Load a cache file written by :meth:`save_cache`.

        :return: the cached object, or ``None`` when the folder or file is
            missing
        """
        if not os.path.exists(cache_folder):
            self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder))
            return None
        cache_folder_year_month = os.path.join(cache_folder, cache_date[:6])
        if not os.path.exists(cache_folder_year_month):
            self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder_year_month))
            return None

        cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkb2'.format(cache_symbol, cache_date))
        if not os.path.isfile(cache_file):
            self.write_error('缓存文件:{}不存在,不能读取'.format(cache_file))
            return None
        # NOTE(security): unpickling executes code embedded in the file; only
        # point cache_folder at directories this application wrote itself.
        with bz2.BZ2File(cache_file, 'rb') as f:
            return pickle.load(f)

    def get_history_transaction_data(self,
                                     symbol: str,
                                     trading_date,
                                     cache_folder: str = None):
        """
        Fetch the tick-by-tick transactions of one trading day.

        :param symbol: ``code.exchange`` or a plain code
        :param trading_date: ``datetime``, or a string such as 2019-01-01 or
            20190101
        :param cache_folder: optional folder of the on-disk tick cache
        :return: ``(success, ticks)``
        """
        if not self.api:
            self.connect()

        ret_datas = []

        # Normalise trading_date to the integer form used in queries.
        if isinstance(trading_date, datetime):
            trading_date = trading_date.strftime('%Y%m%d')
        if isinstance(trading_date, str):
            trading_date = int(trading_date.replace('-', ''))

        cache_symbol = symbol
        cache_date = str(trading_date)

        # The lookup dictionaries are keyed by tdx_code, so query with the
        # resolved values rather than with the raw ``symbol``.
        tdx_code, market_code = self._resolve_symbol(symbol)

        q_size = QSIZE * 5
        # Upper bound on the number of ticks collected for one day.
        max_data_size = 1000000

        # Prefer the on-disk cache.
        if cache_folder:
            buffer_data = self.load_cache(cache_folder, cache_symbol, cache_date)
            if buffer_data:
                return True, buffer_data

        self.write_log(u'开始下载{} 历史{}分笔数据'.format(trading_date, symbol))

        is_today = trading_date == int(datetime.now().strftime('%Y%m%d'))

        try:
            _datas = []
            _pos = 0

            while True:
                if is_today:
                    # Ticks of the current trading day.
                    _res = self.api.get_transaction_data(
                        market=market_code,
                        code=tdx_code,
                        start=_pos,
                        count=q_size)
                else:
                    # Ticks of a past trading day.
                    _res = self.api.get_history_transaction_data(
                        market=market_code,
                        date=trading_date,
                        code=tdx_code,
                        start=_pos,
                        count=q_size)

                # None or an empty page ends the download; checking before
                # len() keeps what was already collected.
                if not _res:
                    break

                _datas = _res + _datas
                _pos += min(q_size, len(_res))
                self.write_log(u'分段取{}分笔数据:{} ~{}, {}条,累计:{}条'
                               .format(trading_date, _res[0]['time'], _res[-1]['time'], len(_res), _pos))

                if len(_datas) >= max_data_size:
                    break

            if len(_datas) == 0:
                self.write_error(u'{}分笔成交数据获取为空'.format(trading_date))
                return False, _datas

            # Ticks carry minute resolution only; ticks that do not advance
            # the clock get synthetic seconds added to the previous stamp.
            last_dt = None
            for d in _datas:
                dt = datetime.strptime(str(trading_date) + ' ' + d.get('time'), '%Y%m%d %H:%M')
                if last_dt is None or last_dt < dt:
                    last_dt = dt
                elif last_dt < dt + timedelta(seconds=59):
                    last_dt = last_dt + timedelta(seconds=1)
                d.update({'datetime': last_dt})
                d.update({'volume': d.pop('vol', 0)})
                d.update({'trading_date': last_dt.strftime('%Y-%m-%d')})

            _datas = sorted(_datas, key=lambda s: s['datetime'])

            # Only completed days are written to the cache.
            if cache_folder and not is_today:
                self.save_cache(cache_folder, cache_symbol, cache_date, _datas)

            return True, _datas

        except Exception as ex:
            self.write_error(
                'exception in get_transaction_data:{},{},{}'.format(symbol, str(ex), traceback.format_exc()))
            return False, ret_datas
Exemplo n.º 11
0
class Engine:
    """Convenience wrapper around a pytdx ``TdxHq_API`` quote connection.

    The class offers lazily cached security/block listings and helpers that
    page through the API to build quote, bar and tick ``DataFrame`` objects.
    It can be used as a context manager; the connection is closed on exit
    (it is *not* opened on enter, call :meth:`connect` first).
    """

    # Default number of spawned jobs collected per batch in ``_get_k_data``.
    concurrent_thread_count = 50

    # Number of securities requested per ``get_security_quotes`` call.
    _QUOTE_BATCH_SIZE = 80

    def __init__(self, *args, **kwargs):
        """Create the underlying API object (no connection is made yet).

        Keyword arguments consumed here:
            ip: server address to use; takes precedence over ``best_ip``.
            best_ip: when true (and ``ip`` is absent) use ``select_best_ip()``.
            concurrent_thread_count: per-instance override of the class value.
            thread_num: stored on the instance (default 1).

        Every remaining positional/keyword argument is forwarded to
        ``TdxHq_API`` and ``raise_exception`` is always forced to ``True``.
        The previous code passed the ``args`` tuple and ``kwargs`` dict as two
        positional values, so caller options never reached the API object
        under their own names.
        """
        # Always consume ``best_ip`` so it is never forwarded to TdxHq_API.
        use_best_ip = kwargs.pop('best_ip', False)
        if 'ip' in kwargs:
            self.ip = kwargs.pop('ip')
        elif use_best_ip:
            self.ip = self.best_ip
        else:
            self.ip = '14.17.75.71'

        if 'concurrent_thread_count' in kwargs:
            self.concurrent_thread_count = kwargs.pop(
                'concurrent_thread_count')
        self.thread_num = kwargs.pop('thread_num', 1)

        kwargs['raise_exception'] = True
        self.api = TdxHq_API(*args, **kwargs)

    def connect(self):
        """Connect the API to ``self.ip`` and return ``self`` for chaining."""
        self.api.connect(self.ip)
        return self

    def __enter__(self):
        """Return the engine itself; the connection must already be open."""
        return self

    def exit(self):
        """Disconnect the underlying API."""
        self.api.disconnect()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Disconnect the underlying API when leaving a ``with`` block."""
        self.api.disconnect()

    def _batched_quotes(self, keys):
        """Fetch quotes for *keys* in batches and concatenate the results.

        Only non-empty batches are requested, so no trailing empty request is
        sent when ``len(keys)`` is a multiple of the batch size. An empty
        ``DataFrame`` is returned when *keys* is empty.
        """
        size = self._QUOTE_BATCH_SIZE
        frames = [
            self.api.to_df(self.api.get_security_quotes(keys[pos:pos + size]))
            for pos in range(0, len(keys), size)
        ]
        if not frames:
            return pd.DataFrame()
        return pd.concat(frames)

    def quotes(self, code):
        """Return quotes for one code or a list of codes.

        Codes are matched against ``security_list``; the matching index
        entries are what is sent to ``get_security_quotes``.
        """
        codes = code if isinstance(code, list) else [code]
        listing = self.security_list
        keys = listing[listing.code.isin(codes)].index.tolist()
        return self._batched_quotes(keys)

    def stock_quotes(self):
        """Return quotes for every entry of ``stock_list``."""
        return self._batched_quotes(self.stock_list.index.tolist())

    @lazyval
    def security_list(self):
        """All securities of markets 0 and 1, indexed by ``(sse, code)``.

        The listing is fetched in pages of 1000 rows per market; ``sse`` holds
        the market number the row was requested from.
        """
        frames = []
        for market in range(2):
            pages = int(self.api.get_security_count(market) / 1000) + 1
            for page in range(pages):
                frame = self.api.to_df(
                    self.api.get_security_list(market, page * 1000))
                frames.append(
                    frame.assign(sse=market).set_index(['sse', 'code'],
                                                       drop=False))
        return pd.concat(frames, axis=0)

    @lazyval
    def stock_list(self):
        """Rows of ``security_list`` whose index entry passes ``stock_filter``."""
        keep = [stock_filter(key) for key in self.security_list.index.tolist()]
        return self.security_list[keep]

    @lazyval
    def best_ip(self):
        """Server chosen by ``select_best_ip()`` (computed once)."""
        return select_best_ip()

    @lazyval
    def concept(self):
        """Block information parsed from ``TDXParams.BLOCK_GN``."""
        return self.api.to_df(
            self.api.get_and_parse_block_info(TDXParams.BLOCK_GN))

    @lazyval
    def index(self):
        """Block information parsed from ``TDXParams.BLOCK_SZ``."""
        return self.api.to_df(
            self.api.get_and_parse_block_info(TDXParams.BLOCK_SZ))

    @lazyval
    def fengge(self):
        """Block information parsed from ``TDXParams.BLOCK_FG``."""
        return self.api.to_df(
            self.api.get_and_parse_block_info(TDXParams.BLOCK_FG))

    @lazyval
    def block(self):
        """Block information parsed from ``TDXParams.BLOCK_DEFAULT``."""
        return self.api.to_df(
            self.api.get_and_parse_block_info(TDXParams.BLOCK_DEFAULT))

    @lazyval
    def customer_block(self):
        """Customer blocks read from ``CUSTOMER_BLOCK_PATH``."""
        return CustomerBlockReader().get_df(CUSTOMER_BLOCK_PATH)

    def xdxr(self, code):
        """Return the xdxr records of *code* indexed by their date.

        The ``year``/``month``/``day`` columns are merged into a ``datetime``
        index. An empty result is returned unchanged.
        """
        df = self.api.to_df(
            self.api.get_xdxr_info(self.get_security_type(code), code))
        if df.empty:
            return df
        ymd = df.year * 10000 + df.month * 100 + df.day
        df['datetime'] = pd.to_datetime(ymd.apply(str))
        return df.drop(['year', 'month', 'day'], axis=1).set_index('datetime')

    @lazyval
    def gbbq(self):
        """Rows of the GBBQ file with ``category == 1`` and parsed dates."""
        # Copy so that adding a column does not write into a filtered view.
        df = GbbqReader().get_df(GBBQ_PATH).query('category == 1').copy()
        df['datetime'] = pd.to_datetime(df['datetime'], format='%Y%m%d')
        return df

    def get_security_type(self, code):
        """Return the ``sse`` market number of *code*.

        Raises:
            SecurityNotExists: if *code* is not present in ``security_list``.
        """
        listing = self.security_list
        matches = listing.loc[listing.code == code, 'sse']
        if matches.empty:
            raise SecurityNotExists()
        # ``.values`` replaces ``DataFrame.as_matrix()``, which no longer
        # exists in current pandas releases.
        return matches.values[0]

    @retry(3)
    def get_security_bars(self, code, freq, start=None, end=None, index=False):
        """Download daily or 1-minute bars for *code*.

        Args:
            code: security code.
            freq: ``'1d'``/``'day'`` or ``'1m'``/``'min'``.
            start: optional first timestamp to keep (timezone is dropped).
            end: optional last date to keep (timezone is dropped).
            index: use the index-bar endpoint instead of the stock one.

        Returns:
            A ``DataFrame`` indexed by ``datetime`` with a ``code`` column, or
            an empty ``DataFrame`` when there is no usable data.

        Raises:
            Exception: if *freq* is not one of the supported values.
        """
        if index:
            exchange = self.get_security_type(code)
            func = self.api.get_index_bars
        else:
            exchange = get_stock_type(code)
            func = self.api.get_security_bars

        if start:
            start = start.tz_localize(None)
        if end:
            end = end.tz_localize(None)

        if freq in ['1d', 'day']:
            freq = 9
        elif freq in ['1m', 'min']:
            freq = 8
        else:
            raise Exception("1d and 1m frequency supported only")

        # Page backwards 800 bars at a time until the data runs out or the
        # oldest bar of the current page lies before ``start``.
        res = []
        pos = 0
        while True:
            data = func(freq, exchange, code, pos, 800)
            if not data:
                break
            res = data + res
            pos += 800

            if start and pd.to_datetime(data[0]['datetime']) < start:
                break

        if not res:
            # e.g. a security that is not listed yet: nothing to build.
            logger.warning("no k line data for {}".format(code))
            return pd.DataFrame()

        try:
            df = self.api.to_df(res).drop(
                ['year', 'month', 'day', 'hour', 'minute'], axis=1)
            df['datetime'] = pd.to_datetime(df.datetime)
            df.set_index('datetime', inplace=True)
            if freq == 9:
                df.index = df.index.normalize()
        except (ValueError, KeyError):
            # Older pandas raised ValueError for missing labels in ``drop``;
            # current releases raise KeyError.
            logger.warning("no k line data for {}".format(code))
            return pd.DataFrame()

        if start:
            df = df.loc[df.index >= start]
        if end:
            df = df.loc[df.index.normalize() <= end]

        if df.empty:
            return df

        # A single zero-volume bar dated ``end`` is the placeholder returned
        # before the market opens; report it as "no data".
        if (int(df['vol'].iloc[-1]) <= 0 and end == df.index[-1]
                and len(df) == 1):
            return pd.DataFrame()
        return df.assign(code=code)

    def _get_transaction(self, code, date):
        """Return all historical ticks of *code* on *date* (``YYYYMMDD`` int).

        The result is indexed by tick timestamp and carries ``date`` and
        ``code`` columns; an empty ``DataFrame`` is returned when the API has
        no data for that day.
        """
        res = []
        start = 0
        while True:
            data = self.api.get_history_transaction_data(
                get_stock_type(code), code, start, 2000, date)
            if not data:
                break
            start += 2000
            res = data + res

        if len(res) == 0:
            return pd.DataFrame()
        df = self.api.to_df(res).assign(date=date)
        # The first tick takes the time of the second one; with a single
        # tick there is nothing to copy from.
        if len(df) > 1:
            df.loc[0, 'time'] = df.time[1]
        df.index = pd.to_datetime(str(date) + " " + df["time"])
        df['code'] = code
        return df.drop("time", axis=1)

    def time_and_price(self, code):
        """Return today's ticks of *code* indexed by timestamp.

        An empty ``DataFrame`` is returned when the API yields no ticks.
        """
        start = 0
        res = []
        exchange = self.get_security_type(code)
        while True:
            data = self.api.get_transaction_data(exchange, code, start, 2000)
            if not data:
                break
            res = data + res
            start += 2000

        df = self.api.to_df(res)
        if df.empty:
            return df
        df['time'] = pd.to_datetime(
            str(pd.to_datetime('today').date()) + " " + df['time'])
        # Same first-tick adjustment as in ``_get_transaction``.
        if len(df) > 1:
            df.loc[0, 'time'] = df.time[1]
        return df.set_index('time')

    @classmethod
    def minute_bars_from_transaction(cls, transaction, freq):
        """Aggregate a day of ticks into OHLC/volume bars of width *freq*.

        Ticks before 12:00 of the first tick's day and ticks from 12:00 on
        are resampled separately. When both halves have bars, the last
        morning bar is relabelled to one minute before the first afternoon
        bar. The combined frame is passed through ``fillna``.
        """
        if transaction.empty:
            return pd.DataFrame()
        noon = transaction.index[0].normalize() + pd.Timedelta(hours=12)
        mask = transaction.index < noon

        def resample(ticks):
            if ticks.empty:
                return pd.DataFrame()
            data = ticks['price'].resample(freq,
                                           label='right',
                                           closed='left').ohlc()
            data['volume'] = ticks['vol'].resample(freq,
                                                   label='right',
                                                   closed='left').sum()
            data['code'] = ticks['code'].iloc[0]
            return data

        morning = resample(transaction[mask])
        afternoon = resample(transaction[~mask])
        if morning.empty and afternoon.empty:
            return pd.DataFrame()

        if not morning.empty and not afternoon.empty:
            # Build a new index instead of writing into ``index.values``;
            # the old in-place write failed when there were no morning bars.
            last_label = afternoon.index[0] - pd.Timedelta(minutes=1)
            morning.index = morning.index[:-1].append(
                pd.DatetimeIndex([last_label], name=morning.index.name))

        parts = [part for part in (morning, afternoon) if not part.empty]
        return fillna(pd.concat(parts))

    def _get_k_data(self, code, freq, sessions):
        """Build bars for *code* from tick data, one trading day at a time.

        *sessions* must support ``strftime`` (e.g. a ``DatetimeIndex``).
        Returns the concatenated bars or an empty ``DataFrame``.
        """
        trade_days = map(int, sessions.strftime("%Y%m%d"))
        if freq == '1m':
            freq = '1 min'

        if freq == '1d':
            freq = '24 H'

        res = []
        jobs = []

        def drain():
            # Wait (at most 30s) for the pending jobs and keep non-empty
            # results; unfinished jobs have ``value`` None and are dropped.
            gevent.joinall(jobs, timeout=30)
            for job in jobs:
                if job.value is not None and not job.value.empty:
                    res.append(job.value)
            jobs.clear()

        for trade_day in trade_days:
            # NOTE(review): the tick download runs here, in the caller, before
            # the job is spawned; only the resampling happens in the greenlet.
            ticks = self._get_transaction(code, trade_day)
            jobs.append(
                gevent.spawn(Engine.minute_bars_from_transaction, ticks,
                             freq))
            if len(jobs) >= self.concurrent_thread_count:
                drain()
        drain()

        if res:
            return pd.concat(res)
        return pd.DataFrame()

    def get_k_data(self, code, start, end, freq, check=True):
        """Return tick-derived bars for *code* between *start* and *end*.

        With ``check`` enabled, the trading days are taken from the daily
        bars, and days whose derived open/close differ from the daily bar are
        downloaded a second time. Otherwise every weekday in the range is
        used. If the daily bars are missing, they are returned as-is.
        """
        if isinstance(start, str) or isinstance(end, str):
            start = pd.Timestamp(start)
            end = pd.Timestamp(end)

        daily_bars = None
        if check:
            daily_bars = self.get_security_bars(code, '1d', start, end)
            if daily_bars is None or daily_bars.empty:
                return daily_bars
            sessions = daily_bars.index
        else:
            sessions = pd.bdate_range(start,
                                      end,
                                      weekmask='Mon Tue Wed Thu Fri')
        df = self._get_k_data(code, freq, sessions)

        def check_df(freq, df, daily_bars):
            """Return the days whose open/close disagree with *daily_bars*."""
            if freq == '1m':
                need_check = pd.DataFrame({
                    'open': df['open'].resample('1D').first(),
                    'high': df['high'].resample('1D').max(),
                    'low': df['low'].resample('1D').min(),
                    'close': df['close'].resample('1D').last(),
                    'volume': df['volume'].resample('1D').sum()
                }).dropna()
            else:
                need_check = df

            if daily_bars.shape[0] != need_check.shape[0]:
                logger.warning("{} merged {}, expected {}".format(
                    code, need_check.shape[0], daily_bars.shape[0]))
                need_check = fillna(
                    need_check.reindex(daily_bars.index, copy=False))
            diff = daily_bars[['open',
                               'close']] == need_check[['open', 'close']]
            matches = diff.open & diff.close
            return matches[~matches].index

        if df.empty or not check:
            return df

        bad_sessions = check_df(freq, df, daily_bars)
        if bad_sessions.shape[0] != 0:
            logger.info("fixing data for {}-{} with sessions: {}".format(
                code, freq, bad_sessions))
            fix = self._get_k_data(code, freq, bad_sessions)
            # Nothing to patch in when the second download came back empty.
            if not fix.empty:
                df.loc[fix.index] = fix
        return df
Exemplo n.º 12
0
class TdxHelper:
    """Helper around a pytdx ``TdxHq_API`` connection plus MySQL handles.

    Besides a set of small demo methods that print what each API call
    returns, it normalises minute bars into a fixed CSV layout and loads
    block membership into the database.
    """

    ip_list = [{
        'ip': '119.147.212.81',
        'port': 7709
    }, {
        'ip': '60.12.136.250',
        'port': 7709
    }]

    # Market code -> exchange suffix used to build ``ts_code`` values.
    _MARKET_SUFFIX = {0: 'SZ', 1: 'SH'}

    # Column layout of the normalised bar frames / CSV files.
    _BAR_COLUMNS = [
        'code', 'ts_code', 'trade_date', 'trade_time', 'time_index', 'open',
        'high', 'low', 'close', 'amount', 'volume'
    ]

    def __init__(self):
        """Connect to the tdx server and create the MySQL helpers."""
        # Connect to the tdx interface; a failure is only reported.
        self.api = TdxHq_API()
        if not self.api.connect('60.12.136.250', 7709):
            print("服务器连接失败!")

        # pandas display setting: show every column.
        pd.set_option('display.max_columns', None)

        # Plain MySQL helper object.
        self.mysql = mysqlHelper(config.mysql_host, config.mysql_username,
                                 bluedothe.mysql_password, config.mysql_dbname)

        # SQLAlchemy engine used by pandas.
        # NOTE(review): the password is interpolated into the URL verbatim;
        # confirm it never contains characters that need URL-encoding.
        self.engine = create_engine(
            f'mysql+pymysql://{config.mysql_username}:{bluedothe.mysql_password}@{config.mysql_host}/{config.mysql_dbname}?charset=utf8'
        )

    # Disconnect from the tdx interface.
    def close_connect(self):
        """Disconnect the tdx API."""
        self.api.disconnect()

    def _load_bars(self, category, market, code, start, count):
        """Fetch bars and reshape them into the ``_BAR_COLUMNS`` layout.

        Returns the frame produced by the API unchanged when it is empty.
        """
        ts_code = code + "." + self._MARKET_SUFFIX[market]
        df = self.api.to_df(
            self.api.get_security_bars(category, market, code, start, count))
        if df.empty:
            return df

        df.insert(0, 'ts_code', ts_code)
        df.insert(0, 'code', code)
        # Characters 11-18 of the datetime text are the time, 0-9 the date.
        df['trade_time'] = df['datetime'].apply(lambda x: str(x)[11:19])
        df['time_index'] = df['trade_time'].apply(
            lambda x: datatime_util.stockTradeTime2Index(x))
        df['trade_date'] = df['datetime'].apply(
            lambda x: (str(x)[0:10]).replace('-', ''))
        df.rename(columns={'vol': 'volume'}, inplace=True)
        df.drop(['year', 'month', 'day', 'hour', 'minute', 'datetime'],
                axis=1,
                inplace=True)
        df['volume'] = df['volume'].apply(int)  # truncate to integer
        # Rows carrying this exact amount are reset to 0 (presumably the
        # API's placeholder for "no amount" -- TODO confirm).
        df.loc[df['amount'] == 5.877471754111438e-39, 'amount'] = 0
        return df[self._BAR_COLUMNS]

    @staticmethod
    def _write_csv(df, filename):
        """Append *df* to *filename*, or create it with a header row.

        Returns ``True`` when the file already existed.
        """
        existed = os.path.isfile(filename)
        df.to_csv(filename,
                  index=False,
                  mode='a' if existed else 'w',
                  header=not existed,
                  sep=',',
                  encoding="utf_8_sig")
        return existed

    # Fetch K-line bars. The API only returns data counted backwards from the
    # most recent trading day.
    # Parameters: category (K-line type), market code (0: Shenzhen,
    # 1: Shanghai), stock code, start offset (0 = most recent trading day),
    # number of records.
    # K-line types: 0 5-minute; 1 15-minute; 2 30-minute; 3 1-hour; 4 daily;
    # 5 weekly; 6 monthly; 7 1-minute; 8 1-minute; 9 daily; 10 quarterly;
    # 11 yearly.
    # API columns: open,close,high,low,vol,amount,year,month,day,hour,minute,
    # datetime
    # CSV columns: code,ts_code,trade_date,trade_time,time_index,open,high,
    # low,close,amount,volume
    def get_security_bars(self, category, market, code, start=0, count=240):
        """Fetch bars and append them to the per-security CSV file.

        Returns the (empty) frame when the API has no data, otherwise
        ``None`` after the CSV file has been written.
        """
        df = self._load_bars(category, market, code, start, count)
        if df.empty:
            return df

        ts_code = code + "." + self._MARKET_SUFFIX[market]
        filename = config.tdx_csv_minline1_all + ts_code + ".csv"
        if not self._write_csv(df, filename):
            print("新增加的一分钟all股票数据:", filename)

    # Fetch 1-minute K-line bars; same parameters, K-line types and column
    # layouts as ``get_security_bars`` above.
    def get_security_bars_minute1(self, category, market, code, start, count):
        """Fetch bars and drop every trading day without any volume.

        Returns ``None`` when the API has no data, otherwise the filtered
        frame in the ``_BAR_COLUMNS`` layout.
        """
        df = self._load_bars(category, market, code, start, count)
        if df.empty:
            return

        # Suspended stocks still return bars (priced at the previous close),
        # so a day is recognised as suspended by its volume being zero.
        # Only the volume column is aggregated: averaging the whole frame
        # fails on the text columns with current pandas.
        daily_volume = df.groupby('trade_date')['volume'].transform('mean')
        return df[daily_volume != 0]

    # Quotes for several securities at once.
    # Columns: market,code,active1,price,last_close,open,high,low,
    # reversed_bytes0,reversed_bytes1,vol,cur_vol,amount,s_vol,
    # reversed_bytes2,reversed_bytes3,bid1..5,ask1..5,bid_vol1..5,
    # ask_vol1..5,reversed_bytes4..9,active2
    def get_security_quotes(self):
        """Print the quotes of two sample securities."""
        samples = [(0, '000001'), (1, '600300')]
        df = self.api.to_df(self.api.get_security_quotes(samples))
        print(df)

    # Number of securities in a market.
    # Columns: value
    def get_security_count(self):
        """Print the security count of market 0."""
        count = self.api.get_security_count(0)  # 0 - Shenzhen, 1 - Shanghai
        print(self.api.to_df(count))

    # Security list; besides stocks it also contains bonds etc.
    # Columns: code,volunit,decimal_point,name,pre_close
    def get_security_list(self):
        """Print the security list of market 0 starting at offset 10000."""
        # Arguments: market code, start offset (e.g. 0,0 or 1,100).
        listing = self.api.get_security_list(0, 10000)
        print(self.api.to_df(listing))

    # Index K-line bars; parameters as for the stock K-line call.
    # Columns: open,close,high,low,vol,amount,year,month,day,hour,minute,
    # datetime,up_count,down_count
    def get_index_bars(self):
        """Print the two most recent daily bars of several indices."""
        # Codes: sh = SSE Composite, sz = SZSE Component, zxb = SME Board,
        # cyb = ChiNext, szz = SZSE Composite, sz50 = SSE 50, hs300 = CSI 300.
        index_dict = {
            "sh": "999999",
            "sz": "399001",
            "zxb": "399005",
            "cyb": "399006",
            "szz": "399106",
            "sz50": "000016",
            "hs300": "000300"
        }
        for index_code in index_dict.values():
            df = self.api.to_df(
                self.api.get_index_bars(9, 1, index_code, 0, 2))
            print(df)

    # Intraday minute data of the most recent trading day, one row a minute.
    # Columns: price,vol
    def get_minute_time_data(self):
        """Print the intraday minute data of a sample security."""
        data = self.api.get_minute_time_data(1, '600300')  # market, code
        print(self.api.to_df(data))

    # Historical intraday minute data.
    # Columns: price,vol
    def get_history_minute_time_data(self):
        """Print the intraday minute data of a sample security and date."""
        # Arguments: market code, stock code, date.
        data = self.api.get_history_minute_time_data(TDXParams.MARKET_SH,
                                                     '603887', 20200420)
        print(self.api.to_df(data))

    # Tick-by-tick trades of the most recent trading day.
    # Columns: time,price,vol,num,buyorsell
    def get_transaction_data(self):
        """Print the first 30 ticks of a sample security."""
        # Arguments: market code, stock code, start offset, count.
        data = self.api.get_transaction_data(TDXParams.MARKET_SZ, '000001', 0,
                                             30)
        print(self.api.to_df(data))

    # Historical tick-by-tick trades.
    # Columns: time,price,vol,buyorsell
    def get_history_transaction_data(self):
        """Print the first 10 ticks of a sample security on a sample date."""
        # Arguments: market code, stock code, start offset, count, date.
        data = self.api.get_history_transaction_data(TDXParams.MARKET_SZ,
                                                     '000001', 0, 10,
                                                     20170209)
        print(self.api.to_df(data))

    # Company information directory; this is an index, not the content.
    # Columns: name,filename,start,length
    def get_company_info_category(self):
        """Print the company information directory of a sample security."""
        # Arguments: market code, stock code.
        data = self.api.get_company_info_category(TDXParams.MARKET_SZ,
                                                  '000001')
        print(self.api.to_df(data))

    # Company information content.
    # Columns: value
    def get_company_info_content(self):
        """Print a slice of the company information of a sample security."""
        # Arguments: market code, stock code, file name, start offset, length.
        data = self.api.get_company_info_content(0, '000001', '000001.txt', 0,
                                                 1000)
        print(self.api.to_df(data))

    # Ex-rights / ex-dividend records.
    # Columns: year,month,day,category,name,fenhong,peigujia,songzhuangu,peigu
    def get_xdxr_info(self):
        """Print the ex-rights/ex-dividend records of a sample security."""
        data = self.api.get_xdxr_info(1, '600300')  # market code, stock code
        print(self.api.to_df(data))

    # Financial information.
    # Columns: market,code,liutongguben,province,industry,updated_date,
    # ipo_date,zongguben,guojiagu,faqirenfarengu,farengu,bgu,hgu,zhigonggu,
    # zongzichan,liudongzichan,gudingzichan,wuxingzichan,gudongrenshu,
    # liudongfuzhai,changqifuzhai,zibengongjijin,jingzichan,zhuyingshouru,
    # zhuyinglirun,yingshouzhangkuan,yingyelirun,touzishouyu,
    # jingyingxianjinliu,zongxianjinliu,cunhuo,lirunzonghe,shuihoulirun,
    # jinglirun,weifenlirun,baoliu1,baoliu2
    def get_finance_info(self):
        """Print the financial information of a sample security."""
        data = self.api.get_finance_info(1, '600300')  # market code, stock code
        print(self.api.to_df(data))

    # K-line data for a date range.
    # Columns: value
    def get_k_data(self):
        """Print the K-line data of a sample security and date range."""
        # Arguments: stock code, start date, end date.
        data = self.api.get_k_data('600300', '2017-07-03', '2017-07-10')
        print(self.api.to_df(data))

    # Block information.
    # Columns: blockname, block_type, code_index, code
    # File names: BLOCK_SZ = "block_zs.dat"; BLOCK_FG = "block_fg.dat";
    # BLOCK_GN = "block_gn.dat"; BLOCK_DEFAULT = "block.dat"
    def get_and_parse_block_info(self):
        """Dump every block file to its own CSV, replacing older dumps."""
        # Index blocks, style blocks, concept blocks, general blocks.
        block_filename = [
            "block_zs.dat", "block_fg.dat", "block_gn.dat", "block.dat"
        ]
        for block in block_filename:
            df = self.api.to_df(self.api.get_and_parse_block_info(block))
            # Strip the ".dat" extension to name the CSV file.
            filename = config.tdx_csv_block + block[0:-4] + ".csv"
            if os.path.isfile(filename):
                os.remove(filename)
            df.to_csv(filename,
                      index=False,
                      mode='w',
                      header=True,
                      sep=',',
                      encoding="utf_8_sig")

    # Block information of several block files merged into one frame.
    # Columns: data_source, block_category, block_type, block_name,
    # block_code, ts_code, create_time
    def update_block_member(self):
        """Reload block definitions and their members into the database.

        Returns ``None`` when no block data is available, otherwise a tuple
        ``(number of blocks, number of member rows)``.
        """
        # Index, style and concept blocks; "block.dat" is left out because
        # its content is already covered by the other files.
        block_filename = ["block_zs.dat", "block_fg.dat", "block_gn.dat"]
        data_source = "tdx"
        frames = []
        for block in block_filename:
            df = self.api.to_df(self.api.get_and_parse_block_info(block))
            df['data_source'] = data_source
            if block == "block.dat":
                df['block_category'] = data_source + ".yb"
            else:
                # Characters 6-7 of the file name are the category ("zs"...).
                df['block_category'] = data_source + "." + block[6:8]
            df['block_type'] = df['block_type'].map(str)
            df['block_type'] = df['block_category'].str.cat(df['block_type'],
                                                            sep=".")
            # Must not be None when the frame is written with pandas.
            df['block_code'] = ""
            df['ts_code'] = df['code'].apply(lambda x: x + ".SH"
                                             if x[0:1] == "6" else x + ".SZ")
            if not df.empty:
                frames.append(df)
        if not frames:
            return None

        # ``pd.concat`` replaces ``DataFrame.append``, removed in pandas 2.0.
        dfall = pd.concat(frames, ignore_index=True)
        dfall.rename(columns={'blockname': 'block_name'}, inplace=True)
        created = time.strftime('%Y-%m-%d %H:%M:%S',
                                time.localtime(time.time()))
        dfall['create_time'] = created
        # Reorder the columns.
        dfall = dfall[[
            'data_source', 'block_category', 'block_type', 'block_name',
            'block_code', 'ts_code', 'create_time'
        ]]

        # Count the members of every block.
        dfg = dfall.groupby(by=[
            'data_source', 'block_category', 'block_type', 'block_name',
            'block_code'
        ],
                            as_index=False).count()
        # After ``count`` the ts_code column holds the member count and
        # create_time holds a count as well, so rename one and reset the other.
        dfg.rename(columns={'ts_code': 'member_count'}, inplace=True)
        dfg['create_time'] = created
        delete_condition = f"data_source = '{data_source}'"
        mysql_script.df2db_update(delete_condition=delete_condition,
                                  block_basic_df=dfg,
                                  block_member_df=dfall)
        return (len(dfg), len(dfall))

    # Fetch 1-minute data for a date range. One API call returns at most
    # three days of minute bars (240 * 3), so several calls are needed.
    # Return value: 0 = no data extracted; 1 = data extracted.
    def get_minute1_data(self, category, market, code, start_date, end_date):
        """Append the 1-minute bars of a date range to the security's CSV.

        Args:
            category: K-line category passed to the API.
            market: market code (0 or 1).
            code: stock code.
            start_date: first date; dashes are removed before comparing.
            end_date: last date; dashes are removed before comparing.

        Returns:
            ``0`` when nothing could be extracted, ``1`` after the CSV file
            has been written (the previous code fell off the end and
            returned ``None`` here despite the documented contract).
        """
        day = datatime_util.diffrentPeriod(datatime_util.DAILY, start_date,
                                           end_date)
        df = self.get_security_bars_minute1(category, market, code, 0,
                                            240 * 3)
        if df is None or df.empty:
            print('{0}没有交易数据'.format(code))
            return 0
        print(market, '--', code, '--', start_date, '--', end_date)
        data_start_date = df['trade_date'].min()
        data_end_date = df['trade_date'].max()

        start_date = start_date.replace('-', '')
        end_date = end_date.replace('-', '')
        if data_end_date < start_date or end_date < data_start_date:
            print("采集时间在数据范围之外,退出函数")
            return 0

        if start_date < data_start_date:
            # The range starts before the first three days already loaded:
            # keep paging backwards, three days per call plus a remainder.
            n = (day - 3) // 3
            m = (day - 3) % 3
            requests = [(240 * 3 * (i + 1), 240 * 3) for i in range(0, n)]
            if m > 0:
                requests.append((240 * 3 * (n + 1), 240 * m))
            for offset, count in requests:
                dfn = self.get_security_bars_minute1(category, market, code,
                                                     offset, count)
                if (dfn is not None) and (not dfn.empty):
                    # ``pd.concat`` replaces the removed ``DataFrame.append``.
                    df = pd.concat([dfn, df], ignore_index=True)

        df = df.sort_values(by=['trade_date', 'time_index'],
                            axis=0,
                            ascending=True)
        # Drop everything outside the requested date range.
        df = df[(df['trade_date'] >= start_date) &
                (df['trade_date'] <= end_date)]

        ts_code = code + "." + self._MARKET_SUFFIX[market]
        filename = config.tdx_csv_minline1_all + ts_code + ".csv"
        if self._write_csv(df, filename):
            print("更新一分钟all股票数据:", filename)
        else:
            print("新增加的一分钟all股票数据:", filename)
        return 1
Exemplo n.º 13
0
class PYTDXService:
    """pytdx data service backed by a quote API and a mongo client."""

    # Price-like quote columns that are divided by 10 for 3-decimal securities.
    _PRICE_COLUMNS = (
        "price", "last_close", "open", "high", "low",
        "ask1", "bid1", "ask2", "bid2", "ask3", "bid3",
        "ask4", "bid4", "ask5", "bid5",
    )

    def __init__(self, client):
        """Store the mongo *client*; no connection is opened here."""
        self.connected = False  # connection state of the data service
        self.hq_api = None  # quote API object
        self.client = client  # mongo client

    def connect_api(self):
        """Connect the quote API using the host/port from ``SETTINGS``.

        Returns:
            ``True`` once a connection is available.

        Raises:
            ConnectionError: if the settings are missing, the connection
                attempt raises, or ``connect`` reports failure by returning
                a falsy value. Previously that falsy result was ignored and
                the service was marked as connected anyway.
        """
        if self.connected:
            return True
        try:
            host = SETTINGS["TDX_HOST"]
            port = SETTINGS["TDX_PORT"]
            self.hq_api = TdxHq_API()
            if not self.hq_api.connect(host, port):
                raise ConnectionError("pytdx连接错误")
            self.connected = True
            return True
        except Exception as exc:
            raise ConnectionError("pytdx连接错误") from exc

    def get_realtime_data(self, symbol: str):
        """Return the real-time quote frame of *symbol* (``code.exchange``).

        For securities whose stored ``decimal_point`` is 3, all price columns
        are divided by 10 (the feed reports fund prices at ten times their
        real value).

        Raises:
            ValueError: if anything goes wrong while fetching or scaling.
        """
        try:
            symbols = self.generate_symbols(symbol)
            market, code = symbols[0]
            df = self.hq_api.to_df(self.hq_api.get_security_quotes(symbols))
            data = self.client["stocks"]["security"].find_one(
                {"code": code, "market": str(market)}
            )
            if data["decimal_point"] == 3:
                for column in self._PRICE_COLUMNS:
                    df[column] = df[column] / 10
            return df
        except Exception as exc:
            raise ValueError("股票数据获取失败") from exc

    def get_history_transaction_data(self, symbol, date):
        """Return the historical ticks of *symbol* on *date*.

        Example of the underlying call:
        ``get_history_transaction_data(TDXParams.MARKET_SZ, '000001', 0, 10,
        20170209)`` with the arguments market code, stock code, start offset,
        count and date.
        Output columns: time, price, vol, buyorsell (0: buy, 1: sell,
        2: neutral).

        Four pages of 2000 ticks are requested (offsets 6000 down to 0),
        concatenated and de-duplicated.
        """
        code, market = self.check_symbol(symbol)

        check_date = int(date)
        count = 2000
        pages = [
            self.hq_api.to_df(
                self.hq_api.get_history_transaction_data(
                    market, code, start, count, check_date
                )
            )
            for start in (6000, 4000, 2000, 0)
        ]

        df = pd.concat(pages)
        df.drop_duplicates(inplace=True)
        return df

    @staticmethod
    def generate_symbols(symbol: str):
        """Turn ``code.exchange`` into the ``[(market, code)]`` list pytdx expects."""
        code, exchange = symbol.split(".")
        return [(exchange_map[exchange], code)]

    @staticmethod
    def check_symbol(symbol: str):
        """Split *symbol* into ``(code, market)``.

        Returns ``False`` when *symbol* is empty; ``market`` is ``None`` when
        the exchange is not present in ``exchange_map``.
        """
        if not symbol:
            return False
        code, market = symbol.split(".")
        return code, exchange_map.get(market)

    def close(self):
        """Mark the service as disconnected and close the quote API, if any."""
        self.connected = False
        # ``hq_api`` stays None until ``connect_api`` has been called.
        if self.hq_api is not None:
            self.hq_api.disconnect()
class TdxStockData(object):
    """
    Stock market data source backed by a pytdx TdxHq_API connection.

    Provides K-line bars and tick-by-tick transaction history, with an
    optional on-disk cache for the latter. The attributes below live on the
    class, so they are shared by every instance.
    """
    # Server chosen by select_best_ip(); resolved once and then reused
    best_ip = None
    # tdx code -> symbol as passed by the caller
    # (original note: tdx contract to vn exchange)
    symbol_exchange_dict = {}  # tdx contract -> vn exchange mapping
    # tdx code -> tdx market code
    symbol_market_dict = {}  # tdx contract -> tdx market mapping

    # ----------------------------------------------------------------------
    def __init__(self, strategy=None):
        """
        构造函数
        :param strategy: 上层策略,主要用与使用write_log()
        """
        self.api = None

        self.connection_status = False  # 连接状态

        self.strategy = strategy

        self.connect()

    def write_log(self, content):
        if self.strategy:
            self.strategy.write_log(content)
        else:
            print(content)

    def write_error(self, content):
        if self.strategy:
            self.strategy.write_log(content, level=ERROR)
        else:
            print(content, file=sys.stderr)

    def connect(self):
        """
        连接API
        :return:
        """
        # 创建api连接对象实例
        try:
            if self.api is None or not self.connection_status:
                self.write_log(u'开始连接通达信股票行情服务器')
                self.api = TdxHq_API(heartbeat=True,
                                     auto_retry=True,
                                     raise_exception=True)

                # 选取最佳服务器
                if TdxStockData.best_ip is None:
                    from pytdx.util.best_ip import select_best_ip
                    TdxStockData.best_ip = select_best_ip()

                self.api.connect(self.best_ip.get('ip'),
                                 self.best_ip.get('port'))
                self.write_log(f'创建tdx连接, : {self.best_ip}')
                TdxStockData.connection_status = True

        except Exception as ex:
            self.write_log(u'连接服务器tdx异常:{},{}'.format(str(ex),
                                                      traceback.format_exc()))
            return

    def disconnect(self):
        if self.api is not None:
            self.api = None

    # ----------------------------------------------------------------------
    def get_bars(self,
                 symbol,
                 period,
                 callback,
                 bar_is_completed=False,
                 bar_freq=1,
                 start_dt=None):
        """
        返回k线数据
        symbol:股票 000001.XG
        period: 周期: 1min,5min,15min,30min,1hour,1day,
        """
        if self.api is None:
            self.connect()

        # 新版一劳永逸偷懒写法zzz
        if '.' in symbol:
            tdx_code, market_str = symbol.split('.')
            market_code = 1 if market_str.upper() == 'XSHG' else 0
            self.symbol_exchange_dict.update({tdx_code:
                                              symbol})  # tdx合约与vn交易所的字典
            self.symbol_market_dict.update({tdx_code:
                                            market_code})  # tdx合约与tdx市场的字典
        else:
            market_code = get_tdx_market_code(symbol)
            tdx_code = symbol
            self.symbol_exchange_dict.update({symbol:
                                              symbol})  # tdx合约与vn交易所的字典
            self.symbol_market_dict.update({symbol:
                                            market_code})  # tdx合约与tdx市场的字典

        # https://github.com/rainx/pytdx/issues/33
        # 0 - 深圳, 1 - 上海

        ret_bars = []

        if period not in PERIOD_MAPPING.keys():
            self.write_error(u'{} 周期{}不在下载清单中: {}'.format(
                datetime.now(), period, list(PERIOD_MAPPING.keys())))
            # print(u'{} 周期{}不在下载清单中: {}'.format(datetime.now(), period, list(PERIOD_MAPPING.keys())))
            return False, ret_bars

        if self.api is None:
            return False, ret_bars

        tdx_period = PERIOD_MAPPING.get(period)

        if start_dt is None:
            self.write_log(u'没有设置开始时间,缺省为10天前')
            qry_start_date = datetime.now() - timedelta(days=10)
            start_dt = qry_start_date
        else:
            qry_start_date = start_dt
        end_date = datetime.now()
        if qry_start_date > end_date:
            qry_start_date = end_date

        self.write_log('{}开始下载tdx股票:{} {}数据, {} to {}.'.format(
            datetime.now(), tdx_code, tdx_period, qry_start_date, end_date))

        try:
            _start_date = end_date
            _bars = []
            _pos = 0
            while _start_date > qry_start_date:
                _res = self.api.get_security_bars(
                    category=PERIOD_MAPPING[period],
                    market=market_code,
                    code=tdx_code,
                    start=_pos,
                    count=QSIZE)
                if _res is not None:
                    _bars = _res + _bars
                _pos += QSIZE
                if _res is not None and len(_res) > 0:
                    _start_date = _res[0]['datetime']
                    _start_date = datetime.strptime(_start_date,
                                                    '%Y-%m-%d %H:%M')
                    self.write_log(u'分段取数据开始:{}'.format(_start_date))
                else:
                    break
            if len(_bars) == 0:
                self.write_error('{} Handling {}, len1={}..., continue'.format(
                    str(datetime.now()), tdx_code, len(_bars)))
                return False, ret_bars

            current_datetime = datetime.now()
            data = self.api.to_df(_bars)
            data = data.assign(datetime=to_datetime(data['datetime']))
            data = data.assign(ticker=symbol)
            data['symbol'] = symbol
            data = data.drop(
                ['year', 'month', 'day', 'hour', 'minute', 'price', 'ticker'],
                errors='ignore',
                axis=1)
            data = data.rename(index=str, columns={'amount': 'volume'})
            if len(data) == 0:
                print('{} Handling {}, len2={}..., continue'.format(
                    str(datetime.now()), tdx_code, len(data)))
                return False, ret_bars

            # 通达信是以bar的结束时间标记的,vnpy是以bar开始时间标记的,所以要扣减bar本身的分钟数
            data['datetime'] = data['datetime'].apply(lambda x: x - timedelta(
                minutes=NUM_MINUTE_MAPPING.get(period, 1)))
            data['trading_date'] = data['datetime'].apply(
                lambda x: (x.strftime('%Y-%m-%d')))
            data['date'] = data['datetime'].apply(lambda x:
                                                  (x.strftime('%Y-%m-%d')))
            data['time'] = data['datetime'].apply(lambda x:
                                                  (x.strftime('%H:%M:%S')))
            data = data.set_index('datetime', drop=False)

            for index, row in data.iterrows():
                add_bar = BarData()
                try:
                    add_bar.symbol = symbol
                    add_bar.datetime = index
                    add_bar.date = row['date']
                    add_bar.time = row['time']
                    add_bar.trading_date = row['trading_date']
                    add_bar.open = float(row['open'])
                    add_bar.high = float(row['high'])
                    add_bar.low = float(row['low'])
                    add_bar.close = float(row['close'])
                    add_bar.volume = float(row['volume'])
                except Exception as ex:
                    self.write_error(
                        'error when convert bar:{},ex:{},t:{}'.format(
                            row, str(ex), traceback.format_exc()))
                    # print('error when convert bar:{},ex:{},t:{}'.format(row, str(ex), traceback.format_exc()))
                    return False

                if start_dt is not None and index < start_dt:
                    continue
                ret_bars.append(add_bar)

                if callback is not None:
                    freq = bar_freq
                    bar_is_completed = True
                    if period != '1min' and index == data['datetime'][-1]:
                        # 最后一个bar,可能是不完整的,强制修改
                        # - 5min修改后freq基本正确
                        # - 1day在VNPY合成时不关心已经收到多少Bar, 所以影响也不大
                        # - 但其它分钟周期因为不好精确到每个品种, 修改后的freq可能有错
                        if index > current_datetime:
                            bar_is_completed = False
                            # 根据秒数算的话,要+1,例如13:31,freq=31,第31根bar
                            freq = NUM_MINUTE_MAPPING[period] - int(
                                (index - current_datetime).total_seconds() /
                                60)
                    callback(add_bar, bar_is_completed, freq)

            return True, ret_bars
        except Exception as ex:
            self.write_error('exception in get:{},{},{}'.format(
                tdx_code, str(ex), traceback.format_exc()))
            # print('exception in get:{},{},{}'.format(tdx_symbol,str(ex), traceback.format_exc()))
            self.write_log(u'重置连接')
            TdxStockData.api = None
            self.connect()
            return False, ret_bars

    def save_cache(self, cache_folder, cache_symbol, cache_date, data_list):
        """保存文件到缓存"""

        os.makedirs(cache_folder, exist_ok=True)

        if not os.path.exists(cache_folder):
            self.write_error('缓存目录不存在:{},不能保存'.format(cache_folder))
            return
        cache_folder_year_month = os.path.join(cache_folder, cache_date[:6])
        os.makedirs(cache_folder_year_month, exist_ok=True)

        save_file = os.path.join(cache_folder_year_month,
                                 '{}_{}.pkb2'.format(cache_symbol, cache_date))
        with bz2.BZ2File(save_file, 'wb') as f:
            pickle.dump(data_list, f)
            self.write_log(u'缓存成功:{}'.format(save_file))

    def load_cache(self, cache_folder, cache_symbol, cache_date):
        """加载缓存数据"""
        if not os.path.exists(cache_folder):
            self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder))
            return None
        cache_folder_year_month = os.path.join(cache_folder, cache_date[:6])
        if not os.path.exists(cache_folder_year_month):
            self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder_year_month))
            return None

        cache_file = os.path.join(
            cache_folder_year_month,
            '{}_{}.pkb2'.format(cache_symbol, cache_date))
        if not os.path.isfile(cache_file):
            self.write_error('缓存文件:{}不存在,不能读取'.format(cache_file))
            return None
        with bz2.BZ2File(cache_file, 'rb') as f:
            data = pickle.load(f)
            return data

        return None

    def get_history_transaction_data(self, symbol, date, cache_folder=None):
        """获取当某一交易日的历史成交记录"""
        ret_datas = []
        if isinstance(date, datetime):
            date = date.strftime('%Y%m%d')
        if isinstance(date, str):
            date = int(date)

        cache_symbol = symbol
        cache_date = str(date)

        max_data_size = sys.maxsize
        # symbol.exchange => tdx_code market_code
        if '.' in symbol:
            tdx_code, market_str = symbol.split('.')
            market_code = 1 if market_str.upper() == 'XSHG' else 0
            self.symbol_exchange_dict.update({tdx_code:
                                              symbol})  # tdx合约与vn交易所的字典
            self.symbol_market_dict.update({tdx_code:
                                            market_code})  # tdx合约与tdx市场的字典
        else:
            market_code = get_tdx_market_code(symbol)
            tdx_code = symbol
            self.symbol_exchange_dict.update({symbol:
                                              symbol})  # tdx合约与vn交易所的字典
            self.symbol_market_dict.update({symbol:
                                            market_code})  # tdx合约与tdx市场的字典

        q_size = QSIZE * 5
        # 每秒 2个, 10小时
        max_data_size = 1000000
        # 优先从缓存加载
        if cache_folder:
            buffer_data = self.load_cache(cache_folder, cache_symbol,
                                          cache_date)
            if buffer_data:
                return True, buffer_data

        self.write_log(u'开始下载{} 历史{}分笔数据'.format(date, symbol))

        is_today = False
        if date == int(datetime.now().strftime('%Y%m%d')):
            is_today = True

        try:
            _datas = []
            _pos = 0

            while (True):
                if is_today:
                    _res = self.api.get_transaction_data(
                        market=self.symbol_market_dict[symbol],
                        code=symbol,
                        start=_pos,
                        count=q_size)
                else:
                    _res = self.api.get_history_transaction_data(
                        market=self.symbol_market_dict[symbol],
                        date=date,
                        code=symbol,
                        start=_pos,
                        count=q_size)
                last_dt = None
                if _res is not None:
                    _datas = _res + _datas
                _pos += min(q_size, len(_res))

                if _res is not None and len(_res) > 0:
                    self.write_log(u'分段取{}分笔数据:{} ~{}, {}条,累计:{}条'.format(
                        date, _res[0]['time'], _res[-1]['time'], len(_res),
                        _pos))
                else:
                    break

                if len(_datas) >= max_data_size:
                    break

            if len(_datas) == 0:
                self.write_error(u'{}分笔成交数据获取为空'.format(date))
                return False, _datas

            for d in _datas:
                dt = datetime.strptime(
                    str(date) + ' ' + d.get('time'), '%Y%m%d %H:%M')
                if last_dt is None or last_dt < dt:
                    last_dt = dt
                else:
                    if last_dt < dt + timedelta(seconds=59):
                        last_dt = last_dt + timedelta(seconds=1)
                d.update({'datetime': last_dt})
                d.update({'volume': d.pop('vol', 0)})
                d.update({'trading_date': last_dt.strftime('%Y-%m-%d')})

            _datas = sorted(_datas, key=lambda s: s['datetime'])

            # 缓存文件
            if cache_folder and not is_today:
                self.save_cache(cache_folder, cache_symbol, cache_date, _datas)

            return True, _datas

        except Exception as ex:
            self.write_error(
                'exception in get_transaction_data:{},{},{}'.format(
                    symbol, str(ex), traceback.format_exc()))
            return False, ret_datas
Exemplo n.º 15
0
    print("导入股票代码表")
    #import_stock_name(connect, api, 'SH', quotations)
    #import_stock_name(connect, api, 'SZ', quotations)

    add_count = 0

    print("\n导入上证日线数据")
    #add_count = import_data(connect, 'SH', 'DAY', ['bond'], api, dest_dir, progress=ProgressBar)
    print("\n导入数量:", add_count)

    print("\n导入深证日线数据")
    #add_count = import_data(connect, 'SZ', 'DAY', ['stock'], api, dest_dir, progress=ProgressBar)
    print("\n导入数量:", add_count)

    for i in range(10):
        x = api.get_history_transaction_data(TDXParams.MARKET_SZ, '000001',
                                             (9 - i) * 2000, 2000, 20181112)
        #x = api.get_transaction_data(TDXParams.MARKET_SZ, '000001', (9-i)*800, 800)
        if x is not None and len(x) > 0:
            print(i, len(x), x[0], x[-1])

    api.disconnect()

    connect.close()

    endtime = time.time()
    print("\nTotal time:")
    print("%.2fs" % (endtime - starttime))
    print("%.2fm" % ((endtime - starttime) / 60))
            # \x80\x03 = 14:56

            hour, minute, pos = get_time(body_buf, pos)

            price_raw, pos = get_price(body_buf, pos)
            vol, pos = get_price(body_buf, pos)
            buyorsell, pos = get_price(body_buf, pos)
            _, pos = get_price(body_buf, pos)

            last_price = last_price + price_raw

            tick = OrderedDict(
                [
                    ("time", "%02d:%02d" % (hour, minute)),
                    ("price", float(last_price)/100),
                    ("vol", vol),
                    ("buyorsell", buyorsell),
                ]
            )

            ticks.append(tick)

        return ticks


if __name__ == '__main__':
    from pytdx.hq import TdxHq_API

    # Manual smoke test: print ten historical ticks as a DataFrame.
    # Positional arguments: market code, stock code, start position,
    # count, date.
    api = TdxHq_API()
    with api.connect():
        ticks = api.get_history_transaction_data(0, '000001', 0, 10, 20170811)
        frame = api.to_df(ticks)
        print(frame)