# Example #1
class Trader(object):
    """
    Base class of a trading robot that walks the timestamps of ``level`` between
    ``start_timestamp`` and ``end_timestamp`` and trades the targets picked by its selectors.

    Subclasses must set ``entity_schema`` and implement ``init_selectors`` so that
    ``self.selectors`` is populated. Trading signals are dispatched to every object in
    ``self.trading_signal_listeners`` through ``on_trading_signal``; a ``SimAccountService``
    is registered as a listener by default.
    """
    logger = logging.getLogger(__name__)

    # must be set by subclasses; __init__ asserts it is not None
    entity_schema: EntityMixin = None

    def __init__(self,
                 entity_ids: List[str] = None,
                 # NOTE: this default list object is shared between calls; never mutate it
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        """
        Set up the account service and selectors, wipe any previous records of the same
        trader name, then call ``on_start``.

        :param entity_ids: passed to ``init_selectors`` and recorded (JSON) in the trader record
        :param exchanges: passed to ``init_selectors``; the first one is used as the exchange
            for the trading-calendar helpers in ``run``
        :param codes: passed to ``init_selectors`` and recorded (JSON) in the trader record
        :param start_timestamp: required, converted with ``to_pd_timestamp``
        :param end_timestamp: required, converted with ``to_pd_timestamp``; in real-time mode
            it must not be in the past
        :param provider: passed to ``SimAccountService`` and recorded
        :param level: converted to ``IntervalLevel``; the step of the ``run`` loop
        :param trader_name: defaults to the lower-cased class name; existing
            ``business.Trader`` rows with this name are deleted
        :param real_time: when True, ``run`` makes selectors ``move_on`` to incoming data
        :param kdata_use_begin_time: when True, a bar's end is its timestamp plus one ``level``
        :param draw_result: when True, ``on_finish`` draws the account value curve
        :raises ValueError: if ``start_timestamp`` or ``end_timestamp`` is missing
        """
        assert self.entity_schema is not None

        self.trader_name = trader_name or type(self).__name__.lower()

        self.trading_signal_listeners = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids
        self.exchanges = exchanges
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        # both bounds are required; an explicit raise (instead of ``assert False``) survives
        # ``python -O`` and tells the caller what is wrong
        if not (start_timestamp and end_timestamp):
            raise ValueError(
                'start_timestamp and end_timestamp are required, got start_timestamp={}, end_timestamp={}'
                .format(start_timestamp, end_timestamp))
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        if real_time:
            # use the class-level logger; a bare ``logger`` name is not defined by this class
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever'
            )
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        # distinct levels of all selectors, in ascending and descending order
        self.trading_level_asc = sorted(
            {IntervalLevel(selector.level) for selector in self.selectors})
        self.trading_level_desc = list(reversed(self.trading_level_asc))

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', 'business')
        trader = get_trader(session=self.session,
                            trader_name=self.trader_name,
                            return_type='domain',
                            limit=1)

        if trader:
            self.logger.warning(
                "trader:{} has run before,old result would be deleted".format(
                    self.trader_name))
            self.session.query(business.Trader).filter(
                business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()
        self.on_start()

    def on_start(self):
        """
        Run every selector once (history data), then persist this trader's configuration
        as a ``business.Trader`` row.

        :raises Exception: if ``init_selectors`` left ``self.selectors`` empty
        """
        if not self.selectors:
            raise Exception(
                'please setup self.selectors in init_selectors at first')

        # run all the selectors for the history data at first
        for selector in self.selectors:
            selector.run()

        # list settings are stored JSON-encoded; empty/None values are stored as None
        entity_ids = json.dumps(self.entity_ids) if self.entity_ids else None
        exchanges = json.dumps(self.exchanges) if self.exchanges else None
        codes = json.dumps(self.codes) if self.codes else None

        trader_domain = business.Trader(
            id=self.trader_name,
            timestamp=self.start_timestamp,
            trader_name=self.trader_name,
            # NOTE(review): the JSON-encoded entity_ids are written to the ``entity_type``
            # column; confirm against the business.Trader schema whether an
            # ``entity_ids`` column was intended
            entity_type=entity_ids,
            exchanges=exchanges,
            codes=codes,
            start_timestamp=self.start_timestamp,
            end_timestamp=self.end_timestamp,
            provider=self.provider,
            level=self.level.value,
            real_time=self.real_time,
            kdata_use_begin_time=self.kdata_use_begin_time)
        self.session.add(trader_domain)
        self.session.commit()

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes,
                       start_timestamp, end_timestamp):
        """
        implement this to init selectors

        """
        raise NotImplementedError

    def init_selectors_comparator(self):
        """
        overwrite this to set selectors_comparator

        """
        return LimitSelectorsComparator(self.selectors)

    def add_trading_signal_listener(self, listener):
        """Register *listener* (ignored if it is already registered)."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        """Unregister *listener* (no-op if it is not registered)."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, timestamp):
        """
        this function would be called in every cycle, you could overwrite it for your custom algorithm to select the
        targets of different levels

        the default implementation is selecting the targets in all levels: the long (short)
        selection is the intersection of the long (short) targets of every level that has
        targets in the slot

        :param timestamp: the current cycle timestamp, forwarded to ``send_trading_signals``
        """
        long_selected = None
        short_selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if not targets:
                continue
            long_targets = set(targets[0])
            short_targets = set(targets[1])

            # ``is None`` (not truthiness): once an intersection is empty it must stay empty
            # instead of being replaced by the next level's targets
            if long_selected is None:
                long_selected = long_targets
            else:
                long_selected = long_selected & long_targets

            if short_selected is None:
                short_selected = short_targets
            else:
                short_selected = short_selected & short_targets

        self.logger.debug('timestamp:{},long_selected:{}'.format(
            timestamp, long_selected))

        self.logger.debug('timestamp:{},short_selected:{}'.format(
            timestamp, short_selected))

        self.send_trading_signals(timestamp=timestamp,
                                  long_selected=long_selected,
                                  short_selected=short_selected)

    def send_trading_signals(self, timestamp, long_selected, short_selected):
        """
        Turn the selected targets into trading signals and dispatch them to all listeners.

        - open long for each entity in *long_selected* that is not currently held, splitting
          the account cash equally among them
        - close the whole long position (``position_pct=1.0``) for each held entity in
          *short_selected*

        An entity counts as held when its position has ``available_long > 0``.
        """
        account = self.account_service.latest_account
        current_holdings = set(
            position['entity_id'] for position in account['positions']
            if position['available_long'] > 0)

        if long_selected:
            # just long the security not in the positions
            longed = long_selected - current_holdings
            if longed:
                order_money = account['cash'] * (1.0 / len(longed))

                for entity_id in longed:
                    trading_signal = TradingSignal(
                        entity_id=entity_id,
                        the_timestamp=timestamp,
                        trading_signal_type=TradingSignalType.trading_signal_open_long,
                        trading_level=self.level,
                        order_money=order_money)
                    for listener in self.trading_signal_listeners:
                        listener.on_trading_signal(trading_signal)

        # just short the security in current_holdings and short_selected
        if short_selected:
            shorted = current_holdings & short_selected

            for entity_id in shorted:
                trading_signal = TradingSignal(
                    entity_id=entity_id,
                    the_timestamp=timestamp,
                    trading_signal_type=TradingSignalType.trading_signal_close_long,
                    position_pct=1.0,
                    trading_level=self.level)
                for listener in self.trading_signal_listeners:
                    listener.on_trading_signal(trading_signal)

    def on_finish(self):
        """Draw the ``all_value`` curve of this trader's account when ``draw_result`` is set."""
        if self.draw_result:
            import plotly.io as pio
            pio.renderers.default = "browser"
            reader = AccountReader(trader_names=[self.trader_name])
            drawer = Drawer(main_data=NormalData(
                reader.data_df.copy()[['trader_name', 'timestamp', 'all_value']],
                category_field='trader_name'))
            drawer.draw_line()

    def run(self):
        """
        Main loop: for every trading timestamp of ``self.level`` in
        [start_timestamp, end_timestamp], open the account, select and send signals from the
        targets slot, refresh the targets of every level whose bar has finished, then close
        the account. Calls ``on_finish`` at the end.
        """
        # the first exchange drives all the trading-calendar helpers below
        exchange = self.exchanges[0]

        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in iterate_timestamps(
                entity_type=self.entity_schema,
                exchange=exchange,
                start_timestamp=self.start_timestamp,
                end_timestamp=self.end_timestamp,
                level=self.level):

            if not is_trading_date(entity_type=self.entity_schema,
                                   exchange=exchange,
                                   timestamp=timestamp):
                continue

            if self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(
                        seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                waiting_seconds, _ = self.level.count_from_timestamp(
                    real_end_timestamp,
                    one_day_trading_minutes=get_one_day_trading_minutes(
                        entity_type=self.entity_schema))
                # meaning the future kdata not ready yet,we could move on to check
                if waiting_seconds and (waiting_seconds > 0):
                    # iterate the selector from min to max which in finished timestamp kdata
                    for level in self.trading_level_asc:
                        if is_in_finished_timestamps(
                                entity_type=self.entity_schema,
                                exchange=exchange,
                                timestamp=timestamp,
                                level=level):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp,
                                                     self.kdata_use_begin_time,
                                                     timeout=waiting_seconds + 20)

            # on_trading_open to setup the account
            if self.level == IntervalLevel.LEVEL_1DAY or is_open_time(
                    entity_type=self.entity_schema,
                    exchange=exchange,
                    timestamp=timestamp):
                self.account_service.on_trading_open(timestamp)

            # the time always move on by min level step and we could check all level targets in the slot
            self.handle_targets_slot(timestamp=timestamp)

            for level in self.trading_level_asc:
                # in every cycle, all level selector do its job in its time
                if is_in_finished_timestamps(entity_type=self.entity_schema,
                                             exchange=exchange,
                                             timestamp=timestamp,
                                             level=level):
                    long_targets, short_targets = self.selectors_comparator.make_decision(
                        timestamp=timestamp, trading_level=level)

                    self.targets_slot.input_targets(level, long_targets,
                                                    short_targets)

            # on_trading_close to calculate date account
            if self.level == IntervalLevel.LEVEL_1DAY or is_close_time(
                    entity_type=self.entity_schema,
                    exchange=exchange,
                    timestamp=timestamp):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
