Beispiel #1
0
    def __init__(self,
                 security_type=SecurityType.stock,
                 exchanges=['sh', 'sz'],
                 codes=None,
                 start_timestamp=None,
                 end_timestamp=None) -> None:
        """Set up the trader, its simulated account and its selectors.

        :param security_type: security type forwarded to ``init_selectors``.
        :param exchanges: exchanges forwarded to ``init_selectors``. A list is
            copied, so the shared default list is never stored on an instance.
        :param codes: optional codes forwarded to ``init_selectors``.
        :param start_timestamp: start of the trading window (required).
        :param end_timestamp: end of the trading window (required).
        :raises AssertionError: if either timestamp is missing.
        """
        self.trader_name = type(self).__name__.lower()
        self.trading_signal_listeners = []
        self.state_listeners = []

        # presumably filled in by init_selectors() below -- confirm
        self.selectors: List[TargetSelector] = None

        self.security_type = security_type
        # copy lists so instances never alias the mutable default argument
        self.exchanges = list(exchanges) if isinstance(exchanges, list) else exchanges
        self.codes = codes

        if not (start_timestamp and end_timestamp):
            # explicit raise instead of `assert False`: asserts are stripped
            # under `python -O`, which would let __init__ continue and fail
            # later with an AttributeError on self.start_timestamp
            raise AssertionError('both start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.account_service = SimAccountService(
            trader_name=self.trader_name, timestamp=self.start_timestamp)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_type=self.security_type,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        # NOTE(review): selectors_comparator is not assigned in this method;
        # presumably a class attribute -- confirm it exists before this call
        self.selectors_comparator.add_selectors(self.selectors)
Beispiel #2
0
    def __init__(self,
                 security_type=SecurityType.stock,
                 exchanges=['sh', 'sz'],
                 codes=None,
                 start_timestamp=None,
                 end_timestamp=None,
                 provider=Provider.JOINQUANT,
                 trading_level=TradingLevel.LEVEL_1DAY,
                 trader_name=None) -> None:
        """Set up the trader, its simulated account, selectors and level order.

        :param security_type: security type forwarded to ``init_selectors``.
        :param exchanges: exchanges forwarded to ``init_selectors``. A list is
            copied, so the shared default list is never stored on an instance.
        :param codes: optional codes forwarded to ``init_selectors``.
        :param start_timestamp: start of the trading window (required).
        :param end_timestamp: end of the trading window (required).
        :param provider: data provider forwarded to ``SimAccountService``.
        :param trading_level: level forwarded to ``SimAccountService``.
        :param trader_name: defaults to the lower-cased class name.
        :raises AssertionError: if either timestamp is missing.
        """
        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()
        self.trading_signal_listeners = []
        self.state_listeners = []

        # presumably filled in by init_selectors() below -- confirm
        self.selectors: List[TargetSelector] = None

        self.security_type = security_type
        # copy lists so instances never alias the mutable default argument
        self.exchanges = list(exchanges) if isinstance(exchanges, list) else exchanges
        self.codes = codes
        # make sure the min level selector correspond to the provider and level
        self.provider = provider
        self.trading_level = trading_level

        if not (start_timestamp and end_timestamp):
            # explicit raise instead of `assert False`: asserts are stripped
            # under `python -O`, which would let __init__ continue and fail
            # later with an AttributeError on self.start_timestamp
            raise AssertionError('both start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.account_service = SimAccountService(
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.trading_level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_type=self.security_type,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = LimitSelectorsComparator(self.selectors)

        # distinct selector levels, ascending and descending
        self.trading_level_asc = sorted(
            {TradingLevel(selector.level) for selector in self.selectors})
        self.trading_level_desc = list(reversed(self.trading_level_asc))

        self.targets_slot: TargetsSlot = TargetsSlot()
Beispiel #3
0
    def __init__(self, security_id, trading_level, timestamp, trader_name, history_size=250) -> None:
        """Initialise a per-security model and attach its simulated account.

        :param security_id: id of the traded security; also passed to
            ``get_close_time`` to obtain the close hour/minute.
        :param trading_level: level object providing ``floor_timestamp`` and
            ``value``.
        :param timestamp: starting time; floored to ``trading_level`` for
            ``current_timestamp``, but passed unmodified to the account service.
        :param trader_name: prefix of the generated ``model_name``.
        :param history_size: stored as-is (presumably how many bars of history
            to keep -- confirm).
        """
        self.security_id = security_id
        self.trading_level = trading_level

        normalized_timestamp = to_pd_timestamp(timestamp)
        self.current_timestamp = trading_level.floor_timestamp(normalized_timestamp)

        # model name: <trader>_<ClassName>_<level value>
        name_parts = (trader_name, type(self).__name__, trading_level.value)
        self.model_name = "{}_{}_{}".format(*name_parts)

        self.history_size = history_size

        account_service = SimAccountService(trader_name=trader_name,
                                            model_name=self.model_name,
                                            timestamp=timestamp)
        self.add_trading_signal_listener(account_service)

        close_hour, close_minute = get_close_time(self.security_id)
        self.close_hour = close_hour
        self.close_minute = close_minute
Beispiel #4
0
class Trader(object):
    """Trading loop driven by a list of ``TargetSelector`` objects.

    Subclasses set ``entity_schema`` and implement ``init_selectors`` to fill
    ``self.selectors``. ``run()`` walks every ``self.level`` timestamp between
    ``start_timestamp`` and ``end_timestamp``. It collects long/short targets
    from the selectors and turns them into ``TradingSignal`` objects. Those
    signals go to every registered listener; ``SimAccountService`` is
    registered by default.
    """
    entity_schema: EntityMixin = None

    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        """Configure the trader, run its selectors on history and record it.

        :param entity_ids: optional entities, forwarded to ``init_selectors``.
        :param exchanges: optional exchanges, forwarded to ``init_selectors``.
        :param codes: optional codes, forwarded to ``init_selectors``.
        :param start_timestamp: start of the trading window (required).
        :param end_timestamp: end of the trading window (required).
        :param provider: data provider forwarded to ``SimAccountService``.
        :param level: trading level; must equal the smallest selector level.
        :param trader_name: defaults to the lower-cased class name.
        :param real_time: if True, ``end_timestamp`` must not be in the past.
        :param kdata_use_begin_time: in real-time mode, when True a bar's end
            is taken as ``timestamp + one level interval``.
        :param draw_result: plot the account value in ``on_finish``.
        :raises AssertionError: if ``entity_schema`` or a timestamp is missing,
            or if ``end_timestamp`` is in the past in real-time mode.
        :raises Exception: if ``level`` is not the smallest selector level.
        """
        # explicit raises instead of `assert`: asserts vanish under `python -O`
        if self.entity_schema is None:
            raise AssertionError('entity_schema must be set on the Trader subclass')

        self.logger = logging.getLogger(__name__)

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners: List[TradingListener] = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        self.exchanges = exchanges
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        if not (start_timestamp and end_timestamp):
            raise AssertionError('both start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.trading_dates = self.entity_schema.get_trading_dates(
            start_date=self.start_timestamp, end_date=self.end_timestamp)

        if real_time:
            # use the instance logger; the bare name `logger` relied on a
            # module-level global instead of the logger configured above
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever'
            )
            if not self.end_timestamp >= now_pd_timestamp():
                raise AssertionError('end_timestamp must not be in the past in real_time mode')

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(
            entity_schema=self.entity_schema,
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        # always defined, so run() (real-time branch) and handle_targets_slot()
        # can iterate them even when no selector was configured
        self.trading_level_asc: List[IntervalLevel] = []
        self.trading_level_desc: List[IntervalLevel] = []

        if self.selectors:
            # distinct selector levels, ascending
            self.trading_level_asc = sorted(
                {IntervalLevel(selector.level) for selector in self.selectors})

            self.logger.info(
                f'trader level:{self.level},selectors level:{self.trading_level_asc}'
            )

            if self.level != self.trading_level_asc[0]:
                raise Exception(
                    "trader level should be the min of the selectors")

            self.trading_level_desc = list(reversed(self.trading_level_asc))

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', data_schema=TraderInfo)
        self.on_start()

    def on_start(self):
        """Run every selector on history data and persist this trader as a ``TraderInfo`` row."""
        # run all the selectors
        for selector in self.selectors:
            # run for the history data at first
            selector.run()

        # list-valued settings are stored as JSON strings (None when unset/empty)
        entity_ids = json.dumps(self.entity_ids) if self.entity_ids else None
        exchanges = json.dumps(self.exchanges) if self.exchanges else None
        codes = json.dumps(self.codes) if self.codes else None

        sim_account = TraderInfo(
            id=self.trader_name,
            entity_id=f'trader_zvt_{self.trader_name}',
            timestamp=self.start_timestamp,
            trader_name=self.trader_name,
            entity_ids=entity_ids,
            exchanges=exchanges,
            codes=codes,
            start_timestamp=self.start_timestamp,
            end_timestamp=self.end_timestamp,
            provider=self.provider,
            level=self.level.value,
            real_time=self.real_time,
            kdata_use_begin_time=self.kdata_use_begin_time)
        self.session.add(sim_account)
        self.session.commit()

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes,
                       start_timestamp, end_timestamp):
        """
        implement this to init selectors

        """
        pass

    def add_trading_signal_listener(self, listener):
        """Register ``listener`` once; duplicates are ignored."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        """Unregister ``listener`` if it is registered."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, due_timestamp: pd.Timestamp,
                            happen_timestamp: pd.Timestamp):
        """
        this function would be called in every cycle, you could overwrite it for your custom algorithm to select the
        targets of different levels

        the default implementation is selecting the targets in all levels

        :param due_timestamp: timestamp by which the resulting signals are due
        :param happen_timestamp: timestamp at which the targets were generated

        """
        long_selected = None
        short_selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if targets:
                long_targets = set(targets[0])
                short_targets = set(targets[1])

                # NOTE(review): `not long_selected` is also True for an empty
                # set, so a level that selected nothing is overwritten by the
                # next level's targets rather than intersected to empty --
                # confirm this is intended (vs. `is None`)
                if not long_selected:
                    long_selected = long_targets
                else:
                    long_selected = long_selected & long_targets

                if not short_selected:
                    short_selected = short_targets
                else:
                    short_selected = short_selected & short_targets

        self.logger.debug('timestamp:{},long_selected:{}'.format(
            due_timestamp, long_selected))

        self.logger.debug('timestamp:{},short_selected:{}'.format(
            due_timestamp, short_selected))

        self.trade_the_targets(due_timestamp=due_timestamp,
                               happen_timestamp=happen_timestamp,
                               long_selected=long_selected,
                               short_selected=short_selected)

    def get_current_account(self) -> AccountStats:
        """Return the account currently held by the account service."""
        return self.account_service.account

    def buy(self,
            due_timestamp,
            happen_timestamp,
            entity_ids,
            position_pct=1.0,
            ignore_in_position=True):
        """Send ``open_long`` signals, splitting ``position_pct`` evenly over ``entity_ids``.

        :param entity_ids: iterable of entity ids to buy; ``None`` is a no-op.
        :param position_pct: fraction divided evenly across the entities.
        :param ignore_in_position: skip entities already held with
            ``available_long > 0``.
        """
        if entity_ids is None:
            # nothing selected; previously set(None) raised TypeError
            return

        if ignore_in_position:
            account = self.get_current_account()
            current_holdings = []
            if account.positions:
                current_holdings = [
                    position.entity_id for position in account.positions
                    if position is not None and position.available_long > 0
                ]

            entity_ids = set(entity_ids) - set(current_holdings)

        if entity_ids:
            position_pct = (1.0 / len(entity_ids)) * position_pct

        for entity_id in entity_ids:
            trading_signal = TradingSignal(
                entity_id=entity_id,
                due_timestamp=due_timestamp,
                happen_timestamp=happen_timestamp,
                trading_signal_type=TradingSignalType.open_long,
                trading_level=self.level,
                position_pct=position_pct)
            self.send_trading_signal(trading_signal)

    def sell(self,
             due_timestamp,
             happen_timestamp,
             entity_ids,
             position_pct=1.0):
        """Send ``close_long`` signals for held entities that appear in ``entity_ids``.

        :param entity_ids: iterable of entity ids to sell; ``None`` is a no-op.
        :param position_pct: fraction of each position to close.
        """
        if entity_ids is None:
            # nothing selected; previously `set & None` raised TypeError
            return

        # current position
        account = self.get_current_account()
        current_holdings = []
        if account.positions:
            current_holdings = [
                position.entity_id for position in account.positions
                if position is not None and position.available_long > 0
            ]

        # set(...) so any iterable (e.g. a list) works; `set & list` raises
        shorted = set(current_holdings) & set(entity_ids)

        for entity_id in shorted:
            trading_signal = TradingSignal(
                entity_id=entity_id,
                due_timestamp=due_timestamp,
                happen_timestamp=happen_timestamp,
                trading_signal_type=TradingSignalType.close_long,
                trading_level=self.level,
                position_pct=position_pct)
            self.send_trading_signal(trading_signal)

    def trade_the_targets(self,
                          due_timestamp,
                          happen_timestamp,
                          long_selected,
                          short_selected,
                          long_pct=1.0,
                          short_pct=1.0):
        """Buy ``long_selected`` and then sell ``short_selected``."""
        self.buy(due_timestamp=due_timestamp,
                 happen_timestamp=happen_timestamp,
                 entity_ids=long_selected,
                 position_pct=long_pct)
        self.sell(due_timestamp=due_timestamp,
                  happen_timestamp=happen_timestamp,
                  entity_ids=short_selected,
                  position_pct=short_pct)

    def send_trading_signal(self, signal: TradingSignal):
        """Deliver ``signal`` to every listener; a listener's failure is logged and reported back to it."""
        for listener in self.trading_signal_listeners:
            try:
                listener.on_trading_signal(signal)
            except Exception as e:
                self.logger.exception(e)
                listener.on_trading_error(timestamp=signal.happen_timestamp,
                                          error=e)

    def on_finish(self):
        """Plot the account value over time in the browser when ``draw_result`` is set."""
        # show the result
        if self.draw_result:
            import plotly.io as pio
            pio.renderers.default = "browser"
            reader = AccountStatsReader(trader_names=[self.trader_name])
            df = reader.data_df
            # select first, then copy: avoids copying unused columns
            drawer = Drawer(main_data=NormalData(
                df[['trader_name', 'timestamp', 'all_value']].copy(),
                category_field='trader_name'))
            drawer.draw_line(show=True)

    def select_long_targets(self, long_targets: List[str]) -> List[str]:
        """Keep at most the first 10 long targets."""
        if len(long_targets) > 10:
            return long_targets[0:10]
        return long_targets

    def select_short_targets(self, short_targets: List[str]) -> List[str]:
        """Keep at most the first 10 short targets."""
        if len(short_targets) > 10:
            return short_targets[0:10]
        return short_targets

    def in_trading_date(self, timestamp):
        """Return whether ``timestamp``'s date string is in ``self.trading_dates``."""
        return to_time_str(timestamp) in self.trading_dates

    def on_time(self, timestamp):
        """Hook called once per processed timestamp; logs it at debug level."""
        self.logger.debug(f'current timestamp:{timestamp}')

    def run(self):
        """Drive the trading loop over every ``self.level`` timestamp in the window, then call ``on_finish``."""
        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in self.entity_schema.get_interval_timestamps(
                start_date=self.start_timestamp,
                end_date=self.end_timestamp,
                level=self.level):

            if not self.in_trading_date(timestamp=timestamp):
                continue

            if self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(
                        seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                seconds = (now_pd_timestamp() -
                           real_end_timestamp).total_seconds()
                waiting_seconds = self.level.to_second() - seconds
                # meaning the future kdata not ready yet,we could move on to check
                if waiting_seconds > 0:
                    # iterate the selector from min to max which in finished timestamp kdata
                    for level in self.trading_level_asc:
                        if self.entity_schema.is_finished_kdata_timestamp(
                                timestamp=timestamp, level=level):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp,
                                                     self.kdata_use_begin_time,
                                                     timeout=waiting_seconds +
                                                     20)

            # on_trading_open to setup the account
            # (daily level opens every cycle; intraday only at the open timestamp)
            if (self.level == IntervalLevel.LEVEL_1DAY
                    or self.entity_schema.is_open_timestamp(timestamp)):
                self.account_service.on_trading_open(timestamp)

            self.on_time(timestamp=timestamp)

            if self.selectors:
                for level in self.trading_level_asc:
                    # in every cycle, all level selector do its job in its time
                    if self.entity_schema.is_finished_kdata_timestamp(
                            timestamp=timestamp, level=level):
                        all_long_targets = []
                        all_short_targets = []
                        for selector in self.selectors:
                            if selector.level == level:
                                long_targets = selector.get_open_long_targets(
                                    timestamp=timestamp)
                                long_targets = self.select_long_targets(
                                    long_targets)

                                short_targets = selector.get_open_short_targets(
                                    timestamp=timestamp)
                                short_targets = self.select_short_targets(
                                    short_targets)

                                all_long_targets += long_targets
                                all_short_targets += short_targets

                        if all_long_targets or all_short_targets:
                            self.targets_slot.input_targets(
                                level, all_long_targets, all_short_targets)
                            # the time always move on by min level step and we could check all level targets in the slot
                            # 1)the targets is generated for next interval
                            # 2)the acceptable price is next interval prices,you could buy it at current price if the time is before the timestamp(due_timestamp) when trading signal received
                            # 3)the suggest price the the close price for generating the signal(happen_timestamp)
                            if level == self.level:
                                due_timestamp = timestamp + pd.Timedelta(
                                    seconds=self.level.to_second())
                                self.handle_targets_slot(
                                    due_timestamp=due_timestamp,
                                    happen_timestamp=timestamp)

            # on_trading_close to calculate date account
            # (daily level closes every cycle; intraday only at the close timestamp)
            if (self.level == IntervalLevel.LEVEL_1DAY
                    or self.entity_schema.is_close_timestamp(timestamp)):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
