def set_env(self, environment):
    """Bind ``environment`` and import/configure every enabled mod.

    For each mod whose config in ``environment.config.mod`` has ``enabled``
    set, the library name is resolved (explicit ``lib`` setting, else the
    built-in ``rqalpha.mod.rqalpha_mod_<name>`` for names in
    ``SYSTEM_MOD_LIST``, else ``rqalpha_mod_<name>``), imported with
    ``import_mod`` and instantiated via its ``load_mod()``. The mod's
    ``__config__`` defaults are deep-copied and overlaid with the user's
    settings, and the merged config is written back to ``config.mod``.

    A mod whose library cannot be imported (``import_mod`` returns None) is
    dropped from ``self._mod_list``; the remaining mods are still loaded.
    Finally ``self._mod_list`` is sorted by ``priority`` (default 100) and
    ``self._mod_dict`` is published as ``environment.mod_dict``.
    """
    self._env = environment
    config = environment.config

    # Collect (mod_name, user_mod_config) for every enabled mod.
    for mod_name in config.mod.__dict__:
        mod_config = getattr(config.mod, mod_name)
        if not mod_config.enabled:
            continue
        self._mod_list.append((mod_name, mod_config))

    # Resolve each mod's library, import it and merge its configuration.
    loaded_mods = []
    for mod_name, user_mod_config in self._mod_list:
        if hasattr(user_mod_config, 'lib'):
            lib_name = user_mod_config.lib
        elif mod_name in SYSTEM_MOD_LIST:
            lib_name = "rqalpha.mod.rqalpha_mod_" + mod_name
        else:
            lib_name = "rqalpha_mod_" + mod_name
        system_log.debug(_(u"loading mod {}").format(lib_name))
        mod_module = import_mod(lib_name)
        if mod_module is None:
            # Skip only this mod. Returning here would leave later mods
            # unloaded, the list unsorted and environment.mod_dict unset.
            continue
        mod = mod_module.load_mod()

        # User settings override the mod's default configuration.
        mod_config = RqAttrDict(copy.deepcopy(getattr(mod_module, "__config__", {})))
        mod_config.update(user_mod_config)
        setattr(config.mod, mod_name, mod_config)
        loaded_mods.append((mod_name, mod_config))
        self._mod_dict[mod_name] = mod

    # Replace the contents in place so existing references to the list stay valid.
    self._mod_list[:] = loaded_mods
    self._mod_list.sort(key=lambda item: getattr(item[1], "priority", 100))
    # Expose the loaded mods on the environment.
    environment.mod_dict = self._mod_dict
def convert2date(date: "str | datetime.date | datetime.datetime | numpy.str_"):
    """Convert ``date`` to a ``datetime.date``.

    Strings are matched against the module-level ``date_format`` patterns.
    Patterns 0 and 1 are sliced as ``YYYY?MM?DD`` (year ``[:4]``, month
    ``[5:7]``, day ``[8:10]``). Pattern 2 is sliced as ``YYYYMMDD``.
    ``numpy.str_`` values are treated as plain strings, and
    ``datetime.datetime`` values are truncated to their date.

    :param date: value to convert.
    :return: the ``datetime.date``, or None (after a debug log) when the
        input is an unrecognised string or an unsupported type.
    """
    # The old annotation ``str or datetime.date or ...`` evaluated to just
    # ``str``; the string annotation above documents the real accepted types.
    if isinstance(date, numpy.str_):
        date = str(date)
    if isinstance(date, str):
        if re.match(date_format[0], date) or re.match(date_format[1], date):
            # Both patterns use the same slice offsets.
            return datetime.date(int(date[:4]), int(date[5:7]), int(date[8:10]))
        if re.match(date_format[2], date):
            return datetime.date(int(date[:4]), int(date[4:6]), int(date[6:8]))
        system_log.debug(
            '3-->日期格式不对 支持格式"YYYYMMDD"或"YYYY-MM-DD"或"YYYY:MM:DD"或datetime.date'
        )
        return None
    # datetime.datetime subclasses datetime.date, so it must be tested first.
    if isinstance(date, datetime.datetime):
        return date.date()
    if isinstance(date, datetime.date):
        return date
    system_log.debug(
        '4-->日期格式不对 支持格式"YYYYMMDD"或"YYYY-MM-DD"或"YYYY:MM:DD"或datetime.date'
    )
    return None
def restore(self):
    """Push persisted state back into every registered object.

    For each ``(key, obj)`` in ``self._objects`` the state is loaded from
    ``self._persist_provider``. Objects whose loaded state is falsy are left
    untouched; the others receive it via ``obj.set_state``.
    """
    for name, target in six.iteritems(self._objects):
        saved = self._persist_provider.load(name)
        system_log.debug('restore {} with state = {}', name, saved)
        if saved:
            target.set_state(saved)
def _on_settlement(self, event):
    """Settle the futures account.

    Steps, in order:
      1. snapshot ``self.total_value`` into ``self._static_total_value``;
      2. for each position: drop it with a user warning if it is delisted
         but still holds quantity, drop it silently if both buy and sell
         quantities are zero, otherwise call ``apply_settlement()``;
      3. if ``forced_liquidation`` is enabled and the snapshot is <= 0,
         warn (only when positions remain), clear every position and set
         the static total value to 0;
      4. clear ``self._backward_trade_set`` and log the resulting state.
    """
    self._static_total_value = self.total_value

    # Iterate over a copy because positions are deleted inside the loop.
    for pos in list(self._positions.values()):
        book_id = pos.order_book_id
        if pos.is_de_listed() and pos.buy_quantity + pos.sell_quantity != 0:
            user_system_log.warn(
                _(u"{order_book_id} is expired, close all positions by system").format(order_book_id=book_id))
            del self._positions[book_id]
            continue
        if pos.buy_quantity == 0 and pos.sell_quantity == 0:
            del self._positions[book_id]
            continue
        pos.apply_settlement()

    # A total value <= 0 is treated as a blown-up account: clear all
    # positions and zero the funds.
    blown_up = self._static_total_value <= 0 and self.forced_liquidation
    if blown_up:
        if self._positions:
            user_system_log.warn(_("Trigger Forced Liquidation, current total_value is 0"))
        self._positions.clear()
        self._static_total_value = 0

    self._backward_trade_set.clear()
    current_state = self.get_state()
    system_log.debug("future account applied settlement, current state: {}".format(current_state))
