def __init__(self, portfolio, vtSymbol):
    """Set up Turtle-style signal state and warm the bar container with history.

    Stores the owning portfolio and contract code, assigns the strategy
    parameters, zeroes every channel / entry / stop level, and replays the bars
    returned by ``portfolio.engine._bc_loadInitBar`` through ``self.am``.
    """
    # Wiring: owning portfolio, contract code, bar container, latest bar.
    self.portfolio = portfolio
    self.vtSymbol = vtSymbol
    self.am = ArrayManager()
    self.bar = None

    # --- strategy parameters ---
    self.initBars = 60          # number of bars used for initialization
    self.entryWindow = 20       # entry channel length
    self.exitWindow = 50        # exit channel length
    self.atrWindow = 5          # ATR length
    self.profitCheck = True     # whether to check the previous trade's profit
    self.minx = 'day'

    # --- indicator scratch values: ATR volatility, entry and exit channels ---
    self.atrVolatility = 0
    self.entryUp = self.entryDown = 0
    self.exitUp = self.exitDown = 0

    # Long entry levels 1..4 and long stop.
    self.longEntry1 = self.longEntry2 = self.longEntry3 = self.longEntry4 = 0
    self.longStop = 0

    # Short entry levels 1..4 and short stop.
    self.shortEntry1 = self.shortEntry2 = self.shortEntry3 = self.shortEntry4 = 0
    self.shortStop = 0

    # --- persisted state ---
    self.unit = 0
    self.result = None          # trade currently open
    self.resultList = []        # closed trades

    # Replay history so the bar container is initialized before live bars.
    history = self.portfolio.engine._bc_loadInitBar(self.vtSymbol, self.initBars, self.minx)
    for hist_bar in history:
        self.bar = hist_bar
        self.am.updateBar(hist_bar)
    def __init__(self, portfolio, vtSymbol):
        self.portfolio = portfolio  # 投资组合
        self.vtSymbol = vtSymbol  # 合约代码
        self.am = ArrayManager()  # K线容器
        self.bar = None  # 最新K线

        # 策略参数
        self.atrLength = 1  # 计算ATR指标的窗口数
        self.atrMaLength = 14  # 计算ATR均线的窗口数
        self.rsiLength = 5  # 计算RSI的窗口数
        self.rsiEntry = 16  # RSI的开仓信号
        self.trailingPercent = 0.7  # 百分比移动止损
        self.victoryPercent = 0.3
        self.initBars = 60  # 初始化数据所用的天数
        self.fixedSize = 1  # 每次交易的数量
        self.ratio_atrMa = 0.8
        self.minx = 'min5'
        # 初始化RSI入场阈值
        self.rsiBuy = 50 + self.rsiEntry
        self.rsiSell = 50 - self.rsiEntry

        # 策略临时变量
        self.atrValue = 0  # 最新的ATR指标数值
        self.atrMa = 0  # ATR移动平均的数值
        self.rsiValue = 0  # RSI指标的数值
        self.iswave = True

        # 需要持久化保存的变量
        self.unit = 0
        self.cost = 0
        self.intraTradeHigh = 0  # 移动止损用的持仓期内最高价
        self.intraTradeLow = 0  # 持仓期内的最低点
        self.stop = 0  # 多头止损

        self.result = None  # 当前的交易
        self.resultList = []  # 交易列表

        # 载入历史数据,并采用回放计算的方式初始化策略数值
        initData = self.portfolio.engine._bc_loadInitBar(
            self.vtSymbol, self.initBars, self.minx)
        for bar in initData:
            self.bar = bar
            self.am.updateBar(bar)
    def __init__(self, portfolio, vtSymbol):
        self.type = 'kong'

        # 策略参数
        self.fixedSize = 1  # 每次交易的数量
        self.initBars = 100  # 初始化数据所用的天数
        self.minx = 'min30'

        # 策略临时变量
        self.can_buy = False
        self.can_sell = False
        self.can_short = False
        self.can_cover = False

        # 需要持久化保存的变量
        self.cost = 0

        size_am = 100
        assert self.initBars <= size_am
        Signal.__init__(self, portfolio, vtSymbol)

        self.bm_bar = None
        self.bm = ArrayManager(60)
        self.init_bm()
