class Trader(object):
    """
    Trading driver bound to an ``entity_schema``.

    Subclasses must set ``entity_schema`` and implement :meth:`init_selectors` to fill ``self.selectors``.
    :meth:`run` walks the interval timestamps of ``level`` and collects selector targets into ``targets_slot``.
    It turns the combined targets into :class:`TradingSignal` objects and dispatches them to
    ``trading_signal_listeners``. A ``SimAccountService`` is always registered as the first listener.
    """
    entity_schema: EntityMixin = None

    def __init__(self, entity_ids: List[str] = None, exchanges: List[str] = None, codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None, end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None, level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None, real_time: bool = False, kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        """
        :param entity_ids: entity ids forwarded to init_selectors and recorded in TraderInfo
        :param exchanges: exchanges forwarded to init_selectors and recorded in TraderInfo
        :param codes: codes forwarded to init_selectors and recorded in TraderInfo
        :param start_timestamp: required; start of the trading range
        :param end_timestamp: required; end of the trading range (must be in the future in real_time mode)
        :param provider: data provider passed to the account service
        :param level: trader level; must equal the smallest level among the selectors
        :param trader_name: defaults to the lower-cased class name
        :param real_time: wait for live kdata instead of replaying history
        :param kdata_use_begin_time: whether kdata timestamps mark the beginning of the interval
        :param draw_result: draw the account value curve in on_finish
        """
        assert self.entity_schema is not None

        self.logger = logging.getLogger(__name__)

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners: List[TradingListener] = []
        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids
        self.exchanges = exchanges
        self.codes = codes
        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        if start_timestamp and end_timestamp:
            self.start_timestamp = to_pd_timestamp(start_timestamp)
            self.end_timestamp = to_pd_timestamp(end_timestamp)
        else:
            assert False, 'start_timestamp and end_timestamp are required'

        self.trading_dates = self.entity_schema.get_trading_dates(start_date=self.start_timestamp,
                                                                  end_date=self.end_timestamp)

        if real_time:
            # use the instance logger; a module-level `logger` is not guaranteed to exist
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(entity_schema=self.entity_schema, trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp, provider=self.provider,
                                                 level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids, entity_schema=self.entity_schema, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        if self.selectors:
            self.trading_level_asc = list(set([IntervalLevel(selector.level) for selector in self.selectors]))
            self.trading_level_asc.sort()

            self.logger.info(f'trader level:{self.level},selectors level:{self.trading_level_asc}')

            if self.level != self.trading_level_asc[0]:
                raise Exception("trader level should be the min of the selectors")

            self.trading_level_desc = list(self.trading_level_asc)
            self.trading_level_desc.reverse()
        else:
            # always define the level lists: run() (real_time branch) and handle_targets_slot iterate them
            self.trading_level_asc = []
            self.trading_level_desc = []

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', data_schema=TraderInfo)
        self.on_start()

    def on_start(self):
        """Run every selector over its history, then persist this trader's settings as a TraderInfo row."""
        # run all the selectors
        for selector in self.selectors:
            # run for the history data at first
            selector.run()

        entity_ids = json.dumps(self.entity_ids) if self.entity_ids else None
        exchanges = json.dumps(self.exchanges) if self.exchanges else None
        codes = json.dumps(self.codes) if self.codes else None

        sim_account = TraderInfo(id=self.trader_name,
                                 entity_id=f'trader_zvt_{self.trader_name}',
                                 timestamp=self.start_timestamp,
                                 trader_name=self.trader_name,
                                 entity_ids=entity_ids,
                                 exchanges=exchanges,
                                 codes=codes,
                                 start_timestamp=self.start_timestamp,
                                 end_timestamp=self.end_timestamp,
                                 provider=self.provider,
                                 level=self.level.value,
                                 real_time=self.real_time,
                                 kdata_use_begin_time=self.kdata_use_begin_time)
        self.session.add(sim_account)
        self.session.commit()

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp):
        """
        implement this to init selectors
        """
        pass

    def add_trading_signal_listener(self, listener):
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, due_timestamp: pd.Timestamp, happen_timestamp: pd.Timestamp):
        """
        this function would be called in every cycle, you could overwrite it for your custom algorithm to select
        the targets of different levels

        The default implementation intersects the long (and short) targets of every level that has targets.

        :param due_timestamp: timestamp by which the generated signals should be acted on
        :param happen_timestamp: timestamp of the kdata that produced the signals
        """
        long_selected = None
        short_selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if not targets:
                continue

            long_targets = set(targets[0])
            short_targets = set(targets[1])
            # None means "no level contributed yet"; an empty set is a genuine empty intersection
            # and must not be replaced by the next level's targets
            long_selected = long_targets if long_selected is None else long_selected & long_targets
            short_selected = short_targets if short_selected is None else short_selected & short_targets

        self.logger.debug('timestamp:{},long_selected:{}'.format(due_timestamp, long_selected))
        self.logger.debug('timestamp:{},short_selected:{}'.format(due_timestamp, short_selected))

        self.trade_the_targets(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp,
                               long_selected=long_selected, short_selected=short_selected)

    def get_current_account(self) -> AccountStats:
        return self.account_service.account

    def _get_long_holdings(self) -> List[str]:
        """Return entity ids of the current account positions that have ``available_long > 0``."""
        account = self.get_current_account()
        if not account.positions:
            return []
        return [position.entity_id for position in account.positions
                if position is not None and position.available_long > 0]

    def buy(self, due_timestamp, happen_timestamp, entity_ids, position_pct=1.0, ignore_in_position=True):
        """
        Send an ``open_long`` signal for each entity, splitting ``position_pct`` evenly among them.

        :param entity_ids: iterable of entity ids; ``None`` (no targets selected) means nothing to buy
        :param position_pct: total position ratio shared by the bought entities
        :param ignore_in_position: skip entities that already have a position with ``available_long > 0``
        """
        # handle_targets_slot passes None when no level has produced targets yet
        if entity_ids is None:
            return

        if ignore_in_position:
            entity_ids = set(entity_ids) - set(self._get_long_holdings())

        if entity_ids:
            position_pct = (1.0 / len(entity_ids)) * position_pct

        for entity_id in entity_ids:
            trading_signal = TradingSignal(entity_id=entity_id,
                                           due_timestamp=due_timestamp,
                                           happen_timestamp=happen_timestamp,
                                           trading_signal_type=TradingSignalType.open_long,
                                           trading_level=self.level,
                                           position_pct=position_pct)
            self.send_trading_signal(trading_signal)

    def sell(self, due_timestamp, happen_timestamp, entity_ids, position_pct=1.0):
        """
        Send a ``close_long`` signal for every entity that is both held and in ``entity_ids``.

        :param entity_ids: iterable of entity ids; ``None`` (no targets selected) means nothing to sell
        :param position_pct: ratio of each position to close
        """
        if entity_ids is None:
            return

        # current position; set() lets callers pass any iterable, not only a set
        shorted = set(self._get_long_holdings()) & set(entity_ids)

        for entity_id in shorted:
            trading_signal = TradingSignal(entity_id=entity_id,
                                           due_timestamp=due_timestamp,
                                           happen_timestamp=happen_timestamp,
                                           trading_signal_type=TradingSignalType.close_long,
                                           trading_level=self.level,
                                           position_pct=position_pct)
            self.send_trading_signal(trading_signal)

    def trade_the_targets(self, due_timestamp, happen_timestamp, long_selected, short_selected, long_pct=1.0,
                          short_pct=1.0):
        self.buy(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp, entity_ids=long_selected,
                 position_pct=long_pct)
        self.sell(due_timestamp=due_timestamp, happen_timestamp=happen_timestamp, entity_ids=short_selected,
                  position_pct=short_pct)

    def send_trading_signal(self, signal: TradingSignal):
        """Dispatch ``signal`` to every listener; a failing listener is logged and told via on_trading_error."""
        for listener in self.trading_signal_listeners:
            try:
                listener.on_trading_signal(signal)
            except Exception as e:
                self.logger.exception(e)
                listener.on_trading_error(timestamp=signal.happen_timestamp, error=e)

    def on_finish(self):
        # show the result
        if self.draw_result:
            import plotly.io as pio
            pio.renderers.default = "browser"
            reader = AccountStatsReader(trader_names=[self.trader_name])
            df = reader.data_df
            drawer = Drawer(main_data=NormalData(df.copy()[['trader_name', 'timestamp', 'all_value']],
                                                 category_field='trader_name'))
            drawer.draw_line(show=True)

    def select_long_targets(self, long_targets: List[str]) -> List[str]:
        """Keep at most the first 10 long targets of a selector."""
        if len(long_targets) > 10:
            return long_targets[0:10]
        return long_targets

    def select_short_targets(self, short_targets: List[str]) -> List[str]:
        """Keep at most the first 10 short targets of a selector."""
        if len(short_targets) > 10:
            return short_targets[0:10]
        return short_targets

    def in_trading_date(self, timestamp):
        return to_time_str(timestamp) in self.trading_dates

    def on_time(self, timestamp):
        self.logger.debug(f'current timestamp:{timestamp}')

    def run(self):
        """Walk every interval timestamp of ``self.level``, update the account and trade the selected targets."""
        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in self.entity_schema.get_interval_timestamps(start_date=self.start_timestamp,
                                                                    end_date=self.end_timestamp, level=self.level):

            if not self.in_trading_date(timestamp=timestamp):
                continue

            if self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                seconds = (now_pd_timestamp() - real_end_timestamp).total_seconds()
                waiting_seconds = self.level.to_second() - seconds
                # meaning the future kdata not ready yet,we could move on to check
                if waiting_seconds > 0:
                    # iterate the selector from min to max which in finished timestamp kdata
                    for level in self.trading_level_asc:
                        if self.entity_schema.is_finished_kdata_timestamp(timestamp=timestamp, level=level):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp, self.kdata_use_begin_time,
                                                     timeout=waiting_seconds + 20)

            # on_trading_open to setup the account
            if self.level == IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY and self.entity_schema.is_open_timestamp(timestamp)):
                self.account_service.on_trading_open(timestamp)

            self.on_time(timestamp=timestamp)

            if self.selectors:
                for level in self.trading_level_asc:
                    # in every cycle, all level selector do its job in its time
                    if self.entity_schema.is_finished_kdata_timestamp(timestamp=timestamp, level=level):
                        all_long_targets = []
                        all_short_targets = []
                        for selector in self.selectors:
                            if selector.level == level:
                                long_targets = selector.get_open_long_targets(timestamp=timestamp)
                                long_targets = self.select_long_targets(long_targets)

                                short_targets = selector.get_open_short_targets(timestamp=timestamp)
                                short_targets = self.select_short_targets(short_targets)

                                all_long_targets += long_targets
                                all_short_targets += short_targets

                        if all_long_targets or all_short_targets:
                            self.targets_slot.input_targets(level, all_long_targets, all_short_targets)

                        # the time always move on by min level step and we could check all level targets in the slot
                        # 1)the targets is generated for next interval
                        # 2)the acceptable price is next interval prices,you could buy it at current price if the
                        #   time is before the timestamp(due_timestamp) when trading signal received
                        # 3)the suggest price the the close price for generating the signal(happen_timestamp)
                        due_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                        if level == self.level:
                            self.handle_targets_slot(due_timestamp=due_timestamp, happen_timestamp=timestamp)

            # on_trading_close to calculate date account
            if self.level == IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY and self.entity_schema.is_close_timestamp(timestamp)):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