def set_env(self, environment):
    """Bind ``environment`` and import/configure every enabled mod.

    For each mod whose config in ``environment.config.mod`` has ``enabled``
    set, the library name is resolved (explicit ``lib`` setting, else the
    built-in ``rqalpha.mod.rqalpha_mod_<name>`` for names in
    ``SYSTEM_MOD_LIST``, else ``rqalpha_mod_<name>``), imported with
    ``import_mod`` and instantiated via its ``load_mod()``. The mod's
    ``__config__`` defaults are deep-copied and overlaid with the user's
    settings, and the merged config is written back to ``config.mod``.

    A mod whose library cannot be imported (``import_mod`` returns None) is
    dropped from ``self._mod_list``; the remaining mods are still loaded.
    Finally ``self._mod_list`` is sorted by ``priority`` (default 100) and
    ``self._mod_dict`` is published as ``environment.mod_dict``.
    """
    self._env = environment
    config = environment.config

    for mod_name in config.mod.__dict__:
        mod_config = getattr(config.mod, mod_name)
        if not mod_config.enabled:
            continue
        self._mod_list.append((mod_name, mod_config))

    loaded_mods = []
    for mod_name, user_mod_config in self._mod_list:
        if hasattr(user_mod_config, 'lib'):
            lib_name = user_mod_config.lib
        elif mod_name in SYSTEM_MOD_LIST:
            lib_name = "rqalpha.mod.rqalpha_mod_" + mod_name
        else:
            lib_name = "rqalpha_mod_" + mod_name
        system_log.debug(_(u"loading mod {}").format(lib_name))
        mod_module = import_mod(lib_name)
        if mod_module is None:
            # Skip only this mod. Returning here would leave later mods
            # unloaded, the list unsorted and environment.mod_dict unset.
            continue
        mod = mod_module.load_mod()

        # User settings override the mod's default configuration.
        mod_config = RqAttrDict(copy.deepcopy(getattr(mod_module, "__config__", {})))
        mod_config.update(user_mod_config)
        setattr(config.mod, mod_name, mod_config)
        loaded_mods.append((mod_name, mod_config))
        self._mod_dict[mod_name] = mod

    # Replace the contents in place so existing references to the list stay valid.
    self._mod_list[:] = loaded_mods
    self._mod_list.sort(key=lambda item: getattr(item[1], "priority", 100))
    environment.mod_dict = self._mod_dict
def _on_trade(self, event):
    """Apply a trade event addressed to this account and log the new state.

    Events whose ``account`` is a different account are ignored.
    """
    if self != event.account:
        return
    trade, order = event.trade, event.order
    self._apply_trade(trade, order)
    state_after = self.get_state()
    system_log.debug("future account applied trade, current state: {}".format(state_after))
def get_bar_day(self, instrument, dt):
    """Return the first daily bar at or before ``dt`` for ``instrument`` as a dict.

    Bars are taken from ``self._cache['history_kline']`` when that cache
    already holds the instrument. Otherwise they are fetched with
    ``_get_cur_cache`` when ``dt`` falls on today's local date, or with
    ``_get_history_cache`` for any other date. ``dt`` defaults to now.

    The "first" bar is row 0 of the filtered frame (``datetime`` <= midnight
    of ``dt`` as ``YYYYMMDD000000``). NOTE(review): this is the latest bar
    only if the cached frame is ordered newest-first — confirm.

    Raises:
        Exception: when the fetch reports ``RET_ERROR`` or returns no data.
    """
    if dt is None:
        dt = datetime.now()
    today_str = time.strftime("%Y%m%d", time.localtime())
    dt_str = dt.strftime("%Y%m%d")
    system_log.debug("order_book_id6 " + str(self._dt))

    kline_cache = self._cache['history_kline']
    book_id = instrument.order_book_id
    if kline_cache is not None and book_id in kline_cache.keys():
        ret_code, bar_data = 0, kline_cache[book_id]
    elif dt_str == today_str:
        # Requested date is today: use the "current" fetch path.
        ret_code, bar_data = self._get_cur_cache(instrument)
    else:
        ret_code, bar_data = self._get_history_cache(instrument)

    if ret_code == RET_ERROR or bar_data is None:
        raise Exception("can't get bar data")
    system_log.debug("order_book_id5 , dt=" + str(self._dt))
    cutoff = int(dt_str + "000000")
    selected = bar_data[bar_data.datetime <= cutoff].iloc[0]
    return selected.to_dict()
def _ensure_before_trading(self, event):
    """Publish SETTLEMENT and BEFORE_TRADING at most once per trading date.

    Returns True (publishing nothing) when before_trading already ran for
    ``event.trading_dt``'s date or when ``config.extra.is_hold`` is set.

    Otherwise:
    - if a previous before_trading date is recorded (i.e. not the first
      day), move the environment's time to the previous trading date
      (unless already there) and publish SETTLEMENT;
    - record the new date, publish BEFORE_TRADING and return False.
    """
    ran_today = self._last_before_trading == event.trading_dt.date()
    if ran_today or self._env.config.extra.is_hold:
        return True

    # No settlement on the very first day, when nothing is recorded yet.
    if self._last_before_trading:
        prev_date = self._env.data_proxy.get_previous_trading_date(event.trading_dt).date()
        if self._env.trading_dt.date() != prev_date:
            # Keep the time of day, shift the date back to the previous trading day.
            cal_dt = datetime.combine(prev_date, self._env.calendar_dt.time())
            trd_dt = datetime.combine(prev_date, self._env.trading_dt.time())
            self._env.update_time(cal_dt, trd_dt)
        system_log.debug(
            "publish settlement events with calendar_dt={}, trading_dt={}".format(
                self._env.calendar_dt, self._env.trading_dt))
        self._split_and_publish(Event(EVENT.SETTLEMENT))

    self._last_before_trading = event.trading_dt.date()
    before_trading = Event(EVENT.BEFORE_TRADING,
                           calendar_dt=event.calendar_dt,
                           trading_dt=event.trading_dt)
    self._split_and_publish(before_trading)
    return False
def get_tick(self):
    """Block until a tick is available on ``self._tick_que`` and return it.

    Waits in 1-second slices and logs each timeout; it never returns None
    on a timeout.
    """
    while True:
        try:
            tick = self._tick_que.get(block=True, timeout=1)
        except Empty:
            system_log.debug("get tick timeout")
        else:
            return tick
def format_date(date: "str | datetime.date | datetime.datetime"):
    """Normalise ``date`` to a ``"YYYY-MM-DD"`` string.

    Strings are matched against the module-level ``date_format`` patterns:
    - pattern 0: returned unchanged;
    - pattern 1: ``:`` separators replaced by ``-``;
    - pattern 2: treated as ``YYYYMMDD`` and dashes inserted.

    ``datetime.datetime`` and ``datetime.date`` values are formatted with
    ``%Y-%m-%d``.

    :param date: value to format.
    :return: the formatted string, or None (after a debug log) for an
        unrecognised string or an unsupported type.
    """
    # The old annotation ``str or datetime.date or ...`` evaluated to just
    # ``str``; the string annotation above documents the real accepted types.
    if isinstance(date, str):
        if re.match(date_format[0], date):
            pass
        elif re.match(date_format[1], date):
            date = date.replace(':', '-')
        elif re.match(date_format[2], date):
            date = date[:4] + '-' + date[4:6] + '-' + date[6:8]
        else:
            system_log.debug(
                "1-->类型日期格式错误 支持格式'YYYYMMDD'或'YYYY-MM-DD'或'YYYY:MM:DD'或datetime.date"
            )
            return None
    # datetime.datetime subclasses datetime.date, so it must be tested first.
    elif isinstance(date, datetime.datetime):
        date = date.date().strftime("%Y-%m-%d")
    elif isinstance(date, datetime.date):
        date = date.strftime("%Y-%m-%d")
    else:
        system_log.debug(
            '2-->类型日期格式错误,支持格式"YYYYMMDD"或"YYYY-MM-DD"或"YYYY:MM:DD"或datetime.date'
        )
        return None
    return date