Beispiel #5
0
class Trader(Constructor):
    logger = logging.getLogger(__name__)

    security_type: SecurityType = None

    def __init__(self,
                 security_list: List[str] = None,
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: Union[str, Provider] = Provider.JOINQUANT,
                 level: Union[str, TradingLevel] = TradingLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False) -> None:

        assert self.security_type is not None

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []
        self.state_listeners = []

        self.selectors: List[TargetSelector] = []

        self.security_list = security_list

        self.exchanges = exchanges
        self.codes = codes

        # FIXME:handle this case gracefully
        if self.security_list:
            security_type, exchange, code = decode_security_id(self.security_list[0])
            if not self.security_type:
                self.security_type = security_type
            if not self.exchanges:
                self.exchanges = [exchange]

        self.provider = Provider(provider)
        # make sure the min level selector correspond to the provider and level
        self.level = TradingLevel(level)
        self.real_time = real_time

        if start_timestamp and end_timestamp:
            self.start_timestamp = to_pd_timestamp(start_timestamp)
            self.end_timestamp = to_pd_timestamp(end_timestamp)
        else:
            assert False

        if real_time:
            logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time

        self.account_service = SimAccountService(trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_list=security_list, security_type=self.security_type, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp, end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        self.trading_level_asc = list(set([TradingLevel(selector.level) for selector in self.selectors]))
        self.trading_level_asc.sort()

        self.trading_level_desc = list(self.trading_level_asc)
        self.trading_level_desc.reverse()

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', StoreCategory.business)
        trader = get_trader(session=self.session, trader_name=self.trader_name, return_type='domain', limit=1)

        if trader:
            self.logger.warning("trader:{} has run before,old result would be deleted".format(self.trader_name))
            self.session.query(business.Trader).filter(business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()
        self.on_start()

    def on_start(self):
        if not self.selectors:
            raise Exception('please setup self.selectors in init_selectors at first')

        # run all the selectors
        technical_factors = []
        for selector in self.selectors:
            # run for the history data at first
            selector.run()

            for factor in selector.filter_factors:
                if isinstance(factor, TechnicalFactor):
                    technical_factors.append(factor)

        technical_factors = simplejson.dumps(technical_factors, for_json=True)

        if self.security_list:
            security_list = json.dumps(self.security_list)
        else:
            security_list = None

        if self.exchanges:
            exchanges = json.dumps(self.exchanges)
        else:
            exchanges = None

        if self.codes:
            codes = json.dumps(self.codes)
        else:
            codes = None

        trader_domain = business.Trader(id=self.trader_name, timestamp=self.start_timestamp,
                                        trader_name=self.trader_name,
                                        security_type=security_list, exchanges=exchanges, codes=codes,
                                        start_timestamp=self.start_timestamp,
                                        end_timestamp=self.end_timestamp, provider=self.provider.value,
                                        level=self.level.value,
                                        real_time=self.real_time, kdata_use_begin_time=self.kdata_use_begin_time,
                                        technical_factors=technical_factors)
        self.session.add(trader_domain)
        self.session.commit()

    def init_selectors(self, security_list, security_type, exchanges, codes, start_timestamp, end_timestamp):
        """
        implement this to init selectors

        """
        raise NotImplementedError

    def init_selectors_comparator(self):
        """
        overwrite this to set selectors_comparator

        """
        return LimitSelectorsComparator(self.selectors)

    def add_trading_signal_listener(self, listener):
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, timestamp):
        """
        this function would be called in every cycle, you could overwrite it for your custom algorithm to select the
        targets of different levels

        the default implementation is selecting the targets in all levels

        :param timestamp:
        :type timestamp:
        """
        long_selected = None
        short_selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if targets:
                long_targets = set(targets[0])
                short_targets = set(targets[1])

                if not long_selected:
                    long_selected = long_targets
                else:
                    long_selected = long_selected & long_targets

                if not short_selected:
                    short_selected = short_targets
                else:
                    short_selected = short_selected & short_targets

        self.logger.debug('timestamp:{},long_selected:{}'.format(timestamp, long_selected))

        self.logger.debug('timestamp:{},short_selected:{}'.format(timestamp, short_selected))

        self.send_trading_signals(timestamp=timestamp, long_selected=long_selected, short_selected=short_selected)

    def send_trading_signals(self, timestamp, long_selected, short_selected):
        # current position
        account = self.account_service.latest_account
        current_holdings = [position['security_id'] for position in account['positions'] if
                            position['available_long'] > 0]

        if long_selected:
            # just long the security not in the positions
            longed = long_selected - set(current_holdings)
            if longed:
                position_pct = 1.0 / len(longed)
                order_money = account['cash'] * position_pct

                for security_id in longed:
                    trading_signal = TradingSignal(security_id=security_id,
                                                   the_timestamp=timestamp,
                                                   trading_signal_type=TradingSignalType.trading_signal_open_long,
                                                   trading_level=self.level,
                                                   order_money=order_money)
                    for listener in self.trading_signal_listeners:
                        listener.on_trading_signal(trading_signal)

        # just short the security in current_holdings and short_selected
        if short_selected:
            shorted = set(current_holdings) & short_selected

            for security_id in shorted:
                trading_signal = TradingSignal(security_id=security_id,
                                               the_timestamp=timestamp,
                                               trading_signal_type=TradingSignalType.trading_signal_close_long,
                                               position_pct=1.0,
                                               trading_level=self.level)
                for listener in self.trading_signal_listeners:
                    listener.on_trading_signal(trading_signal)

    def on_finish(self):
        # show the result
        pass

    def run(self):
        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in iterate_timestamps(security_type=self.security_type, exchange=self.exchanges[0],
                                            start_timestamp=self.start_timestamp, end_timestamp=self.end_timestamp,
                                            level=self.level):

            if not is_trading_date(security_type=self.security_type, exchange=self.exchanges[0], timestamp=timestamp):
                continue
            if self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                waiting_seconds, _ = self.level.count_from_timestamp(real_end_timestamp,
                                                                     one_day_trading_minutes=get_one_day_trading_minutes(
                                                                         security_type=self.security_type))
                # meaning the future kdata not ready yet,we could move on to check
                if waiting_seconds and (waiting_seconds > 0):
                    # iterate the selector from min to max which in finished timestamp kdata
                    for level in self.trading_level_asc:
                        if (is_in_finished_timestamps(security_type=self.security_type, exchange=self.exchanges[0],
                                                      timestamp=timestamp, level=level)):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp, self.kdata_use_begin_time, timeout=waiting_seconds + 20)

            # on_trading_open to setup the account
            if self.level == TradingLevel.LEVEL_1DAY or (
                    self.level != TradingLevel.LEVEL_1DAY and is_open_time(security_type=self.security_type,
                                                                           exchange=self.exchanges[0],
                                                                           timestamp=timestamp)):
                self.account_service.on_trading_open(timestamp)

            # the time always move on by min level step and we could check all level targets in the slot
            self.handle_targets_slot(timestamp=timestamp)

            for level in self.trading_level_asc:
                # in every cycle, all level selector do its job in its time
                if (is_in_finished_timestamps(security_type=self.security_type, exchange=self.exchanges[0],
                                              timestamp=timestamp, level=level)):
                    long_targets, short_targets = self.selectors_comparator.make_decision(timestamp=timestamp,
                                                                                          trading_level=level)

                    self.targets_slot.input_targets(level, long_targets, short_targets)

            # on_trading_close to calculate date account
            if self.level == TradingLevel.LEVEL_1DAY or (
                    self.level != TradingLevel.LEVEL_1DAY and is_close_time(security_type=self.security_type,
                                                                            exchange=self.exchanges[0],
                                                                            timestamp=timestamp)):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()

    @classmethod
    def get_constructor_meta(cls):
        """Extend the parent class's constructor metadata with this class's security type.

        The ``security_type`` class attribute is converted with ``marshal_object_for_ui``
        and stored under the ``'security_type'`` key of the returned meta's ``metas``.
        """
        constructor_meta = super().get_constructor_meta()
        ui_security_type = marshal_object_for_ui(cls.security_type)
        constructor_meta.metas['security_type'] = ui_security_type
        return constructor_meta
