class Trader(object):
    """
    Base trader driving target selectors over kdata timestamps and turning their targets into trading signals.

    Subclasses must set ``entity_schema`` and usually override ``init_selectors`` to fill ``self.selectors``.
    ``run`` iterates the timestamps of ``level`` between start_timestamp and end_timestamp; every generated
    TradingSignal is sent to the registered listeners (a SimAccountService is always registered).
    """
    entity_schema: EntityMixin = None

    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        """
        :param entity_ids: entity ids handed to init_selectors and recorded in TraderInfo
        :param exchanges: exchanges handed to init_selectors and recorded in TraderInfo
        :param codes: codes handed to init_selectors and recorded in TraderInfo
        :param start_timestamp: first timestamp to trade, required
        :param end_timestamp: last timestamp to trade, required (must not be in the past in real_time mode)
        :param provider: data provider passed to the account service
        :param level: trading level; must equal the smallest level of the selectors
        :param trader_name: defaults to the lower-cased class name
        :param real_time: let selectors move on to fresh kdata in run()
        :param kdata_use_begin_time: kdata timestamps mark the beginning of their interval
        :param draw_result: draw the account value curve in on_finish
        :raises AssertionError: if entity_schema is unset, a timestamp is missing,
            or end_timestamp is in the past in real_time mode
        :raises Exception: if level is not the smallest level of the selectors
        """
        assert self.entity_schema is not None

        self.logger = logging.getLogger(__name__)

        self.trader_name = trader_name if trader_name else type(self).__name__.lower()

        self.trading_signal_listeners: List[TradingListener] = []
        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids
        self.exchanges = exchanges
        self.codes = codes
        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        if not (start_timestamp and end_timestamp):
            # an explicit raise keeps this check alive under `python -O`, unlike `assert False`
            raise AssertionError('start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.trading_dates = self.entity_schema.get_trading_dates(start_date=self.start_timestamp,
                                                                  end_date=self.end_timestamp)

        if real_time:
            # use the instance logger: a module-level `logger` is not guaranteed to exist
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(entity_schema=self.entity_schema,
                                                 trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids, entity_schema=self.entity_schema, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        # empty defaults so that run() does not raise AttributeError for a trader without selectors
        self.trading_level_asc: List[IntervalLevel] = []
        self.trading_level_desc: List[IntervalLevel] = []

        if self.selectors:
            self.trading_level_asc = sorted(set(IntervalLevel(selector.level) for selector in self.selectors))

            self.logger.info(f'trader level:{self.level},selectors level:{self.trading_level_asc}')

            if self.level != self.trading_level_asc[0]:
                raise Exception("trader level should be the min of the selectors")

            self.trading_level_desc = list(reversed(self.trading_level_asc))

        # per-level (long, short) targets: filled by run(), read by handle_targets_slot
        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', data_schema=TraderInfo)

        self.on_start()

    def on_start(self):
        """Run every selector on the history data, then persist this trader's settings as a TraderInfo row."""
        # run all the selectors for the history data at first
        for selector in self.selectors:
            selector.run()

        # list settings are stored as json strings, empty ones as None
        entity_ids = json.dumps(self.entity_ids) if self.entity_ids else None
        exchanges = json.dumps(self.exchanges) if self.exchanges else None
        codes = json.dumps(self.codes) if self.codes else None

        sim_account = TraderInfo(id=self.trader_name,
                                 entity_id=f'trader_zvt_{self.trader_name}',
                                 timestamp=self.start_timestamp,
                                 trader_name=self.trader_name,
                                 entity_ids=entity_ids,
                                 exchanges=exchanges,
                                 codes=codes,
                                 start_timestamp=self.start_timestamp,
                                 end_timestamp=self.end_timestamp,
                                 provider=self.provider,
                                 level=self.level.value,
                                 real_time=self.real_time,
                                 kdata_use_begin_time=self.kdata_use_begin_time)
        self.session.add(sim_account)
        self.session.commit()

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp):
        """
        implement this to init selectors (append them to self.selectors)
        """
        pass

    def add_trading_signal_listener(self, listener):
        """Register ``listener`` for trading signals; registering twice has no effect."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        """Unregister ``listener``; unknown listeners are ignored."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, due_timestamp: pd.Timestamp, happen_timestamp: pd.Timestamp):
        """
        this function would be called in every cycle, you could overwrite it for your custom algorithm to select
        the targets of different levels

        the default implementation is selecting the targets in all levels (intersection over the levels
        that have targets in the slot)

        :param due_timestamp: timestamp until which the signal is valid
        :param happen_timestamp: timestamp of the kdata generating the signal
        """
        long_selected = None
        short_selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if targets:
                long_targets = set(targets[0])
                short_targets = set(targets[1])

                # `is None` rather than falsiness: an empty intersection must stay empty instead of
                # being replaced by the next level's targets
                long_selected = long_targets if long_selected is None else long_selected & long_targets
                short_selected = short_targets if short_selected is None else short_selected & short_targets

        self.logger.debug('timestamp:{},long_selected:{}'.format(due_timestamp, long_selected))
        self.logger.debug('timestamp:{},short_selected:{}'.format(due_timestamp, short_selected))

        self.trade_the_targets(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp,
                               long_selected=long_selected, short_selected=short_selected)

    def get_current_account(self) -> AccountStats:
        return self.account_service.account

    def _get_long_holdings(self) -> List[str]:
        """Return the entity ids of current positions that have available long volume."""
        account = self.get_current_account()
        if not account.positions:
            return []
        return [position.entity_id for position in account.positions
                if position is not None and position.available_long > 0]

    def buy(self, due_timestamp, happen_timestamp, entity_ids, position_pct=1.0, ignore_in_position=True):
        """
        Send an open_long signal for every entity in ``entity_ids``.

        :param entity_ids: entities to buy; None or empty sends nothing
        :param position_pct: total position pct, split equally among the entities
        :param ignore_in_position: skip entities already held with available long volume
        """
        # handle_targets_slot passes None when no level produced targets; set(None) would raise
        if not entity_ids:
            return

        if ignore_in_position:
            entity_ids = set(entity_ids) - set(self._get_long_holdings())

        if entity_ids:
            position_pct = (1.0 / len(entity_ids)) * position_pct

        for entity_id in entity_ids:
            trading_signal = TradingSignal(entity_id=entity_id,
                                           due_timestamp=due_timestamp,
                                           happen_timestamp=happen_timestamp,
                                           trading_signal_type=TradingSignalType.open_long,
                                           trading_level=self.level,
                                           position_pct=position_pct)
            self.send_trading_signal(trading_signal)

    def sell(self, due_timestamp, happen_timestamp, entity_ids, position_pct=1.0):
        """
        Send a close_long signal for every entity in ``entity_ids`` that is currently held.

        :param entity_ids: entities to sell; None or empty sends nothing
        :param position_pct: pct of each position to close
        """
        if not entity_ids:
            return

        # only entities currently held with available long volume can be closed
        shorted = set(self._get_long_holdings()) & set(entity_ids)

        for entity_id in shorted:
            trading_signal = TradingSignal(entity_id=entity_id,
                                           due_timestamp=due_timestamp,
                                           happen_timestamp=happen_timestamp,
                                           trading_signal_type=TradingSignalType.close_long,
                                           trading_level=self.level,
                                           position_pct=position_pct)
            self.send_trading_signal(trading_signal)

    def trade_the_targets(self, due_timestamp, happen_timestamp, long_selected, short_selected, long_pct=1.0,
                          short_pct=1.0):
        """Buy ``long_selected`` then sell ``short_selected``; either may be None."""
        self.buy(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp, entity_ids=long_selected,
                 position_pct=long_pct)
        self.sell(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp, entity_ids=short_selected,
                  position_pct=short_pct)

    def send_trading_signal(self, signal: TradingSignal):
        """Deliver ``signal`` to every listener; a failing listener is logged and told via on_trading_error."""
        for listener in self.trading_signal_listeners:
            try:
                listener.on_trading_signal(signal)
            except Exception as e:
                self.logger.exception(e)
                listener.on_trading_error(timestamp=signal.happen_timestamp, error=e)

    def on_finish(self):
        """Called at the end of run(); draws the account value curve when draw_result is set."""
        # show the result
        if self.draw_result:
            import plotly.io as pio
            pio.renderers.default = "browser"
            reader = AccountStatsReader(trader_names=[self.trader_name])
            df = reader.data_df
            drawer = Drawer(main_data=NormalData(df.copy()[['trader_name', 'timestamp', 'all_value']],
                                                 category_field='trader_name'))
            drawer.draw_line(show=True)

    def select_long_targets(self, long_targets: List[str]) -> List[str]:
        """Keep at most the first 10 long targets of a selector."""
        if len(long_targets) > 10:
            return long_targets[0:10]
        return long_targets

    def select_short_targets(self, short_targets: List[str]) -> List[str]:
        """Keep at most the first 10 short targets of a selector."""
        if len(short_targets) > 10:
            return short_targets[0:10]
        return short_targets

    def in_trading_date(self, timestamp):
        return to_time_str(timestamp) in self.trading_dates

    def on_time(self, timestamp):
        """Hook called once per trading timestamp, before the selectors' targets are handled."""
        self.logger.debug(f'current timestamp:{timestamp}')

    def run(self):
        """
        Iterate the timestamps of the trader level and trade on the selectors' targets.

        For every timestamp on a trading date: (real_time only) let selectors move on to the newest kdata,
        open the account at the open timestamp, call on_time, collect the targets of every level whose kdata
        is finished, trade at the trader level and close the account at the close timestamp.
        on_finish is called at the end.
        """
        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in self.entity_schema.get_interval_timestamps(start_date=self.start_timestamp,
                                                                    end_date=self.end_timestamp,
                                                                    level=self.level):
            if not self.in_trading_date(timestamp=timestamp):
                continue

            if self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                seconds = (now_pd_timestamp() - real_end_timestamp).total_seconds()
                waiting_seconds = self.level.to_second() - seconds
                # meaning the future kdata not ready yet,we could move on to check
                if waiting_seconds > 0:
                    # iterate the selector from min to max which in finished timestamp kdata
                    for level in self.trading_level_asc:
                        if self.entity_schema.is_finished_kdata_timestamp(timestamp=timestamp, level=level):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp, self.kdata_use_begin_time,
                                                     timeout=waiting_seconds + 20)

            # on_trading_open to setup the account
            if self.level == IntervalLevel.LEVEL_1DAY or self.entity_schema.is_open_timestamp(timestamp):
                self.account_service.on_trading_open(timestamp)

            self.on_time(timestamp=timestamp)

            if self.selectors:
                for level in self.trading_level_asc:
                    # in every cycle, all level selector do its job in its time
                    if self.entity_schema.is_finished_kdata_timestamp(timestamp=timestamp, level=level):
                        all_long_targets = []
                        all_short_targets = []
                        for selector in self.selectors:
                            if selector.level == level:
                                long_targets = selector.get_open_long_targets(timestamp=timestamp)
                                long_targets = self.select_long_targets(long_targets)

                                short_targets = selector.get_open_short_targets(timestamp=timestamp)
                                short_targets = self.select_short_targets(short_targets)

                                all_long_targets += long_targets
                                all_short_targets += short_targets

                        if all_long_targets or all_short_targets:
                            self.targets_slot.input_targets(level, all_long_targets, all_short_targets)

                        # the time always move on by min level step and we could check all level targets in the slot
                        # 1)the targets is generated for next interval
                        # 2)the acceptable price is next interval prices,you could buy it at current price
                        # if the time is before the timestamp(due_timestamp) when trading signal received
                        # 3)the suggest price the the close price for generating the signal(happen_timestamp)
                        due_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                        if level == self.level:
                            self.handle_targets_slot(due_timestamp=due_timestamp, happen_timestamp=timestamp)

            # on_trading_close to calculate date account
            if self.level == IntervalLevel.LEVEL_1DAY or self.entity_schema.is_close_timestamp(timestamp):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