def wrapper(*args, **kwargs):
    """Call ``func`` unless the strategy is on hold (``config.extra.is_hold``).

    While held, the call is skipped, a debug message is logged and None is
    returned.
    """
    held = Environment.get_instance().config.extra.is_hold
    if held:
        system_log.debug(_(u"not run {}({}, {}) because strategy is hold").format(func, args, kwargs))
        return None
    return func(*args, **kwargs)
def dis_collection2factor_map(factors: "list | f", factor2collection_map):
    """Group factors by the collection that ``factor2collection_map`` assigns to them.

    :param factors: a single factor (instance of ``f``) or a list of them.
    :param factor2collection_map: maps ``factor.name`` to its collection.
        Only plain string values are implemented.
    :return: dict mapping collection -> list of factors (in input order), or
        None (after a debug log) when ``factors`` is neither an ``f`` nor a
        list. Factors whose name is missing from the map are logged and
        skipped.
    """
    # The old annotation ``list or f`` evaluated to just ``list``.
    if isinstance(factors, f):
        factors = [factors]
    elif not isinstance(factors, list):
        system_log.debug("错误的factors输入!")
        return None

    collection2factor_map = dict()
    for factor in factors:
        if factor.name not in factor2collection_map:
            system_log.debug('未收录因子' + factor.name + ',或您没有获取该因子权限,请联系管理员。')
            continue
        mapped = factor2collection_map[factor.name]
        # Without name translation the mapped str is the collection itself.
        # NOTE(review): non-str mappings (the unimplemented "translated field"
        # case) fall back to the placeholder key 1, as before — confirm
        # whether they should be skipped instead.
        collection = mapped if isinstance(mapped, str) else 1
        # Append in place instead of rebuilding the list for every factor.
        collection2factor_map.setdefault(collection, []).append(factor)
    return collection2factor_map
def set_env(self, environment):
    """Bind ``environment`` and import/configure every enabled mod.

    For each mod whose config in ``environment.config.mod`` has ``enabled``
    set, the library name is resolved (explicit ``lib`` setting, else the
    built-in ``rqalpha.mod.rqalpha_mod_<name>`` for names in
    ``SYSTEM_MOD_LIST``, else ``rqalpha_mod_<name>``), imported with
    ``import_mod`` and instantiated via its ``load_mod()``. The mod's
    ``__config__`` defaults are deep-copied and overlaid with the user's
    settings, and the merged config is written back to ``config.mod``.

    A mod whose library cannot be imported (``import_mod`` returns None) is
    dropped from ``self._mod_list``; the remaining mods are still loaded.
    Finally ``self._mod_list`` is sorted by ``priority`` (default 100) and
    ``self._mod_dict`` is published as ``environment.mod_dict``.
    """
    self._env = environment
    config = environment.config

    for mod_name in config.mod.__dict__:
        mod_config = getattr(config.mod, mod_name)
        if not mod_config.enabled:
            continue
        self._mod_list.append((mod_name, mod_config))

    loaded_mods = []
    for mod_name, user_mod_config in self._mod_list:
        if hasattr(user_mod_config, 'lib'):
            lib_name = user_mod_config.lib
        elif mod_name in SYSTEM_MOD_LIST:
            lib_name = "rqalpha.mod.rqalpha_mod_" + mod_name
        else:
            lib_name = "rqalpha_mod_" + mod_name
        system_log.debug(_(u"loading mod {}").format(lib_name))
        mod_module = import_mod(lib_name)
        if mod_module is None:
            # Skip only this mod. Returning here would leave later mods
            # unloaded, the list unsorted and environment.mod_dict unset.
            continue
        mod = mod_module.load_mod()

        # User settings override the mod's default configuration.
        mod_config = RqAttrDict(
            copy.deepcopy(getattr(mod_module, "__config__", {})))
        mod_config.update(user_mod_config)
        setattr(config.mod, mod_name, mod_config)
        loaded_mods.append((mod_name, mod_config))
        self._mod_dict[mod_name] = mod

    # Replace the contents in place so existing references to the list stay valid.
    self._mod_list[:] = loaded_mods
    self._mod_list.sort(key=lambda item: getattr(item[1], "priority", 100))
    environment.mod_dict = self._mod_dict
def on_tick(self, event):
    """Convert a vnpy tick event to an rqalpha tick dict and dispatch it.

    The tick is queued on ``self._tick_que`` only when its order_book_id
    is in ``self.strategy_subscribed``. The tick snapshot in
    ``self._data_factory`` is updated for every tick.
    """
    raw_tick = event.dict_['data']
    tick = make_tick(raw_tick)
    is_subscribed = tick['order_book_id'] in self.strategy_subscribed
    if is_subscribed:
        system_log.debug("on_tick {}", raw_tick.__dict__)
        self._tick_que.put(tick)
    self._data_factory.put_tick_snapshot(tick)
def set_state(self, state):
    """Restore context attributes from a pickled ``{name: pickled_value}`` mapping.

    Each attribute is unpickled independently. If restoring one fails, a
    warning is written to the user log and the rest are still restored.
    NOTE(review): uses ``pickle.loads`` — only safe if the persisted state
    comes from a trusted source.
    """
    restored = pickle.loads(state)
    for attr, blob in six.iteritems(restored):
        try:
            self.__dict__[attr] = pickle.loads(blob)
            system_log.debug("restore context.{} {}", attr, type(self.__dict__[attr]))
        except Exception:
            user_system_log.warn('context.{} can not restore', attr)
def set_state(self, state):
    """Restore context attributes from a jsonpickle-encoded ``{name: value}`` mapping.

    Values are assigned directly into ``self.__dict__``. If anything in the
    per-attribute step raises, a warning is written to the user log and the
    remaining attributes are still processed.
    NOTE(review): ``jsonpickle.decode`` can instantiate arbitrary objects —
    only safe for trusted state.
    """
    decoded = jsonpickle.decode(state)
    for attr_name, attr_value in six.iteritems(decoded):
        try:
            self.__dict__[attr_name] = attr_value
            system_log.debug("restore context.{} {}", attr_name, type(self.__dict__[attr_name]))
        except Exception:
            user_system_log.warn('context.{} can not restore', attr_name)