class Trader(Constructor):
    """
    Trading driver bound to a ``security_type``.

    Subclasses must set ``security_type`` and implement :meth:`init_selectors`. Each cycle, :meth:`run` does two things.
    It trades the targets already in ``targets_slot``. It then refills the slot from ``selectors_comparator`` for
    every level whose kdata is finished at that timestamp.
    """
    logger = logging.getLogger(__name__)

    security_type: SecurityType = None

    def __init__(self, security_list: List[str] = None, exchanges: List[str] = ['sh', 'sz'], codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None, end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: Union[str, Provider] = Provider.JOINQUANT,
                 level: Union[str, TradingLevel] = TradingLevel.LEVEL_1DAY, trader_name: str = None,
                 real_time: bool = False, kdata_use_begin_time: bool = False) -> None:
        # NOTE: the list default for `exchanges` is kept (not mutated here) because Constructor
        # metadata may expose the signature defaults.
        assert self.security_type is not None

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []
        self.state_listeners = []

        self.selectors: List[TargetSelector] = []

        self.security_list = security_list
        self.exchanges = exchanges
        self.codes = codes

        # FIXME:handle this case gracefully
        if self.security_list:
            security_type, exchange, code = decode_security_id(self.security_list[0])
            if not self.security_type:
                self.security_type = security_type
            if not self.exchanges:
                self.exchanges = [exchange]

        self.provider = Provider(provider)
        # make sure the min level selector correspond to the provider and level
        self.level = TradingLevel(level)
        self.real_time = real_time

        if start_timestamp and end_timestamp:
            self.start_timestamp = to_pd_timestamp(start_timestamp)
            self.end_timestamp = to_pd_timestamp(end_timestamp)
        else:
            assert False, 'start_timestamp and end_timestamp are required'

        if real_time:
            # use the class logger; a module-level `logger` is not guaranteed to exist
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time

        self.account_service = SimAccountService(trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_list=security_list, security_type=self.security_type, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        self.trading_level_asc = list(set([TradingLevel(selector.level) for selector in self.selectors]))
        self.trading_level_asc.sort()

        self.trading_level_desc = list(self.trading_level_asc)
        self.trading_level_desc.reverse()

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', StoreCategory.business)
        trader = get_trader(session=self.session, trader_name=self.trader_name, return_type='domain', limit=1)

        if trader:
            self.logger.warning("trader:{} has run before,old result would be deleted".format(self.trader_name))
            self.session.query(business.Trader).filter(business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()
        self.on_start()

    def on_start(self):
        """Run every selector over its history, then persist this trader's settings as a business.Trader row."""
        if not self.selectors:
            raise Exception('please setup self.selectors in init_selectors at first')

        # run all the selectors
        technical_factors = []
        for selector in self.selectors:
            # run for the history data at first
            selector.run()

            for factor in selector.filter_factors:
                if isinstance(factor, TechnicalFactor):
                    technical_factors.append(factor)

        technical_factors = simplejson.dumps(technical_factors, for_json=True)

        security_list = json.dumps(self.security_list) if self.security_list else None
        exchanges = json.dumps(self.exchanges) if self.exchanges else None
        codes = json.dumps(self.codes) if self.codes else None

        # NOTE(review): `security_type` is filled with the JSON security list, not self.security_type;
        # confirm against the business.Trader columns.
        trader_domain = business.Trader(id=self.trader_name, timestamp=self.start_timestamp,
                                        trader_name=self.trader_name,
                                        security_type=security_list, exchanges=exchanges, codes=codes,
                                        start_timestamp=self.start_timestamp,
                                        end_timestamp=self.end_timestamp, provider=self.provider.value,
                                        level=self.level.value,
                                        real_time=self.real_time, kdata_use_begin_time=self.kdata_use_begin_time,
                                        technical_factors=technical_factors)
        self.session.add(trader_domain)
        self.session.commit()

    def init_selectors(self, security_list, security_type, exchanges, codes, start_timestamp, end_timestamp):
        """
        implement this to init selectors
        """
        raise NotImplementedError

    def init_selectors_comparator(self):
        """
        overwrite this to set selectors_comparator
        """
        return LimitSelectorsComparator(self.selectors)

    def add_trading_signal_listener(self, listener):
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, timestamp):
        """
        this function would be called in every cycle, you could overwrite it for your custom algorithm to select
        the targets of different levels

        The default implementation intersects the long (and short) targets of every level that has targets.

        :param timestamp: current cycle timestamp
        :type timestamp: pd.Timestamp
        """
        long_selected = None
        short_selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if not targets:
                continue

            long_targets = set(targets[0])
            short_targets = set(targets[1])
            # None means "no level contributed yet"; an empty set is a genuine empty intersection
            # and must not be replaced by the next level's targets
            long_selected = long_targets if long_selected is None else long_selected & long_targets
            short_selected = short_targets if short_selected is None else short_selected & short_targets

        self.logger.debug('timestamp:{},long_selected:{}'.format(timestamp, long_selected))
        self.logger.debug('timestamp:{},short_selected:{}'.format(timestamp, short_selected))

        self.send_trading_signals(timestamp=timestamp, long_selected=long_selected, short_selected=short_selected)

    def send_trading_signals(self, timestamp, long_selected, short_selected):
        """
        Open longs for selected securities not yet held, and close longs for held securities in short_selected.

        :param long_selected: set of security ids to long, or None
        :param short_selected: set of security ids to close, or None
        """
        # current position
        account = self.account_service.latest_account
        current_holdings = [position['security_id'] for position in account['positions'] if
                            position['available_long'] > 0]

        if long_selected:
            # just long the security not in the positions
            longed = long_selected - set(current_holdings)
            if longed:
                position_pct = 1.0 / len(longed)
                order_money = account['cash'] * position_pct

                for security_id in longed:
                    trading_signal = TradingSignal(security_id=security_id,
                                                   the_timestamp=timestamp,
                                                   trading_signal_type=TradingSignalType.trading_signal_open_long,
                                                   trading_level=self.level,
                                                   order_money=order_money)
                    for listener in self.trading_signal_listeners:
                        listener.on_trading_signal(trading_signal)

        # just short the security in current_holdings and short_selected
        if short_selected:
            shorted = set(current_holdings) & short_selected

            for security_id in shorted:
                trading_signal = TradingSignal(security_id=security_id,
                                               the_timestamp=timestamp,
                                               trading_signal_type=TradingSignalType.trading_signal_close_long,
                                               position_pct=1.0,
                                               trading_level=self.level)
                for listener in self.trading_signal_listeners:
                    listener.on_trading_signal(trading_signal)

    def on_finish(self):
        # show the result
        pass

    def run(self):
        """Walk every interval timestamp of ``self.level``, trade slot targets and refill the slot per level."""
        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in iterate_timestamps(security_type=self.security_type, exchange=self.exchanges[0],
                                            start_timestamp=self.start_timestamp, end_timestamp=self.end_timestamp,
                                            level=self.level):

            if not is_trading_date(security_type=self.security_type, exchange=self.exchanges[0],
                                   timestamp=timestamp):
                continue

            if self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                waiting_seconds, _ = self.level.count_from_timestamp(
                    real_end_timestamp,
                    one_day_trading_minutes=get_one_day_trading_minutes(security_type=self.security_type))

                # meaning the future kdata not ready yet,we could move on to check
                if waiting_seconds and (waiting_seconds > 0):
                    # iterate the selector from min to max which in finished timestamp kdata
                    for level in self.trading_level_asc:
                        if is_in_finished_timestamps(security_type=self.security_type, exchange=self.exchanges[0],
                                                     timestamp=timestamp, level=level):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp, self.kdata_use_begin_time,
                                                     timeout=waiting_seconds + 20)

            # on_trading_open to setup the account
            if self.level == TradingLevel.LEVEL_1DAY or (
                    self.level != TradingLevel.LEVEL_1DAY and is_open_time(security_type=self.security_type,
                                                                           exchange=self.exchanges[0],
                                                                           timestamp=timestamp)):
                self.account_service.on_trading_open(timestamp)

            # the time always move on by min level step and we could check all level targets in the slot
            self.handle_targets_slot(timestamp=timestamp)

            for level in self.trading_level_asc:
                # in every cycle, all level selector do its job in its time
                if is_in_finished_timestamps(security_type=self.security_type, exchange=self.exchanges[0],
                                             timestamp=timestamp, level=level):
                    long_targets, short_targets = self.selectors_comparator.make_decision(timestamp=timestamp,
                                                                                          trading_level=level)

                    self.targets_slot.input_targets(level, long_targets, short_targets)

            # on_trading_close to calculate date account
            if self.level == TradingLevel.LEVEL_1DAY or (
                    self.level != TradingLevel.LEVEL_1DAY and is_close_time(security_type=self.security_type,
                                                                            exchange=self.exchanges[0],
                                                                            timestamp=timestamp)):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()

    @classmethod
    def get_constructor_meta(cls):
        meta = super().get_constructor_meta()
        meta.metas['security_type'] = marshal_object_for_ui(cls.security_type)
        return meta
