Пример #1
0
class RsiSignal(CtaSignal):
    """
    Signal driven by RSI leaving a band centred on 50.

    The long threshold is ``50 + rsi_level`` and the short threshold is
    ``50 - rsi_level``. The signal position is set to 1 at or above the long
    threshold, -1 at or below the short threshold, and 0 in between.
    """

    def __init__(self, rsi_window: int, rsi_level: float):
        """
        Store the RSI settings and create the bar generator / array manager.

        rsi_window: window passed to ``ArrayManager.rsi``.
        rsi_level: distance from 50 used to derive both thresholds.
        """
        super().__init__()

        self.rsi_window = rsi_window
        self.rsi_level = rsi_level
        self.rsi_long = 50 + self.rsi_level
        self.rsi_short = 50 - self.rsi_level

        # Bars produced by the generator are routed to on_bar.
        self.bg = BarGenerator(self.on_bar)
        self.am = ArrayManager()

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.

        The tick is forwarded to the bar generator.
        """
        self.bg.update_tick(tick)

    def on_bar(self, bar: BarData):
        """
        Callback of new bar data update.

        Feeds the bar into the array manager and refreshes the signal
        position from the latest RSI value.
        """
        self.am.update_bar(bar)
        if not self.am.inited:
            # Stay flat and stop here: without this return the code below
            # would overwrite the flat signal with one derived from an
            # array manager that does not yet report itself as inited.
            self.set_signal_pos(0)
            return

        rsi_value = self.am.rsi(self.rsi_window)

        if rsi_value >= self.rsi_long:
            self.set_signal_pos(1)
        elif rsi_value <= self.rsi_short:
            self.set_signal_pos(-1)
        else:
            self.set_signal_pos(0)
Пример #2
0
class RSICurveItem(ChartItem):
    """
    Chart item that draws one RSI curve per configured period.

    Periods default to ``RSI_PARAMS`` and pens to ``RSI_COLORS``; both can be
    overridden through the ``VISUAL_SETTING`` entry keyed by ``name``.
    """
    name = 'rsi'
    plot_name = 'indicator'
    RSI_PARAMS = [6, 12, 24]
    RSI_COLORS = {
        6: pg.mkPen(color=(255, 255, 255), width=PEN_WIDTH),
        12: pg.mkPen(color=(255, 255, 0), width=PEN_WIDTH),
        24: pg.mkPen(color=(218, 112, 214), width=PEN_WIDTH)
    }

    def __init__(self, manager: BarManager):
        """Set up settings, the rolling array manager and drawing state."""
        super().__init__(manager)
        self.init_setting()
        self._arrayManager = ArrayManager(150)
        # period -> {bar index -> RSI value}
        self.rsis = defaultdict(dict)
        # Index passed to the most recent call that produced a new picture.
        self.last_ix = 0
        # Running extremes of the drawn RSI values, used by boundingRect.
        self.br_max = -np.inf
        self.br_min = np.inf
        self.last_picture = QtGui.QPicture()

    def init_setting(self):
        """
        Apply overrides from ``VISUAL_SETTING[self.name]``, if any.

        ``params`` replaces the list of RSI periods. ``pen`` maps
        ``str(period)`` to keyword arguments for ``pg.mkPen``; when it is
        present, every configured period must have an entry.
        """
        setting = VISUAL_SETTING.get(self.name, {})
        self.RSI_PARAMS = setting.get('params', self.RSI_PARAMS)
        # Work on a per-instance copy so configured pens are never written
        # into the class-level default dict shared by every instance.
        self.RSI_COLORS = dict(self.RSI_COLORS)
        if 'pen' in setting:
            pen_settings = setting['pen']
            for p in self.RSI_PARAMS:
                self.RSI_COLORS[p] = pg.mkPen(**pen_settings[str(p)])

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """
        Build the picture holding the RSI segments that end at ``ix - 1``.

        The bar at ``ix - 1`` is fetched from the manager and fed into the
        array manager; the ``bar`` argument itself is not used. For each
        period a segment is drawn from ``(ix - 2, previous RSI)`` to
        ``(ix - 1, latest RSI)``. The previous picture is returned unchanged
        when ``ix`` is not newer than the last processed index or when no
        preceding bar exists.
        """
        # Presumably guards against feeding the same bar twice when an index
        # is redrawn; only a strictly newer index advances the state.
        if ix <= self.last_ix:
            return self.last_picture

        pre_bar = self._manager.get_bar(ix - 1)

        if not pre_bar:
            return self.last_picture

        rsi_picture = QtGui.QPicture()
        self._arrayManager.update_bar(pre_bar)
        painter = QtGui.QPainter(rsi_picture)

        # Draw one RSI line segment per period
        for p in self.RSI_PARAMS:
            rsi_ = self._arrayManager.rsi(p, True)
            pre_rsi = rsi_[-2]
            rsi = rsi_[-1]
            # Stored before the NaN check, so NaN values are recorded too.
            self.rsis[p][ix - 1] = rsi
            if np.isnan(pre_rsi) or np.isnan(rsi):
                continue

            self.br_max = max(self.br_max, rsi)
            self.br_min = min(self.br_min, rsi)

            rsi_sp = QtCore.QPointF(ix - 2, pre_rsi)
            rsi_ep = QtCore.QPointF(ix - 1, rsi)
            drawPath(painter, rsi_sp, rsi_ep, self.RSI_COLORS[p])

        # Finish
        painter.end()
        self.last_ix = ix
        self.last_picture = rsi_picture
        return rsi_picture

    def boundingRect(self) -> QtCore.QRectF:
        """
        Return a rectangle spanning all bar pictures horizontally and the
        running RSI minimum/maximum vertically.
        """
        rect = QtCore.QRectF(0, self.br_min, len(self._bar_picutures),
                             self.br_max - self.br_min)
        return rect

    def get_y_range(self,
                    min_ix: int = None,
                    max_ix: int = None) -> Tuple[float, float]:
        """
        Get range of y-axis with given x-axis range.

        If min_ix and max_ix not specified, then return range with whole data set.

        Only the first configured period is inspected, and ``max_ix`` itself
        is excluded. When no value is recorded in the window, the nominal RSI
        scale ``(0, 100)`` is returned instead of an inverted infinite range.
        """
        min_ix = 0 if min_ix is None else min_ix
        max_ix = self.last_ix if max_ix is None else max_ix

        values = self.rsis[self.RSI_PARAMS[0]]
        min_v = np.inf
        max_v = -np.inf
        for i in range(min_ix, max_ix):
            # Missing indices fall back to the current extreme (a no-op).
            min_v = min(min_v, values.get(i, min_v))
            max_v = max(max_v, values.get(i, max_v))

        if min_v > max_v:
            # Nothing usable in the window: the extremes are still +/-inf.
            return 0.0, 100.0

        return min_v, max_v

    def get_info_text(self, ix: int) -> str:
        """
        Get information text to show by cursor.

        One line per recorded period; ``nan`` is shown when no value is
        stored for ``ix``.
        """
        text = '\n'.join(f'rsi{p}: {v.get(ix, np.nan):.2f}'
                         for p, v in self.rsis.items())
        return f"RSI \n{text}"

    def clear_all(self) -> None:
        """
        Clear all data in the item.

        Resets the array manager, stored RSI values, last index/picture and
        the running extremes to their initial state.
        """
        super().clear_all()
        self._arrayManager = ArrayManager(150)
        self.last_ix = 0
        self.last_picture = QtGui.QPicture()
        self.rsis = defaultdict(dict)
        self.br_max = -np.inf
        self.br_min = np.inf
Пример #3
0
    def show(self):
        """
        Plot ``self.barDatas`` as a candlestick chart with optional indicators.

        Bars are processed in ascending datetime order (the sequence is
        reversed when its first bar is later than its last). Depending on
        ``open_boll`` / ``open_obv`` / ``open_rsi``, Bollinger bands, OBV and
        RSI columns are added and drawn as extra plots. Until the array
        manager has seen ``window_size`` bars, placeholder values are used:
        the close price for both bands, the bar volume for OBV and 50 for RSI.

        Returns None without plotting when there are no bars.
        """
        bars = self.barDatas
        if not bars:
            # Nothing to plot; indexing bars[0] below would raise IndexError.
            return

        if bars[0].datetime > bars[-1].datetime:
            bars = reversed(bars)

        data = []
        index = []
        am = ArrayManager(self.window_size * 2)

        # Initialise the column names
        columns = ['Open', 'High', 'Low', 'Close', "Volume"]

        if self.open_boll:
            columns.append("boll_up")
            columns.append("boll_down")
        if self.open_obv:
            columns.append("obv")
        if self.open_rsi:
            columns.append("rsi")

        for bar in bars:
            index.append(bar.datetime)
            row = [
                bar.open_price, bar.high_price, bar.low_price, bar.close_price,
                bar.volume
            ]
            am.update_bar(bar)
            # Indicators are only computed once enough bars have been seen.
            warmed_up = am.count >= self.window_size

            # Bollinger band values
            if self.open_boll:
                if warmed_up:
                    up, down = am.boll(self.window_size, 3.4)
                    row.append(up)
                    row.append(down)
                else:
                    row.append(bar.close_price)
                    row.append(bar.close_price)

            if self.open_obv:
                if warmed_up:
                    # NOTE(review): confirm this ArrayManager.obv accepts a
                    # window argument; some implementations take an
                    # "array" flag here instead.
                    obv = am.obv(self.window_size)
                    row.append(obv)
                else:
                    row.append(bar.volume)

            if self.open_rsi:
                if warmed_up:
                    rsi = am.rsi(self.window_size)
                    row.append(rsi)
                else:
                    row.append(50)

            data.append(row)

        trades = pd.DataFrame(data, index=index, columns=columns)

        apds = []

        # Bollinger bands overlay
        if self.open_boll:
            apds.append(
                mpf.make_addplot(trades['boll_up'], linestyle='dashdot'))
            apds.append(
                mpf.make_addplot(trades['boll_down'], linestyle='dashdot'))

        if self.open_obv:
            apds.append(
                mpf.make_addplot(trades['obv'],
                                 panel='lower',
                                 color='g',
                                 secondary_y=True))
        if self.open_rsi:
            apds.append(
                mpf.make_addplot(trades['rsi'],
                                 panel='lower',
                                 color='b',
                                 secondary_y=True))

        mpf.plot(trades,
                 type='candle',
                 volume=True,
                 mav=5,
                 figscale=1.3,
                 style='yahoo',
                 addplot=apds)
Пример #4
0
class MultiTimeframeStrategy(CtaTemplate):
    """
    Strategy that combines a 15-minute moving-average trend with 5-minute
    RSI readings to open and close positions.
    """
    author = "中科云集"

    # Configurable parameters.
    rsi_signal = 20
    rsi_window = 14
    fast_window = 5
    slow_window = 20
    fixed_size = 1

    # Runtime variables.
    rsi_value = 0
    rsi_long = 0
    rsi_short = 0
    fast_ma = 0
    slow_ma = 0
    ma_trend = 0

    parameters = [
        "rsi_signal",
        "rsi_window",
        "fast_window",
        "slow_window",
        "fixed_size",
    ]

    variables = [
        "rsi_value",
        "rsi_long",
        "rsi_short",
        "fast_ma",
        "slow_ma",
        "ma_trend",
    ]

    def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
        """Derive the RSI thresholds and build the 5/15-minute pipelines."""
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)

        # Thresholds sit symmetrically around 50.
        self.rsi_long = 50 + self.rsi_signal
        self.rsi_short = 50 - self.rsi_signal

        # One generator / array manager pair per timeframe.
        self.bg5 = BarGenerator(self.on_bar, 5, self.on_5min_bar)
        self.am5 = ArrayManager()

        self.bg15 = BarGenerator(self.on_bar, 15, self.on_15min_bar)
        self.am15 = ArrayManager()

    def on_init(self):
        """Log initialisation and request bar loading via ``load_bar(10)``."""
        self.write_log("策略初始化")
        self.load_bar(10)

    def on_start(self):
        """Log that the strategy has started."""
        self.write_log("策略启动")

    def on_stop(self):
        """Log that the strategy has stopped."""
        self.write_log("策略停止")

    def on_tick(self, tick: TickData):
        """Pass the incoming tick to the 5-minute bar generator."""
        self.bg5.update_tick(tick)

    def on_bar(self, bar: BarData):
        """Pass the incoming bar to both the 5- and 15-minute generators."""
        for generator in (self.bg5, self.bg15):
            generator.update_bar(bar)

    def on_5min_bar(self, bar: BarData):
        """
        Cancel outstanding orders, refresh RSI and trade on the 5-minute bar.

        Nothing is traded until the 5-minute array manager is inited and a
        trend direction has been set by the 15-minute handler.
        """
        self.cancel_all()

        self.am5.update_bar(bar)
        if not (self.am5.inited and self.ma_trend):
            return

        self.rsi_value = self.am5.rsi(self.rsi_window)

        trend_up = self.ma_trend > 0
        trend_down = self.ma_trend < 0
        close = bar.close_price
        position = self.pos

        if position > 0:
            # Long: exit on a down trend or RSI below the midline.
            if trend_down or self.rsi_value < 50:
                self.sell(close - 5, abs(self.pos))
        elif position < 0:
            # Short: exit on an up trend or RSI above the midline.
            if trend_up or self.rsi_value > 50:
                self.cover(close + 5, abs(self.pos))
        elif position == 0:
            # Flat: enter only when trend and RSI threshold agree.
            if trend_up and self.rsi_value >= self.rsi_long:
                self.buy(close + 5, self.fixed_size)
            elif trend_down and self.rsi_value <= self.rsi_short:
                self.short(close - 5, self.fixed_size)

        self.put_event()

    def on_15min_bar(self, bar: BarData):
        """
        Refresh the fast/slow moving averages and the trend direction.

        ``ma_trend`` becomes 1 when the fast average is above the slow one,
        otherwise -1.
        """
        self.am15.update_bar(bar)
        if not self.am15.inited:
            return

        self.fast_ma = self.am15.sma("c", self.fast_window)
        self.slow_ma = self.am15.sma("c", self.slow_window)

        self.ma_trend = 1 if self.fast_ma > self.slow_ma else -1

    def on_order(self, order: OrderData):
        """Order update callback; intentionally does nothing."""
        pass

    def on_trade(self, trade: TradeData):
        """Trade update callback; emits a strategy event."""
        self.put_event()

    def on_stop_order(self, stop_order: StopOrder):
        """Stop-order update callback; intentionally does nothing."""
        pass