class Fut_AtrRsiSignal(object):
    """ATR/RSI momentum signal with a percentage trailing stop.

    Entry (when flat): the latest ATR is above its moving average and RSI is
    above ``rsiBuy`` (go long) or below ``rsiSell`` (go short). The order is
    placed at the bar close.
    Exit: the close retraces ``victoryPercent`` percent from the highest high
    (long) or lowest low (short) seen while in the position.
    Orders are forwarded to ``portfolio._bc_newSignal``. Trades and state are
    appended to CSV files under ``get_dss()``.
    """

    #----------------------------------------------------------------------
    def __init__(self, portfolio, vtSymbol):
        """Initialize parameters and state, then warm up ``self.am`` with history.

        :param portfolio: owning portfolio. ``portfolio.engine._bc_loadInitBar``
            supplies the warm-up bars and ``portfolio._bc_newSignal`` receives orders.
        :param vtSymbol: contract code this signal trades.
        """
        self.portfolio = portfolio  # owning portfolio
        self.vtSymbol = vtSymbol  # contract code
        self.am = ArrayManager()  # bar container
        self.bar = None  # latest bar

        # Strategy parameters
        self.atrLength = 22  # ATR window
        self.atrMaLength = 10  # window of the ATR moving average
        self.rsiLength = 5  # RSI window
        self.rsiEntry = 16  # RSI entry offset from 50
        self.victoryPercent = 0.8  # trailing-stop distance, in percent
        self.initBars = 90  # number of bars used for initialization
        self.fixedSize = 1  # volume per trade
        self.minx = 'min5'

        # RSI entry thresholds (symmetric around 50)
        self.rsiBuy = 50 + self.rsiEntry
        self.rsiSell = 50 - self.rsiEntry

        # Indicator scratch values
        self.atrValue = 0  # latest ATR
        self.atrMa = 0  # ATR moving average
        self.rsiValue = 0  # latest RSI
        # open()/close() write this into the deal log, but this class never
        # changes it. It used to be missing, so the first trade raised
        # AttributeError.
        self.iswave = True

        # Persisted state (see load_var / save_var)
        self.unit = 0
        self.cost = 0
        self.intraTradeHigh = 0  # highest high while long (trailing stop)
        self.intraTradeLow = 0  # lowest low while short
        self.stop = 0  # current trailing-stop level

        self.result = None  # trade currently open
        self.resultList = []  # closed trades

        # Replay history so indicators are ready before live bars arrive.
        initData = self.portfolio.engine._bc_loadInitBar(
            self.vtSymbol, self.initBars, self.minx)
        for bar in initData:
            self.bar = bar
            self.am.updateBar(bar)

    #----------------------------------------------------------------------
    def set_param(self, param_dict):
        """Override tunable parameters from ``param_dict``; unknown keys are ignored.

        Recognized keys: ``atrLength``, ``atrMaLength``, ``rsiLength``,
        ``rsiEntry`` and ``victoryPercent``. Setting ``rsiEntry`` also refreshes
        ``rsiBuy`` / ``rsiSell``, which are derived from it.
        """
        for key in ('atrLength', 'atrMaLength', 'rsiLength', 'victoryPercent'):
            if key in param_dict:
                setattr(self, key, param_dict[key])
        if 'rsiEntry' in param_dict:
            self.rsiEntry = param_dict['rsiEntry']
            self.rsiBuy = 50 + self.rsiEntry
            self.rsiSell = 50 - self.rsiEntry

    #----------------------------------------------------------------------
    def onBar(self, bar):
        """Handle a new bar: store it, update ``self.am`` and, once warmed up, trade on it."""
        self.bar = bar
        self.am.updateBar(bar)
        if not self.am.inited:
            return

        self.calculateIndicator()  # compute indicators
        self.generateSignal(bar)  # turn them into orders

    #----------------------------------------------------------------------
    def calculateIndicator(self):
        """Refresh ``atrValue``, ``atrMa`` and ``rsiValue``.

        ``atrMa`` is the mean of the last ``atrMaLength`` values of the ATR series.
        """
        atrArray = self.am.atr(self.atrLength, array=True)

        self.atrValue = atrArray[-1]
        self.atrMa = atrArray[-self.atrMaLength:].mean()

        self.rsiValue = self.am.rsi(self.rsiLength)

    #----------------------------------------------------------------------
    def generateSignal(self, bar):
        """Open on ATR/RSI conditions when flat; trail and check the stop when in a position."""
        # Flat
        if self.unit == 0:
            self.intraTradeHigh = bar.high
            self.intraTradeLow = bar.low

            # ATR above its average: short-term volatility is expanding, so a
            # trend is more likely, which suits opening a CTA position.
            if self.atrValue > self.atrMa:
                # In trends RSI tends to linger in the overbought/oversold zone;
                # use that as the entry trigger. Orders are placed at the bar close.
                if self.rsiValue > self.rsiBuy:
                    self.buy(bar.close, self.fixedSize)
                elif self.rsiValue < self.rsiSell:
                    self.short(bar.close, self.fixedSize)

        # Long: the stop trails victoryPercent % below the highest high since entry.
        elif self.unit > 0:
            self.intraTradeHigh = max(self.intraTradeHigh, bar.high)
            self.stop = self.intraTradeHigh * (1 - self.victoryPercent / 100)
            if bar.close <= self.stop:
                self.sell(bar.close, abs(self.unit))

        # Short: the stop trails victoryPercent % above the lowest low since entry.
        elif self.unit < 0:
            self.intraTradeLow = min(self.intraTradeLow, bar.low)
            self.stop = self.intraTradeLow * (1 + self.victoryPercent / 100)
            if bar.close >= self.stop:
                self.cover(bar.close, abs(self.unit))

    #----------------------------------------------------------------------
    def load_var(self):
        """Restore persisted state from the last row for ``vtSymbol`` in the var CSV.

        If the file does not exist yet (first run), the state is left untouched.
        Previously a missing file raised here.
        """
        filename = get_dss() + 'fut/check/signal_atrrsi_var.csv'
        try:
            df = pd.read_csv(filename)
        except FileNotFoundError:
            return
        df = df[df.vtSymbol == self.vtSymbol]
        if len(df) > 0:
            rec = df.iloc[-1, :]  # most recent record
            self.unit = rec.unit
            self.cost = rec.cost
            self.intraTradeHigh = rec.intraTradeHigh
            self.intraTradeLow = rec.intraTradeLow
            self.stop = rec.stop
            if rec.has_result == 1:
                self.result = SignalResult()
                self.result.unit = rec.result_unit
                self.result.entry = rec.result_entry
                self.result.exit = rec.result_exit
                self.result.pnl = rec.result_pnl

    #----------------------------------------------------------------------
    def save_var(self):
        """Append the current persisted state as one headerless row to the var CSV."""
        if self.result is None:
            result_cols = [0, 0, 0, 0, 0]
        else:
            result_cols = [1, self.result.unit, self.result.entry,
                           self.result.exit, self.result.pnl]
        row = [self.portfolio.result.date, self.vtSymbol, self.unit, self.cost,
               self.intraTradeHigh, self.intraTradeLow, self.stop] + result_cols
        df = pd.DataFrame([row], columns=['datetime', 'vtSymbol', 'unit', 'cost',
                                          'intraTradeHigh', 'intraTradeLow', 'stop',
                                          'has_result', 'result_unit', 'result_entry',
                                          'result_exit', 'result_pnl'])
        filename = get_dss() + 'fut/check/signal_atrrsi_var.csv'
        df.to_csv(filename, index=False, mode='a', header=False)

    #----------------------------------------------------------------------
    def newSignal(self, direction, offset, price, volume):
        """Forward an order instruction to the portfolio.

        buy/sell/short/cover already called this, but the class never defined it.
        """
        self.portfolio._bc_newSignal(self, direction, offset, price, volume)

    #----------------------------------------------------------------------
    def buy(self, price, volume):
        """Open long."""
        self.open(price, volume)
        self.newSignal(DIRECTION_LONG, OFFSET_OPEN, price, volume)

    #----------------------------------------------------------------------
    def sell(self, price, volume):
        """Close long. ``volume`` is ignored: the whole position ``abs(self.unit)`` is closed."""
        volume = abs(self.unit)

        self.close(price)
        self.newSignal(DIRECTION_SHORT, OFFSET_CLOSE, price, volume)

    #----------------------------------------------------------------------
    def short(self, price, volume):
        """Open short."""
        self.open(price, -volume)
        self.newSignal(DIRECTION_SHORT, OFFSET_OPEN, price, volume)

    #----------------------------------------------------------------------
    def cover(self, price, volume):
        """Close short. ``volume`` is ignored: the whole position ``abs(self.unit)`` is closed."""
        volume = abs(self.unit)

        self.close(price)
        self.newSignal(DIRECTION_LONG, OFFSET_CLOSE, price, volume)

    #----------------------------------------------------------------------
    def _log_deal(self, direction, offset, volume, price, pnl):
        """Append one trade row plus the current indicator and stop values to this signal's deal CSV.

        The file is ``signal_atrrsi_<symbol>.csv``. It used to be
        ``signal_turtle_<symbol>.csv``, which mixed these 13-column rows into
        Fut_TurtleSignal's 7-column log for the same symbol.
        """
        r = [[self.portfolio.result.date, direction, offset, volume, price, pnl,
              self.atrValue, self.atrMa, self.rsiValue,
              self.iswave, self.intraTradeHigh, self.intraTradeLow,
              self.stop]]
        df = pd.DataFrame(r, columns=['datetime', 'direction', 'offset', 'volume', 'price', 'pnl',
                                      'atrValue', 'atrMa', 'rsiValue', 'iswave',
                                      'intraTradeHigh', 'intraTradeLow', 'stop'])
        filename = get_dss() + 'fut/deal/signal_atrrsi_' + self.vtSymbol + '.csv'
        df.to_csv(filename, index=False, mode='a', header=False)

    #----------------------------------------------------------------------
    def open(self, price, change):
        """Open or add to a position by ``change`` units (negative means short) and log the deal."""
        self.unit += change

        if not self.result:
            self.result = SignalResult()
        self.result.open(price, change)

        self._log_deal('多' if change > 0 else '空', '开', abs(change), price, 0)

    #----------------------------------------------------------------------
    def close(self, price):
        """Flatten the position at ``price``, log the deal and archive the trade result."""
        self.unit = 0
        self.result.close(price)

        # Logged after result.close() so the row carries the realized pnl.
        self._log_deal('', '平', 0, price, self.result.pnl)

        self.resultList.append(self.result)
        self.result = None