class Trader(object):
    """
    Trading driver bound to an ``entity_schema``.

    Subclasses must set ``entity_schema`` and implement :meth:`init_selectors`. Each cycle, :meth:`run` does two things.
    It trades the targets already in ``targets_slot``. It then refills the slot from ``selectors_comparator`` for
    every level whose kdata is finished at that timestamp.
    """
    logger = logging.getLogger(__name__)

    entity_schema: EntityMixin = None

    def __init__(self, entity_ids: List[str] = None, exchanges: List[str] = ['sh', 'sz'], codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None, end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None, level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None, real_time: bool = False, kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        assert self.entity_schema is not None

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        self.exchanges = exchanges
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        if start_timestamp and end_timestamp:
            self.start_timestamp = to_pd_timestamp(start_timestamp)
            self.end_timestamp = to_pd_timestamp(end_timestamp)
        else:
            assert False, 'start_timestamp and end_timestamp are required'

        if real_time:
            # use the class logger; a module-level `logger` is not guaranteed to exist
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids, entity_schema=self.entity_schema, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        self.trading_level_asc = list(set([IntervalLevel(selector.level) for selector in self.selectors]))
        self.trading_level_asc.sort()

        self.trading_level_desc = list(self.trading_level_asc)
        self.trading_level_desc.reverse()

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', 'business')
        trader = get_trader(session=self.session, trader_name=self.trader_name, return_type='domain', limit=1)

        if trader:
            self.logger.warning("trader:{} has run before,old result would be deleted".format(self.trader_name))
            self.session.query(business.Trader).filter(business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()
        self.on_start()

    def on_start(self):
        """Run every selector over its history, then persist this trader's settings as a business.Trader row."""
        if not self.selectors:
            raise Exception('please setup self.selectors in init_selectors at first')

        # run all the selectors
        for selector in self.selectors:
            # run for the history data at first
            selector.run()

        entity_ids = json.dumps(self.entity_ids) if self.entity_ids else None
        exchanges = json.dumps(self.exchanges) if self.exchanges else None
        codes = json.dumps(self.codes) if self.codes else None

        # NOTE(review): `entity_type` is filled with the JSON entity id list; confirm this matches
        # the business.Trader column meaning.
        trader_domain = business.Trader(id=self.trader_name, timestamp=self.start_timestamp,
                                        trader_name=self.trader_name,
                                        entity_type=entity_ids, exchanges=exchanges, codes=codes,
                                        start_timestamp=self.start_timestamp,
                                        end_timestamp=self.end_timestamp, provider=self.provider,
                                        level=self.level.value,
                                        real_time=self.real_time, kdata_use_begin_time=self.kdata_use_begin_time)
        self.session.add(trader_domain)
        self.session.commit()

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp):
        """
        implement this to init selectors
        """
        raise NotImplementedError

    def init_selectors_comparator(self):
        """
        overwrite this to set selectors_comparator
        """
        return LimitSelectorsComparator(self.selectors)

    def add_trading_signal_listener(self, listener):
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, timestamp):
        """
        this function would be called in every cycle, you could overwrite it for your custom algorithm to select
        the targets of different levels

        The default implementation intersects the long (and short) targets of every level that has targets.

        :param timestamp: current cycle timestamp
        :type timestamp: pd.Timestamp
        """
        long_selected = None
        short_selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if not targets:
                continue

            long_targets = set(targets[0])
            short_targets = set(targets[1])
            # None means "no level contributed yet"; an empty set is a genuine empty intersection
            # and must not be replaced by the next level's targets
            long_selected = long_targets if long_selected is None else long_selected & long_targets
            short_selected = short_targets if short_selected is None else short_selected & short_targets

        self.logger.debug('timestamp:{},long_selected:{}'.format(timestamp, long_selected))
        self.logger.debug('timestamp:{},short_selected:{}'.format(timestamp, short_selected))

        self.send_trading_signals(timestamp=timestamp, long_selected=long_selected, short_selected=short_selected)

    def send_trading_signals(self, timestamp, long_selected, short_selected):
        """
        Open longs for selected entities not yet held, and close longs for held entities in short_selected.

        :param long_selected: set of entity ids to long, or None
        :param short_selected: set of entity ids to close, or None
        """
        # current position
        account = self.account_service.latest_account
        current_holdings = [position['entity_id'] for position in account['positions'] if
                            position['available_long'] > 0]

        if long_selected:
            # just long the security not in the positions
            longed = long_selected - set(current_holdings)
            if longed:
                position_pct = 1.0 / len(longed)
                order_money = account['cash'] * position_pct

                for entity_id in longed:
                    trading_signal = TradingSignal(entity_id=entity_id,
                                                   the_timestamp=timestamp,
                                                   trading_signal_type=TradingSignalType.trading_signal_open_long,
                                                   trading_level=self.level,
                                                   order_money=order_money)
                    for listener in self.trading_signal_listeners:
                        listener.on_trading_signal(trading_signal)

        # just short the security in current_holdings and short_selected
        if short_selected:
            shorted = set(current_holdings) & short_selected

            for entity_id in shorted:
                trading_signal = TradingSignal(entity_id=entity_id,
                                               the_timestamp=timestamp,
                                               trading_signal_type=TradingSignalType.trading_signal_close_long,
                                               position_pct=1.0,
                                               trading_level=self.level)
                for listener in self.trading_signal_listeners:
                    listener.on_trading_signal(trading_signal)

    def on_finish(self):
        # show the result
        if self.draw_result:
            import plotly.io as pio
            pio.renderers.default = "browser"
            reader = AccountReader(trader_names=[self.trader_name])
            df = reader.data_df.reset_index()
            drawer = Drawer(main_data=NormalData(df.copy()[['trader_name', 'timestamp', 'all_value']],
                                                 category_field='trader_name'))
            drawer.draw_line()

    def run(self):
        """Walk every interval timestamp of ``self.level``, trade slot targets and refill the slot per level."""
        # NOTE(review): every helper below receives entity_type=self.entity_schema (a class), except the
        # second is_in_finished_timestamps call, which passes self.entity_schema.__name__.lower() (a str).
        # Confirm which form the helpers expect and make them consistent.

        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in iterate_timestamps(entity_type=self.entity_schema, exchange=self.exchanges[0],
                                            start_timestamp=self.start_timestamp, end_timestamp=self.end_timestamp,
                                            level=self.level):

            if not is_trading_date(entity_type=self.entity_schema, exchange=self.exchanges[0],
                                   timestamp=timestamp):
                continue

            if self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                waiting_seconds, _ = self.level.count_from_timestamp(
                    real_end_timestamp,
                    one_day_trading_minutes=get_one_day_trading_minutes(entity_type=self.entity_schema))

                # meaning the future kdata not ready yet,we could move on to check
                if waiting_seconds and (waiting_seconds > 0):
                    # iterate the selector from min to max which in finished timestamp kdata
                    for level in self.trading_level_asc:
                        if is_in_finished_timestamps(entity_type=self.entity_schema, exchange=self.exchanges[0],
                                                     timestamp=timestamp, level=level):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp, self.kdata_use_begin_time,
                                                     timeout=waiting_seconds + 20)

            # on_trading_open to setup the account
            if self.level == IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY and is_open_time(entity_type=self.entity_schema,
                                                                            exchange=self.exchanges[0],
                                                                            timestamp=timestamp)):
                self.account_service.on_trading_open(timestamp)

            # the time always move on by min level step and we could check all level targets in the slot
            self.handle_targets_slot(timestamp=timestamp)

            for level in self.trading_level_asc:
                # in every cycle, all level selector do its job in its time
                if is_in_finished_timestamps(entity_type=self.entity_schema.__name__.lower(),
                                             exchange=self.exchanges[0], timestamp=timestamp, level=level):
                    long_targets, short_targets = self.selectors_comparator.make_decision(timestamp=timestamp,
                                                                                          trading_level=level)

                    self.targets_slot.input_targets(level, long_targets, short_targets)

            # on_trading_close to calculate date account
            if self.level == IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY and is_close_time(entity_type=self.entity_schema,
                                                                             exchange=self.exchanges[0],
                                                                             timestamp=timestamp)):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