Beispiel #6
0
    def __init__(self,
                 security_list: List[str] = None,
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: Union[str, Provider] = Provider.JOINQUANT,
                 level: Union[str, TradingLevel] = TradingLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False) -> None:
        """Configure the trader, create its account and selectors, and reset stored state.

        :param security_list: security ids to trade; when given, the first id is decoded to
            fill in ``security_type`` / ``exchanges`` if those are unset
        :param exchanges: exchanges to trade (copied, never aliased)
        :param codes: codes to trade, forwarded to ``init_selectors``
        :param start_timestamp: start of the trading window (required)
        :param end_timestamp: end of the trading window (required; must not be in the past
            when ``real_time`` is set)
        :param provider: data provider, converted with ``Provider(...)``
        :param level: trading level, converted with ``TradingLevel(...)``
        :param trader_name: defaults to the lower-cased class name
        :param real_time: run in real-time mode
        :param kdata_use_begin_time: stored on the trader for the trading loop
        :raises ValueError: if a timestamp is missing, or ``end_timestamp`` is in the past
            in real-time mode
        """
        # security_type is a class-level setting that concrete traders must define
        assert self.security_type is not None

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []
        self.state_listeners = []

        self.selectors: List[TargetSelector] = []

        self.security_list = security_list

        # copy so the shared mutable default (or a caller's list) is never aliased
        self.exchanges = list(exchanges) if exchanges is not None else None
        self.codes = codes

        # FIXME:handle this case gracefully
        if self.security_list:
            security_type, exchange, _ = decode_security_id(self.security_list[0])
            # only reachable when the assert above is stripped (python -O)
            if not self.security_type:
                self.security_type = security_type
            if not self.exchanges:
                self.exchanges = [exchange]

        self.provider = Provider(provider)
        # make sure the min level selector correspond to the provider and level
        self.level = TradingLevel(level)
        self.real_time = real_time

        # explicit check instead of `assert False`, which is stripped under `python -O`
        if not (start_timestamp and end_timestamp):
            raise ValueError('start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        if real_time:
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            if not self.end_timestamp >= now_pd_timestamp():
                raise ValueError('end_timestamp must not be in the past in real_time mode')

        self.kdata_use_begin_time = kdata_use_begin_time

        self.account_service = SimAccountService(trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_list=security_list, security_type=self.security_type, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp, end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        # distinct selector levels, ascending, plus a descending copy
        self.trading_level_asc = sorted({TradingLevel(selector.level) for selector in self.selectors})
        self.trading_level_desc = self.trading_level_asc[::-1]

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', StoreCategory.business)
        trader = get_trader(session=self.session, trader_name=self.trader_name, return_type='domain', limit=1)

        # a previous run under the same trader name is deleted before starting
        if trader:
            self.logger.warning("trader:{} has run before,old result would be deleted".format(self.trader_name))
            self.session.query(business.Trader).filter(business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()
        self.on_start()
Beispiel #7
0
    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        """Configure the trader, create its account and selectors, and reset stored state.

        :param entity_ids: ids of the entities to trade, forwarded to ``init_selectors``
        :param exchanges: exchanges to trade (copied, never aliased)
        :param codes: codes to trade, forwarded to ``init_selectors``
        :param start_timestamp: start of the trading window (required)
        :param end_timestamp: end of the trading window (required; must not be in the past
            when ``real_time`` is set)
        :param provider: data provider, forwarded to the account service
        :param level: trading level, converted with ``IntervalLevel(...)``
        :param trader_name: defaults to the lower-cased class name
        :param real_time: run in real-time mode
        :param kdata_use_begin_time: stored on the trader for the trading loop
        :param draw_result: stored on the trader for result plotting
        :raises ValueError: if a timestamp is missing, or ``end_timestamp`` is in the past
            in real-time mode
        """
        # entity_schema is a class-level setting that concrete traders must define
        assert self.entity_schema is not None

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        # copy so the shared mutable default (or a caller's list) is never aliased
        self.exchanges = list(exchanges) if exchanges is not None else None
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        # explicit check instead of `assert False`, which is stripped under `python -O`
        if not (start_timestamp and end_timestamp):
            raise ValueError('start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        if real_time:
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever'
            )
            if not self.end_timestamp >= now_pd_timestamp():
                raise ValueError('end_timestamp must not be in the past in real_time mode')

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        # distinct selector levels, ascending, plus a descending copy
        self.trading_level_asc = sorted({IntervalLevel(selector.level) for selector in self.selectors})
        self.trading_level_desc = self.trading_level_asc[::-1]

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', 'business')
        trader = get_trader(session=self.session,
                            trader_name=self.trader_name,
                            return_type='domain',
                            limit=1)

        # a previous run under the same trader name is deleted before starting
        if trader:
            self.logger.warning(
                "trader:{} has run before,old result would be deleted".format(
                    self.trader_name))
            self.session.query(business.Trader).filter(
                business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()
        self.on_start()
Beispiel #8
0
class Trader(object):
    """Base trader that combines selector targets across levels into trading signals.

    Subclasses set ``entity_schema`` and implement ``init_selectors`` to fill
    ``self.selectors``. ``run`` walks the timestamps of ``self.level``, lets the
    selectors of every finished level make decisions, combines the targets of all
    levels in ``handle_targets_slot`` and sends the resulting signals to the
    registered listeners (a ``SimAccountService`` is registered in ``__init__``).
    """
    logger = logging.getLogger(__name__)

    # schema of the traded entities; concrete traders must override it
    entity_schema: EntityMixin = None

    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        """Configure the trader, create its account and selectors, and reset stored state.

        :param entity_ids: ids of the entities to trade, forwarded to ``init_selectors``
        :param exchanges: exchanges to trade (copied, never aliased); the first one drives
            the trading calendar in ``run``
        :param codes: codes to trade, forwarded to ``init_selectors``
        :param start_timestamp: start of the trading window (required)
        :param end_timestamp: end of the trading window (required; must not be in the past
            when ``real_time`` is set)
        :param provider: data provider, forwarded to the account service
        :param level: level the trading loop steps by
        :param trader_name: defaults to the lower-cased class name
        :param real_time: let selectors wait for incoming data in ``run``
        :param kdata_use_begin_time: whether a kdata timestamp marks the start of its bar
        :param draw_result: plot the account value in ``on_finish``
        :raises ValueError: if a timestamp is missing, or ``end_timestamp`` is in the past
            in real-time mode
        """
        assert self.entity_schema is not None

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        # copy so the shared mutable default (or a caller's list) is never aliased
        self.exchanges = list(exchanges) if exchanges is not None else None
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        # explicit check instead of `assert False`, which is stripped under `python -O`
        if not (start_timestamp and end_timestamp):
            raise ValueError('start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        if real_time:
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever'
            )
            if not self.end_timestamp >= now_pd_timestamp():
                raise ValueError('end_timestamp must not be in the past in real_time mode')

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        # distinct selector levels, ascending; the descending copy is used to combine targets
        self.trading_level_asc = sorted({IntervalLevel(selector.level) for selector in self.selectors})
        self.trading_level_desc = self.trading_level_asc[::-1]

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', 'business')
        trader = get_trader(session=self.session,
                            trader_name=self.trader_name,
                            return_type='domain',
                            limit=1)

        # a previous run under the same trader name is deleted before starting
        if trader:
            self.logger.warning(
                "trader:{} has run before,old result would be deleted".format(
                    self.trader_name))
            self.session.query(business.Trader).filter(
                business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()
        self.on_start()

    def on_start(self):
        """Run every selector over its history, then persist this trader's configuration.

        :raises Exception: if ``init_selectors`` did not populate ``self.selectors``
        """
        if not self.selectors:
            raise Exception(
                'please setup self.selectors in init_selectors at first')

        # run all the selectors for the history data at first
        for selector in self.selectors:
            selector.run()

        # list settings are stored as JSON strings, None when unset/empty
        entity_ids = json.dumps(self.entity_ids) if self.entity_ids else None
        exchanges = json.dumps(self.exchanges) if self.exchanges else None
        codes = json.dumps(self.codes) if self.codes else None

        trader_domain = business.Trader(
            id=self.trader_name,
            timestamp=self.start_timestamp,
            trader_name=self.trader_name,
            # NOTE(review): the JSON list of entity ids is written to `entity_type`;
            # confirm which business.Trader column it is meant for
            entity_type=entity_ids,
            exchanges=exchanges,
            codes=codes,
            start_timestamp=self.start_timestamp,
            end_timestamp=self.end_timestamp,
            provider=self.provider,
            level=self.level.value,
            real_time=self.real_time,
            kdata_use_begin_time=self.kdata_use_begin_time)
        self.session.add(trader_domain)
        self.session.commit()

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes,
                       start_timestamp, end_timestamp):
        """
        implement this to init selectors (populate ``self.selectors``)

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def init_selectors_comparator(self):
        """
        overwrite this to set selectors_comparator

        :return: by default a ``LimitSelectorsComparator`` over ``self.selectors``
        """
        return LimitSelectorsComparator(self.selectors)

    def add_trading_signal_listener(self, listener):
        """Register ``listener`` for trading signals; a listener is registered at most once."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        """Unregister ``listener``; does nothing when it is not registered."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, timestamp):
        """
        this function would be called in every cycle, you could overwrite it for your custom algorithm to select the
        targets of different levels

        the default implementation keeps only the targets selected at every level that has produced targets
        (intersection, computed separately for long and short) and sends them as trading signals

        :param timestamp: the current cycle timestamp
        """
        long_selected = None
        short_selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if targets:
                long_targets = set(targets[0])
                short_targets = set(targets[1])

                # `is None` instead of falsiness: an intersection that became empty must stay
                # empty rather than being reset to the next level's targets
                if long_selected is None:
                    long_selected = long_targets
                else:
                    long_selected = long_selected & long_targets

                if short_selected is None:
                    short_selected = short_targets
                else:
                    short_selected = short_selected & short_targets

        self.logger.debug('timestamp:{},long_selected:{}'.format(
            timestamp, long_selected))

        self.logger.debug('timestamp:{},short_selected:{}'.format(
            timestamp, short_selected))

        self.send_trading_signals(timestamp=timestamp,
                                  long_selected=long_selected,
                                  short_selected=short_selected)

    def send_trading_signals(self, timestamp, long_selected, short_selected):
        """Emit open-long signals for new long targets and close-long signals for held short targets.

        The account's cash is split evenly across the long targets not already held;
        every held short target is closed entirely (``position_pct=1.0``). Each signal is
        passed to every registered listener.
        """
        # current position
        account = self.account_service.latest_account
        current_holdings = {
            position['entity_id'] for position in account['positions']
            if position['available_long'] > 0
        }

        if long_selected:
            # just long the security not in the positions
            longed = long_selected - current_holdings
            if longed:
                position_pct = 1.0 / len(longed)
                order_money = account['cash'] * position_pct

                for entity_id in longed:
                    trading_signal = TradingSignal(
                        entity_id=entity_id,
                        the_timestamp=timestamp,
                        trading_signal_type=TradingSignalType.
                        trading_signal_open_long,
                        trading_level=self.level,
                        order_money=order_money)
                    for listener in self.trading_signal_listeners:
                        listener.on_trading_signal(trading_signal)

        # just short the security in current_holdings and short_selected
        if short_selected:
            shorted = current_holdings & short_selected

            for entity_id in shorted:
                trading_signal = TradingSignal(
                    entity_id=entity_id,
                    the_timestamp=timestamp,
                    trading_signal_type=TradingSignalType.
                    trading_signal_close_long,
                    position_pct=1.0,
                    trading_level=self.level)
                for listener in self.trading_signal_listeners:
                    listener.on_trading_signal(trading_signal)

    def on_finish(self):
        """Called after the last timestamp; plots the account value when ``draw_result`` is set."""
        # show the result
        if self.draw_result:
            import plotly.io as pio
            pio.renderers.default = "browser"
            reader = AccountReader(trader_names=[self.trader_name])
            df = reader.data_df.reset_index()
            drawer = Drawer(main_data=NormalData(
                df.copy()[['trader_name', 'timestamp', 'all_value']],
                category_field='trader_name'))
            drawer.draw_line()

    def run(self):
        """Step through every ``self.level`` timestamp of the window and trade on it.

        On each trading-day timestamp: in real-time mode let the selectors of finished
        levels move on to the incoming data, open the account, act on the targets
        gathered so far, let every finished level's selectors make new decisions, and
        close the account. ``on_finish`` is called after the last timestamp.
        """
        # the first exchange drives the trading calendar
        exchange = self.exchanges[0]
        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in iterate_timestamps(
                entity_type=self.entity_schema,
                exchange=exchange,
                start_timestamp=self.start_timestamp,
                end_timestamp=self.end_timestamp,
                level=self.level):

            if not is_trading_date(entity_type=self.entity_schema,
                                   exchange=exchange,
                                   timestamp=timestamp):
                continue
            if self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(
                        seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                waiting_seconds, _ = self.level.count_from_timestamp(
                    real_end_timestamp,
                    one_day_trading_minutes=get_one_day_trading_minutes(
                        entity_type=self.entity_schema))
                # meaning the future kdata not ready yet,we could move on to check
                if waiting_seconds and (waiting_seconds > 0):
                    # iterate the selector from min to max which in finished timestamp kdata
                    for level in self.trading_level_asc:
                        if is_in_finished_timestamps(
                                entity_type=self.entity_schema,
                                exchange=exchange,
                                timestamp=timestamp,
                                level=level):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp,
                                                     self.kdata_use_begin_time,
                                                     timeout=waiting_seconds + 20)

            # on_trading_open to setup the account: every cycle at day level, else at the open
            if self.level == IntervalLevel.LEVEL_1DAY or is_open_time(
                    entity_type=self.entity_schema,
                    exchange=exchange,
                    timestamp=timestamp):
                self.account_service.on_trading_open(timestamp)

            # the time always move on by min level step and we could check all level targets in the slot
            self.handle_targets_slot(timestamp=timestamp)

            for level in self.trading_level_asc:
                # in every cycle, all level selector do its job in its time
                # NOTE(review): every other calendar helper call here passes `self.entity_schema`
                # as entity_type; this one passes its lower-cased class name — confirm which
                # form is_in_finished_timestamps expects
                if is_in_finished_timestamps(
                        entity_type=self.entity_schema.__name__.lower(),
                        exchange=exchange,
                        timestamp=timestamp,
                        level=level):
                    long_targets, short_targets = self.selectors_comparator.make_decision(
                        timestamp=timestamp, trading_level=level)

                    self.targets_slot.input_targets(level, long_targets,
                                                    short_targets)

            # on_trading_close to calculate date account: every cycle at day level, else at the close
            if self.level == IntervalLevel.LEVEL_1DAY or is_close_time(
                    entity_type=self.entity_schema,
                    exchange=exchange,
                    timestamp=timestamp):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
Beispiel #9
0
class Trader(object):
    entity_schema: Type[TradableEntity] = None

    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True,
                 rich_mode: bool = False,
                 adjust_type: AdjustType = None,
                 profit_threshold=(3, -0.3),
                 keep_history=False) -> None:
        """Configure the trader, its simulated account and its selectors.

        :param entity_ids: ids of the entities to trade, forwarded to ``init_selectors``
        :param exchanges: exchanges to trade, forwarded to ``init_selectors``
        :param codes: codes to trade, forwarded to ``init_selectors``
        :param start_timestamp: start of the trading window (required)
        :param end_timestamp: end of the trading window (required; must not be in the past
            when ``real_time`` is set)
        :param provider: data provider, forwarded to the account service
        :param level: trading level; must equal the minimum level of the selectors
        :param trader_name: defaults to the lower-cased class name
        :param real_time: run in real-time mode
        :param kdata_use_begin_time: stored on the trader
        :param draw_result: plot the account value in ``on_finish``
        :param rich_mode: forwarded to the account service
        :param adjust_type: price adjust type, converted with ``AdjustType(...)``
        :param profit_threshold: ``(take-profit rate, stop-loss rate)`` used by ``on_profit_control``
        :param keep_history: forwarded to the account service
        :raises ValueError: if a timestamp is missing, or ``end_timestamp`` is in the past
            in real-time mode
        :raises Exception: if ``level`` is not the minimum level of the selectors
        """
        # entity_schema is a class-level setting that concrete traders must define
        assert self.entity_schema is not None
        # explicit check instead of asserts, which are stripped under `python -O`
        if start_timestamp is None or end_timestamp is None:
            raise ValueError('start_timestamp and end_timestamp are required')

        self.logger = logging.getLogger(__name__)

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.entity_ids = entity_ids
        self.exchanges = exchanges
        self.codes = codes
        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.trading_dates = self.entity_schema.get_trading_dates(start_date=self.start_timestamp,
                                                                  end_date=self.end_timestamp)

        if real_time:
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            if not self.end_timestamp >= now_pd_timestamp():
                raise ValueError('end_timestamp must not be in the past in real_time mode')

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result
        self.rich_mode = rich_mode

        # NOTE(review): with the default adjust_type=None this relies on AdjustType(None)
        # being accepted — confirm against the AdjustType definition
        self.adjust_type = AdjustType(adjust_type)
        self.profit_threshold = profit_threshold
        self.keep_history = keep_history

        self.level_map_long_targets = {}
        self.level_map_short_targets = {}
        self.trading_signals: List[TradingSignal] = []
        self.trading_signal_listeners: List[TradingListener] = []
        self.selectors: List[TargetSelector] = []
        # defaults for traders without selectors, so methods iterating the levels (e.g.
        # on_targets_selected_from_levels) do not raise AttributeError; set before
        # init_selectors so a subclass may still override them there
        self.trading_level_asc: List[IntervalLevel] = []
        self.trading_level_desc: List[IntervalLevel] = []

        self.account_service = SimAccountService(entity_schema=self.entity_schema,
                                                 trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level,
                                                 rich_mode=self.rich_mode,
                                                 adjust_type=self.adjust_type,
                                                 keep_history=self.keep_history)

        self.register_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=self.entity_ids, entity_schema=self.entity_schema, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp, end_timestamp=self.end_timestamp,
                            adjust_type=self.adjust_type)

        if self.selectors:
            self.trading_level_asc = sorted({IntervalLevel(selector.level) for selector in self.selectors})

            self.logger.info(f'trader level:{self.level},selectors level:{self.trading_level_asc}')

            if self.level != self.trading_level_asc[0]:
                raise Exception("trader level should be the min of the selectors")

            self.trading_level_desc = self.trading_level_asc[::-1]

            # run selectors for history data at first
            for selector in self.selectors:
                selector.run()

        self.on_start()

    def on_start(self):
        self.logger.info(f'trader:{self.trader_name} on_start')

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp,
                       adjust_type=None):
        """
        overwrite it to init selectors if you want to use selector/factor computing model
        :param adjust_type:

        """
        pass

    def update_targets_by_level(self, level: IntervalLevel, long_targets: List[str],
                                short_targets: List[str], ) -> None:
        """
        the trading signals is generated in min level,before that,we should cache targets of all levels

        :param level:
        :param long_targets:
        :param short_targets:
        """
        self.logger.debug(
            f'level:{level},old long targets:{self.level_map_long_targets.get(level)},new long targets:{long_targets}')
        self.level_map_long_targets[level] = long_targets

        self.logger.debug(
            f'level:{level},old short targets:{self.level_map_short_targets.get(level)},new short targets:{short_targets}')
        self.level_map_short_targets[level] = short_targets

    def get_long_targets_by_level(self, level: IntervalLevel) -> List[str]:
        return self.level_map_long_targets.get(level)

    def get_short_targets_by_level(self, level: IntervalLevel) -> List[str]:
        return self.level_map_short_targets.get(level)

    def on_targets_selected_from_levels(self, timestamp) -> Tuple[List[str], List[str]]:
        """
        this method's called in every min level cycle to select targets in all levels generated by the previous cycle
        the default implementation is selecting the targets in all levels
        overwrite it for your custom logic

        :param timestamp: current event time
        :return: long targets, short targets
        """

        long_selected = None

        short_selected = None

        for level in self.trading_level_desc:
            long_targets = self.level_map_long_targets.get(level)
            # long must in all
            if long_targets:
                long_targets = set(long_targets)
                if long_selected is None:
                    long_selected = long_targets
                else:
                    long_selected = long_selected & long_targets
            else:
                long_selected = set()

            short_targets = self.level_map_short_targets.get(level)
            # short any
            if short_targets:
                short_targets = set(short_targets)
                if short_selected is None:
                    short_selected = short_targets
                else:
                    short_selected = short_selected | short_targets

        return long_selected, short_selected

    def get_current_account(self) -> AccountStats:
        return self.account_service.get_current_account()

    def get_current_positions(self) -> List[Position]:
        return self.get_current_account().positions

    def long_position_control(self):
        positions = self.get_current_positions()

        position_pct = 1.0
        if not positions:
            # 没有仓位,买2成
            position_pct = 0.2
        elif len(positions) <= 10:
            # 小于10个持仓,买5成
            position_pct = 0.5

        # 买完
        return position_pct

    def short_position_control(self):
        # 卖完
        return 1.0

    def on_profit_control(self):
        if self.profit_threshold and self.get_current_positions():
            positive = self.profit_threshold[0]
            negative = self.profit_threshold[1]
            close_long_entity_ids = []
            for position in self.get_current_positions():
                if position.available_long > 1:
                    # 止盈
                    if position.profit_rate >= positive:
                        close_long_entity_ids.append(position.entity_id)
                        self.logger.info(f'close profit {position.profit_rate} for {position.entity_id}')
                    # 止损
                    if position.profit_rate <= negative:
                        close_long_entity_ids.append(position.entity_id)
                        self.logger.info(f'cut lost {position.profit_rate} for {position.entity_id}')

            return close_long_entity_ids, None
        return None, None

    def buy(self, due_timestamp, happen_timestamp, entity_ids, ignore_in_position=True):
        if ignore_in_position:
            account = self.get_current_account()
            current_holdings = []
            if account.positions:
                current_holdings = [position.entity_id for position in account.positions if position != None and
                                    position.available_long > 0]

            entity_ids = set(entity_ids) - set(current_holdings)

        if entity_ids:
            position_pct = self.long_position_control()
            position_pct = (1.0 / len(entity_ids)) * position_pct

            for entity_id in entity_ids:
                trading_signal = TradingSignal(entity_id=entity_id,
                                               due_timestamp=due_timestamp,
                                               happen_timestamp=happen_timestamp,
                                               trading_signal_type=TradingSignalType.open_long,
                                               trading_level=self.level,
                                               position_pct=position_pct)
                self.trading_signals.append(trading_signal)

    def sell(self, due_timestamp, happen_timestamp, entity_ids):
        # current position
        account = self.get_current_account()
        current_holdings = []
        if account.positions:
            current_holdings = [position.entity_id for position in account.positions if position != None and
                                position.available_long > 0]

        shorted = set(current_holdings) & set(entity_ids)

        if shorted:
            position_pct = self.short_position_control()

            for entity_id in shorted:
                trading_signal = TradingSignal(entity_id=entity_id,
                                               due_timestamp=due_timestamp,
                                               happen_timestamp=happen_timestamp,
                                               trading_signal_type=TradingSignalType.close_long,
                                               trading_level=self.level,
                                               position_pct=position_pct)
                self.trading_signals.append(trading_signal)

    def trade_the_targets(self, due_timestamp, happen_timestamp, long_selected, short_selected):
        if short_selected:
            self.sell(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp, entity_ids=short_selected)
        if long_selected:
            self.buy(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp, entity_ids=long_selected)

    def on_finish(self, timestmap):
        """
        Notify the listeners that trading finished and, if ``self.draw_result``, plot the account value.

        :param timestmap: the finishing timestamp, forwarded to ``on_trading_finish``
        """
        self.on_trading_finish(timestmap)

        if not self.draw_result:
            return

        # show the result: the all_value line of this trader, rendered in the browser
        import plotly.io as pio
        pio.renderers.default = "browser"

        stats_df = AccountStatsReader(trader_names=[self.trader_name]).data_df
        value_data = NormalData(stats_df.copy()[['trader_name', 'timestamp', 'all_value']],
                                category_field='trader_name')
        Drawer(main_data=value_data).draw_line(show=True)

    def on_targets_filtered(self, timestamp, level, selector: TargetSelector, long_targets: List[str],
                            short_targets: List[str]) -> Tuple[List[str], List[str]]:
        """
        Overwrite it to filter the targets from a selector.

        The default implementation keeps at most the first 10 long targets and passes
        the short targets through unchanged.

        :param timestamp: the event time
        :param level: the level
        :param selector: the selector
        :param long_targets: the long targets from the selector, may be empty or None
        :param short_targets: the short targets from the selector, may be empty or None
        :return: filtered long targets, filtered short targets
        """
        self.logger.info(f'on_targets_filtered {level} long:{long_targets}')

        # run() invokes this hook when either side is non-empty, so long_targets may be
        # falsy (e.g. None) here; guard it instead of letting len(None) raise TypeError
        if long_targets and len(long_targets) > 10:
            long_targets = long_targets[0:10]
        self.logger.info(f'on_targets_filtered {level} filtered long:{long_targets}')

        return long_targets, short_targets

    def in_trading_date(self, timestamp):
        """Return whether ``timestamp``, rendered by ``to_time_str``, is one of ``self.trading_dates``."""
        date_key = to_time_str(timestamp)
        return date_key in self.trading_dates

    def on_time(self, timestamp: pd.Timestamp):
        """
        Hook called by ``run`` once per iterated timestamp of the trader's level; overwrite it for custom logic.

        The default implementation only logs the timestamp at debug level.

        :param timestamp: event time
        """
        message = 'current timestamp:{}'.format(timestamp)
        self.logger.debug(message)

    def on_trading_signals(self, trading_signals: List[TradingSignal]):
        """
        Forward a batch of trading signals to every registered listener.

        Unlike ``on_trading_signal``, exceptions raised by a listener are not caught here.

        :param trading_signals: the signals to dispatch
        """
        listeners = self.trading_signal_listeners
        for listener in listeners:
            listener.on_trading_signals(trading_signals)

    def on_trading_signal(self, trading_signal: TradingSignal):
        """
        Forward one trading signal to every listener, isolating listener failures.

        When a listener raises, the exception is logged and reported back to that same
        listener via ``on_trading_error``; the loop then continues with the next listener.

        :param trading_signal: the signal to dispatch
        """
        for listener in self.trading_signal_listeners:
            try:
                listener.on_trading_signal(trading_signal)
            except Exception as error:
                self.logger.exception(error)
                listener.on_trading_error(timestamp=trading_signal.happen_timestamp, error=error)

    def on_trading_open(self, timestamp):
        """Broadcast the trading-open event at ``timestamp`` to every trading signal listener."""
        listeners = self.trading_signal_listeners
        for listener in listeners:
            listener.on_trading_open(timestamp)

    def on_trading_close(self, timestamp):
        """Broadcast the trading-close event at ``timestamp`` to every trading signal listener."""
        listeners = self.trading_signal_listeners
        for listener in listeners:
            listener.on_trading_close(timestamp)

    def on_trading_finish(self, timestamp):
        """Broadcast the trading-finish event at ``timestamp`` to every trading signal listener."""
        listeners = self.trading_signal_listeners
        for listener in listeners:
            listener.on_trading_finish(timestamp)

    def on_trading_error(self, timestamp, error):
        """Report ``error`` raised at ``timestamp`` to every trading signal listener."""
        listeners = self.trading_signal_listeners
        for listener in listeners:
            listener.on_trading_error(timestamp, error)

    def run(self):
        """
        Drive the trading loop over every ``self.level`` interval between the start and end timestamps.

        For each iterated timestamp that is a trading date:

        1. if the interval's data may not be ready yet (today's daily bar, or real-time mode),
           wait, then let the selectors of levels with finished kdata ``move_on``;
        2. fire ``on_trading_open`` (levels >= 1 day, or the open timestamp of intraday levels)
           and ``on_time``;
        3. for every level (ascending) whose kdata is finished at the timestamp, collect the
           selectors' targets, filter them with ``on_targets_filtered`` and pass them to
           ``update_targets_by_level``; at ``self.level`` turn the final selection plus the
           passive shorts from ``on_profit_control`` into trading signals;
        4. dispatch the buffered signals, clear the buffer and fire ``on_trading_close``
           (levels >= 1 day, or the close timestamp of intraday levels).

        ``on_finish`` is called once at the end with the last iterated timestamp, or with
        ``self.end_timestamp`` when nothing was iterated.
        """
        # value handed to on_finish; without it an empty iteration raised UnboundLocalError
        timestamp = self.end_timestamp

        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in self.entity_schema.get_interval_timestamps(start_date=self.start_timestamp,
                                                                    end_date=self.end_timestamp, level=self.level):

            if not self.in_trading_date(timestamp=timestamp):
                continue

            waiting_seconds = 0

            if self.level == IntervalLevel.LEVEL_1DAY:
                # today's daily bar: poll once a minute until now_pd_timestamp() reaches hour 19
                if is_same_date(timestamp, now_pd_timestamp()):
                    while True:
                        self.logger.info(f'time is:{now_pd_timestamp()},just smoke for minutes')
                        time.sleep(60)
                        current = now_pd_timestamp()
                        if current.hour >= 19:
                            waiting_seconds = 20
                            break

            elif self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                seconds = (now_pd_timestamp() - real_end_timestamp).total_seconds()
                waiting_seconds = self.level.to_second() - seconds

            # meaning the future kdata not ready yet,we could move on to check
            if waiting_seconds > 0:
                # iterate the selector from min to max which in finished timestamp kdata
                for level in self.trading_level_asc:
                    if self.entity_schema.is_finished_kdata_timestamp(timestamp=timestamp, level=level):
                        for selector in self.selectors:
                            if selector.level == level:
                                selector.move_on(timestamp, self.kdata_use_begin_time, timeout=waiting_seconds + 20)

            # on_trading_open to setup the account
            # (a level below 1 day can never equal LEVEL_1DAY, so only the open check is needed)
            if self.level >= IntervalLevel.LEVEL_1DAY or self.entity_schema.is_open_timestamp(timestamp):
                self.on_trading_open(timestamp)

            self.on_time(timestamp=timestamp)

            # Selectors (factors) are usually fast at computing history for many targets and make
            # multi-level computation convenient, so they are commonly used for a coarse filter
            # over the whole market. Finer control can be added in on_targets_filtered, and
            # on_time can host custom logic that cooperates with the filtering.
            if self.selectors:
                # Key points of the multi-level traversal:
                # 1) compute the targets of every level, filter them with on_targets_filtered and
                #    cache them in level_map_long_targets, level_map_short_targets
                # 2) at the smallest level, on_targets_selected_from_levels builds the final
                #    selection from the cached targets of all levels
                # Note that a smaller level sees the bigger level's targets of the previous cycle,
                # which is reasonable.
                for level in self.trading_level_asc:
                    # in every cycle, all level selector do its job in its time
                    if self.entity_schema.is_finished_kdata_timestamp(timestamp=timestamp, level=level):
                        all_long_targets = []
                        all_short_targets = []

                        # filter the targets of the selectors of this level
                        for selector in self.selectors:
                            if selector.level == level:
                                long_targets = selector.get_open_long_targets(timestamp=timestamp)
                                short_targets = selector.get_open_short_targets(timestamp=timestamp)

                                if long_targets or short_targets:
                                    long_targets, short_targets = self.on_targets_filtered(timestamp=timestamp,
                                                                                           level=level,
                                                                                           selector=selector,
                                                                                           long_targets=long_targets,
                                                                                           short_targets=short_targets)

                                if long_targets:
                                    all_long_targets += long_targets
                                if short_targets:
                                    all_short_targets += short_targets

                        # cache the targets of this level in level_map_long_targets, level_map_short_targets
                        self.update_targets_by_level(level, all_long_targets, all_short_targets)

                        # the time always move on by min level step and we could check all targets of levels
                        # 1)the targets is generated for next interval
                        # 2)the acceptable price is next interval prices,you could buy it at current price
                        # if the time is before the timestamp(due_timestamp) when trading signal received
                        # 3)the suggest price the the close price for generating the signal(happen_timestamp)
                        due_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())

                        # generate the final trading signals at the smallest level
                        if level == self.level:
                            long_selected, short_selected = self.on_targets_selected_from_levels(timestamp)

                            # take-profit / stop-loss: merge the passive shorts into the selection
                            passive_short, _ = self.on_profit_control()
                            if passive_short:
                                if not short_selected:
                                    short_selected = passive_short
                                else:
                                    short_selected = list(set(short_selected) | set(passive_short))

                            self.logger.debug('timestamp:{},long_selected:{}'.format(due_timestamp, long_selected))
                            self.logger.debug('timestamp:{},short_selected:{}'.format(due_timestamp, short_selected))

                            if long_selected or short_selected:
                                self.trade_the_targets(due_timestamp=due_timestamp, happen_timestamp=timestamp,
                                                       long_selected=long_selected, short_selected=short_selected)

            if self.trading_signals:
                self.on_trading_signals(self.trading_signals)
            # clear
            self.trading_signals = []

            # on_trading_close to calculate date account
            # (a level below 1 day can never equal LEVEL_1DAY, so only the close check is needed)
            if self.level >= IntervalLevel.LEVEL_1DAY or self.entity_schema.is_close_timestamp(timestamp):
                self.on_trading_close(timestamp)

        self.on_finish(timestamp)

    def register_trading_signal_listener(self, listener):
        """Subscribe ``listener`` to trading events; registering an already-present listener is a no-op."""
        listeners = self.trading_signal_listeners
        if listener in listeners:
            return
        listeners.append(listener)

    def deregister_trading_signal_listener(self, listener):
        """Unsubscribe ``listener``; unknown listeners are ignored."""
        listeners = self.trading_signal_listeners
        if listener not in listeners:
            return
        listeners.remove(listener)