# Example #2
class FixedCycleDataRecorder(TimeSeriesDataRecorder):
    """
    Time series recorder for data produced at a fixed interval ``level`` (e.g. kdata).

    It derives how many records to fetch for an entity from its latest saved record and,
    unless ``one_shot`` is set, sleeps while the next record is not ready yet.
    """

    def __init__(self,
                 entity_type='stock',
                 # NOTE: this default list object is shared between calls; never mutate it
                 exchanges=['sh', 'sz'],
                 entity_ids=None,
                 codes=None,
                 batch_size=10,
                 force_update=False,
                 sleeping_time=10,
                 default_size=2000,
                 one_shot=False,
                 fix_duplicate_way='add',
                 start_timestamp=None,
                 end_timestamp=None,
                 contain_unfinished_data=False,
                 level=IntervalLevel.LEVEL_1DAY,
                 kdata_use_begin_time=False,
                 close_hour=0,
                 close_minute=0,
                 one_day_trading_minutes=24 * 60) -> None:
        """
        The arguments from ``entity_type`` through ``end_timestamp`` are forwarded
        positionally to ``TimeSeriesDataRecorder.__init__``.

        :param default_size: size returned by ``evaluate_start_end_size_timestamps`` when
            no latest timestamp is known for an entity
        :param one_shot: when True, never sleep waiting for the next record
        :param contain_unfinished_data: when True, ``persist`` drops today's latest record
            until two minutes after the close time
        :param level: sampling interval, converted to ``IntervalLevel``
        :param kdata_use_begin_time: when True, record timestamps are interval begin times,
            so one interval is added before counting missing records
        :param close_hour: hour of the daily close; ``(0, 0)`` together with
            ``close_minute`` means the market never closes (e.g. coin)
        :param close_minute: minute of the daily close
        :param one_day_trading_minutes: trading minutes per day, forwarded to
            ``IntervalLevel.count_from_timestamp``
        """
        super().__init__(entity_type, exchanges, entity_ids, codes, batch_size,
                         force_update, sleeping_time, default_size, one_shot,
                         fix_duplicate_way, start_timestamp, end_timestamp)

        self.level = IntervalLevel(level)
        # FIXME:should remove unfinished data when recording,always set it to False now
        self.contain_unfinished_data = contain_unfinished_data
        self.kdata_use_begin_time = kdata_use_begin_time
        self.close_hour = close_hour
        self.close_minute = close_minute
        self.one_day_trading_minutes = one_day_trading_minutes

    def get_latest_saved_record(self, entity):
        """
        Return the ``get_data`` result (``return_type='domain'``, ``limit=1``) holding the
        latest saved record of *entity* at ``self.level``, ordered by the evaluated time
        field descending. Callers index it as a sequence.
        """
        # getattr performs the same column lookup as the former eval() of a code string
        time_column = getattr(self.data_schema, self.get_evaluated_time_field())

        return get_data(entity_id=entity.id,
                        provider=self.provider,
                        data_schema=self.data_schema,
                        order=time_column.desc(),
                        limit=1,
                        return_type='domain',
                        session=self.session,
                        level=self.level.value)

    def evaluate_start_end_size_timestamps(self, entity):
        """
        Compute what to fetch for *entity*.

        :return: ``(latest_timestamp, None, size, None)`` where ``size`` is the number of
            records to fetch (``0`` when already up to date, ``default_size`` when no
            latest timestamp is known)
        """
        latest_saved_record = self.get_latest_saved_record(entity=entity)

        if latest_saved_record:
            latest_timestamp = latest_saved_record[0].timestamp
        else:
            latest_timestamp = entity.timestamp

        if not latest_timestamp:
            return latest_timestamp, None, self.default_size, None

        current_time = pd.Timestamp.now()
        time_delta = current_time - latest_timestamp

        if self.level == IntervalLevel.LEVEL_1DAY:
            if is_same_date(current_time, latest_timestamp):
                return latest_timestamp, None, 0, None
            return latest_timestamp, None, time_delta.days + 1, None

        # to today,check closing time
        # 0,0 means never stop,e.g,coin. Only both being zero disables the check: a close
        # time such as 15:00 (minute == 0) must still be honoured.
        never_closes = self.close_hour == 0 and self.close_minute == 0
        if not never_closes and time_delta.days == 0:
            if latest_timestamp.hour == self.close_hour and latest_timestamp.minute == self.close_minute:
                return latest_timestamp, None, 0, None

        if self.kdata_use_begin_time:
            touching_timestamp = latest_timestamp + pd.Timedelta(
                seconds=self.level.to_second())
        else:
            touching_timestamp = latest_timestamp

        waiting_seconds, size = self.level.count_from_timestamp(
            touching_timestamp,
            one_day_trading_minutes=self.one_day_trading_minutes)
        if not self.one_shot and waiting_seconds and (waiting_seconds > 30):
            t = waiting_seconds / 2
            self.logger.info(
                'level:{},recorded_time:{},touching_timestamp:{},current_time:{},next_ok_time:{},just sleep:{} seconds'
                .format(
                    self.level.value, latest_timestamp, touching_timestamp,
                    current_time, touching_timestamp +
                    pd.Timedelta(seconds=self.level.to_second()), t))
            time.sleep(t)

        return latest_timestamp, None, size, None

    def persist(self, entity, domain_list):
        """
        Save *domain_list* (ascending or descending by timestamp) for *entity* and commit.

        When ``contain_unfinished_data`` is set and the latest record is from today, that
        latest record is skipped until two minutes after the close time.
        """
        if not domain_list:
            return

        head_timestamp = domain_list[0].timestamp
        tail_timestamp = domain_list[-1].timestamp
        if head_timestamp >= tail_timestamp:
            first_timestamp, last_timestamp = tail_timestamp, head_timestamp
        else:
            first_timestamp, last_timestamp = head_timestamp, tail_timestamp

        self.logger.info(
            "persist {} for entity_id:{},time interval:[{},{}]".format(
                self.data_schema, entity.id, first_timestamp,
                last_timestamp))

        current_timestamp = now_pd_timestamp()

        saving_datas = domain_list

        # FIXME:remove this logic
        # FIXME:should remove unfinished data when recording,always set it to False now
        if is_same_date(current_timestamp,
                        last_timestamp) and self.contain_unfinished_data:
            # compare as minutes since midnight; comparing hour and minute separately
            # wrongly treated e.g. 16:01 as before a 15:00 close
            now_minutes = current_timestamp.hour * 60 + current_timestamp.minute
            close_minutes = self.close_hour * 60 + self.close_minute
            if now_minutes >= close_minutes + 2:
                # after the closing time of the day,we think the last data is finished
                saving_datas = domain_list
            else:
                # ignore unfinished kdata: drop the latest record, which is at the head of a
                # descending list and at the tail otherwise
                if head_timestamp > tail_timestamp:
                    saving_datas = domain_list[1:]
                else:
                    saving_datas = domain_list[:-1]
                self.logger.info(
                    "ignore kdata for entity_id:{},level:{},timestamp:{},current_timestamp:{}"
                    .format(entity.id, self.level, last_timestamp,
                            current_timestamp))

        self.session.add_all(saving_datas)
        self.session.commit()