class Trader(object):
    """
    Region-aware base trader driving target selectors and emitting batched trading signals.

    Subclasses set ``entity_schema`` and either override ``init_selectors`` (selector/factor model) or write
    their strategy in ``on_time``. ``buy``/``sell`` only queue TradingSignals; ``run`` hands the queued
    signals to every listener once per cycle through ``on_trading_signals``.
    """
    entity_schema: EntityMixin = None

    def __init__(self,
                 region: Region,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: Provider = Provider.Default,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True,
                 rich_mode: bool = True) -> None:
        """
        :param region: region used for the current time (now_pd_timestamp)
        :param entity_ids: entity ids handed to init_selectors and recorded in TraderInfo
        :param exchanges: exchanges handed to init_selectors and recorded in TraderInfo
        :param codes: codes handed to init_selectors and recorded in TraderInfo
        :param start_timestamp: first timestamp to trade, required
        :param end_timestamp: last timestamp to trade, required (must not be in the past in real_time mode)
        :param provider: data provider passed to the account service
        :param level: trading level; must equal the smallest level of the selectors
        :param trader_name: defaults to the lower-cased class name
        :param real_time: let selectors move on to fresh kdata in run()
        :param kdata_use_begin_time: kdata timestamps mark the beginning of their interval
        :param draw_result: draw the account value curve in on_finish
        :param rich_mode: passed through to SimAccountService
        :raises AssertionError: if entity_schema is unset, a timestamp is missing,
            or end_timestamp is in the past in real_time mode
        :raises Exception: if level is not the smallest level of the selectors
        """
        assert self.entity_schema is not None

        self.logger = logging.getLogger(__name__)

        self.trader_name = trader_name if trader_name else type(self).__name__.lower()

        self.trading_signal_listeners: List[TradingListener] = []
        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids
        self.exchanges = exchanges
        self.codes = codes
        self.region = region
        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        if not (start_timestamp and end_timestamp):
            # an explicit raise keeps this check alive under `python -O`, unlike `assert False`
            raise AssertionError('start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.trading_dates = self.entity_schema.get_trading_dates(start_date=self.start_timestamp,
                                                                  end_date=self.end_timestamp)

        if real_time:
            # use the instance logger: a module-level `logger` is not guaranteed to exist
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            assert self.end_timestamp >= now_pd_timestamp(self.region)

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result
        self.rich_mode = rich_mode

        self.account_service = SimAccountService(entity_schema=self.entity_schema,
                                                 trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level,
                                                 rich_mode=rich_mode)

        self.register_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids, entity_schema=self.entity_schema, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        # empty defaults so that run() does not raise AttributeError for a trader without selectors
        self.trading_level_asc: List[IntervalLevel] = []
        self.trading_level_desc: List[IntervalLevel] = []

        if self.selectors:
            self.trading_level_asc = sorted(set(IntervalLevel(selector.level) for selector in self.selectors))

            self.logger.info(f'trader level:{self.level},selectors level:{self.trading_level_asc}')

            if self.level != self.trading_level_asc[0]:
                raise Exception("trader level should be the min of the selectors")

            self.trading_level_desc = list(reversed(self.trading_level_asc))

        self.session = get_db_session('zvt', data_schema=TraderInfo)

        # latest targets per level, updated by run() via set_*_targets_by_level
        self.level_map_long_targets = {}
        self.level_map_short_targets = {}
        # signals queued by buy/sell and flushed to the listeners once per cycle in run()
        self.trading_signals: List[TradingSignal] = []

        self.on_start()

    def on_start(self):
        """Run every selector on the history data, then persist this trader's settings as a TraderInfo row."""
        # run all the selectors for the history data at first
        for selector in self.selectors:
            selector.run()

        # list settings are stored as json strings, empty ones as None
        entity_ids = json.dumps(self.entity_ids) if self.entity_ids else None
        exchanges = json.dumps(self.exchanges) if self.exchanges else None
        codes = json.dumps(self.codes) if self.codes else None

        sim_account = TraderInfo(id=self.trader_name,
                                 entity_id=f'trader_zvt_{self.trader_name}',
                                 timestamp=self.start_timestamp,
                                 trader_name=self.trader_name,
                                 entity_ids=entity_ids,
                                 exchanges=exchanges,
                                 codes=codes,
                                 start_timestamp=self.start_timestamp,
                                 end_timestamp=self.end_timestamp,
                                 provider=self.provider,
                                 level=self.level.value,
                                 real_time=self.real_time,
                                 kdata_use_begin_time=self.kdata_use_begin_time)
        self.session.add(sim_account)
        self.session.commit()

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp):
        """
        overwrite it to init selectors if you want to use selector/factor computing model or just write strategy in on_time
        """
        pass

    def register_trading_signal_listener(self, listener):
        """Register ``listener``; registering twice has no effect."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def deregister_trading_signal_listener(self, listener):
        """Unregister ``listener``; unknown listeners are ignored."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def set_long_targets_by_level(self, level: IntervalLevel, targets: List[str]) -> None:
        """Replace the cached long targets of ``level``."""
        self.logger.debug(
            f'level:{level},old long targets:{self.level_map_long_targets.get(level)},new long targets:{targets}')
        self.level_map_long_targets[level] = targets

    def set_short_targets_by_level(self, level: IntervalLevel, targets: List[str]) -> None:
        """Replace the cached short targets of ``level``."""
        self.logger.debug(
            f'level:{level},old short targets:{self.level_map_short_targets.get(level)},new short targets:{targets}')
        self.level_map_short_targets[level] = targets

    def get_long_targets_by_level(self, level: IntervalLevel) -> List[str]:
        return self.level_map_long_targets.get(level)

    def get_short_targets_by_level(self, level: IntervalLevel) -> List[str]:
        return self.level_map_short_targets.get(level)

    def _intersect_level_targets(self, level_map_targets) -> set:
        """Intersect the targets of all levels that have any; None if no level has targets."""
        selected = None
        for level in self.trading_level_desc:
            targets = level_map_targets.get(level)
            if targets:
                targets = set(targets)
                # `is None` rather than falsiness: an empty intersection must stay empty instead of
                # being replaced by the next level's targets
                selected = targets if selected is None else selected & targets
        return selected

    def select_long_targets_from_levels(self, timestamp):
        """
        overwrite it to select long targets from multiple levels,the default implementation is selecting the targets in all level

        Levels without long targets are skipped.

        :param timestamp: current timestamp (unused by the default implementation)
        :return: set of long targets, or None if no level has long targets
        """
        return self._intersect_level_targets(self.level_map_long_targets)

    def select_short_targets_from_levels(self, timestamp):
        """
        overwrite it to select short targets from multiple levels,the default implementation is selecting the targets in all level

        Levels without short targets are skipped.

        :param timestamp: current timestamp (unused by the default implementation)
        :return: set of short targets, or None if no level has short targets
        """
        return self._intersect_level_targets(self.level_map_short_targets)

    def get_current_account(self) -> AccountStats:
        return self.account_service.account

    def get_current_positions(self) -> List[Position]:
        return self.get_current_account().positions

    def _get_long_holdings(self) -> List[str]:
        """Return the entity ids of current positions that have available long volume."""
        positions = self.get_current_account().positions
        if not positions:
            return []
        return [position.entity_id for position in positions
                if position is not None and position.available_long > 0]

    def long_position_control(self):
        """
        Total position pct to spend on one batch of new long targets.

        0.2 without positions, 0.5 with at most 10 positions, 1.0 otherwise.
        """
        positions = self.get_current_positions()

        position_pct = 1.0
        if not positions:
            position_pct = 0.2
        elif len(positions) <= 10:
            position_pct = 0.5

        return position_pct

    def short_position_control(self):
        """Pct of each held position to close on a short signal."""
        return 1.0

    def buy(self, due_timestamp, happen_timestamp, entity_ids, ignore_in_position=True):
        """
        Queue an open_long signal for every entity in ``entity_ids``.

        :param entity_ids: entities to buy; None or empty queues nothing
        :param ignore_in_position: skip entities already held with available long volume
        """
        # guard direct callers passing None: set(None) would raise
        if not entity_ids:
            return

        if ignore_in_position:
            entity_ids = set(entity_ids) - set(self._get_long_holdings())

        if entity_ids:
            # long_position_control() is the total pct, split equally among the entities
            position_pct = self.long_position_control()
            position_pct = (1.0 / len(entity_ids)) * position_pct

            for entity_id in entity_ids:
                trading_signal = TradingSignal(entity_id=entity_id,
                                               due_timestamp=due_timestamp,
                                               happen_timestamp=happen_timestamp,
                                               trading_signal_type=TradingSignalType.open_long,
                                               trading_level=self.level,
                                               position_pct=position_pct)
                self.trading_signals.append(trading_signal)

    def sell(self, due_timestamp, happen_timestamp, entity_ids):
        """
        Queue a close_long signal for every entity in ``entity_ids`` that is currently held.

        :param entity_ids: entities to sell; None or empty queues nothing
        """
        if not entity_ids:
            return

        # only entities currently held with available long volume can be closed
        shorted = set(self._get_long_holdings()) & set(entity_ids)

        if shorted:
            position_pct = self.short_position_control()

            for entity_id in shorted:
                trading_signal = TradingSignal(entity_id=entity_id,
                                               due_timestamp=due_timestamp,
                                               happen_timestamp=happen_timestamp,
                                               trading_signal_type=TradingSignalType.close_long,
                                               trading_level=self.level,
                                               position_pct=position_pct)
                self.trading_signals.append(trading_signal)

    def trade_the_targets(self, due_timestamp, happen_timestamp, long_selected, short_selected):
        """Queue sell signals for ``short_selected`` first, then buy signals for ``long_selected``."""
        if short_selected:
            self.sell(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp, entity_ids=short_selected)
        if long_selected:
            self.buy(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp, entity_ids=long_selected)

    def on_finish(self, timestmap):
        """
        Notify the listeners that trading finished and optionally draw the account value curve.

        :param timestmap: the last timestamp handled by run() (misspelled name kept for compatibility)
        """
        self.on_trading_finish(timestmap)
        # show the result
        if self.draw_result:
            import plotly.io as pio
            pio.renderers.default = "browser"
            reader = AccountStatsReader(trader_names=[self.trader_name])
            df = reader.data_df
            drawer = Drawer(main_data=NormalData(df.copy()[['trader_name', 'timestamp', 'all_value']],
                                                 category_field='trader_name'))
            drawer.draw_line(show=True)

    def filter_selector_long_targets(self, timestamp, selector: TargetSelector, long_targets: List[str]) -> List[str]:
        """Keep at most the first 10 long targets of a selector."""
        if len(long_targets) > 10:
            return long_targets[0:10]
        return long_targets

    def filter_selector_short_targets(self, timestamp, selector: TargetSelector, short_targets: List[str]) -> List[
        str]:
        """Keep at most the first 10 short targets of a selector."""
        if len(short_targets) > 10:
            return short_targets[0:10]
        return short_targets

    def in_trading_date(self, timestamp):
        return to_time_str(timestamp) in self.trading_dates

    def on_time(self, timestamp):
        """Hook called once per trading timestamp; write a strategy here if not using selectors."""
        self.logger.debug(f'current timestamp:{timestamp}')

    def on_trading_signals(self, trading_signals: List[TradingSignal]):
        """Hand a batch of signals to every listener (exceptions propagate)."""
        for l in self.trading_signal_listeners:
            l.on_trading_signals(trading_signals)

    def on_trading_signal(self, trading_signal: TradingSignal):
        """Hand one signal to every listener; a failing listener is logged and told via on_trading_error."""
        for l in self.trading_signal_listeners:
            try:
                l.on_trading_signal(trading_signal)
            except Exception as e:
                self.logger.exception(e)
                l.on_trading_error(timestamp=trading_signal.happen_timestamp, error=e)

    def on_trading_open(self, timestamp):
        for l in self.trading_signal_listeners:
            l.on_trading_open(timestamp)

    def on_trading_close(self, timestamp):
        for l in self.trading_signal_listeners:
            l.on_trading_close(timestamp)

    def on_trading_finish(self, timestamp):
        for l in self.trading_signal_listeners:
            l.on_trading_finish(timestamp)

    def on_trading_error(self, timestamp, error):
        for l in self.trading_signal_listeners:
            l.on_trading_error(timestamp, error)

    def run(self):
        """
        Iterate the timestamps of the trader level and trade on the selectors' targets.

        For every timestamp on a trading date: wait for today's daily data (daily level) or let selectors move
        on to the newest kdata (real_time), open the account at the open timestamp, call on_time, collect the
        targets of every level whose kdata is finished, queue signals at the trader level, flush the queued
        signals to the listeners and close the account at the close timestamp. on_finish is called at the end.
        """
        # fallback for on_finish when no timestamp is iterated (otherwise `timestamp` would be unbound)
        last_timestamp = self.end_timestamp
        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in self.entity_schema.get_interval_timestamps(start_date=self.start_timestamp,
                                                                    end_date=self.end_timestamp,
                                                                    level=self.level):
            last_timestamp = timestamp
            if not self.in_trading_date(timestamp=timestamp):
                continue

            # refresh the clock every cycle: the loop may run for a long time and a stale `now`
            # makes waiting_seconds (and the move_on timeout) grow without bound
            now = now_pd_timestamp(self.region)
            waiting_seconds = 0

            if self.level == IntervalLevel.LEVEL_1DAY:
                if is_same_date(timestamp, now):
                    while True:
                        self.logger.info(f'time is:{now},just smoke for minutes')
                        time.sleep(60)
                        # re-read the clock, otherwise `now.hour` never changes and the loop never ends
                        now = now_pd_timestamp(self.region)
                        if now.hour >= 19:
                            waiting_seconds = 20
                            break

            elif self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                seconds = (now - real_end_timestamp).total_seconds()
                waiting_seconds = self.level.to_second() - seconds

            # meaning the future kdata not ready yet,we could move on to check
            if waiting_seconds > 0:
                # iterate the selector from min to max which in finished timestamp kdata
                for level in self.trading_level_asc:
                    if self.entity_schema.is_finished_kdata_timestamp(timestamp=timestamp, level=level):
                        for selector in self.selectors:
                            if selector.level == level:
                                selector.move_on(timestamp, self.kdata_use_begin_time,
                                                 timeout=waiting_seconds + 20)

            # on_trading_open to setup the account
            if self.level >= IntervalLevel.LEVEL_1DAY or self.entity_schema.is_open_timestamp(timestamp):
                self.on_trading_open(timestamp)

            self.on_time(timestamp=timestamp)

            if self.selectors:
                for level in self.trading_level_asc:
                    # in every cycle, all level selector do its job in its time
                    if self.entity_schema.is_finished_kdata_timestamp(timestamp=timestamp, level=level):
                        all_long_targets = []
                        all_short_targets = []
                        for selector in self.selectors:
                            if selector.level == level:
                                long_targets = selector.get_open_long_targets(timestamp=timestamp)
                                long_targets = self.filter_selector_long_targets(timestamp=timestamp,
                                                                                 selector=selector,
                                                                                 long_targets=long_targets)

                                short_targets = selector.get_open_short_targets(timestamp=timestamp)
                                short_targets = self.filter_selector_short_targets(timestamp=timestamp,
                                                                                   selector=selector,
                                                                                   short_targets=short_targets)

                                if long_targets:
                                    all_long_targets += long_targets
                                if short_targets:
                                    all_short_targets += short_targets

                        if all_long_targets:
                            self.set_long_targets_by_level(level, all_long_targets)
                        if all_short_targets:
                            self.set_short_targets_by_level(level, all_short_targets)

                        # the time always move on by min level step and we could check all targets of levels
                        # 1)the targets is generated for next interval
                        # 2)the acceptable price is next interval prices,you could buy it at current price
                        # if the time is before the timestamp(due_timestamp) when trading signal received
                        # 3)the suggest price the the close price for generating the signal(happen_timestamp)
                        due_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                        if level == self.level:
                            long_selected = self.select_long_targets_from_levels(timestamp)
                            short_selected = self.select_short_targets_from_levels(timestamp)

                            self.logger.debug('timestamp:{},long_selected:{}'.format(due_timestamp, long_selected))
                            self.logger.debug('timestamp:{},short_selected:{}'.format(due_timestamp, short_selected))

                            self.trade_the_targets(due_timestamp=due_timestamp, happen_timestamp=timestamp,
                                                   long_selected=long_selected, short_selected=short_selected)

            # flush the signals queued by buy/sell (from the selectors or on_time) in this cycle
            if self.trading_signals:
                self.on_trading_signals(self.trading_signals)
            # clear
            self.trading_signals = []

            # on_trading_close to calculate date account
            if self.level >= IntervalLevel.LEVEL_1DAY or self.entity_schema.is_close_timestamp(timestamp):
                self.on_trading_close(timestamp)

        self.on_finish(last_timestamp)