class Trader(object):
    """
    Day-level trading driver.

    Each business day, it takes the securities picked by ``selectors_comparator``. It opens longs on those not
    yet held, splitting the cash evenly. It closes every held security that was not picked.
    """
    logger = logging.getLogger(__name__)

    # overwrite it to custom your trader
    # (used as a template: each instance works on its own deep copy, see __init__)
    selectors_comparator = SelectorsComparator(limit=5)

    def __init__(self, security_type=SecurityType.stock, exchanges=['sh', 'sz'], codes=None, start_timestamp=None,
                 end_timestamp=None) -> None:
        """
        :param security_type: security type forwarded to init_selectors
        :param exchanges: exchanges forwarded to init_selectors
        :param codes: codes forwarded to init_selectors
        :param start_timestamp: required; first day of the backtest
        :param end_timestamp: required; last day of the backtest
        """
        import copy

        self.trader_name = type(self).__name__.lower()
        self.trading_signal_listeners = []
        self.state_listeners = []

        self.selectors: List[TargetSelector] = None

        self.security_type = security_type
        self.exchanges = exchanges
        self.codes = codes

        if start_timestamp and end_timestamp:
            self.start_timestamp = to_pd_timestamp(start_timestamp)
            self.end_timestamp = to_pd_timestamp(end_timestamp)
        else:
            assert False, 'start_timestamp and end_timestamp are required'

        self.account_service = SimAccountService(trader_name=self.trader_name, timestamp=self.start_timestamp)

        self.add_trading_signal_listener(self.account_service)

        # The class attribute is one object shared by every instance of the class; calling add_selectors on it
        # would accumulate selectors across traders. Work on a per-instance copy instead.
        self.selectors_comparator = copy.deepcopy(self.selectors_comparator)

        self.init_selectors(security_type=self.security_type, exchanges=self.exchanges, codes=self.codes,
                            start_timestamp=self.start_timestamp, end_timestamp=self.end_timestamp)

        self.selectors_comparator.add_selectors(self.selectors)

    def init_selectors(self, security_type, exchanges, codes, start_timestamp, end_timestamp):
        """
        implement this to init selectors

        :param security_type:
        :type security_type:
        :param exchanges:
        :type exchanges:
        :param codes:
        :type codes:
        :param start_timestamp:
        :type start_timestamp:
        :param end_timestamp:
        :type end_timestamp:
        """
        raise NotImplementedError

    def add_trading_signal_listener(self, listener):
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def run(self):
        """Simulate every business day between start and end timestamps."""
        # now we just support day level
        for timestamp in pd.date_range(start=self.start_timestamp, end=self.end_timestamp, freq='B'):

            self.account_service.on_trading_open(timestamp)

            account = self.account_service.latest_account
            current_holdings = [position['security_id'] for position in account['positions']]

            df = self.selectors_comparator.make_decision(timestamp=timestamp)

            selected = set()
            if not df.empty:
                selected = set(df['security_id'].to_list())

            if selected:
                # just long the security not in the positions
                longed = selected - set(current_holdings)
                if longed:
                    position_pct = 1.0 / len(longed)
                    order_money = account['cash'] * position_pct

                    for security_id in longed:
                        trading_signal = TradingSignal(security_id=security_id,
                                                       the_timestamp=timestamp,
                                                       trading_signal_type=TradingSignalType.trading_signal_open_long,
                                                       trading_level=TradingLevel.LEVEL_1DAY,
                                                       order_money=order_money)
                        for listener in self.trading_signal_listeners:
                            listener.on_trading_signal(trading_signal)

            # close every held security that was not selected today
            shorted = set(current_holdings) - selected

            for security_id in shorted:
                trading_signal = TradingSignal(security_id=security_id,
                                               the_timestamp=timestamp,
                                               trading_signal_type=TradingSignalType.trading_signal_close_long,
                                               position_pct=1.0,
                                               trading_level=TradingLevel.LEVEL_1DAY)
                for listener in self.trading_signal_listeners:
                    listener.on_trading_signal(trading_signal)

            self.account_service.on_trading_close(timestamp)