def set_state(self, state):
    """Restore context attributes from a pickled ``{name: pickled_value}`` mapping.

    Each attribute is unpickled on its own. A failure is reported to the
    user log and does not stop the remaining attributes.
    NOTE(review): ``pickle.loads`` must only be used on trusted state.
    """
    snapshot = pickle.loads(state)
    for name, payload in snapshot.items():
        try:
            self.__dict__[name] = pickle.loads(payload)
            system_log.debug("restore context.{} {}", name, type(self.__dict__[name]))
        except Exception:
            user_system_log.warn('context.{} can not restore', name)
def set_state(self, state):
    """Restore global-vars (``g``) attributes from a pickled ``{name: pickled_value}`` mapping.

    Each attribute is unpickled separately. One that fails is reported to
    the user log and skipped.
    NOTE(review): ``pickle.loads`` must only be used on trusted state.
    """
    entries = pickle.loads(state)
    for var_name, pickled in six.iteritems(entries):
        try:
            self.__dict__[var_name] = pickle.loads(pickled)
            system_log.debug("restore g.{} {}", var_name, type(self.__dict__[var_name]))
        except Exception:
            user_system_log.warn("g.{} restore failed", var_name)
def _restore_obj(self, key, obj):
    """Load the persisted state for ``key`` and hand it to ``obj.set_state``.

    Returns False when the loaded state is falsy (nothing to restore).
    Otherwise returns True, including when ``set_state`` raised. In that
    case the failure is logged with its traceback instead of being
    propagated.
    """
    saved_state = self._persist_provider.load(key)
    system_log.debug('restore {} with state = {}', key, saved_state)
    if saved_state:
        try:
            obj.set_state(saved_state)
        except Exception:
            system_log.exception('restore failed: key={} state={}'.format(key, saved_state))
        return True
    return False
def wrapped(instrument, frequency, start_dt=None, end_dt=None, length=None):
    """Serve history bars from a per-(order_book_id, frequency) cache.

    A ``Cache`` is created lazily for each key. When it yields None, the
    miss is logged and the call falls through to ``func`` with the same
    arguments.
    """
    cache_key = (instrument.order_book_id, frequency)
    if cache_key not in self._caches:
        self._caches[cache_key] = Cache(self, self.CACHE_LENGTH, instrument, frequency)
    bars = self._caches[cache_key].raw_history_bars(start_dt, end_dt, length)
    if bars is None:
        system_log.debug("缓存未命中: 品种[{}]频率[{}] from {} to {}, length {}".format(
            instrument.order_book_id, frequency, start_dt, end_dt, length
        ))
        return func(instrument, frequency, start_dt=start_dt, end_dt=end_dt, length=length)
    return bars
def events(self, start_date, end_date, frequency):
    """Start the clock and quotation engines, then yield queued events forever.

    Each yielded item is ``(dt, dt, event)`` built from a ``(dt, event)``
    pair taken (blocking) from ``self.event_queue``. ``start_date``,
    ``end_date`` and ``frequency`` are not used.
    """
    self.clock_engine_thread.start()
    self.quotation_engine_thread.start()
    while True:
        now = datetime.datetime.now()
        dt, event = self.event_queue.get()
        system_log.debug("real_dt {}, dt {}, event {}", now, dt, event)
        yield dt, dt, event
def on_trade(self, event):
    """Handle a vnpy trade event.

    Before the account is initialised, the raw vnpy trade is passed to
    ``cache_vnpy_trade_before_init``. Afterwards the matching order is
    looked up, an rqalpha trade is built and ``EVENT.TRADE`` is published
    for the order's account.
    """
    vnpy_trade = event.dict_['data']
    system_log.debug("on_trade {}", vnpy_trade.__dict__)
    if not self._account_inited:
        self._data_factory.cache_vnpy_trade_before_init(vnpy_trade)
        return
    factory = self._data_factory
    order = factory.get_order(vnpy_trade)
    trade = factory.make_trade(vnpy_trade, order.order_id)
    account = Environment.get_instance().get_account(order.order_book_id)
    self._env.event_bus.publish_event(Event(EVENT.TRADE, account=account, trade=trade))
def _get_history_cache(self, instrument):
    """Download the instrument's daily K-line history and store it in the cache.

    Requests one-year windows from ``self._quote_context.get_history_kline``,
    walking backwards from Dec 31 of the current year until an empty window
    comes back. Each request is retried up to 3 times (0.1 s apart) while it
    returns ``RET_ERROR``.

    Each window is normalised before being stored:
    - the ``code`` column is dropped;
    - ``time_key`` ("YYYY-MM-DD HH:MM:SS") becomes the int YYYYMMDDHHMMSS;
    - ``volume`` is cast to float64;
    - ``time_key``/``turnover`` are renamed to ``datetime``/``total_turnover``;
    - the row order is reversed.

    Returns:
        ``(ret_code, DataFrame)``. The frame is also stored in
        ``self._cache['history_kline'][instrument.order_book_id]``. If a
        window still fails after the retries, ``RET_ERROR`` is returned with
        whatever was downloaded before the failure.
    """
    order_book_id = instrument.order_book_id
    if self._cache['history_kline'] is None:
        self._cache['history_kline'] = {}

    chunks = []
    end_date = date.today().replace(month=12, day=31)
    one_year = timedelta(days=365)
    while True:
        # NOTE(review): consecutive windows share the boundary day
        # ``begin_date``; if get_history_kline treats start/end inclusively
        # that day's bar is stored twice — confirm against the API.
        begin_date = end_date - one_year
        for _attempt in range(3):
            # ``order_book_id`` was previously an undefined name here (NameError).
            system_log.debug("order_book_id3=" + order_book_id + ", dt=" + str(self._dt))
            ret_code, bar_data = self._quote_context.get_history_kline(
                order_book_id,
                start=begin_date.strftime('%Y-%m-%d'),
                end=end_date.strftime('%Y-%m-%d'),
                ktype='K_DAY')
            system_log.debug("order_book_id4=" + order_book_id + ", dt=" + str(self._dt))
            if ret_code != RET_ERROR:
                break
            time.sleep(0.1)

        if ret_code == RET_ERROR or isinstance(bar_data, str):
            # bar_data may be an error message string here; stop instead of
            # crashing on ``bar_data.empty``.
            system_log.error("get history kline error: {}", bar_data)
            ret_code = RET_ERROR
            break
        if bar_data.empty:
            break
        end_date = begin_date

        # Normalise the window before storing it.
        del bar_data['code']
        bar_data['time_key'] = (bar_data['time_key']
                                .str.replace(r'[-: ]', '', regex=True)
                                .astype('int64'))
        bar_data['volume'] = bar_data['volume'].astype('float64')
        bar_data.rename(columns={
            'time_key': 'datetime',
            'turnover': 'total_turnover'
        }, inplace=True)
        chunks.append(bar_data[::-1])

    # A single concat replaces the removed (and quadratic) DataFrame.append.
    history = pd.concat(chunks) if chunks else pd.DataFrame()
    self._cache['history_kline'][order_book_id] = history
    return ret_code, history