class Fut_TurtleSignal(object):
    """Turtle-style Donchian channel breakout signal.

    Daily bars drive both entries and exits.
    - Entries: up to four units per side, spaced half an ATR apart beyond
      the entry channel.
    - Exits: the opposite exit channel, or a stop 2 ATR from the last add-on
      price.
    Bars passed with ``minx='min1'`` are only checked against ``self.stop``.
    Trades and state are appended to CSV files under ``get_dss()``.
    """

    #----------------------------------------------------------------------
    def __init__(self, portfolio, vtSymbol):
        """Initialize parameters and state, then warm up ``self.am`` with history.

        :param portfolio: owning portfolio. ``portfolio.engine._bc_loadInitBar``
            supplies the warm-up bars and ``portfolio._bc_newSignal`` receives orders.
        :param vtSymbol: contract code this signal trades.
        """
        self.portfolio = portfolio      # owning portfolio
        self.vtSymbol = vtSymbol        # contract code
        self.am = ArrayManager()        # bar container
        self.bar = None                 # latest bar

        # Strategy parameters
        self.initBars = 60              # number of bars used for initialization
        self.entryWindow = 20           # entry channel length
        self.exitWindow = 50            # exit channel length
        self.atrWindow = 5              # ATR length
        self.profitCheck = True         # whether to check the last trade's pnl (not read in this class)
        self.minx = 'day'

        # Indicator scratch values
        self.atrVolatility = 0          # ATR volatility (N)
        self.entryUp = 0                # entry channel
        self.entryDown = 0
        self.exitUp = 0                 # exit channel
        self.exitDown = 0

        self.longEntry1 = 0             # long entry levels
        self.longEntry2 = 0
        self.longEntry3 = 0
        self.longEntry4 = 0
        self.longStop = 0               # long stop

        self.shortEntry1 = 0            # short entry levels
        self.shortEntry2 = 0
        self.shortEntry3 = 0
        self.shortEntry4 = 0
        self.shortStop = 0              # short stop

        # Persisted state (see load_var / save_var).
        # cost, intraTradeHigh, intraTradeLow and stop used to be left
        # uninitialized. save_var() and on_bar_min1() then raised
        # AttributeError unless load_var() had found a saved record.
        self.unit = 0
        self.cost = 0
        self.intraTradeHigh = 0
        self.intraTradeLow = 0
        self.stop = 0                   # 0 means "no stop level set"

        self.result = None              # trade currently open
        self.resultList = []            # closed trades

        # Replay history so indicators are ready before live bars arrive.
        initData = self.portfolio.engine._bc_loadInitBar(self.vtSymbol, self.initBars, self.minx)
        for bar in initData:
            self.bar = bar
            self.am.updateBar(bar)

    #----------------------------------------------------------------------
    def load_param(self):
        """Load per-product parameters from ``fut/cfg/signal_turtle_param.csv`` if present.

        NOTE(review): the attributes set here (rsiLength, trailingPercent,
        victoryPercent) are not read anywhere in this class. This looks copied
        from the ATR/RSI signal.
        """
        filename = get_dss() + 'fut/cfg/signal_turtle_param.csv'
        if os.path.exists(filename):
            df = pd.read_csv(filename)
            df = df[df.pz == get_contract(self.vtSymbol).pz]
            if len(df) > 0:
                rec = df.iloc[0, :]
                self.rsiLength = rec.rsiLength
                self.trailingPercent = rec.trailingPercent
                self.victoryPercent = rec.victoryPercent

    #----------------------------------------------------------------------
    def set_param(self, param_dict):
        """Override the listed attributes from ``param_dict``; unknown keys are ignored.

        NOTE(review): none of these attributes is read by the turtle logic in this class.
        """
        if 'atrMaLength' in param_dict:
            self.atrMaLength = param_dict['atrMaLength']
        if 'rsiLength' in param_dict:
            self.rsiLength = param_dict['rsiLength']
        if 'trailingPercent' in param_dict:
            self.trailingPercent = param_dict['trailingPercent']
        if 'victoryPercent' in param_dict:
            self.victoryPercent = param_dict['victoryPercent']

    #----------------------------------------------------------------------
    def onBar(self, bar, minx='day'):
        """Dispatch a new bar: ``'min1'`` bars only check the stop; anything else is a daily bar."""
        if minx == 'min1':
            self.on_bar_min1(bar)
        else:
            self.on_bar_day(bar)

    #----------------------------------------------------------------------
    def on_bar_min1(self, bar):
        """Close the whole position when a minute bar's close crosses ``self.stop``."""
        # A stop of 0 means no level is known (the default). Without this
        # guard, a short would be covered on the first minute bar because
        # every close is >= 0.
        if not self.stop:
            return

        # Long position
        if self.unit > 0:
            if bar.close <= self.stop:
                self.sell(bar.close, abs(self.unit))

        # Short position
        elif self.unit < 0:
            if bar.close >= self.stop:
                self.cover(bar.close, abs(self.unit))

    #----------------------------------------------------------------------
    def on_bar_day(self, bar):
        """Handle a daily bar.

        The signal is generated before the indicators are recalculated, so it
        uses the levels computed from the previous bars.
        """
        self.bar = bar
        self.am.updateBar(bar)
        if not self.am.inited:
            return

        self.generateSignal(bar)    # trade against the existing levels
        self.calculateIndicator()   # then refresh channels / entry levels

    #----------------------------------------------------------------------
    def generateSignal(self, bar):
        """Check for trade signals on ``bar``.

        Only one kind of action (buy / sell / short / cover) is meant to run per bar.
        """
        # Entry levels not computed yet: do nothing.
        if self.longEntry1 == 0:
            return

        # Exits first.
        if self.unit > 0:
            longExit = max(self.longStop, self.exitDown)

            if bar.low <= longExit:
                self.sell(longExit)
                return
        elif self.unit < 0:
            shortExit = min(self.shortStop, self.exitUp)
            if bar.high >= shortExit:
                self.cover(shortExit)
                return

        # Flat or long: may go long / add units.
        if self.unit >= 0:
            trade = False

            if bar.high >= self.longEntry1 and self.unit < 1:
                self.buy(self.longEntry1, 1)
                trade = True

            if bar.high >= self.longEntry2 and self.unit < 2:
                self.buy(self.longEntry2, 1)
                trade = True

            if bar.high >= self.longEntry3 and self.unit < 3:
                self.buy(self.longEntry3, 1)
                trade = True

            if bar.high >= self.longEntry4 and self.unit < 4:
                self.buy(self.longEntry4, 1)
                trade = True

            if trade:
                return

        # Flat or short: may go short / add units.
        if self.unit <= 0:
            if bar.low <= self.shortEntry1 and self.unit > -1:
                self.short(self.shortEntry1, 1)

            if bar.low <= self.shortEntry2 and self.unit > -2:
                self.short(self.shortEntry2, 1)

            if bar.low <= self.shortEntry3 and self.unit > -3:
                self.short(self.shortEntry3, 1)

            if bar.low <= self.shortEntry4 and self.unit > -4:
                self.short(self.shortEntry4, 1)

    #----------------------------------------------------------------------
    def calculateIndicator(self):
        """Refresh the Donchian channels and, when flat, the ATR and entry levels."""
        self.entryUp, self.entryDown = self.am.donchian(self.entryWindow)
        self.exitUp, self.exitDown = self.am.donchian(self.exitWindow)

        # Once in a position, ATR volatility and entry levels stay frozen.
        if self.unit == 0:
            self.atrVolatility = self.am.atr(self.atrWindow)

            self.longEntry1 = self.entryUp
            self.longEntry2 = self.entryUp + self.atrVolatility * 0.5
            self.longEntry3 = self.entryUp + self.atrVolatility * 1
            self.longEntry4 = self.entryUp + self.atrVolatility * 1.5
            self.longStop = 0

            self.shortEntry1 = self.entryDown
            self.shortEntry2 = self.entryDown - self.atrVolatility * 0.5
            self.shortEntry3 = self.entryDown - self.atrVolatility * 1
            self.shortEntry4 = self.entryDown - self.atrVolatility * 1.5
            self.shortStop = 0

    #----------------------------------------------------------------------
    def load_var(self):
        """Restore persisted state from the last row for ``vtSymbol`` in the var CSV.

        If the file does not exist yet (first run), the state is left untouched.
        Previously a missing file raised here.
        """
        filename = get_dss() + 'fut/check/signal_turtle_var.csv'
        try:
            df = pd.read_csv(filename)
        except FileNotFoundError:
            return
        df = df[df.vtSymbol == self.vtSymbol]
        if len(df) > 0:
            rec = df.iloc[-1, :]            # most recent record
            self.unit = rec.unit
            self.cost = rec.cost
            self.intraTradeHigh = rec.intraTradeHigh
            self.intraTradeLow = rec.intraTradeLow
            self.stop = rec.stop
            if rec.has_result == 1:
                self.result = SignalResult()
                self.result.unit = rec.result_unit
                self.result.entry = rec.result_entry
                self.result.exit = rec.result_exit
                self.result.pnl = rec.result_pnl

    #----------------------------------------------------------------------
    def save_var(self):
        """Append the current persisted state as one headerless row to the var CSV."""
        if self.result is None:
            result_cols = [0, 0, 0, 0, 0]
        else:
            result_cols = [1, self.result.unit, self.result.entry,
                           self.result.exit, self.result.pnl]
        row = [self.portfolio.result.date, self.vtSymbol, self.unit, self.cost,
               self.intraTradeHigh, self.intraTradeLow, self.stop] + result_cols
        df = pd.DataFrame([row], columns=['datetime', 'vtSymbol', 'unit', 'cost',
                                          'intraTradeHigh', 'intraTradeLow', 'stop',
                                          'has_result', 'result_unit', 'result_entry',
                                          'result_exit', 'result_pnl'])
        filename = get_dss() + 'fut/check/signal_turtle_var.csv'
        df.to_csv(filename, index=False, mode='a', header=False)

    #----------------------------------------------------------------------
    def newSignal(self, direction, offset, price, volume):
        """Forward an order instruction to the portfolio."""
        self.portfolio._bc_newSignal(self, direction, offset, price, volume)

    #----------------------------------------------------------------------
    def buy(self, price, volume):
        """Open or add long, then move the long stop to 2 ATR below the fill price."""
        price = self.calculateTradePrice(DIRECTION_LONG, price)

        self.open(price, volume)
        self.newSignal(DIRECTION_LONG, OFFSET_OPEN, price, volume)

        # Stop = last add-on price minus 2N.
        self.longStop = price - self.atrVolatility * 2

    #----------------------------------------------------------------------
    def sell(self, price, volume=None):
        """Close the whole long position.

        ``volume`` is accepted for call compatibility with ``on_bar_min1`` and
        ignored: ``abs(self.unit)`` is always closed. The missing parameter
        used to make that call raise TypeError.
        """
        price = self.calculateTradePrice(DIRECTION_SHORT, price)

        volume = abs(self.unit)

        self.close(price)
        self.newSignal(DIRECTION_SHORT, OFFSET_CLOSE, price, volume)

    #----------------------------------------------------------------------
    def short(self, price, volume):
        """Open or add short, then move the short stop to 2 ATR above the fill price."""
        price = self.calculateTradePrice(DIRECTION_SHORT, price)

        self.open(price, -volume)
        self.newSignal(DIRECTION_SHORT, OFFSET_OPEN, price, volume)

        # Stop = last add-on price plus 2N.
        self.shortStop = price + self.atrVolatility * 2

    #----------------------------------------------------------------------
    def cover(self, price, volume=None):
        """Close the whole short position.

        ``volume`` is accepted for call compatibility with ``on_bar_min1`` and
        ignored: ``abs(self.unit)`` is always closed.
        """
        price = self.calculateTradePrice(DIRECTION_LONG, price)
        volume = abs(self.unit)

        self.close(price)
        self.newSignal(DIRECTION_LONG, OFFSET_CLOSE, price, volume)

    #----------------------------------------------------------------------
    def _log_deal(self, direction, offset, volume, price, pnl):
        """Append one trade row (with the current ``unit``) to this symbol's turtle deal CSV."""
        r = [[self.portfolio.result.date, direction, offset,
              volume, price, pnl,
              self.unit]]
        df = pd.DataFrame(r, columns=['datetime', 'direction', 'offset', 'volume', 'price', 'pnl',
                                      'unit'])
        filename = get_dss() + 'fut/deal/signal_turtle_' + self.vtSymbol + '.csv'
        df.to_csv(filename, index=False, mode='a', header=False)

    #----------------------------------------------------------------------
    def open(self, price, change):
        """Add ``change`` units (negative means short) to the position and log the deal."""
        self.unit += change

        if not self.result:
            self.result = SignalResult()
        self.result.open(price, change)

        self._log_deal('多' if change > 0 else '空', '开', abs(change), price, 0)

    #----------------------------------------------------------------------
    def close(self, price):
        """Flatten the position at ``price``, log the deal and archive the trade result."""
        self.unit = 0
        self.result.close(price)

        # Logged after result.close() so the row carries the realized pnl.
        self._log_deal('', '平', 0, price, self.result.pnl)

        self.resultList.append(self.result)
        self.result = None

    #----------------------------------------------------------------------
    def getLastPnl(self):
        """Return the pnl of the most recent closed trade, or 0 if there is none."""
        if not self.resultList:
            return 0

        result = self.resultList[-1]
        return result.pnl

    #----------------------------------------------------------------------
    def calculateTradePrice(self, direction, price):
        """Compute a stop-order fill price that is no better than the current bar's open."""
        # Buying: the fill cannot be below the current bar's open.
        if direction == DIRECTION_LONG:
            tradePrice = max(self.bar.open, price)
        # Selling: the fill cannot be above the current bar's open.
        else:
            tradePrice = min(self.bar.open, price)

        return tradePrice