def _evaluate_period_size(start_timestamp, time_delta, trade_day, trade_days_per_bar, calendar_days_per_bar):
    """Estimate the size of week/month kdata from trade_day when possible, else from calendar days."""
    if trade_day is not None:
        try:
            size = int(math.ceil(trade_day.index(start_timestamp) / trade_days_per_bar))
        except ValueError:
            if start_timestamp < trade_day[-1]:
                return int(math.ceil(len(trade_day) / trade_days_per_bar))
        else:
            # one extra bar unless nothing is missing (presumably for a partially elapsed period)
            return 0 if size == 0 else size + 1
    return int(math.ceil(time_delta.days / calendar_days_per_bar))


def _evaluate_intraday_size(start_timestamp, time_delta, trade_day, bars_per_day, fallback_bars_per_day,
                            bars_elapsed):
    """
    Estimate the size of intraday kdata.

    With trade_day: whole trading days before the start date plus bars_elapsed(time of start_timestamp);
    otherwise calendar days times fallback_bars_per_day.
    """
    if trade_day is not None:
        start_date = start_timestamp.replace(hour=0, minute=0, second=0)
        try:
            days = trade_day.index(start_date)
        except ValueError:
            if start_date < trade_day[-1]:
                return len(trade_day) * bars_per_day
        else:
            time_of_day = datetime.datetime.time(start_timestamp)
            return days * bars_per_day + bars_elapsed(time_of_day)
    return int(math.ceil(time_delta.days * fallback_bars_per_day))