def tear_down(self, *args):
    """Call ``tear_down(*args)`` on every loaded mod, in reverse list order.

    A mod whose ``tear_down`` raises is logged (with traceback) and skipped.
    Returns a dict mapping mod name -> return value for every mod that
    returned something other than None.
    """
    results = {}
    for name, _mod_config in reversed(self._mod_list):
        try:
            system_log.debug(_(u"mod tear_down [START] {}").format(name))
            outcome = self._mod_dict[name].tear_down(*args)
            system_log.debug(_(u"mod tear_down [END] {}").format(name))
        except Exception:
            system_log.exception("tear down fail for {}", name)
            continue
        if outcome is not None:
            results[name] = outcome
    return results
def publish_settlement(e=None):
    """Publish PRE_SETTLEMENT, SETTLEMENT and POST_SETTLEMENT on ``event_bus``.

    When an event ``e`` is given, the environment's trading/calendar
    datetimes are first moved to the trading date preceding
    ``e.trading_dt``, keeping their time of day, unless the env is already
    on that date.
    """
    if e:
        prev_date = self._env.data_proxy.get_previous_trading_date(e.trading_dt).date()
        if self._env.trading_dt.date() != prev_date:
            self._env.trading_dt = datetime.combine(prev_date, self._env.trading_dt.time())
            self._env.calendar_dt = datetime.combine(prev_date, self._env.calendar_dt.time())
    system_log.debug(
        "publish settlement events with calendar_dt={}, trading_dt={}".format(
            self._env.calendar_dt, self._env.trading_dt))
    event_bus.publish_event(PRE_SETTLEMENT)
    event_bus.publish_event(Event(EVENT.SETTLEMENT))
    event_bus.publish_event(POST_SETTLEMENT)
def update_bars(self, bars, count):
    """Append newly fetched ``bars`` to the in-memory cache and trim it.

    Args:
        bars: newly fetched bars, concatenated onto ``self._data`` along axis
            0 (presumably a numpy structured array with a "datetime" field),
            or None when nothing was fetched.
        count: number of bars requested. Receiving None or fewer than
            ``count`` bars marks the cache as finished.

    At most ``2 * self._chunk`` of the most recent rows are kept.
    """
    # Guard the log: indexing bars["datetime"] used to raise for None or
    # empty bars, even though both cases are handled below.
    if bars is not None and len(bars) > 0:
        system_log.debug("缓存更新,品种:[{}],时间:[{}, {}]".format(
            self.instrument.order_book_id, bars["datetime"][0], bars["datetime"][-1]))

    previous = self._data
    if previous is not None and bars is not None:
        self._data = np.concatenate((previous, bars), axis=0)
    elif bars is not None:
        self._data = bars
    # If only the previous data exists, self._data is left unchanged.

    limit = self._chunk * 2
    if self._data is not None and len(self._data) > limit:
        # Keep twice the chunk length of the newest rows in memory.
        self._data = self._data[len(self._data) - limit:]
    self._finished = bars is None or len(bars) < count
def events(self, start_date, end_date, frequency):
    """Start the clock and quotation engines and yield an Event per queued item.

    Each queue item is ``(dt, event_type)``. The yielded Event uses ``dt``
    as ``trading_dt`` and, as ``calendar_dt``, the wall-clock time read
    before waiting on the queue. The queue is polled with a 1-second
    timeout. ``start_date``, ``end_date`` and ``frequency`` are not used.
    """
    self.clock_engine_thread.start()
    self.quotation_engine_thread.start()
    while True:
        wall_clock = datetime.datetime.now()
        received = False
        while not received:
            try:
                dt, event_type = self.event_queue.get(timeout=1)
                received = True
            except Empty:
                pass
        system_log.debug("real_dt {}, dt {}, event {}", wall_clock, dt, event_type)
        yield Event(event_type, calendar_dt=wall_clock, trading_dt=dt)
def events(self, start_date, end_date, frequency):
    """Yield live-trading events driven by the vnpy engine.

    Unless ``self._mod_config.all_day`` is set, first waits until the local
    date reaches the day before ``start_date``. A daemon thread running
    ``self.mark_time_period`` is then started.
    NOTE(review): it is assumed to update ``self._time_period`` — inferred
    from the loop below; ``mark_time_period`` is not visible here.

    The loop yields:
    - ``EVENT.BEFORE_TRADING`` once per session, with ``trading_dt`` one
      day after ``calendar_dt``;
    - while trading, one ``EVENT.TICK`` per tick from
      ``self._engine.get_tick()``; ticks whose hour is > 20 get the next day
      as ``trading_dt``;
    - ``EVENT.AFTER_TRADING`` once per session.

    ``end_date`` and ``frequency`` are not used.
    """
    import time

    # Poll interval for the idle branches. Without a sleep they re-tested
    # the loop in a tight busy-wait, pinning a CPU core while waiting for
    # the time period to change.
    idle_interval = 0.1

    if not self._mod_config.all_day:
        while datetime.now().date() < start_date - timedelta(days=1):
            time.sleep(1)

    mark_time_thread = Thread(target=self.mark_time_period,
                              args=(start_date, date.fromtimestamp(2147483647)))
    # Thread.setDaemon() is deprecated since Python 3.10.
    mark_time_thread.daemon = True
    mark_time_thread.start()

    while True:
        if self._time_period == TimePeriod.BEFORE_TRADING:
            if self._after_trading_processed:
                self._after_trading_processed = False
            if not self._before_trading_processed:
                system_log.debug("VNPYEventSource: before trading event")
                yield Event(EVENT.BEFORE_TRADING, calendar_dt=datetime.now(),
                            trading_dt=datetime.now() + timedelta(days=1))
                self._before_trading_processed = True
            else:
                time.sleep(idle_interval)
        elif self._time_period == TimePeriod.TRADING:
            if not self._before_trading_processed:
                system_log.debug("VNPYEventSource: before trading event")
                yield Event(EVENT.BEFORE_TRADING, calendar_dt=datetime.now(),
                            trading_dt=datetime.now() + timedelta(days=1))
                self._before_trading_processed = True
            else:
                # get_tick() presumably blocks until a tick arrives, so no sleep here.
                tick = self._engine.get_tick()
                calendar_dt = tick['datetime']
                # Late-evening ticks (hour > 20) are attributed to the next day.
                if calendar_dt.hour > 20:
                    trading_dt = calendar_dt + timedelta(days=1)
                else:
                    trading_dt = calendar_dt
                system_log.debug("VNPYEventSource: tick {}", tick)
                yield Event(EVENT.TICK, calendar_dt=calendar_dt, trading_dt=trading_dt,
                            tick=RqAttrDict(tick))
        elif self._time_period == TimePeriod.AFTER_TRADING:
            if self._before_trading_processed:
                self._before_trading_processed = False
            if not self._after_trading_processed:
                system_log.debug("VNPYEventSource: after trading event")
                yield Event(EVENT.AFTER_TRADING, calendar_dt=datetime.now(),
                            trading_dt=datetime.now())
                self._after_trading_processed = True
            else:
                time.sleep(idle_interval)
        else:
            # Period not (yet) one of the handled values: wait for it to change.
            time.sleep(idle_interval)