class Trader(Constructor):
    """Multi-level trader driven by selectors of possibly different levels.

    Subclasses implement :meth:`init_selectors`, which is expected to populate
    ``self.selectors``. :meth:`run` iterates timestamps at the trader's own
    level. At every timestamp that finishes a selector level, the comparator's
    decision for that level is stored in a :class:`TargetsSlot`.
    :meth:`handle_targets_slot` then combines the per-level targets into
    trading signals.
    """
    logger = logging.getLogger(__name__)

    def __init__(self,
                 security_list: List[str] = None,
                 security_type: Union[str, SecurityType] = SecurityType.stock,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: Union[str, Provider] = Provider.JOINQUANT,
                 level: Union[str, TradingLevel] = TradingLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False) -> None:
        """Set up the simulated account, selectors, comparator and targets slot.

        :param security_list: explicit securities forwarded to :meth:`init_selectors`.
        :param security_type: security type; normalised with ``SecurityType(...)``.
        :param exchanges: exchanges; defaults to ``['sh', 'sz']``. The first one
            drives the trading calendar in :meth:`run`.
        :param codes: security codes forwarded to :meth:`init_selectors`.
        :param start_timestamp: start of the run; required.
        :param end_timestamp: end of the run; required. Must not be in the past
            when ``real_time`` is set.
        :param provider: data provider forwarded to the account service.
        :param level: the trader's (minimum) level; normalised with ``TradingLevel(...)``.
        :param trader_name: name of the trader; defaults to the lower-cased class name.
        :param real_time: when true, :meth:`run` also moves the selectors on with new data.
        :param kdata_use_begin_time: whether kdata timestamps mark the bar's begin time.
        :raises AssertionError: if a timestamp is missing, or if ``real_time`` is set
            and ``end_timestamp`` is already in the past.
        """
        # fresh list per instance instead of a shared mutable default argument
        if exchanges is None:
            exchanges = ['sh', 'sz']

        self.trader_name = trader_name if trader_name else type(self).__name__.lower()

        self.trading_signal_listeners = []
        self.state_listeners = []
        self.selectors: List[TargetSelector] = None

        self.security_list = security_list
        self.security_type = SecurityType(security_type)
        self.exchanges = exchanges
        self.codes = codes
        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = TradingLevel(level)
        self.real_time = real_time

        # explicit raise instead of `assert False` so the check survives `python -O`
        if not (start_timestamp and end_timestamp):
            raise AssertionError('start_timestamp and end_timestamp are both required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        if real_time:
            # use the class logger; a module-level `logger` is not guaranteed to exist
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            if self.end_timestamp < now_pd_timestamp():
                raise AssertionError('end_timestamp must not be in the past in real_time mode')

        self.kdata_use_begin_time = kdata_use_begin_time

        self.account_service = SimAccountService(trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level)
        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_list=security_list,
                            security_type=self.security_type,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        # distinct selector levels, ascending and descending
        self.trading_level_asc = sorted({TradingLevel(selector.level) for selector in self.selectors})
        self.trading_level_desc = list(reversed(self.trading_level_asc))

        self.targets_slot: TargetsSlot = TargetsSlot()

    def init_selectors(self, security_list, security_type, exchanges, codes, start_timestamp, end_timestamp):
        """Create this trader's selectors; subclasses must implement it.

        Implementations are expected to assign ``self.selectors``. ``__init__``
        reads it right after this call.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def init_selectors_comparator(self):
        """Return the comparator that ranks the selectors' results; overwrite this to customise it.

        The default is a ``LimitSelectorsComparator`` over ``self.selectors``.
        """
        return LimitSelectorsComparator(self.selectors)

    def add_trading_signal_listener(self, listener):
        """Register ``listener`` for trading signals, ignoring duplicates."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        """Unregister ``listener`` if it is registered; otherwise do nothing."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def _emit_trading_signal(self, trading_signal):
        """Deliver ``trading_signal`` to every registered listener, in registration order."""
        for listener in self.trading_signal_listeners:
            listener.on_trading_signal(trading_signal)

    def handle_targets_slot(self, timestamp):
        """
        this function would be called in every cycle, you could overwrite it for your custom algorithm to select the
        targets of different levels

        the default implementation is selecting the targets in all levels, i.e. the intersection of every level's
        targets; a level that has no targets (yet) counts as an empty set and therefore empties the result.

        :param timestamp: the current cycle's timestamp
        :type timestamp: presumably pd.Timestamp, as yielded by iterate_timestamps
        """
        selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level) or set()
            # `is None` rather than falsiness: an empty intersection must stay
            # empty instead of being replaced by the next level's targets
            if selected is None:
                # copy so the slot's own set is never mutated
                selected = set(targets)
            else:
                selected &= set(targets)

        if selected:
            self.logger.debug('timestamp:%s,selected:%s', timestamp, selected)

        # called even when nothing is selected so that current holdings get closed
        self.send_trading_signals(timestamp=timestamp, selected=selected)

    def send_trading_signals(self, timestamp, selected):
        """Turn the selected securities into open/close trading signals.

        Selected securities not currently held with ``available_long > 0`` are
        opened long, with the account's cash split equally among them. Held
        securities that are not selected are closed completely.

        :param timestamp: timestamp attached to the signals
        :param selected: set of selected security ids; may be empty or None
        """
        account = self.account_service.latest_account
        current_holdings = {
            position['security_id'] for position in account['positions']
            if position['available_long'] > 0
        }
        selected = set(selected) if selected else set()

        # just long the security not in the positions
        longed = selected - current_holdings
        if longed:
            position_pct = 1.0 / len(longed)
            order_money = account['cash'] * position_pct
            for security_id in longed:
                self._emit_trading_signal(
                    TradingSignal(security_id=security_id,
                                  the_timestamp=timestamp,
                                  trading_signal_type=TradingSignalType.trading_signal_open_long,
                                  trading_level=self.level,
                                  order_money=order_money))

        # just short the security not in the selected but in current_holdings
        for security_id in current_holdings - selected:
            self._emit_trading_signal(
                TradingSignal(security_id=security_id,
                              the_timestamp=timestamp,
                              trading_signal_type=TradingSignalType.trading_signal_close_long,
                              position_pct=1.0,
                              trading_level=self.level))

    def on_finish(self):
        """Hook called once after :meth:`run` finishes; does nothing by default."""
        # show the result
        pass

    def run(self):
        """Iterate the trader's timestamps, feeding selectors and trading on their targets.

        For each timestamp, in order:

        1. In real-time mode, move on the selectors of every level whose
           timestamp has just finished.
        2. Open the account (every cycle at day level, otherwise only at
           open time).
        3. Trade on the targets currently in the slot.
        4. Store fresh comparator decisions for every finished level.
        5. Close the account (every cycle at day level, otherwise only at
           close time).

        :meth:`on_finish` is called at the end.
        """
        exchange = self.exchanges[0]
        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in iterate_timestamps(security_type=self.security_type,
                                            exchange=exchange,
                                            start_timestamp=self.start_timestamp,
                                            end_timestamp=self.end_timestamp,
                                            level=self.level):
            if self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                waiting_seconds, _ = self.level.count_from_timestamp(
                    real_end_timestamp,
                    one_day_trading_minutes=get_one_day_trading_minutes(security_type=self.security_type))

                # meaning the future kdata not ready yet,we could move on to check
                if waiting_seconds and waiting_seconds > 0:
                    # iterate the selector from min to max which in finished timestamp kdata
                    for level in self.trading_level_asc:
                        if is_in_finished_timestamps(security_type=self.security_type, exchange=exchange,
                                                     timestamp=timestamp, level=level):
                            for selector in self.selectors:
                                # normalise like trading_level_asc does, so a raw level
                                # value still matches the TradingLevel enum
                                if TradingLevel(selector.level) == level:
                                    selector.move_on(timestamp, self.kdata_use_begin_time)

            # on_trading_open to setup the account
            if self.level == TradingLevel.LEVEL_1DAY or is_open_time(security_type=self.security_type,
                                                                    exchange=exchange,
                                                                    timestamp=timestamp):
                self.account_service.on_trading_open(timestamp)

            # the time always move on by min level step and we could check all level targets in the slot
            self.handle_targets_slot(timestamp=timestamp)

            for level in self.trading_level_asc:
                # in every cycle, all level selector do its job in its time
                if is_in_finished_timestamps(security_type=self.security_type, exchange=exchange,
                                             timestamp=timestamp, level=level):
                    df = self.selectors_comparator.make_decision(timestamp=timestamp, trading_level=level)
                    # an empty set (not `{}`, which is a dict) so set operations keep working
                    selected = set() if df.empty else set(df['security_id'].to_list())
                    self.targets_slot.input_targets(level, selected)

            # on_trading_close to calculate date account
            if self.level == TradingLevel.LEVEL_1DAY or is_close_time(security_type=self.security_type,
                                                                     exchange=exchange,
                                                                     timestamp=timestamp):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