Beispiel #10
0
class Trader(object):
    """
    Day-level trader that holds the securities selected by all of its selectors.

    Subclasses implement ``init_selectors`` to populate ``self.selectors``. ``run`` walks the
    business days between ``start_timestamp`` and ``end_timestamp``; each day it intersects the
    (non-empty) selector targets, opens long positions with an equal ``position_pct`` for
    selected securities not yet held, and closes held securities that are no longer selected.
    """
    logger = logging.getLogger(__name__)

    def __init__(self,
                 security_type=SecurityType.stock,
                 exchanges=None,
                 codes=None,
                 start_timestamp=None,
                 end_timestamp=None) -> None:
        """
        :param security_type: the security type to trade
        :param exchanges: the exchanges to trade on, ``['sh', 'sz']`` when omitted
        :param codes: optional codes handed to ``init_selectors``
        :param start_timestamp: first trading timestamp, defaults to now
        :param end_timestamp: last trading timestamp; ``run`` uses now when it is not set
        """
        # a fresh list per instance instead of a shared mutable default argument
        if exchanges is None:
            exchanges = ['sh', 'sz']

        self.trader_name = type(self).__name__.lower()
        self.trading_signal_queue = queue.Queue()
        self.trading_signal_listeners = []
        self.state_listeners = []

        self.selectors: List[TargetSelector] = None

        self.security_type = security_type
        self.exchanges = exchanges
        self.codes = codes

        self.start_timestamp = to_pd_timestamp(start_timestamp) if start_timestamp else now_pd_timestamp()
        self.current_timestamp = self.start_timestamp
        self.end_timestamp = to_pd_timestamp(end_timestamp) if end_timestamp else end_timestamp

        self.account_service = SimAccountService(
            trader_name=self.trader_name, timestamp=self.start_timestamp)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_type=self.security_type,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

    def init_selectors(self, security_type, exchanges, codes, start_timestamp,
                       end_timestamp):
        """
        Implement this to init ``self.selectors``.

        :param security_type: the security type of the trader
        :param exchanges: the exchanges of the trader
        :param codes: the codes of the trader, may be None
        :param start_timestamp: start of the trading period
        :param end_timestamp: end of the trading period, may be None
        """
        raise NotImplementedError

    def send_trading_signal(self, trading_signal):
        """Put ``trading_signal`` on ``trading_signal_queue``; ``run`` notifies listeners directly and does not read it."""
        self.trading_signal_queue.put(trading_signal)

    def add_trading_signal_listener(self, listener):
        """Subscribe ``listener`` to trading signals; duplicates are ignored."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        """Unsubscribe ``listener``; unknown listeners are ignored."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def _select_targets(self, timestamp):
        # intersection of the non-empty selector results at timestamp; empty when none voted
        selected = None
        for selector in self.selectors:
            df = selector.get_targets(timestamp)
            if df.empty:
                continue
            targets = set(df['security_id'].to_list())
            # None means "no selector voted yet", so a disjoint intersection stays empty
            # instead of restarting from the next selector's targets
            selected = targets if selected is None else selected & targets
        return selected if selected is not None else set()

    def run(self):
        """
        Trade every business day from ``start_timestamp`` to ``end_timestamp`` (now when unset).

        Signals are sent to every listener via ``on_trading_signal`` and the closing account is
        saved at the end of each day.
        """
        # pd.date_range rejects start+freq without end (or periods), so fall back to now
        end_timestamp = self.end_timestamp if self.end_timestamp else now_pd_timestamp()

        # now we just support day level
        for timestamp in pd.date_range(start=self.start_timestamp,
                                       end=end_timestamp,
                                       freq='B'):

            account = self.account_service.get_account_at_time(timestamp)
            positions = {position.security_id for position in account.positions}

            # select the targets from the selectors
            selected = self._select_targets(timestamp)

            # just long the security not in the positions; the guard avoids dividing by zero
            # when every selected security is already held
            longed = selected - positions
            if longed:
                position_pct = 1.0 / len(longed)

                for security_id in longed:
                    trading_signal = TradingSignal(
                        security_id=security_id,
                        the_timestamp=timestamp,
                        trading_signal_type=TradingSignalType.trading_signal_open_long,
                        trading_level=None,
                        position_pct=position_pct)
                    for listener in self.trading_signal_listeners:
                        listener.on_trading_signal(trading_signal)

            # close the held securities that are not selected any more
            shorted = positions - selected

            for security_id in shorted:
                trading_signal = TradingSignal(
                    security_id=security_id,
                    the_timestamp=timestamp,
                    trading_signal_type=TradingSignalType.trading_signal_close_long,
                    trading_level=None)
                for listener in self.trading_signal_listeners:
                    listener.on_trading_signal(trading_signal)

            self.account_service.save_closing_account(timestamp)