def raw_history_bars(self, start_dt=None, end_dt=None, length=None, updated=False):
    """Return a slice of the cached bars covering the requested range, or None.

    Supported requests:
    - ``start_dt`` and ``end_dt``: bars in ``[start_dt, end_dt]``;
    - ``end_dt`` and ``length``: the ``length`` bars ending at ``end_dt``;
    - ``start_dt`` and ``length``: the ``length`` bars starting at ``start_dt``.

    Returns None immediately when the request begins before the cached data
    (or needs more history than is cached before ``end_dt``). When the cache
    does not yet reach far enough, ``self._source.update_cache`` is called
    once (only if the cache is not ``_finished``) and the lookup is retried
    with ``updated=True``. Otherwise None is returned.

    NOTE(review): ``searchsorted`` assumes ``bars["datetime"]`` is sorted
    ascending. With an empty ``bars`` array, ``bars[-1]`` / ``bars[0]`` below
    raise IndexError — confirm the cache is never empty-but-not-None.
    """
    bars = self._data
    if bars is not None:
        if end_dt:
            end_dti = np.uint64(convert_dt_to_int(end_dt))
            # Index just past the last bar whose datetime <= end_dt.
            end_pos = bars["datetime"].searchsorted(end_dti, side="right")
        if start_dt:
            start_dti = np.uint64(convert_dt_to_int(start_dt))
            # Index of the first bar whose datetime >= start_dt.
            start_pos = bars["datetime"].searchsorted(start_dti, side="left")
        if start_dt and end_dt:
            # The cache covers end_dt if there are bars after it or the last bar is exactly at it.
            if end_pos < len(bars) or bars[-1]["datetime"] == end_dti:
                if start_pos == 0 and bars[0][
                        "datetime"] != start_dti:
                    # start datetime is early than cache
                    return None
                else:
                    return bars[start_pos:end_pos]
            # else update the cache
        elif length is not None:
            if end_dt:
                if end_pos < len(bars) or bars[-1]["datetime"] == end_dti:
                    # Not enough cached bars before end_dt to satisfy length.
                    if end_pos - length < 0:
                        return None
                    else:
                        return bars[end_pos - length:end_pos]
                # else update the cache
            elif start_dt:
                # start datetime is earlier than the cache.
                if start_pos == 0 and bars[0]["datetime"] != start_dti:
                    return None
                if start_pos + length <= len(bars):
                    return bars[start_pos:start_pos + length]
                # else update the cache
    # update the cache
    system_log.debug("缓存更新")
    # Fetch more data at most once per call chain, then retry the lookup.
    if not self._finished and not updated:
        self._source.update_cache(self, end_dt or start_dt)
        return self.raw_history_bars(start_dt, end_dt, length, updated=True)
    return None
def get_trans(cls, lc: Optional[str], trans_dir=None):
    """Return a gettext translations object for locale ``lc``.

    Only locales whose lower-cased name contains "cn" are translated. They
    use the ``zh_Hans_CN`` catalogue of the ``messages`` domain under
    ``trans_dir`` (default: the ``translations`` directory next to this
    file). Any other locale, or a failure to load the catalogue (logged at
    debug level), yields ``NullTranslations``.
    """
    if lc is None or "cn" not in lc.lower():
        return NullTranslations()
    try:
        if trans_dir is None:
            here = os.path.dirname(os.path.abspath(__file__))
            trans_dir = os.path.join(here, "translations")
        return translation(domain="messages", localedir=trans_dir, languages=["zh_Hans_CN"])
    except Exception as e:
        system_log.debug(e)
        return NullTranslations()
def on_order(self, event):
    """Translate a vnpy order update into rqalpha order events.

    - Updates with ``STATUS_UNKNOWN`` are ignored.
    - Before the account is initialised, the raw vnpy order is only cached.
    - Afterwards the order is activated, ``ORDER_CREATION_PASS`` is
      published, the vnpy order is cached, and the open-order cache is
      updated from the vnpy status.
    - A cancelled vnpy order is published as ``ORDER_CANCELLATION_PASS``
      when the rqalpha order was ``PENDING_CANCEL``. Otherwise it is marked
      rejected and published as ``ORDER_UNSOLICITED_UPDATE``.
    """
    vnpy_order = event.dict_['data']
    system_log.debug("on_order {}", vnpy_order.__dict__)
    # FIXME: vnpy has been seen to return the same order more than once; it is
    # still unverified whether this handling can cause orders to be lost.
    if vnpy_order.status == STATUS_UNKNOWN:
        return

    vnpy_order_id = vnpy_order.vtOrderID
    factory = self._data_factory
    order = factory.get_order(vnpy_order)

    if not self._account_inited:
        factory.cache_vnpy_order_before_init(vnpy_order)
        return

    account = Environment.get_instance().get_account(order.order_book_id)
    order.active()
    bus = self._env.event_bus
    bus.publish_event(Event(EVENT.ORDER_CREATION_PASS, account=account, order=order))
    factory.cache_vnpy_order(order.order_id, vnpy_order)

    if vnpy_order.status in (STATUS_NOTTRADED, STATUS_PARTTRADED):
        factory.cache_open_order(vnpy_order_id, order)
    elif vnpy_order.status == STATUS_ALLTRADED:
        factory.del_open_order(vnpy_order_id)
    elif vnpy_order.status == STATUS_CANCELLED:
        factory.del_open_order(vnpy_order_id)
        if order.status == ORDER_STATUS.PENDING_CANCEL:
            order.mark_cancelled("%d order has been cancelled by user." % order.order_id)
            bus.publish_event(Event(EVENT.ORDER_CANCELLATION_PASS, account=account, order=order))
        else:
            order.mark_rejected('Order was rejected or cancelled by vnpy.')
            bus.publish_event(Event(EVENT.ORDER_UNSOLICITED_UPDATE, account=account, order=order))