def evaluate_size_from_timestamp(start_timestamp: pd.Timestamp,
                                 end_timestamp: pd.Timestamp,
                                 level: IntervalLevel,
                                 one_day_trading_minutes,
                                 trade_day=None):
    """
    given from timestamp,level,one_day_trading_minutes,this func evaluate size of kdata to current.
    it maybe a little bigger than the real size for fetching all the kdata.

    :param start_timestamp: first timestamp of the kdata
    :type start_timestamp: pd.Timestamp
    :param end_timestamp: last timestamp of the kdata
    :type end_timestamp: pd.Timestamp
    :param level: kdata level
    :type level: IntervalLevel
    :param one_day_trading_minutes: trading minutes of one day, used by the generic fallback
    :type one_day_trading_minutes: int
    :param trade_day: optional list of trading days; the position of start_timestamp (of its date for
        intraday levels) in it gives the number of days to cover. None or empty means "not available".
    :return: estimated number of kdata rows
    :rtype: int
    """
    time_delta = end_timestamp - to_pd_timestamp(start_timestamp)
    one_day_trading_seconds = one_day_trading_minutes * 60

    # an empty trade_day carries no information; treat it like None (trade_day[-1] would raise IndexError)
    if trade_day is not None and len(trade_day) == 0:
        trade_day = None

    if level == IntervalLevel.LEVEL_1MON:
        return _evaluate_period_size(start_timestamp, time_delta, trade_day,
                                     trade_days_per_bar=22, calendar_days_per_bar=30)

    if level == IntervalLevel.LEVEL_1WEEK:
        return _evaluate_period_size(start_timestamp, time_delta, trade_day,
                                     trade_days_per_bar=5, calendar_days_per_bar=7)

    if level == IntervalLevel.LEVEL_1DAY:
        if trade_day is not None:
            try:
                return trade_day.index(start_timestamp)
            except ValueError:
                if start_timestamp < trade_day[-1]:
                    return len(trade_day)
        return time_delta.days

    # one trading day is assumed to have 4 trading hours for the intraday levels
    if level == IntervalLevel.LEVEL_1HOUR:
        return _evaluate_intraday_size(start_timestamp, time_delta, trade_day,
                                       bars_per_day=4, fallback_bars_per_day=4 * 2,
                                       bars_elapsed=lambda t: int(math.ceil(count_hours_from_day(t))))

    for minute_level, minutes_per_bar in ((IntervalLevel.LEVEL_30MIN, 30),
                                          (IntervalLevel.LEVEL_15MIN, 15),
                                          (IntervalLevel.LEVEL_5MIN, 5),
                                          (IntervalLevel.LEVEL_1MIN, 1)):
        if level == minute_level:
            bars_per_day = 4 * 60 // minutes_per_bar
            # divide the elapsed minutes by this level's bar length (30MIN/15MIN used to divide by 5)
            return _evaluate_intraday_size(
                start_timestamp, time_delta, trade_day,
                bars_per_day=bars_per_day, fallback_bars_per_day=bars_per_day,
                bars_elapsed=lambda t, m=minutes_per_bar: int(math.ceil(count_mins_from_day(t) / m)))

    one_bar_seconds = level.to_second()
    if time_delta.days > 0:
        seconds = (time_delta.days + 1) * one_day_trading_seconds
        return int(math.ceil(seconds / one_bar_seconds))
    seconds = time_delta.total_seconds()
    # cap at one trading day; keep the result an int
    return min(int(math.ceil(seconds / one_bar_seconds)),
               int(math.ceil(one_day_trading_seconds / one_bar_seconds)))