Beispiel #11
0
    def __init__(self,
                 security_list: List[str] = None,
                 security_type: Union[str, SecurityType] = SecurityType.stock,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: Union[str, Provider] = Provider.JOINQUANT,
                 level: Union[str, TradingLevel] = TradingLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False) -> None:
        """
        Set up the trader state, its simulated account and its selectors.

        :param security_list: securities handed to ``init_selectors``
        :param security_type: the security type, converted with ``SecurityType(...)``
        :param exchanges: the exchanges, ``['sh', 'sz']`` when omitted
        :param codes: codes handed to ``init_selectors``
        :param start_timestamp: required start of the trading period
        :param end_timestamp: required end of the trading period; must not be in the past when ``real_time``
        :param provider: the data provider, passed to the account service
        :param level: the trading level, converted with ``TradingLevel(...)``
        :param trader_name: defaults to the lower-cased class name
        :param real_time: enable real-time mode
        :param kdata_use_begin_time: stored on the instance; presumably whether kdata timestamps mark the
            begin of the bar — confirm against the run loop
        :raises AssertionError: if a timestamp is missing, or ``end_timestamp`` is in the past in real-time mode
        """
        # a fresh list per instance instead of a shared mutable default argument
        if exchanges is None:
            exchanges = ['sh', 'sz']

        self.trader_name = trader_name if trader_name else type(self).__name__.lower()

        self.trading_signal_listeners = []
        self.state_listeners = []

        self.selectors: List[TargetSelector] = None

        self.security_list = security_list
        self.security_type = SecurityType(security_type)
        self.exchanges = exchanges
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = TradingLevel(level)
        self.real_time = real_time

        # explicit raise instead of ``assert`` so the check survives ``python -O``;
        # AssertionError is kept for callers relying on the previous assert
        if not (start_timestamp and end_timestamp):
            raise AssertionError('both start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        if real_time:
            # NOTE(review): relies on a module-level ``logger`` not visible in this snippet — confirm it exists
            logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever'
            )
            if not (self.end_timestamp >= now_pd_timestamp()):
                raise AssertionError('end_timestamp must not be in the past in real_time mode')

        self.kdata_use_begin_time = kdata_use_begin_time

        self.account_service = SimAccountService(
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_list=security_list,
                            security_type=self.security_type,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        # distinct selector levels, smallest first, and the same levels largest first
        self.trading_level_asc = sorted({TradingLevel(selector.level) for selector in self.selectors})
        self.trading_level_desc = self.trading_level_asc[::-1]

        self.targets_slot: TargetsSlot = TargetsSlot()
Beispiel #12
0
class Trader(Constructor):
    """
    Multi-level trader: selectors of several levels vote and the levels' targets are intersected.

    Subclasses implement ``init_selectors``. ``run`` walks the timestamps of ``self.level``; at each
    step ``handle_targets_slot`` combines the per-level targets cached in ``targets_slot`` into
    trading signals, then every level whose kdata finished at that timestamp refreshes its slot
    from ``selectors_comparator.make_decision``.
    """
    logger = logging.getLogger(__name__)

    def __init__(self,
                 security_list: List[str] = None,
                 security_type: Union[str, SecurityType] = SecurityType.stock,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: Union[str, Provider] = Provider.JOINQUANT,
                 level: Union[str, TradingLevel] = TradingLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False) -> None:
        """
        Set up the trader state, its simulated account and its selectors.

        :param security_list: securities handed to ``init_selectors``
        :param security_type: the security type, converted with ``SecurityType(...)``
        :param exchanges: the exchanges, ``['sh', 'sz']`` when omitted; ``run`` uses the first one
        :param codes: codes handed to ``init_selectors``
        :param start_timestamp: required start of the trading period
        :param end_timestamp: required end of the trading period; must not be in the past when ``real_time``
        :param provider: the data provider, passed to the account service
        :param level: the trading level, converted with ``TradingLevel(...)``
        :param trader_name: defaults to the lower-cased class name
        :param real_time: wait for and move the selectors on with live data in ``run``
        :param kdata_use_begin_time: when True, ``run`` adds one interval to a kdata timestamp to get its end
        :raises AssertionError: if a timestamp is missing, or ``end_timestamp`` is in the past in real-time mode
        """
        # a fresh list per instance instead of a shared mutable default argument
        if exchanges is None:
            exchanges = ['sh', 'sz']

        self.trader_name = trader_name if trader_name else type(self).__name__.lower()

        self.trading_signal_listeners = []
        self.state_listeners = []

        self.selectors: List[TargetSelector] = None

        self.security_list = security_list
        self.security_type = SecurityType(security_type)
        self.exchanges = exchanges
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = TradingLevel(level)
        self.real_time = real_time

        # explicit raise instead of ``assert`` so the check survives ``python -O``;
        # AssertionError is kept for callers relying on the previous assert
        if not (start_timestamp and end_timestamp):
            raise AssertionError('both start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        if real_time:
            # use the class logger; no module-level ``logger`` is visible for this class
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever'
            )
            if not (self.end_timestamp >= now_pd_timestamp()):
                raise AssertionError('end_timestamp must not be in the past in real_time mode')

        self.kdata_use_begin_time = kdata_use_begin_time

        self.account_service = SimAccountService(
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_list=security_list,
                            security_type=self.security_type,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        # distinct selector levels, smallest first, and the same levels largest first
        self.trading_level_asc = sorted({TradingLevel(selector.level) for selector in self.selectors})
        self.trading_level_desc = self.trading_level_asc[::-1]

        self.targets_slot: TargetsSlot = TargetsSlot()

    def init_selectors(self, security_list, security_type, exchanges, codes,
                       start_timestamp, end_timestamp):
        """
        Implement this to init ``self.selectors``.

        """
        raise NotImplementedError

    def init_selectors_comparator(self):
        """
        Overwrite this to set ``selectors_comparator``; defaults to a ``LimitSelectorsComparator``.

        """
        return LimitSelectorsComparator(self.selectors)

    def add_trading_signal_listener(self, listener):
        """Subscribe ``listener`` to trading signals; duplicates are ignored."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        """Unsubscribe ``listener``; unknown listeners are ignored."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, timestamp):
        """
        This function would be called in every cycle, you could overwrite it for your custom algorithm
        to select the targets of different levels.

        The default implementation intersects the targets of the levels (largest first); a level
        without targets does not veto the others.

        :param timestamp: the current cycle timestamp
        """
        selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if not targets:
                continue
            # None means "no level voted yet", so a disjoint intersection stays empty
            # instead of restarting from the next level's targets
            selected = targets if selected is None else selected & targets

        if selected:
            self.logger.debug('timestamp:{},selected:{}'.format(
                timestamp, selected))

        self.send_trading_signals(timestamp=timestamp, selected=selected)

    def send_trading_signals(self, timestamp, selected):
        """
        Open longs for selected securities not held, close held securities not selected.

        The account cash is split equally across the new longs (``order_money``); closes use
        ``position_pct=1.0``. Signals go to every listener via ``on_trading_signal``.

        :param timestamp: the signal timestamp
        :param selected: the selected security ids, may be empty or None
        """
        # current position
        account = self.account_service.latest_account
        current_holdings = [
            position['security_id'] for position in account['positions']
            if position['available_long'] > 0
        ]

        if selected:
            # just long the security not in the positions
            longed = selected - set(current_holdings)
            if longed:
                position_pct = 1.0 / len(longed)
                order_money = account['cash'] * position_pct

                for security_id in longed:
                    trading_signal = TradingSignal(
                        security_id=security_id,
                        the_timestamp=timestamp,
                        trading_signal_type=TradingSignalType.trading_signal_open_long,
                        trading_level=self.level,
                        order_money=order_money)
                    for listener in self.trading_signal_listeners:
                        listener.on_trading_signal(trading_signal)

        # just short the security not in the selected but in current_holdings
        if selected:
            shorted = set(current_holdings) - selected
        else:
            shorted = set(current_holdings)

        for security_id in shorted:
            trading_signal = TradingSignal(
                security_id=security_id,
                the_timestamp=timestamp,
                trading_signal_type=TradingSignalType.trading_signal_close_long,
                position_pct=1.0,
                trading_level=self.level)
            for listener in self.trading_signal_listeners:
                listener.on_trading_signal(trading_signal)

    def on_finish(self):
        """Hook called at the end of ``run``; the default does nothing."""
        # show the result
        pass

    def run(self):
        """
        Trade every timestamp of ``self.level`` between ``start_timestamp`` and ``end_timestamp``.

        Only ``self.exchanges[0]`` is used for the trading calendar.
        """
        exchange = self.exchanges[0]

        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in iterate_timestamps(
                security_type=self.security_type,
                exchange=exchange,
                start_timestamp=self.start_timestamp,
                end_timestamp=self.end_timestamp,
                level=self.level):
            if self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(
                        seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                waiting_seconds, _ = self.level.count_from_timestamp(
                    real_end_timestamp,
                    one_day_trading_minutes=get_one_day_trading_minutes(
                        security_type=self.security_type))
                # meaning the future kdata not ready yet,we could move on to check
                # (the truthiness check also guards against a None waiting_seconds)
                if waiting_seconds and waiting_seconds > 0:
                    # iterate the selector from min to max which in finished timestamp kdata
                    for level in self.trading_level_asc:
                        if is_in_finished_timestamps(security_type=self.security_type,
                                                     exchange=exchange,
                                                     timestamp=timestamp,
                                                     level=level):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp,
                                                     self.kdata_use_begin_time)

            # on_trading_open to setup the account
            # (a level other than LEVEL_1DAY only needs the open-time check)
            if self.level == TradingLevel.LEVEL_1DAY or is_open_time(security_type=self.security_type,
                                                                     exchange=exchange,
                                                                     timestamp=timestamp):
                self.account_service.on_trading_open(timestamp)

            # the time always move on by min level step and we could check all level targets in the slot;
            # the slots still hold the targets computed in the previous cycle at this point
            self.handle_targets_slot(timestamp=timestamp)

            for level in self.trading_level_asc:
                # in every cycle, all level selector do its job in its time
                if is_in_finished_timestamps(security_type=self.security_type,
                                             exchange=exchange,
                                             timestamp=timestamp,
                                             level=level):
                    df = self.selectors_comparator.make_decision(
                        timestamp=timestamp, trading_level=level)
                    # an empty selection is an empty set (previously a dict literal ``{}``)
                    if not df.empty:
                        selected = set(df['security_id'].to_list())
                    else:
                        selected = set()

                    self.targets_slot.input_targets(level, selected)

            # on_trading_close to calculate date account
            if self.level == TradingLevel.LEVEL_1DAY or is_close_time(security_type=self.security_type,
                                                                      exchange=exchange,
                                                                      timestamp=timestamp):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
Beispiel #13
0
class Trader(object):
    """Backtest trader that turns multi-level selector decisions into trading signals.

    Subclasses implement :meth:`init_selectors` to fill ``self.selectors``.
    :meth:`run` iterates the timestamps of ``trading_level``. At every step it
    sends signals for the targets currently stored in ``targets_slot``, then
    stores the fresh decision of every level whose bar has finished.
    """
    logger = logging.getLogger(__name__)

    def __init__(self,
                 security_type=SecurityType.stock,
                 exchanges=None,
                 codes=None,
                 start_timestamp=None,
                 end_timestamp=None,
                 provider=Provider.JOINQUANT,
                 trading_level=TradingLevel.LEVEL_1DAY,
                 trader_name=None) -> None:
        """
        :param security_type: security type forwarded to the selectors and calendar helpers
        :type security_type: SecurityType
        :param exchanges: exchanges to trade, ``['sh', 'sz']`` when omitted; the first
            one drives the trading calendar in :meth:`run`
        :type exchanges: list of str
        :param codes: optional codes forwarded to :meth:`init_selectors`
        :type codes: list of str
        :param start_timestamp: first timestamp of the backtest (required)
        :param end_timestamp: last timestamp of the backtest (required)
        :param provider: data provider forwarded to the account service
        :type provider: Provider
        :param trading_level: step of :meth:`run` and level of every emitted signal
        :type trading_level: TradingLevel
        :param trader_name: name used by the account service and the drawing helpers;
            defaults to the lower-cased class name
        :type trader_name: str
        :raises ValueError: if start_timestamp or end_timestamp is missing
        """
        # a literal list default would be a single object shared by every instance
        if exchanges is None:
            exchanges = ['sh', 'sz']

        self.trader_name = trader_name if trader_name else type(self).__name__.lower()
        self.trading_signal_listeners = []
        self.state_listeners = []

        # filled by init_selectors() in subclasses
        self.selectors: List[TargetSelector] = None

        self.security_type = security_type
        self.exchanges = exchanges
        self.codes = codes
        # make sure the min level selector correspond to the provider and level
        self.provider = provider
        self.trading_level = trading_level

        # explicit raise instead of `assert False`: `python -O` strips asserts and
        # the timestamps would silently stay unset
        if not (start_timestamp and end_timestamp):
            raise ValueError('start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.account_service = SimAccountService(
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.trading_level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_type=self.security_type,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = LimitSelectorsComparator(self.selectors)

        # distinct selector levels, ascending, plus a separate descending copy
        self.trading_level_asc = sorted(
            {TradingLevel(selector.level) for selector in self.selectors})
        self.trading_level_desc = list(reversed(self.trading_level_asc))

        self.targets_slot: TargetsSlot = TargetsSlot()

    def init_selectors(self, security_type, exchanges, codes, start_timestamp,
                       end_timestamp):
        """
        Hook for subclasses: create the selectors and assign them to ``self.selectors``.

        :param security_type: the trader's security type
        :type security_type: SecurityType
        :param exchanges: the trader's exchanges
        :type exchanges: list of str
        :param codes: the trader's codes (may be None)
        :type codes: list of str
        :param start_timestamp: start of the backtest
        :type start_timestamp: pd.Timestamp
        :param end_timestamp: end of the backtest
        :type end_timestamp: pd.Timestamp
        :raises NotImplementedError: always, in the base class
        """
        raise NotImplementedError

    def add_trading_signal_listener(self, listener):
        """Register ``listener`` for trading signals; registering twice is a no-op."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        """Unregister ``listener``; unknown listeners are ignored."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, timestamp):
        """Combine the targets of all levels and send trading signals for the result.

        Levels are visited in ``trading_level_desc`` order and their target sets
        are intersected. Override this to customize how levels are combined.

        NOTE(review): the check is ``if not selected``, so an empty partial result
        is replaced by the next level's targets instead of vetoing them. Confirm
        this is intended.

        :param timestamp: current timestamp of the trading loop
        """
        selected = None
        for level in self.trading_level_desc:
            # a level without targets contributes an empty *set*; `{}` would be a
            # dict, and `set & dict` raises TypeError
            targets = self.targets_slot.get_targets(level=level) or set()

            if not selected:
                selected = targets
            else:
                selected = selected & targets

        if selected:
            self.logger.info('timestamp:%s,selected:%s', timestamp, selected)

        self.send_trading_signals(timestamp=timestamp, selected=selected)

    def send_trading_signals(self, timestamp, selected):
        """Rebalance the account towards ``selected``.

        First, a long position is opened for every selected security not yet held.
        The account cash, read once before any signal is sent, is split equally
        among those securities. Then every held security that is not selected is
        closed, which means all holdings when ``selected`` is empty or None.

        :param timestamp: timestamp stamped on the signals
        :param selected: set of security ids to hold, or None/empty to hold nothing
        """
        account = self.account_service.latest_account
        current_holdings = {position['security_id'] for position in account['positions']}

        if selected:
            # just long the security not in the positions
            longed = selected - current_holdings
            if longed:
                position_pct = 1.0 / len(longed)
                order_money = account['cash'] * position_pct

                for security_id in longed:
                    trading_signal = TradingSignal(
                        security_id=security_id,
                        the_timestamp=timestamp,
                        trading_signal_type=TradingSignalType.
                        trading_signal_open_long,
                        trading_level=self.trading_level,
                        order_money=order_money)
                    for listener in self.trading_signal_listeners:
                        listener.on_trading_signal(trading_signal)

        # just short the security not in the selected but in current_holdings
        shorted = current_holdings - selected if selected else current_holdings

        for security_id in shorted:
            trading_signal = TradingSignal(
                security_id=security_id,
                the_timestamp=timestamp,
                trading_signal_type=TradingSignalType.
                trading_signal_close_long,
                position_pct=1.0,
                trading_level=self.trading_level)
            for listener in self.trading_signal_listeners:
                listener.on_trading_signal(trading_signal)

    def on_finish(self):
        """Draw the account details and order signals of this trader."""
        draw_account_details(trader_name=self.trader_name)
        draw_order_signals(trader_name=self.trader_name)

    def run(self):
        """Run the backtest over [start_timestamp, end_timestamp] at ``trading_level``.

        At every timestamp the loop does the following, in order:

        1. Open the account day, at every step for day level or at the open time otherwise.
        2. Send signals for the targets gathered so far.
        3. Store the decision of every level whose bar finished at this timestamp.
        4. Close the account day, at every step for day level or at the close time otherwise.

        The results are drawn at the end via :meth:`on_finish`.
        """
        is_day_level = self.trading_level == TradingLevel.LEVEL_1DAY
        # only the first exchange drives the trading calendar
        exchange = self.exchanges[0]

        # iterate timestamp of the min level
        for timestamp in iterate_timestamps(
                security_type=self.security_type,
                exchange=exchange,
                start_timestamp=self.start_timestamp,
                end_timestamp=self.end_timestamp,
                level=self.trading_level):
            # on_trading_open to setup the account
            if is_day_level or is_open_time(security_type=self.security_type,
                                            exchange=exchange,
                                            timestamp=timestamp):
                self.account_service.on_trading_open(timestamp)

            # the time always move on by min level step and we could check all level targets in the slot
            self.handle_targets_slot(timestamp=timestamp)

            for level in self.trading_level_asc:
                # in every cycle, all level selector do its job in its time
                if is_in_finished_timestamps(security_type=self.security_type,
                                             exchange=exchange,
                                             timestamp=timestamp,
                                             level=level):
                    df = self.selectors_comparator.make_decision(
                        timestamp=timestamp, trading_level=level)
                    # an empty *set* (not `{}`) keeps the slot intersectable
                    selected = set(df['security_id'].to_list()) if not df.empty else set()

                    self.targets_slot.input_targets(level, selected)

            # on_trading_close to calculate date account
            if is_day_level or is_close_time(security_type=self.security_type,
                                             exchange=exchange,
                                             timestamp=timestamp):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
Beispiel #14
0
    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True,
                 rich_mode: bool = False,
                 adjust_type: AdjustType = None,
                 profit_threshold=(3, -0.3),
                 keep_history=False) -> None:
        """Set up the trader: time range, simulated account and target selectors.

        The concrete class must define ``entity_schema``. ``init_selectors`` is
        expected to fill ``self.selectors``. When it does, the smallest selector
        level must equal the trader level, and every selector is run once before
        ``on_start`` is called.

        :param entity_ids: forwarded to ``init_selectors``
        :param exchanges: forwarded to ``init_selectors``
        :param codes: forwarded to ``init_selectors``
        :param start_timestamp: start of the trading range (required)
        :param end_timestamp: end of the trading range (required; must not be
            earlier than now when ``real_time`` is set)
        :param provider: data provider for the account service
        :param level: trading level, converted with ``IntervalLevel``
        :param trader_name: trader name, the lower-cased class name by default
        :param real_time: enables the "end in the future" check
        :param kdata_use_begin_time: stored as-is
        :param draw_result: stored as-is
        :param rich_mode: forwarded to the account service
        :param adjust_type: converted with ``AdjustType`` and forwarded to the
            account service and ``init_selectors``
        :param profit_threshold: stored as-is
        :param keep_history: forwarded to the account service
        :raises Exception: if the trader level is not the smallest selector level
        """
        # preconditions: a schema declared by the subclass and an explicit range
        assert self.entity_schema is not None
        assert start_timestamp is not None
        assert end_timestamp is not None

        self.logger = logging.getLogger(__name__)
        self.trader_name = trader_name if trader_name else type(self).__name__.lower()

        # --- what / when to trade ---
        self.entity_ids = entity_ids
        self.exchanges = exchanges
        self.codes = codes
        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.trading_dates = self.entity_schema.get_trading_dates(
            start_date=self.start_timestamp, end_date=self.end_timestamp)

        if real_time:
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            assert self.end_timestamp >= now_pd_timestamp()

        # --- behaviour switches ---
        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result
        self.rich_mode = rich_mode
        self.adjust_type = AdjustType(adjust_type)
        self.profit_threshold = profit_threshold
        self.keep_history = keep_history

        # --- runtime state ---
        self.level_map_long_targets = {}
        self.level_map_short_targets = {}
        self.trading_signals: List[TradingSignal] = []
        self.trading_signal_listeners: List[TradingListener] = []
        self.selectors: List[TargetSelector] = []

        account_options = dict(entity_schema=self.entity_schema,
                               trader_name=self.trader_name,
                               timestamp=self.start_timestamp,
                               provider=self.provider,
                               level=self.level,
                               rich_mode=self.rich_mode,
                               adjust_type=self.adjust_type,
                               keep_history=self.keep_history)
        self.account_service = SimAccountService(**account_options)
        self.register_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=self.entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp,
                            adjust_type=self.adjust_type)

        if self.selectors:
            ascending = sorted({IntervalLevel(sel.level) for sel in self.selectors})
            self.trading_level_asc = ascending

            self.logger.info(f'trader level:{self.level},selectors level:{self.trading_level_asc}')

            # the trader steps at its own level, so it must be the finest one
            if self.level != ascending[0]:
                raise Exception("trader level should be the min of the selectors")

            self.trading_level_desc = ascending[::-1]

            # run selectors for history data at first
            for sel in self.selectors:
                sel.run()

        self.on_start()
Beispiel #15
0
    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        """Set up the trader: time range, simulated account, selectors and DB session.

        The concrete class must define ``entity_schema``. ``init_selectors`` is
        expected to fill ``self.selectors``. When it does, the smallest selector
        level must equal the trader level.

        :param entity_ids: forwarded to ``init_selectors``
        :param exchanges: forwarded to ``init_selectors``
        :param codes: forwarded to ``init_selectors``
        :param start_timestamp: start of the trading range (required)
        :param end_timestamp: end of the trading range (required; must not be
            earlier than now when ``real_time`` is set)
        :param provider: data provider for the account service
        :param level: trading level, converted with ``IntervalLevel``
        :param trader_name: trader name, the lower-cased class name by default
        :param real_time: enables the "end in the future" check
        :param kdata_use_begin_time: stored as-is
        :param draw_result: stored as-is
        :raises ValueError: if start_timestamp or end_timestamp is missing
        :raises Exception: if the trader level is not the smallest selector level
        """
        assert self.entity_schema is not None

        self.logger = logging.getLogger(__name__)

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners: List[TradingListener] = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        self.exchanges = exchanges
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        # explicit raise instead of `assert False`: `python -O` strips asserts and
        # the timestamps would silently stay unset
        if not (start_timestamp and end_timestamp):
            raise ValueError('start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.trading_dates = self.entity_schema.get_trading_dates(
            start_date=self.start_timestamp, end_date=self.end_timestamp)

        if real_time:
            # use the instance logger created above rather than relying on a
            # module-level `logger` name
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever'
            )
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(
            entity_schema=self.entity_schema,
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        if self.selectors:
            # distinct selector levels, ascending
            self.trading_level_asc = sorted(
                {IntervalLevel(selector.level) for selector in self.selectors})

            self.logger.info(
                f'trader level:{self.level},selectors level:{self.trading_level_asc}'
            )

            # the trader steps at its own level, so it must be the finest one
            if self.level != self.trading_level_asc[0]:
                raise Exception(
                    "trader level should be the min of the selectors")

            self.trading_level_desc = list(reversed(self.trading_level_asc))

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', data_schema=TraderInfo)
        self.on_start()
Beispiel #16
0
class Trader(object):
    """Day-level backtest trader driven by a selectors comparator.

    Subclasses implement :meth:`init_selectors` to fill ``self.selectors`` and may
    override the class attribute ``selectors_comparator``.
    """
    logger = logging.getLogger(__name__)

    # overwrite it to custom your trader
    # NOTE(review): this comparator is created once per class and shared by all of
    # its instances; __init__ adds each instance's selectors to it, so a second
    # trader of the same class presumably accumulates selectors — confirm.
    selectors_comparator = SelectorsComparator(limit=5)

    def __init__(self,
                 security_type=SecurityType.stock,
                 exchanges=None,
                 codes=None,
                 start_timestamp=None,
                 end_timestamp=None) -> None:
        """
        :param security_type: security type forwarded to :meth:`init_selectors`
        :type security_type: SecurityType
        :param exchanges: exchanges to trade, ``['sh', 'sz']`` when omitted
        :type exchanges: list of str
        :param codes: optional codes forwarded to :meth:`init_selectors`
        :type codes: list of str
        :param start_timestamp: first day of the backtest (required)
        :param end_timestamp: last day of the backtest (required)
        :raises ValueError: if start_timestamp or end_timestamp is missing
        """
        # a literal list default would be a single object shared by every instance
        if exchanges is None:
            exchanges = ['sh', 'sz']

        self.trader_name = type(self).__name__.lower()
        self.trading_signal_listeners = []
        self.state_listeners = []

        # filled by init_selectors() in subclasses
        self.selectors: List[TargetSelector] = None

        self.security_type = security_type
        self.exchanges = exchanges
        self.codes = codes

        # explicit raise instead of `assert False`: `python -O` strips asserts and
        # the timestamps would silently stay unset
        if not (start_timestamp and end_timestamp):
            raise ValueError('start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.account_service = SimAccountService(
            trader_name=self.trader_name, timestamp=self.start_timestamp)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_type=self.security_type,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator.add_selectors(self.selectors)

    def init_selectors(self, security_type, exchanges, codes, start_timestamp,
                       end_timestamp):
        """
        Hook for subclasses: create the selectors and assign them to ``self.selectors``.

        :param security_type: the trader's security type
        :type security_type: SecurityType
        :param exchanges: the trader's exchanges
        :type exchanges: list of str
        :param codes: the trader's codes (may be None)
        :type codes: list of str
        :param start_timestamp: start of the backtest
        :type start_timestamp: pd.Timestamp
        :param end_timestamp: end of the backtest
        :type end_timestamp: pd.Timestamp
        :raises NotImplementedError: always, in the base class
        """
        raise NotImplementedError

    def add_trading_signal_listener(self, listener):
        """Register ``listener`` for trading signals; registering twice is a no-op."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        """Unregister ``listener``; unknown listeners are ignored."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def run(self):
        """Run the backtest on every business day (``freq='B'``) in the range.

        Each day the loop does the following, in order:

        1. Open the account day.
        2. Ask the comparator for the day's targets.
        3. Open longs for targets not yet held, splitting the cash equally.
        4. Close holdings that are no longer targeted.
        5. Close the account day.
        """
        # now we just support day level
        for timestamp in pd.date_range(start=self.start_timestamp,
                                       end=self.end_timestamp,
                                       freq='B').tolist():

            self.account_service.on_trading_open(timestamp)

            account = self.account_service.latest_account
            current_holdings = {position['security_id'] for position in account['positions']}

            df = self.selectors_comparator.make_decision(timestamp=timestamp)
            selected = set(df['security_id'].to_list()) if not df.empty else set()

            # just long the security not in the positions (empty when nothing is selected)
            longed = selected - current_holdings
            if longed:
                # cash is read once, before the closes below are sent
                position_pct = 1.0 / len(longed)
                order_money = account['cash'] * position_pct

                for security_id in longed:
                    trading_signal = TradingSignal(
                        security_id=security_id,
                        the_timestamp=timestamp,
                        trading_signal_type=TradingSignalType.
                        trading_signal_open_long,
                        trading_level=TradingLevel.LEVEL_1DAY,
                        order_money=order_money)
                    for listener in self.trading_signal_listeners:
                        listener.on_trading_signal(trading_signal)

            # close every holding that is no longer selected
            shorted = current_holdings - selected

            for security_id in shorted:
                trading_signal = TradingSignal(
                    security_id=security_id,
                    the_timestamp=timestamp,
                    trading_signal_type=TradingSignalType.
                    trading_signal_close_long,
                    position_pct=1.0,
                    trading_level=TradingLevel.LEVEL_1DAY)
                for listener in self.trading_signal_listeners:
                    listener.on_trading_signal(trading_signal)

            self.account_service.on_trading_close(timestamp)
Beispiel #17
0
    def __init__(self, security_list=None, security_type=SecurityType.stock, exchanges=['sh', 'sz'], codes=None,
                 start_timestamp=None,
                 end_timestamp=None,
                 provider=Provider.JOINQUANT,
                 trading_level=TradingLevel.LEVEL_1DAY,
                 trader_name=None,
                 real_time=False,
                 kdata_use_begin_time=False) -> None:
        """

        :param security_list:
        :type security_list:
        :param security_type:
        :type security_type:
        :param exchanges:
        :type exchanges:
        :param codes:
        :type codes:
        :param start_timestamp:
        :type start_timestamp:
        :param end_timestamp:
        :type end_timestamp:
        :param provider:
        :type provider:
        :param trading_level:
        :type trading_level:
        :param trader_name:
        :type trader_name:
        :param real_time:
        :type real_time:
        :param kdata_use_begin_time: true means the interval [timestamp,timestamp+level),false means [timestamp-level,timestamp)
        :type kdata_use_begin_time: bool

        """
        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []
        self.state_listeners = []

        self.selectors: List[TargetSelector] = None

        self.security_list = security_list
        self.security_type = security_type
        self.exchanges = exchanges
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.trading_level = trading_level
        self.real_time = real_time

        if start_timestamp and end_timestamp:
            self.start_timestamp = to_pd_timestamp(start_timestamp)
            self.end_timestamp = to_pd_timestamp(end_timestamp)
        else:
            assert False

        if real_time:
            logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time

        self.account_service = SimAccountService(trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.trading_level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_list=security_list, security_type=self.security_type, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp, end_timestamp=self.end_timestamp)

        self.selectors_comparator = LimitSelectorsComparator(self.selectors)

        self.trading_level_asc = list(set([TradingLevel(selector.level) for selector in self.selectors]))
        self.trading_level_asc.sort()

        self.trading_level_desc = list(self.trading_level_asc)
        self.trading_level_desc.reverse()

        self.targets_slot: TargetsSlot = TargetsSlot()