def events(self, start_date, end_date, frequency):
    """Start the engines and yield an Event for every item taken from event_queue.

    The quotation engine thread is started only when
    ``self.mod_config.redis_uri`` is not set. Each queue item is
    ``(dt, event_type)``. The yielded Event uses ``dt`` as ``trading_dt``
    and, as ``calendar_dt``, the wall-clock time read before waiting. The
    queue is polled with a 1-second timeout. ``start_date``, ``end_date``
    and ``frequency`` are not used.
    """
    self.clock_engine_thread.start()
    if not self.mod_config.redis_uri:
        self.quotation_engine_thread.start()
    while True:
        wall_clock = datetime.datetime.now()
        received = False
        while not received:
            try:
                dt, event_type = self.event_queue.get(timeout=1)
                received = True
            except Empty:
                pass
        system_log.debug("real_dt {}, dt {}, event {}", wall_clock, dt, event_type)
        yield Event(event_type, calendar_dt=wall_clock, trading_dt=dt)
def set_env(self, environment):
    """Bind ``environment``, then import every enabled mod in priority order.

    Every enabled mod config must define ``priority`` (sort key) and ``lib``
    (module imported with ``importlib.import_module``). Each module's
    ``load_mod()`` result is stored in ``self._mod_dict``, which is then
    published as ``environment.mod_dict``.
    """
    self._env = environment
    config = environment.config

    for name in config.mod.__dict__:
        cfg = getattr(config.mod, name)
        if cfg.enabled:
            self._mod_list.append((name, cfg))

    self._mod_list.sort(key=lambda pair: pair[1].priority)
    for name, cfg in self._mod_list:
        system_log.debug('loading mod {}', name)
        module = import_module(cfg.lib)
        self._mod_dict[name] = module.load_mod()
    environment.mod_dict = self._mod_dict
def set_locale(self, locales, trans_dir=None):
    """Set ``self.trans`` to the translations for ``locales[0]``.

    A None locale, or one whose lower-cased name contains "en", gets
    ``NullTranslations``. A locale containing "cn" is mapped to
    ``zh_Hans_CN``. Catalogues come from the ``messages`` domain under
    ``trans_dir`` (default: ``translations`` next to this file). Loading
    errors are logged at debug level and fall back to ``NullTranslations``.
    """
    primary = locales[0]
    if primary is None or "en" in primary.lower():
        self.trans = NullTranslations()
        return
    if "cn" in primary.lower():
        locales = ["zh_Hans_CN"]
    try:
        if trans_dir is None:
            module_dir = os.path.dirname(os.path.abspath(__file__))
            trans_dir = os.path.join(module_dir, "translations")
        self.trans = translation(domain="messages", localedir=trans_dir, languages=locales)
    except Exception as e:
        system_log.debug(e)
        self.trans = NullTranslations()
def wrapper(*args, **kwargs):
    """Call ``func`` unless the strategy is on hold (``config.extra.is_hold``).

    While held, the call is skipped, a debug message is logged and None is
    returned.
    """
    held = Environment.get_instance().config.extra.is_hold
    if held:
        system_log.debug(_(u"not run {}({}, {}) because strategy is hold").format(func, args, kwargs))
        return None
    return func(*args, **kwargs)
def run(config, source_code=None, user_funcs=None):
    """Set up the environment for one strategy run, execute it and tear down mods.

    The strategy loader is chosen by priority: ``source_code`` (via
    SourceCodeStrategyLoader) if not None, else ``user_funcs`` (via
    UserFuncStrategyLoader), else the file at ``config.base.strategy_file``.

    When ``config.base.persist`` is set, the core objects, user context,
    global vars, universe, portfolio(s) and any Persistable event source,
    mods and broker are registered with a PersistHelper. Persisted state is
    restored before the run, and persisted again if the run fails after
    initialisation succeeded.

    Returns:
        On normal completion, the value of
        ``mod_handler.tear_down(const.EXIT_CODE.EXIT_SUCCESS)``. When a
        ``CustomException`` or any other ``Exception`` escapes, it is passed
        to ``_exception_handler``, mods are torn down with the resulting
        code, and None is returned.
    """
    env = Environment(config)
    persist_helper = None
    init_succeed = False
    mod_handler = ModHandler()

    try:
        # avoid register handlers everytime
        # when running in ipython
        set_loggers(config)
        basic_system_log.debug("\n" + pformat(config.convert_to_dict()))

        if source_code is not None:
            env.set_strategy_loader(SourceCodeStrategyLoader(source_code))
        elif user_funcs is not None:
            env.set_strategy_loader(UserFuncStrategyLoader(user_funcs))
        else:
            env.set_strategy_loader(FileStrategyLoader(config.base.strategy_file))
        env.set_global_vars(GlobalVars())
        mod_handler.set_env(env)
        mod_handler.start_up()

        # Fall back to the bundle-based data source when none has been set.
        if not env.data_source:
            env.set_data_source(BaseDataSource(config.base.data_bundle_path))
        env.set_data_proxy(DataProxy(env.data_source))
        Scheduler.set_trading_dates_(env.data_source.get_trading_calendar())
        scheduler = Scheduler(config.base.frequency)
        mod_scheduler._scheduler = scheduler

        env._universe = StrategyUniverse()

        _adjust_start_date(env.config, env.data_proxy)

        _validate_benchmark(env.config, env.data_proxy)

        # FIXME
        # Both clocks start at midnight of the configured start date.
        start_dt = datetime.datetime.combine(config.base.start_date, datetime.datetime.min.time())
        env.calendar_dt = start_dt
        env.trading_dt = start_dt

        broker = env.broker
        assert broker is not None
        env.portfolio = broker.get_portfolio()
        env.benchmark_portfolio = create_benchmark_portfolio(env)

        event_source = env.event_source
        assert event_source is not None

        bar_dict = BarMap(env.data_proxy, config.base.frequency)
        env.set_bar_dict(bar_dict)

        # Default price board reads prices from the bar dict.
        if env.price_board is None:
            from .core.bar_dict_price_board import BarDictPriceBoard
            env.price_board = BarDictPriceBoard()

        ctx = ExecutionContext(const.EXECUTION_PHASE.GLOBAL)
        ctx._push()

        env.event_bus.publish_event(Event(EVENT.POST_SYSTEM_INIT))

        # Build the namespace the strategy code runs in: base scope, `g`, then the public APIs.
        scope = create_base_scope()
        scope.update({
            "g": env.global_vars
        })

        apis = api_helper.get_apis()
        scope.update(apis)

        scope = env.strategy_loader.load(scope)

        if env.config.extra.enable_profiler:
            enable_profiler(env, scope)

        ucontext = StrategyContext()
        user_strategy = Strategy(env.event_bus, scope, ucontext)
        scheduler.set_user_context(ucontext)

        # Normally init() runs now; with force_run_init_when_pt_resume it is
        # deferred until persisted state has been restored (further below).
        if not config.extra.force_run_init_when_pt_resume:
            with run_with_user_log_disabled(disabled=config.base.resume_mode):
                user_strategy.init()

        if config.extra.context_vars:
            for k, v in six.iteritems(config.extra.context_vars):
                setattr(ucontext, k, v)

        if config.base.persist:
            persist_provider = env.persist_provider
            persist_helper = PersistHelper(persist_provider, env.event_bus, config.base.persist_mode)
            persist_helper.register('core', CoreObjectsPersistProxy(scheduler))
            persist_helper.register('user_context', ucontext)
            persist_helper.register('global_vars', env.global_vars)
            persist_helper.register('universe', env._universe)
            if isinstance(event_source, Persistable):
                persist_helper.register('event_source', event_source)
            persist_helper.register('portfolio', env.portfolio)
            if env.benchmark_portfolio:
                persist_helper.register('benchmark_portfolio', env.benchmark_portfolio)
            for name, module in six.iteritems(env.mod_dict):
                if isinstance(module, Persistable):
                    persist_helper.register('mod_{}'.format(name), module)
            # broker will restore open orders from account
            if isinstance(broker, Persistable):
                persist_helper.register('broker', broker)

            persist_helper.restore()
            env.event_bus.publish_event(Event(EVENT.POST_SYSTEM_RESTORED))

        # From here on, a failure persists state in the except handlers.
        init_succeed = True

        # When force_run_init_when_pt_resume is active,
        # we should run `init` after restore persist data
        if config.extra.force_run_init_when_pt_resume:
            assert config.base.resume_mode == True
            with run_with_user_log_disabled(disabled=False):
                user_strategy.init()

        from .core.executor import Executor
        Executor(env).run(bar_dict)

        if env.profile_deco:
            output_profile_result(env)
    except CustomException as e:
        if init_succeed and env.config.base.persist and persist_helper:
            persist_helper.persist()

        code = _exception_handler(e)
        mod_handler.tear_down(code, e)
    except Exception as e:
        if init_succeed and env.config.base.persist and persist_helper:
            persist_helper.persist()

        # Wrap the raw exception so it can be reported against the strategy file.
        exc_type, exc_val, exc_tb = sys.exc_info()
        user_exc = create_custom_exception(exc_type, exc_val, exc_tb, config.base.strategy_file)

        code = _exception_handler(user_exc)
        mod_handler.tear_down(code, user_exc)
    else:
        result = mod_handler.tear_down(const.EXIT_CODE.EXIT_SUCCESS)
        system_log.debug(_(u"strategy run successfully, normal exit"))
        return result