def next_timestamp(current_timestamp: pd.Timestamp, level: IntervalLevel) -> pd.Timestamp:
    """
    Return the timestamp one ``level`` interval after ``current_timestamp``.

    :param current_timestamp: any value accepted by to_pd_timestamp
    :param level: interval level whose to_second() gives the step length
    :return: current_timestamp shifted forward by one interval
    """
    start = to_pd_timestamp(current_timestamp)
    step = pd.Timedelta(seconds=level.to_second())
    return start + step
class Trader(object):
    """
    Base class of a trader driven by interval timestamps.

    Subclasses must set ``entity_schema`` and usually overwrite ``init_selectors`` to register
    ``TargetSelector`` instances. ``run`` walks the interval timestamps of ``level``, lets the selectors
    produce long/short targets per level, turns them into ``TradingSignal`` objects and forwards those
    signals to the registered listeners (by default a ``SimAccountService``).
    """
    # the tradable entity type this trader works on; subclasses must set it (asserted in __init__)
    entity_schema: Type[TradableEntity] = None

    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True,
                 rich_mode: bool = False,
                 adjust_type: AdjustType = None,
                 profit_threshold=(3, -0.3),
                 keep_history=False) -> None:
        """
        :param entity_ids: entity ids forwarded to ``init_selectors``
        :param exchanges: exchanges forwarded to ``init_selectors``
        :param codes: codes forwarded to ``init_selectors``
        :param start_timestamp: start of the trading range (required)
        :param end_timestamp: end of the trading range (required; must not be in the past when real_time)
        :param provider: data provider, passed to the account service
        :param level: the trading level; must equal the smallest level of the registered selectors
        :param trader_name: trader name; defaults to the lower-cased class name
        :param real_time: whether to wait for future kdata while running
        :param kdata_use_begin_time: when True a kdata timestamp is treated as the begin of its interval
        :param draw_result: draw the account value curve in ``on_finish``
        :param rich_mode: passed to the account service
        :param adjust_type: adjust type (or its value), passed to selectors and the account service
        :param profit_threshold: (take_profit_rate, stop_loss_rate) used by ``on_profit_control``;
            a falsy value disables profit control
        :param keep_history: passed to the account service
        """
        assert self.entity_schema is not None
        assert start_timestamp is not None
        assert end_timestamp is not None

        self.logger = logging.getLogger(__name__)

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.entity_ids = entity_ids
        self.exchanges = exchanges
        self.codes = codes
        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.trading_dates = self.entity_schema.get_trading_dates(start_date=self.start_timestamp,
                                                                  end_date=self.end_timestamp)

        if real_time:
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result
        self.rich_mode = rich_mode
        try:
            self.adjust_type = AdjustType(adjust_type)
        except ValueError:
            # None is the declared default; if AdjustType has no member for it, keep it as "unspecified"
            # instead of failing. Invalid non-None values still raise as before.
            if adjust_type is not None:
                raise
            self.adjust_type = None
        self.profit_threshold = profit_threshold
        self.keep_history = keep_history

        # level -> cached long/short targets of the latest finished cycle of that level
        self.level_map_long_targets = {}
        self.level_map_short_targets = {}
        self.trading_signals: List[TradingSignal] = []
        self.trading_signal_listeners: List[TradingListener] = []
        self.selectors: List[TargetSelector] = []
        # selector levels (ascending/descending); they stay empty without selectors so that run()
        # can iterate them unconditionally
        self.trading_level_asc: List[IntervalLevel] = []
        self.trading_level_desc: List[IntervalLevel] = []

        self.account_service = SimAccountService(entity_schema=self.entity_schema,
                                                 trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level,
                                                 rich_mode=self.rich_mode,
                                                 adjust_type=self.adjust_type,
                                                 keep_history=self.keep_history)

        self.register_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=self.entity_ids, entity_schema=self.entity_schema,
                            exchanges=self.exchanges, codes=self.codes,
                            start_timestamp=self.start_timestamp, end_timestamp=self.end_timestamp,
                            adjust_type=self.adjust_type)

        if self.selectors:
            self.trading_level_asc = sorted({IntervalLevel(selector.level) for selector in self.selectors})

            self.logger.info(f'trader level:{self.level},selectors level:{self.trading_level_asc}')

            if self.level != self.trading_level_asc[0]:
                raise Exception("trader level should be the min of the selectors")

            self.trading_level_desc = list(reversed(self.trading_level_asc))

        # run selectors for history data at first
        for selector in self.selectors:
            selector.run()

        self.on_start()

    def on_start(self):
        """Hook called at the end of ``__init__``; logs the start by default."""
        self.logger.info(f'trader:{self.trader_name} on_start')

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp,
                       adjust_type=None):
        """
        overwrite it to init selectors if you want to use selector/factor computing model
        :param adjust_type:
        """
        pass

    def update_targets_by_level(self, level: IntervalLevel, long_targets: List[str],
                                short_targets: List[str], ) -> None:
        """
        the trading signals is generated in min level,before that,we should cache targets of all levels

        :param level:
        :param long_targets:
        :param short_targets:
        """
        self.logger.debug(
            f'level:{level},old long targets:{self.level_map_long_targets.get(level)},new long targets:{long_targets}')
        self.level_map_long_targets[level] = long_targets

        self.logger.debug(
            f'level:{level},old short targets:{self.level_map_short_targets.get(level)},new short targets:{short_targets}')
        self.level_map_short_targets[level] = short_targets

    def get_long_targets_by_level(self, level: IntervalLevel) -> List[str]:
        """Return the cached long targets of ``level`` (None if never cached)."""
        return self.level_map_long_targets.get(level)

    def get_short_targets_by_level(self, level: IntervalLevel) -> List[str]:
        """Return the cached short targets of ``level`` (None if never cached)."""
        return self.level_map_short_targets.get(level)

    def on_targets_selected_from_levels(self, timestamp) -> Tuple[List[str], List[str]]:
        """
        this method's called in every min level cycle to select targets in all levels generated by the previous cycle
        the default implementation is selecting the targets in all levels
        overwrite it for your custom logic

        :param timestamp: current event time
        :return: long targets, short targets
        """

        long_selected = None

        short_selected = None

        for level in self.trading_level_desc:
            long_targets = self.level_map_long_targets.get(level)
            # long must in all
            if long_targets:
                long_targets = set(long_targets)
                if long_selected is None:
                    long_selected = long_targets
                else:
                    long_selected = long_selected & long_targets
            else:
                long_selected = set()

            short_targets = self.level_map_short_targets.get(level)
            # short any
            if short_targets:
                short_targets = set(short_targets)
                if short_selected is None:
                    short_selected = short_targets
                else:
                    short_selected = short_selected | short_targets

        return long_selected, short_selected

    def get_current_account(self) -> AccountStats:
        """Return the current account stats from the account service."""
        return self.account_service.get_current_account()

    def get_current_positions(self) -> List[Position]:
        """Return the positions of the current account."""
        return self.get_current_account().positions

    def long_position_control(self):
        """Return the fraction of capital to use for opening long positions."""
        positions = self.get_current_positions()

        position_pct = 1.0
        if not positions:
            # no positions yet: buy with 20%
            position_pct = 0.2
        elif len(positions) <= 10:
            # at most 10 positions: buy with 50%
            position_pct = 0.5

        # otherwise: buy with everything
        return position_pct

    def short_position_control(self):
        """Return the fraction of a position to close; sells everything by default."""
        return 1.0

    def on_profit_control(self):
        """
        Collect positions hitting the take-profit or stop-loss threshold.

        :return: (entity ids to close long, None), or (None, None) when profit control is disabled or
            there are no positions
        """
        positions = self.get_current_positions() if self.profit_threshold else None
        if positions:
            positive = self.profit_threshold[0]
            negative = self.profit_threshold[1]
            close_long_entity_ids = []
            for position in positions:
                if position.available_long > 1:
                    # take profit
                    if position.profit_rate >= positive:
                        close_long_entity_ids.append(position.entity_id)
                        self.logger.info(f'close profit {position.profit_rate} for {position.entity_id}')
                    # stop loss
                    if position.profit_rate <= negative:
                        close_long_entity_ids.append(position.entity_id)
                        self.logger.info(f'cut lost {position.profit_rate} for {position.entity_id}')

            return close_long_entity_ids, None
        return None, None

    def buy(self, due_timestamp, happen_timestamp, entity_ids, ignore_in_position=True):
        """
        Append open_long trading signals for ``entity_ids``.

        :param ignore_in_position: skip entities that already have an available long position
        """
        if ignore_in_position:
            account = self.get_current_account()
            current_holdings = set()
            if account.positions:
                current_holdings = {position.entity_id for position in account.positions
                                    if position is not None and position.available_long > 0}

            # sort so the signal order does not depend on str hash randomization (reproducible runs)
            entity_ids = sorted(set(entity_ids) - current_holdings)

        if entity_ids:
            position_pct = self.long_position_control()
            # split the capital fraction evenly across the targets
            position_pct = (1.0 / len(entity_ids)) * position_pct

            for entity_id in entity_ids:
                trading_signal = TradingSignal(entity_id=entity_id,
                                               due_timestamp=due_timestamp,
                                               happen_timestamp=happen_timestamp,
                                               trading_signal_type=TradingSignalType.open_long,
                                               trading_level=self.level,
                                               position_pct=position_pct)
                self.trading_signals.append(trading_signal)

    def sell(self, due_timestamp, happen_timestamp, entity_ids):
        """Append close_long trading signals for the held entities among ``entity_ids``."""
        # current position
        account = self.get_current_account()
        current_holdings = []
        if account.positions:
            current_holdings = [position.entity_id for position in account.positions
                                if position is not None and position.available_long > 0]

        # sort so the signal order does not depend on str hash randomization (reproducible runs)
        shorted = sorted(set(current_holdings) & set(entity_ids))

        if shorted:
            position_pct = self.short_position_control()

            for entity_id in shorted:
                trading_signal = TradingSignal(entity_id=entity_id,
                                               due_timestamp=due_timestamp,
                                               happen_timestamp=happen_timestamp,
                                               trading_signal_type=TradingSignalType.close_long,
                                               trading_level=self.level,
                                               position_pct=position_pct)
                self.trading_signals.append(trading_signal)

    def trade_the_targets(self, due_timestamp, happen_timestamp, long_selected, short_selected):
        """Generate sell signals for ``short_selected`` first, then buy signals for ``long_selected``."""
        if short_selected:
            self.sell(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp, entity_ids=short_selected)
        if long_selected:
            self.buy(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp, entity_ids=long_selected)

    def on_finish(self, timestmap):
        """Notify listeners that trading finished and optionally draw the account value curve."""
        self.on_trading_finish(timestmap)
        # show the result
        if self.draw_result:
            import plotly.io as pio
            pio.renderers.default = "browser"
            reader = AccountStatsReader(trader_names=[self.trader_name])
            df = reader.data_df
            drawer = Drawer(main_data=NormalData(df.copy()[['trader_name', 'timestamp', 'all_value']],
                                                 category_field='trader_name'))
            drawer.draw_line(show=True)

    def on_targets_filtered(self, timestamp, level, selector: TargetSelector, long_targets: List[str],
                            short_targets: List[str]) -> Tuple[List[str], List[str]]:
        """
        overwrite it to filter the targets from selector

        :param timestamp: the event time
        :param level: the level
        :param selector: the selector
        :param long_targets: the long targets from the selector
        :param short_targets: the short targets from the selector
        :return: filtered long targets, filtered short targets
        """
        self.logger.info(f'on_targets_filtered {level} long:{long_targets}')

        # long_targets may be empty/None when only short targets were produced
        if long_targets and len(long_targets) > 10:
            long_targets = long_targets[0:10]
        self.logger.info(f'on_targets_filtered {level} filtered long:{long_targets}')

        return long_targets, short_targets

    def in_trading_date(self, timestamp):
        """Return whether the date of ``timestamp`` is one of ``self.trading_dates``."""
        return to_time_str(timestamp) in self.trading_dates

    def on_time(self, timestamp: pd.Timestamp):
        """
        called in every min level cycle

        :param timestamp: event time
        """
        self.logger.debug(f'current timestamp:{timestamp}')

    def on_trading_signals(self, trading_signals: List[TradingSignal]):
        """Forward a batch of trading signals to every listener."""
        for l in self.trading_signal_listeners:
            l.on_trading_signals(trading_signals)

    def on_trading_signal(self, trading_signal: TradingSignal):
        """Forward one trading signal to every listener; a listener error is logged and reported back to it."""
        for l in self.trading_signal_listeners:
            try:
                l.on_trading_signal(trading_signal)
            except Exception as e:
                self.logger.exception(e)
                l.on_trading_error(timestamp=trading_signal.happen_timestamp, error=e)

    def on_trading_open(self, timestamp):
        for l in self.trading_signal_listeners:
            l.on_trading_open(timestamp)

    def on_trading_close(self, timestamp):
        for l in self.trading_signal_listeners:
            l.on_trading_close(timestamp)

    def on_trading_finish(self, timestamp):
        for l in self.trading_signal_listeners:
            l.on_trading_finish(timestamp)

    def on_trading_error(self, timestamp, error):
        for l in self.trading_signal_listeners:
            l.on_trading_error(timestamp, error)

    def run(self):
        """Drive the trading loop over the interval timestamps of ``self.level`` and call ``on_finish`` at the end."""
        # used by on_finish when the loop yields no timestamp at all (previously an UnboundLocalError)
        last_timestamp = self.end_timestamp
        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in self.entity_schema.get_interval_timestamps(start_date=self.start_timestamp,
                                                                    end_date=self.end_timestamp, level=self.level):
            last_timestamp = timestamp

            if not self.in_trading_date(timestamp=timestamp):
                continue

            waiting_seconds = 0

            if self.level == IntervalLevel.LEVEL_1DAY:
                if is_same_date(timestamp, now_pd_timestamp()):
                    # today's daily bar: poll every minute until 19:00
                    while True:
                        self.logger.info(f'time is:{now_pd_timestamp()},just smoke for minutes')
                        time.sleep(60)
                        current = now_pd_timestamp()
                        if current.hour >= 19:
                            waiting_seconds = 20
                            break

            elif self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                seconds = (now_pd_timestamp() - real_end_timestamp).total_seconds()
                waiting_seconds = self.level.to_second() - seconds

            # meaning the future kdata not ready yet,we could move on to check
            if waiting_seconds > 0:
                # iterate the selector from min to max which in finished timestamp kdata
                for level in self.trading_level_asc:
                    if self.entity_schema.is_finished_kdata_timestamp(timestamp=timestamp, level=level):
                        for selector in self.selectors:
                            if selector.level == level:
                                selector.move_on(timestamp, self.kdata_use_begin_time, timeout=waiting_seconds + 20)

            # on_trading_open to setup the account
            if self.level >= IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY and self.entity_schema.is_open_timestamp(timestamp)):
                self.on_trading_open(timestamp)

            self.on_time(timestamp=timestamp)

            # Selectors (factors) compute the history data of many targets quickly and make multi-level
            # computation convenient, so they are typically used as a coarse filter over the whole market.
            # Finer control can be done in on_targets_filtered, or with your own logic in on_time.
            if self.selectors:
                # Key points of the multi-level traversal:
                # 1) compute the targets of each level, filter them with on_targets_filtered and cache them in
                #    level_map_long_targets / level_map_short_targets
                # 2) at the min level, on_targets_selected_from_levels derives the final selection from the
                #    cached targets of all levels
                # Note: a smaller level sees the bigger levels' targets from their previous cycle, which is intended
                for level in self.trading_level_asc:
                    # in every cycle, all level selector do its job in its time
                    if self.entity_schema.is_finished_kdata_timestamp(timestamp=timestamp, level=level):
                        all_long_targets = []
                        all_short_targets = []

                        # collect and filter the targets of this level's selectors
                        for selector in self.selectors:
                            if selector.level == level:
                                long_targets = selector.get_open_long_targets(timestamp=timestamp)
                                short_targets = selector.get_open_short_targets(timestamp=timestamp)

                                if long_targets or short_targets:
                                    long_targets, short_targets = self.on_targets_filtered(
                                        timestamp=timestamp, level=level, selector=selector,
                                        long_targets=long_targets, short_targets=short_targets)

                                if long_targets:
                                    all_long_targets += long_targets
                                if short_targets:
                                    all_short_targets += short_targets

                        # cache this level's targets in level_map_long_targets / level_map_short_targets
                        self.update_targets_by_level(level, all_long_targets, all_short_targets)

                    # the time always move on by min level step and we could check all targets of levels
                    # 1)the targets is generated for next interval
                    # 2)the acceptable price is next interval prices,you could buy it at current price
                    # if the time is before the timestamp(due_timestamp) when trading signal received
                    # 3)the suggest price the the close price for generating the signal(happen_timestamp)
                    due_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())

                    # generate the final trading signals at the min level
                    if level == self.level:
                        long_selected, short_selected = self.on_targets_selected_from_levels(timestamp)

                        # apply take-profit / stop-loss
                        passive_short, _ = self.on_profit_control()
                        if passive_short:
                            if not short_selected:
                                short_selected = passive_short
                            else:
                                short_selected = list(set(short_selected) | set(passive_short))

                        self.logger.debug('timestamp:{},long_selected:{}'.format(due_timestamp, long_selected))
                        self.logger.debug('timestamp:{},short_selected:{}'.format(due_timestamp, short_selected))

                        if long_selected or short_selected:
                            self.trade_the_targets(due_timestamp=due_timestamp, happen_timestamp=timestamp,
                                                   long_selected=long_selected, short_selected=short_selected)

            if self.trading_signals:
                self.on_trading_signals(self.trading_signals)
            # clear
            self.trading_signals = []

            # on_trading_close to calculate date account
            if self.level >= IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY and self.entity_schema.is_close_timestamp(timestamp)):
                self.on_trading_close(timestamp)

        self.on_finish(last_timestamp)

    def register_trading_signal_listener(self, listener):
        """Add ``listener`` unless it is already registered."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def deregister_trading_signal_listener(self, listener):
        """Remove ``listener`` if it is registered."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)