def is_in_same_interval(t1: pd.Timestamp, t2: pd.Timestamp, level: IntervalLevel):
    t1 = to_pd_timestamp(t1)
    t2 = to_pd_timestamp(t2)
    if level == IntervalLevel.LEVEL_1WEEK:
        return t1.week == t2.week
    if level == IntervalLevel.LEVEL_1MON:
        return t1.month == t2.month

    return level.floor_timestamp(t1) == level.floor_timestamp(t2)
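# A minimal, self-contained sketch (plain pandas, no project imports) of the
# "same interval" idea used above: two timestamps belong to the same bar when
# flooring them to the bar size gives the same result. The 5-minute bar size
# and the helper name below are illustrative assumptions, not project code.
import pandas as pd

def same_5min_bar(t1, t2):
    # floor both timestamps to the 5-minute boundary and compare
    return pd.Timestamp(t1).floor('5min') == pd.Timestamp(t2).floor('5min')

# same_5min_bar('2020-01-02 09:31', '2020-01-02 09:34') -> True
# same_5min_bar('2020-01-02 09:34', '2020-01-02 09:36') -> False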
def to_high_level_kdata(kdata_df: pd.DataFrame, to_level: IntervalLevel):
    def to_close(s):
        if pd_is_not_null(s):
            return s[-1]

    def to_open(s):
        if pd_is_not_null(s):
            return s[0]

    def to_high(s):
        return np.max(s)

    def to_low(s):
        return np.min(s)

    def to_sum(s):
        return np.sum(s)

    original_level = kdata_df['level'][0]
    entity_id = kdata_df['entity_id'][0]
    provider = kdata_df['provider'][0]
    name = kdata_df['name'][0]
    code = kdata_df['code'][0]

    entity_type, _, _ = decode_entity_id(entity_id=entity_id)

    assert IntervalLevel(original_level) <= IntervalLevel.LEVEL_1DAY
    assert IntervalLevel(original_level) < IntervalLevel(to_level)

    df: pd.DataFrame = None
    if to_level == IntervalLevel.LEVEL_1WEEK:
        # loffset=-2 days: label each weekly bar with Friday
        if entity_type == 'stock':
            df = kdata_df.resample('W', loffset=pd.DateOffset(days=-2)).apply(
                {'close': to_close,
                 'open': to_open,
                 'high': to_high,
                 'low': to_low,
                 'volume': to_sum,
                 'turnover': to_sum})
        else:
            df = kdata_df.resample('W', loffset=pd.DateOffset(days=-2)).apply(
                {'close': to_close,
                 'open': to_open,
                 'high': to_high,
                 'low': to_low,
                 'volume': to_sum,
                 'turnover': to_sum})
    df = df.dropna()
    # id entity_id timestamp provider code name level
    df['entity_id'] = entity_id
    df['provider'] = provider
    df['code'] = code
    df['name'] = name

    return df
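# Independent pandas sketch of the weekly aggregation above, with no project
# imports. It resamples a daily OHLCV frame to weekly bars labelled on Friday
# ('W-FRI' stands in for the deprecated loffset trick); the sample data and
# column names are assumptions for illustration only.
import numpy as np
import pandas as pd

daily = pd.DataFrame(
    {'open': [10, 11, 12, 13, 14],
     'high': [11, 12, 13, 14, 15],
     'low': [9, 10, 11, 12, 13],
     'close': [10.5, 11.5, 12.5, 13.5, 14.5],
     'volume': [100, 110, 120, 130, 140]},
    index=pd.date_range('2020-01-06', periods=5, freq='B'))  # Mon..Fri

weekly = daily.resample('W-FRI').agg(
    {'open': 'first', 'high': 'max', 'low': 'min', 'close': 'last', 'volume': 'sum'})
# weekly has a single row labelled 2020-01-10 (the Friday of that week)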
def __init__(self,
             exchanges=['sh', 'sz'],
             entity_ids=None,
             codes=None,
             batch_size=10,
             force_update=True,
             sleeping_time=10,
             default_size=2000,
             real_time=False,
             fix_duplicate_way='ignore',
             start_timestamp=None,
             end_timestamp=None,
             level=IntervalLevel.LEVEL_1WEEK,
             kdata_use_begin_time=False,
             close_hour=15,
             close_minute=0,
             one_day_trading_minutes=4 * 60) -> None:
    level = IntervalLevel(level)
    self.data_schema = get_kdata_schema(entity_type='stock', level=level)
    self.jq_trading_level = to_jq_trading_level(level)

    super().__init__('stock', exchanges, entity_ids, codes, batch_size, force_update, sleeping_time,
                     default_size, real_time, fix_duplicate_way, start_timestamp, end_timestamp, close_hour,
                     close_minute, level, kdata_use_begin_time, one_day_trading_minutes)

    self.factor = 0
    self.last_timestamp = None
    auth(zvt_env['jq_username'], zvt_env['jq_password'])
def coin_finished_timestamp(timestamp: pd.Timestamp, level: IntervalLevel):
    timestamp = to_pd_timestamp(timestamp)
    if timestamp.microsecond != 0:
        return False
    return timestamp.minute % level.to_minute() == 0
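# Self-contained illustration of the "finished bar" check above for a
# 5-minute level, using plain pandas (the project-specific IntervalLevel type
# is not needed here); the concrete bar size is an assumption for the example.
import pandas as pd

def finished_5min(ts):
    ts = pd.Timestamp(ts)
    # a bar boundary falls on a whole minute that is a multiple of 5
    return ts.second == 0 and ts.microsecond == 0 and ts.minute % 5 == 0

# finished_5min('2020-01-02 09:35:00') -> True
# finished_5min('2020-01-02 09:36:00') -> False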
def __init__(self,
             entity_type='stock',
             exchanges=['sh', 'sz'],
             entity_ids=None,
             codes=None,
             batch_size=10,
             force_update=True,
             sleeping_time=10,
             default_size=2000,
             real_time=False,
             fix_duplicate_way='ignore',
             start_timestamp=None,
             end_timestamp=None,
             close_hour=0,
             close_minute=0,
             # parameters added by this subclass
             level=IntervalLevel.LEVEL_1DAY,
             kdata_use_begin_time=False,
             one_day_trading_minutes=24 * 60) -> None:
    super().__init__(entity_type, exchanges, entity_ids, codes, batch_size, force_update, sleeping_time,
                     default_size, real_time, fix_duplicate_way, start_timestamp, end_timestamp, close_hour,
                     close_minute)

    self.level = IntervalLevel(level)
    self.kdata_use_begin_time = kdata_use_begin_time
    self.one_day_trading_minutes = one_day_trading_minutes
def get_zen_factor_schema(entity_type: str,
                          level: Union[IntervalLevel, str] = IntervalLevel.LEVEL_1DAY):
    if type(level) == str:
        level = IntervalLevel(level)

    schema_str = '{}{}ZenFactor'.format(entity_type.capitalize(), level.value.capitalize())

    return eval(schema_str)
def get_ma_state_stats_schema(entity_type: str,
                              level: Union[IntervalLevel, str] = IntervalLevel.LEVEL_1DAY):
    if type(level) == str:
        level = IntervalLevel(level)

    # ma state stats schema naming rule:
    # 1) name: {SecurityType.value.capitalize()}{IntervalLevel.value.capitalize()}MaStateStats
    schema_str = '{}{}MaStateStats'.format(entity_type.capitalize(), level.value.capitalize())

    return eval(schema_str)
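# Quick illustration of the naming rule above, using plain string formatting
# only (no schema classes involved): the level value such as '1d' is
# capitalized and sandwiched between the entity type and the schema suffix.
entity_type, level_value = 'stock', '1d'   # example inputs
name = '{}{}MaStateStats'.format(entity_type.capitalize(), level_value.capitalize())
# name == 'Stock1dMaStateStats'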
def level_flag(level: IntervalLevel):
    level = IntervalLevel(level)
    if level == IntervalLevel.LEVEL_1DAY:
        return 101
    if level == IntervalLevel.LEVEL_1WEEK:
        return 102
    if level == IntervalLevel.LEVEL_1MON:
        return 103

    assert False
def evaluate_size_from_timestamp(start_timestamp,
                                 level: IntervalLevel,
                                 one_day_trading_minutes,
                                 end_timestamp: pd.Timestamp = None):
    """
    Given a start timestamp, a level and the trading minutes of one day, estimate how many
    kdata records exist from start_timestamp to end_timestamp (now, if not given).
    The estimate may be a little larger than the real size so that all the kdata can be fetched.

    :param start_timestamp:
    :type start_timestamp: pd.Timestamp
    :param level:
    :type level: IntervalLevel
    :param one_day_trading_minutes:
    :type one_day_trading_minutes: int
    :param end_timestamp:
    :type end_timestamp: pd.Timestamp
    """
    if not end_timestamp:
        end_timestamp = pd.Timestamp.now()
    else:
        end_timestamp = to_pd_timestamp(end_timestamp)

    time_delta = end_timestamp - to_pd_timestamp(start_timestamp)

    one_day_trading_seconds = one_day_trading_minutes * 60

    if level == IntervalLevel.LEVEL_1DAY:
        return time_delta.days + 1

    if level == IntervalLevel.LEVEL_1WEEK:
        return int(math.ceil(time_delta.days / 7)) + 1

    if level == IntervalLevel.LEVEL_1MON:
        return int(math.ceil(time_delta.days / 30)) + 1

    if time_delta.days > 0:
        seconds = (time_delta.days + 1) * one_day_trading_seconds
        return int(math.ceil(seconds / level.to_second())) + 1
    else:
        seconds = time_delta.total_seconds()
        return min(int(math.ceil(seconds / level.to_second())) + 1,
                   one_day_trading_seconds / level.to_second() + 1)
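# Worked example of the size estimate above for an intraday level, done with
# plain arithmetic and no project imports. It assumes a 5-minute level
# (300 seconds) and a market that trades 4 * 60 minutes per day.
import math

one_day_trading_minutes = 4 * 60
one_day_trading_seconds = one_day_trading_minutes * 60   # 14400
level_seconds = 300                                       # 5-minute bars
days = 3                                                  # start..end spans 3 whole days

# same formula as the time_delta.days > 0 branch above:
seconds = (days + 1) * one_day_trading_seconds
size = int(math.ceil(seconds / level_seconds)) + 1
# size == 193, a little more than the 3 * 48 = 144 real bars of 3 trading days,
# which is fine because the recorder only uses it as an upper bound for fetching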
def gen_quote_domain(providers: List[str], entity_type: str, levels):
    tables = []
    for level in levels:
        level = IntervalLevel(level)

        cap_entity_type = entity_type.capitalize()
        cap_level = level.value.capitalize()

        if level != IntervalLevel.LEVEL_TICK:
            kdata_common = f'{cap_entity_type}KdataCommon'
        else:
            kdata_common = f'{cap_entity_type}TickCommon'

        class_name = f'{cap_entity_type}{cap_level}Kdata'
        table_name = f'{entity_type}_{level.value}_kdata'

        tables.append(table_name)

        domain_template = f'''# -*- coding: utf-8 -*-
# this file is generated by gen_quote_domain function, don't change it
from sqlalchemy.ext.declarative import declarative_base

from zvdata.contract import register_schema
from zvt.domain.quotes import {kdata_common}

KdataBase = declarative_base()


class {class_name}(KdataBase, {kdata_common}):
    __tablename__ = '{table_name}'


register_schema(providers={providers}, db_name='{table_name}', schema_base=KdataBase)

__all__ = ['{class_name}']
'''
        # generate the domain module
        with open(os.path.join(entity_type, f'{table_name}.py'), 'w') as outfile:
            outfile.write(domain_template)

    # generate the package __init__
    imports = [f'from zvt.domain.quotes.{entity_type}.{table} import *' for table in tables]
    imports_str = '\n'.join(imports)

    package_template = '''# -*- coding: utf-8 -*-
# this file is generated by gen_quote_domain function, don't change it
''' + imports_str

    with open(os.path.join(entity_type, '__init__.py'), 'w') as outfile:
        outfile.write(package_template)
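# Hedged illustration of what one generated module would look like for
# entity_type='stock' and level='1d' (file stock/stock_1d_kdata.py), filled in
# from the template above; providers=['joinquant'] is an assumption used only
# for this example.
# -*- coding: utf-8 -*-
# this file is generated by gen_quote_domain function, don't change it
from sqlalchemy.ext.declarative import declarative_base

from zvdata.contract import register_schema
from zvt.domain.quotes import StockKdataCommon

KdataBase = declarative_base()


class Stock1dKdata(KdataBase, StockKdataCommon):
    __tablename__ = 'stock_1d_kdata'


register_schema(providers=['joinquant'], db_name='stock_1d_kdata', schema_base=KdataBase)

__all__ = ['Stock1dKdata']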
                         keep_all_timestamp, fill_method, effective_number, transformer, acc, persist_factor, dry_run)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--level', help='trading level', default='1d',
                        choices=[item.value for item in IntervalLevel])
    parser.add_argument('--start', help='start code', default='000001')
    parser.add_argument('--end', help='end code', default='000005')

    args = parser.parse_args()

    level = IntervalLevel(args.level)
    start = args.start
    end = args.end

    entities = get_entities(provider='eastmoney', entity_type='stock', columns=[Stock.entity_id, Stock.code],
                            filters=[Stock.code >= start, Stock.code < end])

    codes = entities.index.to_list()

    factor = ZenFactor(codes=codes, start_timestamp='2005-01-01',
                       end_timestamp=now_pd_timestamp(), level=level)
class FixedCycleDataRecorder(TimeSeriesDataRecorder):
    def __init__(self,
                 entity_type='stock',
                 exchanges=['sh', 'sz'],
                 entity_ids=None,
                 codes=None,
                 batch_size=10,
                 force_update=False,
                 sleeping_time=10,
                 default_size=2000,
                 one_shot=False,
                 fix_duplicate_way='add',
                 start_timestamp=None,
                 end_timestamp=None,
                 contain_unfinished_data=False,
                 level=IntervalLevel.LEVEL_1DAY,
                 kdata_use_begin_time=False,
                 close_hour=0,
                 close_minute=0,
                 one_day_trading_minutes=24 * 60) -> None:
        super().__init__(entity_type, exchanges, entity_ids, codes, batch_size, force_update, sleeping_time,
                         default_size, one_shot, fix_duplicate_way, start_timestamp, end_timestamp)

        self.level = IntervalLevel(level)
        # FIXME: unfinished data should be removed when recording; always set it to False for now
        self.contain_unfinished_data = contain_unfinished_data
        self.kdata_use_begin_time = kdata_use_begin_time
        self.close_hour = close_hour
        self.close_minute = close_minute
        self.one_day_trading_minutes = one_day_trading_minutes

    def get_latest_saved_record(self, entity):
        order = eval('self.data_schema.{}.desc()'.format(self.get_evaluated_time_field()))

        return get_data(entity_id=entity.id,
                        provider=self.provider,
                        data_schema=self.data_schema,
                        order=order,
                        limit=1,
                        return_type='domain',
                        session=self.session,
                        level=self.level.value)

    def evaluate_start_end_size_timestamps(self, entity):
        # get the latest record
        latest_saved_record = self.get_latest_saved_record(entity=entity)

        if latest_saved_record:
            latest_timestamp = latest_saved_record[0].timestamp
        else:
            latest_timestamp = entity.timestamp

        if not latest_timestamp:
            return latest_timestamp, None, self.default_size, None

        current_time = pd.Timestamp.now()
        time_delta = current_time - latest_timestamp

        if self.level == IntervalLevel.LEVEL_1DAY:
            if is_same_date(current_time, latest_timestamp):
                return latest_timestamp, None, 0, None
            return latest_timestamp, None, time_delta.days + 1, None

        # recorded up to today, check the closing time
        # 0,0 means never stop, e.g. coin
        if (self.close_hour != 0 and self.close_minute != 0) and time_delta.days == 0:
            if latest_timestamp.hour == self.close_hour and latest_timestamp.minute == self.close_minute:
                return latest_timestamp, None, 0, None

        if self.kdata_use_begin_time:
            touching_timestamp = latest_timestamp + pd.Timedelta(seconds=self.level.to_second())
        else:
            touching_timestamp = latest_timestamp

        waiting_seconds, size = self.level.count_from_timestamp(touching_timestamp,
                                                                one_day_trading_minutes=self.one_day_trading_minutes)

        if not self.one_shot and waiting_seconds and (waiting_seconds > 30):
            t = waiting_seconds / 2
            self.logger.info(
                'level:{},recorded_time:{},touching_timestamp:{},current_time:{},next_ok_time:{},just sleep:{} seconds'.format(
                    self.level.value, latest_timestamp, touching_timestamp, current_time,
                    touching_timestamp + pd.Timedelta(seconds=self.level.to_second()), t))
            time.sleep(t)

        return latest_timestamp, None, size, None

    def persist(self, entity, domain_list):
        if domain_list:
            if domain_list[0].timestamp >= domain_list[-1].timestamp:
                first_timestamp = domain_list[-1].timestamp
                last_timestamp = domain_list[0].timestamp
            else:
                first_timestamp = domain_list[0].timestamp
                last_timestamp = domain_list[-1].timestamp

            self.logger.info(
                "persist {} for entity_id:{},time interval:[{},{}]".format(
                    self.data_schema, entity.id, first_timestamp, last_timestamp))

            current_timestamp = now_pd_timestamp()

            saving_datas = domain_list

            # FIXME: remove this logic
            # FIXME: unfinished data should be removed when recording; always set it to False for now
            if is_same_date(current_timestamp, last_timestamp) and self.contain_unfinished_data:
                if current_timestamp.hour >= self.close_hour and current_timestamp.minute >= self.close_minute + 2:
                    # after the closing time of the day, we consider the last data finished
                    saving_datas = domain_list
                else:
                    # ignore unfinished kdata
                    saving_datas = domain_list[:-1]
                    self.logger.info(
                        "ignore unfinished kdata for entity_id:{},level:{},timestamp:{},current_timestamp:{}".format(
                            entity.id, self.level, last_timestamp, current_timestamp))

            self.session.add_all(saving_datas)
            self.session.commit()
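# Standalone sketch of the "drop the unfinished bar" rule implemented in
# persist() above, using a plain list of (timestamp, value) tuples instead of
# ORM objects; the 15:00 close and the sample data are assumptions.
import pandas as pd

def trim_unfinished(bars, now, close_hour=15, close_minute=0):
    # bars: list of (timestamp, value) sorted ascending by timestamp
    last_ts = bars[-1][0]
    same_day = last_ts.date() == now.date()
    after_close = now.hour >= close_hour and now.minute >= close_minute + 2
    if same_day and not after_close:
        return bars[:-1]          # the last bar is still forming, drop it
    return bars

bars = [(pd.Timestamp('2020-01-02 14:55'), 10.1),
        (pd.Timestamp('2020-01-02 15:00'), 10.2)]
# before the close the last bar is dropped, after the close it is kept
trim_unfinished(bars, now=pd.Timestamp('2020-01-02 14:58'))  # -> first bar only
trim_unfinished(bars, now=pd.Timestamp('2020-01-02 15:02'))  # -> both bars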
def __init__(self,
             data_schema: object,
             entity_ids: List[str] = None,
             entity_type: str = 'stock',
             exchanges: List[str] = ['sh', 'sz'],
             codes: List[str] = None,
             the_timestamp: Union[str, pd.Timestamp] = None,
             start_timestamp: Union[str, pd.Timestamp] = '2018-01-01',
             end_timestamp: Union[str, pd.Timestamp] = '2019-06-23',
             columns: List = None,
             filters: List = None,
             order: object = None,
             limit: int = None,
             provider: str = 'eastmoney',
             level: IntervalLevel = IntervalLevel.LEVEL_1DAY,
             category_field: str = 'entity_id',
             time_field: str = 'timestamp',
             trip_timestamp: bool = True,
             auto_load: bool = True) -> None:
    self.data_schema = data_schema

    self.the_timestamp = the_timestamp
    if the_timestamp:
        self.start_timestamp = the_timestamp
        self.end_timestamp = the_timestamp
    else:
        self.start_timestamp = start_timestamp
        self.end_timestamp = end_timestamp

    self.start_timestamp = to_pd_timestamp(self.start_timestamp)
    self.end_timestamp = to_pd_timestamp(self.end_timestamp)

    self.entity_type = entity_type
    self.exchanges = exchanges

    if codes:
        if type(codes) == str:
            codes = codes.replace(' ', '')
            if codes.startswith('[') and codes.endswith(']'):
                codes = json.loads(codes)
            else:
                codes = codes.split(',')
    self.codes = codes
    self.entity_ids = entity_ids

    self.provider = provider
    self.filters = filters
    self.order = order
    self.limit = limit

    if level:
        self.level = IntervalLevel(level)
    else:
        self.level = level

    self.category_field = category_field
    self.time_field = time_field
    self.trip_timestamp = trip_timestamp
    self.auto_load = auto_load

    self.category_column = eval('self.data_schema.{}'.format(self.category_field))
    self.columns = columns

    # we store the data in a multiple index (category_column, timestamp) DataFrame
    if self.columns:
        # support str
        if type(columns[0]) == str:
            self.columns = []
            for col in columns:
                self.columns.append(eval('data_schema.{}'.format(col)))

        time_col = eval('self.data_schema.{}'.format(self.time_field))
        # always add category_column and time_field for normalizing
        self.columns = list(set(self.columns) | {self.category_column, time_col})

    self.data_listeners: List[DataListener] = []

    self.data_df: pd.DataFrame = None
    self.normal_data: NormalData = None

    if self.auto_load:
        self.load_data()
class Trader(object):
    logger = logging.getLogger(__name__)

    entity_schema: EntityMixin = None

    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        assert self.entity_schema is not None

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []
        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids
        self.exchanges = exchanges
        self.codes = codes
        self.provider = provider
        # make sure the min level selector corresponds to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        if start_timestamp and end_timestamp:
            self.start_timestamp = to_pd_timestamp(start_timestamp)
            self.end_timestamp = to_pd_timestamp(end_timestamp)
        else:
            assert False

        if real_time:
            logger.info(
                'real_time mode, end_timestamp should be in the future; you could set it big enough to run forever')
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids, entity_schema=self.entity_schema, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        self.trading_level_asc = list(set([IntervalLevel(selector.level) for selector in self.selectors]))
        self.trading_level_asc.sort()

        self.trading_level_desc = list(self.trading_level_asc)
        self.trading_level_desc.reverse()

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', 'business')
        trader = get_trader(session=self.session, trader_name=self.trader_name, return_type='domain', limit=1)

        if trader:
            self.logger.warning("trader:{} has run before, the old result will be deleted".format(self.trader_name))
            self.session.query(business.Trader).filter(business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()

        self.on_start()

    def on_start(self):
        if not self.selectors:
            raise Exception('please set up self.selectors in init_selectors first')

        # run all the selectors
        for selector in self.selectors:
            # run for the history data at first
            selector.run()

        if self.entity_ids:
            entity_ids = json.dumps(self.entity_ids)
        else:
            entity_ids = None

        if self.exchanges:
            exchanges = json.dumps(self.exchanges)
        else:
            exchanges = None

        if self.codes:
            codes = json.dumps(self.codes)
        else:
            codes = None

        trader_domain = business.Trader(id=self.trader_name,
                                        timestamp=self.start_timestamp,
                                        trader_name=self.trader_name,
                                        entity_type=entity_ids,
                                        exchanges=exchanges,
                                        codes=codes,
                                        start_timestamp=self.start_timestamp,
                                        end_timestamp=self.end_timestamp,
                                        provider=self.provider,
                                        level=self.level.value,
                                        real_time=self.real_time,
                                        kdata_use_begin_time=self.kdata_use_begin_time)
        self.session.add(trader_domain)
        self.session.commit()

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp):
        """
        implement this to init selectors
        """
        raise NotImplementedError

    def init_selectors_comparator(self):
        """
        overwrite this to set selectors_comparator
        """
        return LimitSelectorsComparator(self.selectors)

    def add_trading_signal_listener(self, listener):
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, timestamp):
        """
        this function is called in every cycle; you could overwrite it with your custom
        algorithm to select the targets of different levels.
        the default implementation intersects the targets of all levels.

        :param timestamp:
        :type timestamp:
        """
        long_selected = None
        short_selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if targets:
                long_targets = set(targets[0])
                short_targets = set(targets[1])

                if not long_selected:
                    long_selected = long_targets
                else:
                    long_selected = long_selected & long_targets

                if not short_selected:
                    short_selected = short_targets
                else:
                    short_selected = short_selected & short_targets

        self.logger.debug('timestamp:{},long_selected:{}'.format(timestamp, long_selected))
        self.logger.debug('timestamp:{},short_selected:{}'.format(timestamp, short_selected))

        self.send_trading_signals(timestamp=timestamp, long_selected=long_selected, short_selected=short_selected)

    def send_trading_signals(self, timestamp, long_selected, short_selected):
        # current positions
        account = self.account_service.latest_account
        current_holdings = [position['entity_id'] for position in account['positions'] if
                            position['available_long'] > 0]

        if long_selected:
            # only long the securities not yet in the positions
            longed = long_selected - set(current_holdings)
            if longed:
                position_pct = 1.0 / len(longed)
                order_money = account['cash'] * position_pct

                for entity_id in longed:
                    trading_signal = TradingSignal(entity_id=entity_id,
                                                   the_timestamp=timestamp,
                                                   trading_signal_type=TradingSignalType.trading_signal_open_long,
                                                   trading_level=self.level,
                                                   order_money=order_money)
                    for listener in self.trading_signal_listeners:
                        listener.on_trading_signal(trading_signal)

        # only short the securities that are both in current_holdings and short_selected
        if short_selected:
            shorted = set(current_holdings) & short_selected

            for entity_id in shorted:
                trading_signal = TradingSignal(entity_id=entity_id,
                                               the_timestamp=timestamp,
                                               trading_signal_type=TradingSignalType.trading_signal_close_long,
                                               position_pct=1.0,
                                               trading_level=self.level)
                for listener in self.trading_signal_listeners:
                    listener.on_trading_signal(trading_signal)

    def on_finish(self):
        # show the result
        if self.draw_result:
            import plotly.io as pio
            pio.renderers.default = "browser"
            reader = AccountReader(trader_names=[self.trader_name])
            drawer = Drawer(main_data=NormalData(reader.data_df.copy()[['trader_name', 'timestamp', 'all_value']],
                                                 category_field='trader_name'))
            drawer.draw_line()

    def run(self):
        # iterate the timestamps of the min level, e.g. 9:30, 9:35, 9:40 ... for the 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in iterate_timestamps(entity_type=self.entity_schema, exchange=self.exchanges[0],
                                            start_timestamp=self.start_timestamp,
                                            end_timestamp=self.end_timestamp, level=self.level):
            if not is_trading_date(entity_type=self.entity_schema, exchange=self.exchanges[0], timestamp=timestamp):
                continue

            if self.real_time:
                # all selectors move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                waiting_seconds, _ = self.level.count_from_timestamp(real_end_timestamp,
                                                                     one_day_trading_minutes=get_one_day_trading_minutes(
                                                                         entity_type=self.entity_schema))
                # waiting_seconds > 0 means the future kdata is not ready yet, so we can move on and check
                if waiting_seconds and (waiting_seconds > 0):
                    # iterate the selectors from min to max level whose kdata at this timestamp is finished
                    for level in self.trading_level_asc:
                        if (is_in_finished_timestamps(entity_type=self.entity_schema, exchange=self.exchanges[0],
                                                      timestamp=timestamp, level=level)):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp, self.kdata_use_begin_time,
                                                     timeout=waiting_seconds + 20)

            # on_trading_open to set up the account
            if self.level == IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY and is_open_time(entity_type=self.entity_schema,
                                                                            exchange=self.exchanges[0],
                                                                            timestamp=timestamp)):
                self.account_service.on_trading_open(timestamp)

            # the time always moves on by min level steps, so we can check all level targets in the slot
            self.handle_targets_slot(timestamp=timestamp)

            for level in self.trading_level_asc:
                # in every cycle, each level selector does its job at its own time
                if (is_in_finished_timestamps(entity_type=self.entity_schema, exchange=self.exchanges[0],
                                              timestamp=timestamp, level=level)):
                    long_targets, short_targets = self.selectors_comparator.make_decision(timestamp=timestamp,
                                                                                          trading_level=level)

                    self.targets_slot.input_targets(level, long_targets, short_targets)

            # on_trading_close to calculate the daily account
            if self.level == IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY and is_close_time(entity_type=self.entity_schema,
                                                                             exchange=self.exchanges[0],
                                                                             timestamp=timestamp)):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
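# Hedged sketch of how a concrete trader might be wired up, based only on the
# hooks shown above: entity_schema must be set on the class and init_selectors
# must fill self.selectors with TargetSelector instances. MyTrader, MySelector
# and the Stock schema reference are hypothetical placeholders, not names
# confirmed by this section.
class MyTrader(Trader):
    entity_schema = Stock   # assumption: a Stock entity schema is importable

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp):
        # MySelector stands in for any TargetSelector subclass
        selector = MySelector(entity_ids=entity_ids, entity_schema=entity_schema, exchanges=exchanges,
                              codes=codes, start_timestamp=start_timestamp, end_timestamp=end_timestamp,
                              level=IntervalLevel.LEVEL_1DAY)
        self.selectors.append(selector)

# MyTrader(codes=['000338'], start_timestamp='2019-01-01', end_timestamp='2019-06-30',
#          level=IntervalLevel.LEVEL_1DAY).run()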
def __init__(self,
             data_schema: Mixin,
             entity_provider: str = None,
             entity_ids: List[str] = None,
             entity_type: str = 'stock',
             exchanges: List[str] = ['sh', 'sz'],
             codes: List[str] = None,
             the_timestamp: Union[str, pd.Timestamp] = None,
             start_timestamp: Union[str, pd.Timestamp] = None,
             end_timestamp: Union[str, pd.Timestamp] = now_pd_timestamp(),
             columns: List = None,
             filters: List = None,
             order: object = None,
             limit: int = None,
             provider: str = None,
             level: IntervalLevel = IntervalLevel.LEVEL_1DAY,
             category_field: str = 'entity_id',
             time_field: str = 'timestamp',
             computing_window: int = None) -> None:
    self.logger = logging.getLogger(self.__class__.__name__)

    self.data_schema = data_schema

    self.the_timestamp = the_timestamp
    if the_timestamp:
        self.start_timestamp = the_timestamp
        self.end_timestamp = the_timestamp
    else:
        self.start_timestamp = start_timestamp
        self.end_timestamp = end_timestamp

    self.start_timestamp = to_pd_timestamp(self.start_timestamp)
    self.end_timestamp = to_pd_timestamp(self.end_timestamp)

    self.entity_type = entity_type
    self.entity_provider = entity_provider
    self.exchanges = exchanges

    if codes:
        if type(codes) == str:
            codes = codes.replace(' ', '')
            if codes.startswith('[') and codes.endswith(']'):
                codes = json.loads(codes)
            else:
                codes = codes.split(',')
    self.codes = codes
    self.entity_ids = entity_ids

    # convert to standard entity_id
    if not self.entity_ids and self.entity_provider:
        self.entity_ids = get_entity_ids(provider=self.entity_provider, entity_type=self.entity_type,
                                         exchanges=self.exchanges, codes=self.codes)

    self.provider = provider
    self.filters = filters
    self.order = order
    self.limit = limit

    if level:
        self.level = IntervalLevel(level)
    else:
        self.level = level

    self.category_field = category_field
    self.time_field = time_field
    self.computing_window = computing_window

    self.category_col = eval('self.data_schema.{}'.format(self.category_field))
    self.time_col = eval('self.data_schema.{}'.format(self.time_field))

    self.columns = columns

    # we store the data in a multiple index (category_column, timestamp) DataFrame
    if self.columns:
        # support str
        if type(columns[0]) == str:
            self.columns = []
            for col in columns:
                self.columns.append(eval('data_schema.{}'.format(col)))

        # always add category_column and time_field for normalizing
        self.columns = list(set(self.columns) | {self.category_col, self.time_col})

    self.data_listeners: List[DataListener] = []

    self.data_df: pd.DataFrame = None

    self.load_data()
def next_timestamp(current_timestamp: pd.Timestamp, level: IntervalLevel) -> pd.Timestamp:
    current_timestamp = to_pd_timestamp(current_timestamp)
    return current_timestamp + pd.Timedelta(seconds=level.to_second())
def is_finished_kdata_timestamp(timestamp, level: IntervalLevel):
    timestamp = to_pd_timestamp(timestamp)
    if level.floor_timestamp(timestamp) == timestamp:
        return True
    return False
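# Plain-pandas illustration of the two helpers above for a 5-minute level,
# where level.to_second() would be 300 and level.floor_timestamp(t) behaves
# like t.floor('5min'); the concrete level is an assumption for the example.
import pandas as pd

ts = pd.Timestamp('2020-01-02 09:35:00')
next_ts = ts + pd.Timedelta(seconds=300)            # next_timestamp -> 09:40:00
finished = ts.floor('5min') == ts                   # is_finished_kdata_timestamp -> True
unfinished = pd.Timestamp('2020-01-02 09:37:00')
finished2 = unfinished.floor('5min') == unfinished  # -> False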