def parse_config(config_args, config_path=None, click_type=False, source_code=None, user_funcs=None):
    """Build the run configuration by layering several config sources.

    Later sources override earlier ones. This assumes ``deep_update(src, dst)``
    merges ``src`` into ``dst``, as its uses below suggest. The order is:
    1. ``default_config()``;
    2. ``user_config()``;
    3. ``project_config()``;
    4. whitelisted sections of the strategy's own config (``code_config``),
       only when ``user_funcs`` is None;
    5. ``config_args``.

    Args:
        config_args: run arguments, mutated in place. With ``click_type`` they
            are a flat dict of ``section__key`` names; otherwise a nested
            dict. Its ``mod_configs`` entry (pairs of dotted name and raw
            value) is popped and re-added as ``mod__...`` keys.
        config_path: not used by this function.
        click_type: whether ``config_args`` uses the flat ``a__b`` form.
        source_code: strategy source passed to ``code_config``.
        user_funcs: when not None, the strategy's code config is skipped.

    Returns:
        An RqAttrDict with dates, run type, accounts and persist mode parsed.
        As side effects, the locale, the log levels of ``system_log``,
        ``std_log``, ``user_log`` and ``user_system_log``, and (for daily
        frequency) ``logger.DATETIME_FORMAT`` are set.
    """
    conf = default_config()
    deep_update(user_config(), conf)
    deep_update(project_config(), conf)
    # The strategy file must be known before code_config() reads the strategy's own config.
    if 'base__strategy_file' in config_args and config_args['base__strategy_file']:
        # FIXME: ugly, we need this to get code
        conf['base']['strategy_file'] = config_args['base__strategy_file']
    elif ('base' in config_args and 'strategy_file' in config_args['base'] and
          config_args['base']['strategy_file']):
        conf['base']['strategy_file'] = config_args['base']['strategy_file']

    if user_funcs is None:
        for k, v in six.iteritems(code_config(conf, source_code)):
            # Only whitelisted top-level sections may be set from strategy code.
            if k in conf['whitelist']:
                deep_update(v, conf[k])

    # Turn ("a.b", value) pairs into flat "mod__a__b" keys.
    mod_configs = config_args.pop('mod_configs', [])
    for k, v in mod_configs:
        key = 'mod__{}'.format(k.replace('.', '__'))
        config_args[key] = mod_config_value_parse(v)

    if click_type:
        # Expand flat "a__b__c" keys into the nested conf dict, skipping unset values.
        for k, v in six.iteritems(config_args):
            if v is None:
                continue
            if k == 'base__accounts' and not v:
                continue

            key_path = k.split('__')
            sub_dict = conf
            for p in key_path[:-1]:
                if p not in sub_dict:
                    sub_dict[p] = {}
                sub_dict = sub_dict[p]
            sub_dict[key_path[-1]] = v
    else:
        deep_update(config_args, conf)

    config = RqAttrDict(conf)

    set_locale(config.extra.locale)

    def _to_date(v):
        return pd.Timestamp(v).date()

    config.base.start_date = _to_date(config.base.start_date)
    config.base.end_date = _to_date(config.base.end_date)

    if config.base.data_bundle_path is None:
        config.base.data_bundle_path = os.path.join(os.path.expanduser(rqalpha_path), "bundle")

    config.base.run_type = parse_run_type(config.base.run_type)
    config.base.accounts = parse_accounts(config.base.accounts)
    config.base.persist_mode = parse_persist_mode(config.base.persist_mode)

    # context_vars may be given as a JSON string; decode it into a mapping.
    if config.extra.context_vars:
        if isinstance(config.extra.context_vars, six.string_types):
            config.extra.context_vars = json.loads(to_utf8(config.extra.context_vars))

    # Unknown level names fall back to logbook.NOTSET.
    system_log.level = getattr(logbook, config.extra.log_level.upper(), logbook.NOTSET)
    std_log.level = getattr(logbook, config.extra.log_level.upper(), logbook.NOTSET)
    user_log.level = getattr(logbook, config.extra.log_level.upper(), logbook.NOTSET)
    user_system_log.level = getattr(logbook, config.extra.log_level.upper(), logbook.NOTSET)

    if config.base.frequency == "1d":
        logger.DATETIME_FORMAT = "%Y-%m-%d"

    system_log.debug("\n" + pformat(config.convert_to_dict()))

    return config