class Trader(object):
    """Multi-level trader driven by selectors of possibly different levels.

    Subclasses implement :meth:`init_selectors`, which is expected to populate
    ``self.selectors``. :meth:`run` iterates timestamps at ``trading_level``.
    At every timestamp that finishes a selector level, the comparator's
    decision for that level is stored in a :class:`TargetsSlot`.
    :meth:`handle_targets_slot` then combines the per-level targets into
    trading signals.
    """
    logger = logging.getLogger(__name__)

    def __init__(self, security_type=SecurityType.stock, exchanges=None, codes=None, start_timestamp=None,
                 end_timestamp=None, provider=Provider.JOINQUANT, trading_level=TradingLevel.LEVEL_1DAY,
                 trader_name=None) -> None:
        """Set up the simulated account, selectors, comparator and targets slot.

        :param security_type: security type forwarded to :meth:`init_selectors`.
        :param exchanges: exchanges; defaults to ``['sh', 'sz']``. The first one
            drives the trading calendar in :meth:`run`.
        :param codes: security codes forwarded to :meth:`init_selectors`.
        :param start_timestamp: start of the run; required.
        :param end_timestamp: end of the run; required.
        :param provider: data provider forwarded to the account service.
        :param trading_level: the trader's (minimum) level; normalised with ``TradingLevel(...)``.
        :param trader_name: name of the trader; defaults to the lower-cased class name.
        :raises AssertionError: if either timestamp is missing.
        """
        # fresh list per instance instead of a shared mutable default argument
        if exchanges is None:
            exchanges = ['sh', 'sz']

        self.trader_name = trader_name if trader_name else type(self).__name__.lower()

        self.trading_signal_listeners = []
        self.state_listeners = []
        self.selectors: List[TargetSelector] = None

        self.security_type = security_type
        self.exchanges = exchanges
        self.codes = codes

        # make sure the min level selector correspond to the provider and level
        self.provider = provider
        # normalised so the `== TradingLevel.LEVEL_1DAY` checks in run() also work for raw values
        self.trading_level = TradingLevel(trading_level)

        # explicit raise instead of `assert False` so the check survives `python -O`
        if not (start_timestamp and end_timestamp):
            raise AssertionError('start_timestamp and end_timestamp are both required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.account_service = SimAccountService(trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.trading_level)
        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(security_type=self.security_type,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = LimitSelectorsComparator(self.selectors)

        # distinct selector levels, ascending and descending
        self.trading_level_asc = sorted({TradingLevel(selector.level) for selector in self.selectors})
        self.trading_level_desc = list(reversed(self.trading_level_asc))

        self.targets_slot: TargetsSlot = TargetsSlot()

    def init_selectors(self, security_type, exchanges, codes, start_timestamp, end_timestamp):
        """Create this trader's selectors; subclasses must implement it.

        Implementations are expected to assign ``self.selectors``. ``__init__``
        reads it right after this call.

        :param security_type: the trader's security type.
        :param exchanges: the trader's exchanges.
        :param codes: the trader's security codes, possibly None.
        :param start_timestamp: run start, already converted with ``to_pd_timestamp``.
        :param end_timestamp: run end, already converted with ``to_pd_timestamp``.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def add_trading_signal_listener(self, listener):
        """Register ``listener`` for trading signals, ignoring duplicates."""
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        """Unregister ``listener`` if it is registered; otherwise do nothing."""
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def _emit_trading_signal(self, trading_signal):
        """Deliver ``trading_signal`` to every registered listener, in registration order."""
        for listener in self.trading_signal_listeners:
            listener.on_trading_signal(trading_signal)

    def handle_targets_slot(self, timestamp):
        """Trade on the intersection of every level's targets.

        The default behavior is selecting the targets in all levels. A level
        without targets (yet) counts as an empty set and therefore empties the
        result.

        :param timestamp: the current cycle's timestamp, presumably a pd.Timestamp from iterate_timestamps
        """
        selected = None
        for level in self.trading_level_desc:
            # an empty *set*; the previous `{}` was a dict, and `set & dict` raises TypeError
            targets = self.targets_slot.get_targets(level=level) or set()
            # `is None` rather than falsiness: an empty intersection must stay
            # empty instead of being replaced by the next level's targets
            if selected is None:
                # copy so the slot's own set is never mutated
                selected = set(targets)
            else:
                selected &= set(targets)

        if selected:
            self.logger.info('timestamp:%s,selected:%s', timestamp, selected)

        # called even when nothing is selected so that current holdings get closed
        self.send_trading_signals(timestamp=timestamp, selected=selected)

    def send_trading_signals(self, timestamp, selected):
        """Turn the selected securities into open/close trading signals.

        Selected securities not currently held are opened long, with the
        account's cash split equally among them. Held securities that are not
        selected are closed completely.

        :param timestamp: timestamp attached to the signals
        :param selected: set of selected security ids; may be empty or None
        """
        account = self.account_service.latest_account
        current_holdings = {position['security_id'] for position in account['positions']}
        selected = set(selected) if selected else set()

        # just long the security not in the positions
        longed = selected - current_holdings
        if longed:
            position_pct = 1.0 / len(longed)
            order_money = account['cash'] * position_pct
            for security_id in longed:
                self._emit_trading_signal(
                    TradingSignal(security_id=security_id,
                                  the_timestamp=timestamp,
                                  trading_signal_type=TradingSignalType.trading_signal_open_long,
                                  trading_level=self.trading_level,
                                  order_money=order_money))

        # just short the security not in the selected but in current_holdings
        for security_id in current_holdings - selected:
            self._emit_trading_signal(
                TradingSignal(security_id=security_id,
                              the_timestamp=timestamp,
                              trading_signal_type=TradingSignalType.trading_signal_close_long,
                              position_pct=1.0,
                              trading_level=self.trading_level))

    def on_finish(self):
        """Draw the account details and order signals of this trader after the run."""
        draw_account_details(trader_name=self.trader_name)
        draw_order_signals(trader_name=self.trader_name)

    def run(self):
        """Iterate the trader's timestamps, storing selector targets and trading on them.

        For each timestamp, in order:

        1. Open the account (every cycle at day level, otherwise only at open
           time).
        2. Trade on the targets currently in the slot.
        3. Store fresh comparator decisions for every finished level.
        4. Close the account (every cycle at day level, otherwise only at
           close time).

        :meth:`on_finish` is called at the end.
        """
        exchange = self.exchanges[0]
        # iterate timestamp of the min level
        for timestamp in iterate_timestamps(security_type=self.security_type,
                                            exchange=exchange,
                                            start_timestamp=self.start_timestamp,
                                            end_timestamp=self.end_timestamp,
                                            level=self.trading_level):
            # on_trading_open to setup the account
            if self.trading_level == TradingLevel.LEVEL_1DAY or is_open_time(security_type=self.security_type,
                                                                            exchange=exchange,
                                                                            timestamp=timestamp):
                self.account_service.on_trading_open(timestamp)

            # the time always move on by min level step and we could check all level targets in the slot
            self.handle_targets_slot(timestamp=timestamp)

            for level in self.trading_level_asc:
                # in every cycle, all level selector do its job in its time
                if is_in_finished_timestamps(security_type=self.security_type, exchange=exchange,
                                             timestamp=timestamp, level=level):
                    df = self.selectors_comparator.make_decision(timestamp=timestamp, trading_level=level)
                    # an empty set (not `{}`, which is a dict) so set operations keep working
                    selected = set() if df.empty else set(df['security_id'].to_list())
                    self.targets_slot.input_targets(level, selected)

            # on_trading_close to calculate date account
            if self.trading_level == TradingLevel.LEVEL_1DAY or is_close_time(security_type=self.security_type,
                                                                             exchange=exchange,
                                                                             timestamp=timestamp):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()