# Beispiel #6 (Example #6), score: 0
# Ad-hoc check: build a bar container from the last `initBars` rows before
# `startDt`, then print the ATR(1) series and the mean of its last 30 values.
# NOTE(review): df, startDt, initBars and r must be defined earlier in the
# script (not visible here).
df['datetime'] = df['date'] + ' ' + df['time']
df = df.loc[df.datetime < startDt]
assert len(df) >= initBars

# Chronological order, keeping only the trailing warm-up window.
df = df.sort_values(by=['date', 'time']).iloc[-initBars:]
print(df)

# Each row becomes a VtBarData by replacing its attribute dict wholesale.
for i, row in df.iterrows():
    d = dict(row)
    bar = VtBarData()
    bar.__dict__ = d
    r.append(bar)

am = ArrayManager(initBars)  # bar container sized to the warm-up window
for bar in r:
    am.updateBar(bar)

atrValue = am.atr(1, array=True)
atrMa = atrValue[-30:].mean()
print(atrValue)
print(atrMa)
class Fut_DualBandSignal_Kong(Signal):
    """Dual moving-average crossover signal for the short side ('kong').

    30-minute bars (``self.am``) drive entries and exits:
    - Short: a 30/90 SMA cross-down while flat, provided the daily 10-SMA is
      not above the daily 60-SMA (``self.bm``).
    - Cover: a 30/90 cross-up while short.
    Only ``can_short`` / ``can_cover`` are ever raised in this class. State and
    deals are appended to per-product CSV files under ``fut/engine/dualband/``.
    """

    #----------------------------------------------------------------------
    def __init__(self, portfolio, vtSymbol):
        """Set parameters and flags, run ``Signal.__init__``, then load the daily bars.

        Attributes are assigned before ``Signal.__init__`` runs, presumably so
        the base initializer can read ``initBars`` / ``minx`` (confirm against Signal).
        """
        self.type = 'kong'

        # Strategy parameters
        self.fixedSize = 1      # volume per trade
        self.initBars = 100     # number of bars used for initialization
        self.minx = 'min30'

        # Per-bar trade-permission flags
        self.can_buy = self.can_sell = False
        self.can_short = self.can_cover = False

        # Persisted state
        self.cost = 0

        # The warm-up length must fit in a 100-bar container.
        am_capacity = 100
        assert self.initBars <= am_capacity
        Signal.__init__(self, portfolio, vtSymbol)

        # Daily bar container used as the trend filter.
        self.bm_bar = None
        self.bm = ArrayManager(60)
        self.init_bm()

    #----------------------------------------------------------------------
    def init_bm(self):
        """Warm up the daily container ``self.bm`` with 60 daily bars from the engine."""
        daily_history = self.portfolio.engine._bc_loadInitBar(self.vtSymbol, 60, 'day')
        for day_bar in daily_history:
            self.bm_bar = day_bar
            self.bm.updateBar(day_bar)

    #----------------------------------------------------------------------
    def load_param(self):
        """Look up this symbol's row in the dual-band parameter CSV, if the file exists.

        NOTE(review): the row is read but none of its values is applied yet.
        """
        filename = get_dss() + 'fut/engine/dualband/signal_dualband_param.csv'
        if not os.path.exists(filename):
            return
        params = pd.read_csv(filename)
        params = params[params.symbol == self.vtSymbol]
        if len(params) == 0:
            return
        rec = params.iloc[0, :]
        print('成功加载策略参数')

    #----------------------------------------------------------------------
    def set_param(self, param_dict):
        """Override ``fixedSize`` from ``param_dict`` when present."""
        if 'fixedSize' not in param_dict:
            return
        self.fixedSize = param_dict['fixedSize']
        print('成功设置策略参数 self.fixedSize: ', self.fixedSize)

    #----------------------------------------------------------------------
    def onBar(self, bar, minx='min1'):
        """Route a bar by timeframe: 'day' feeds the trend filter, 'min30' drives trading."""
        if minx == 'day':
            self.on_bar_day(bar)
        elif minx == 'min30':
            self.on_bar_minx(bar)

    #----------------------------------------------------------------------
    def on_bar_day(self, bar):
        """Append a daily bar to ``self.bm`` unless paused."""
        if self.paused == True:
            return

        self.bm.updateBar(bar)
        # Nothing else happens on daily bars yet.
        if not self.bm.inited:
            return

    #----------------------------------------------------------------------
    def on_bar_minx(self, bar):
        """Store a 30-minute bar and, unless paused and once warmed up, trade on it."""
        self.bar = bar

        if self.paused == True:
            return

        self.am.updateBar(bar)
        if self.am.inited:
            self.calculateIndicator()   # compute indicators
            self.generateSignal(bar)    # turn them into orders

    #----------------------------------------------------------------------
    def calculateIndicator(self):
        """Recompute ``can_short`` / ``can_cover`` from SMA crossovers.

        - Flat: short when the 30-minute SMA(30) crosses below SMA(90) and the
          daily SMA(10) is at or below the daily SMA(60).
        - Short: cover when the 30-minute SMA(30) crosses above SMA(90).
        """
        self.can_short = False
        self.can_cover = False

        fast_min = self.am.sma(30, array=True)
        slow_min = self.am.sma(90, array=True)

        fast_day = self.bm.sma(10, array=True)
        slow_day = self.bm.sma(60, array=True)

        if self.unit == 0:
            crossed_down = fast_min[-2] >= slow_min[-2] and fast_min[-1] < slow_min[-1]
            if crossed_down and fast_day[-1] <= slow_day[-1]:
                self.can_short = True
        elif self.unit < 0:
            crossed_up = fast_min[-2] <= slow_min[-2] and fast_min[-1] > slow_min[-1]
            if crossed_up:
                self.can_cover = True

    #----------------------------------------------------------------------
    def generateSignal(self, bar):
        """Place orders for every raised flag, in order: buy, short, sell, cover.

        Opening orders record the bar close as ``cost``; closing orders reset it to 0.
        """
        # (flag attribute, order method name, whether it opens a position)
        steps = (
            ('can_buy', 'buy', True),
            ('can_short', 'short', True),
            ('can_sell', 'sell', False),
            ('can_cover', 'cover', False),
        )
        for flag, method_name, opens in steps:
            if getattr(self, flag) == True:
                getattr(self, method_name)(bar.close, self.fixedSize)
                self.cost = bar.close if opens else 0

    #----------------------------------------------------------------------
    def _dualband_csv_path(self, kind):
        """Return the per-product CSV path for ``kind`` ('var' or 'deal')."""
        pz = str(get_contract(self.vtSymbol).pz)
        return (get_dss() + 'fut/engine/dualband/signal_dualband_'
                + self.type + '_' + kind + '_' + pz + '.csv')

    #----------------------------------------------------------------------
    def _dualband_append_csv(self, df, filename):
        """Append ``df`` to ``filename``; write a header only when the file is new."""
        if os.path.exists(filename):
            df.to_csv(filename, index=False, mode='a', header=False)
        else:
            df.to_csv(filename, index=False)

    #----------------------------------------------------------------------
    def load_var(self):
        """Restore ``unit``, ``cost`` and any open result from the latest row for this symbol."""
        filename = self._dualband_csv_path('var')
        if not os.path.exists(filename):
            return
        saved = pd.read_csv(filename)
        saved = saved[saved.vtSymbol == self.vtSymbol]
        saved = saved.sort_values(by='datetime').reset_index()
        if len(saved) == 0:
            return

        rec = saved.iloc[-1, :]  # most recent record
        self.unit = rec.unit
        self.cost = rec.cost
        if rec.has_result == 1:
            self.result = SignalResult()
            self.result.unit = rec.result_unit
            self.result.entry = rec.result_entry
            self.result.exit = rec.result_exit
            self.result.pnl = rec.result_pnl

    #----------------------------------------------------------------------
    def save_var(self):
        """Append the current state as one row to the per-product var CSV, unless paused."""
        if self.paused == True:
            return

        if self.result is None:
            result_cols = [0, 0, 0, 0, 0]
        else:
            result_cols = [1, self.result.unit, self.result.entry,
                           self.result.exit, self.result.pnl]
        row = [self.portfolio.result.date, self.vtSymbol, self.unit, self.cost] + result_cols

        df = pd.DataFrame([row], columns=['datetime', 'vtSymbol', 'unit', 'cost',
                                          'has_result', 'result_unit', 'result_entry',
                                          'result_exit', 'result_pnl'])
        self._dualband_append_csv(df, self._dualband_csv_path('var'))

    #----------------------------------------------------------------------
    def _dualband_log_deal(self, direction, offset, volume, price, pnl):
        """Append one trade row, stamped with the current bar's date and time, to the deal CSV."""
        row = [self.bar.date + ' ' + self.bar.time, direction, offset,
               volume, price, pnl, self.vtSymbol]
        df = pd.DataFrame([row],
                          columns=[
                              'datetime', 'direction', 'offset', 'volume',
                              'price', 'pnl', 'symbol'
                          ])
        self._dualband_append_csv(df, self._dualband_csv_path('deal'))

    #----------------------------------------------------------------------
    def open(self, price, change):
        """Add ``change`` units (negative means short) to the position and log the deal."""
        self.unit += change

        if not self.result:
            self.result = SignalResult()
        self.result.open(price, change)

        self._dualband_log_deal('多' if change > 0 else '空', '开', abs(change), price, 0)

    #----------------------------------------------------------------------
    def close(self, price):
        """Flatten the position at ``price``, log the deal and clear the open result."""
        self.unit = 0
        self.result.close(price)

        self._dualband_log_deal('', '平', 0, price, self.result.pnl)

        self.result = None