def _initialize(self, strategy=None, start=None, end=None, timezone=None, commission=None):
    """Prepare the backtest state and then run the user's ``initialize`` hook.

    Stores the run parameters on ``self``, builds the ``Portfolio`` and its
    context, fills ``context.min_order`` from ``get_min_order``, collects the
    assets named in ``self.universe``, computes the date lists and emits them
    through ``output`` before calling ``self.initialize(self.context)``.

    :param strategy: forwarded unchanged to ``Portfolio``
    :param start: backtest start, stored as ``self.start``
    :param end: backtest end, stored as ``self.end``
    :param timezone: stored as ``self.timezone``
    :param commission: stored as ``self.commission`` and forwarded to ``Portfolio``
    :return: ``1`` if the user hook raised (the traceback is sent through
        ``output``), otherwise ``None``
    """
    self.start = start
    self.end = end
    self.commission = commission
    self.timezone = timezone
    self.portfolio = Portfolio(
        strategy, start, end, commission,
        self.frequency, self.refresh_rate, self.trigger_time,
    )
    self.context = self.portfolio.context

    # Minimum order size: one empty bucket per exchange first, then one
    # entry per symbol. Universe items look like 'BTC/USDT.okex'.
    limits = {}
    self.context.min_order = limits
    for item in self.universe:
        limits[item.split('.')[1]] = {}
    for item in self.universe:
        fields = item.split('.')
        venue = fields[1]
        pair = fields[0].lower()
        qty, amount = get_min_order(venue, pair)
        limits[venue][pair] = {'min_order_qty': qty, 'min_order_amount': amount}

    # Record every asset that appears in the universe (first-seen order,
    # base asset before quote asset, no duplicates).
    seen = list()
    self.universe_assets = seen
    for item in self.universe:
        legs = item.split('.')[0].split('/')
        base, quote = legs[0].lower(), legs[1].lower()
        for asset in (base, quote):
            if asset not in seen:
                seen.append(asset)
    self.all_assets = list(set(AccountManager().asset_varieties + self.universe_assets))

    self.total_dates = get_total_dates(
        self.frequency, 1, self.trigger_time, self.start, self.end)
    self.strategy_dates = get_total_dates(
        self.frequency, self.refresh_rate, self.trigger_time, self.start, self.end)

    # Only daily frequencies pass the refresh rate to get_day_dates.
    is_daily = self.frequency in ('d', '1d', 'day', '1day')
    day_args = (self.start, self.end, self.refresh_rate) if is_daily else (self.start, self.end)
    self.day_date = get_day_dates(*day_args)

    # Emit the report's date axis first; it is used for plotting.
    output({"display_type": "strategy", "dates": self.day_date})

    try:
        self.initialize(self.context)
    except Exception:
        output({"display_type": "error", "error_msg": traceback.format_exc()})
        return 1
class Strategy(object):
    """Backtest strategy driver, kept as one shared instance per class.

    Holds the user callbacks (``initialize`` / ``handle_data``), the trading
    universe and the schedule, and exposes the hooks the engine calls:
    ``_initialize`` once, ``_handle_data`` per bar, then one of the report
    methods.

    Every ``Strategy(...)`` call returns the same object for a given class
    and re-runs ``__init__`` on it, which resets the stored state.
    """

    def __new__(cls, *args, **kwargs):
        # Look in cls.__dict__ rather than hasattr(): hasattr() also finds a
        # parent's ``_instance`` through the MRO, so a subclass created after
        # the parent would receive the parent's object (and have its own
        # __init__ skipped because that object is not an instance of it).
        if '_instance' not in cls.__dict__:
            cls._instance = super(Strategy, cls).__new__(cls)
        return cls._instance

    def __init__(self, initialize=None, handle_data=None, universe=None, benchmark=None, freq=None,
                 refresh_rate=None):
        """
        :param initialize: user callback, called as ``initialize(context)``
        :param handle_data: user callback, called as ``handle_data(context)``
        :param universe: e.g. ('BTC/USDT.okex', 'ETH/BTC.okex', 'XBT/USD.bitmex')
        :param benchmark: e.g. 'csi5'; the strategy's reference benchmark
        :param freq: e.g. 'd' for daily-bar backtests, 'm' for 15-minute-bar backtests
        :param refresh_rate: 1 or (1, ['08:00:00']) or (1, ['08:00:00', '18:00:00']);
            the rebalancing interval (trading days when freq = 'd', minutes
            when freq = 'm'). When a tuple is given, its first item becomes
            ``self.refresh_rate`` and its second item becomes
            ``self.trigger_time`` (daily rebalancing times; the original note
            says these are in UTC); otherwise ``self.trigger_time`` is None.
        """
        self.initialize = initialize
        self.handle_data = handle_data
        self.universe = universe
        self.benchmark = benchmark
        self.frequency = freq
        scheduled = isinstance(refresh_rate, tuple)
        self.refresh_rate = refresh_rate[0] if scheduled else refresh_rate
        self.trigger_time = refresh_rate[1] if scheduled else None

        # Filled in by _initialize / _handle_data.
        self.start = None
        self.end = None
        self.commission = None
        self.timezone = None
        self.portfolio = None
        self.context = None
        self.cache_data = None
        self.universe_assets = None
        self.all_assets = None
        self.total_dates = None
        self.strategy_dates = None
        self.day_date = None

    def _initialize(self, strategy=None, start=None, end=None, timezone=None, commission=None):
        """Build the portfolio/context and run the user's ``initialize`` hook.

        :param strategy: forwarded unchanged to ``Portfolio``
        :param start: backtest start, stored as ``self.start``
        :param end: backtest end, stored as ``self.end``
        :param timezone: stored as ``self.timezone``
        :param commission: stored and forwarded to ``Portfolio``
        :return: ``1`` if the user hook raised (the traceback is sent through
            ``output``), otherwise ``None``
        """
        self.start = start
        self.end = end
        self.commission = commission
        self.timezone = timezone
        self.portfolio = Portfolio(strategy, start, end, commission, self.frequency,
                                   self.refresh_rate, self.trigger_time)
        self.context = self.portfolio.context

        # Minimum order size, keyed by exchange and then by symbol.
        # Universe items look like 'BTC/USDT.okex'.
        self.context.min_order = {}
        for univ in self.universe:
            fields = univ.split('.')
            exchange = fields[1]
            symbol = fields[0].lower()
            min_order_qty, min_order_amount = get_min_order(exchange, symbol)
            self.context.min_order.setdefault(exchange, {})[symbol] = {
                'min_order_qty': min_order_qty,
                'min_order_amount': min_order_amount,
            }

        # Record every asset in the universe (first-seen order, no duplicates).
        self.universe_assets = []
        for univ in self.universe:
            pair = univ.split('.')[0].split('/')
            for asset in (pair[0].lower(), pair[1].lower()):
                if asset not in self.universe_assets:
                    self.universe_assets.append(asset)
        self.all_assets = list(set(AccountManager().asset_varieties + self.universe_assets))

        self.total_dates = get_total_dates(self.frequency, 1, self.trigger_time,
                                           self.start, self.end)
        self.strategy_dates = get_total_dates(self.frequency, self.refresh_rate,
                                              self.trigger_time, self.start, self.end)
        # Only daily frequencies pass the refresh rate to get_day_dates.
        if self.frequency in ('d', '1d', 'day', '1day'):
            self.day_date = get_day_dates(self.start, self.end, self.refresh_rate)
        else:
            self.day_date = get_day_dates(self.start, self.end)

        # Emit the report's date axis first; it is used for plotting.
        output({"display_type": "strategy", "dates": self.day_date})

        try:
            self.initialize(self.context)
        except Exception:
            output({"display_type": "error", "error_msg": traceback.format_exc()})
            return 1

    async def _handle_data(self):
        """Process the bar the context clock currently points at.

        Prepares the market data on first use, snapshots the initial
        positions when the clock is at the start timestamp, calls the user's
        ``handle_data`` on strategy dates and finally records portfolio
        history.

        :return: ``1`` if data preparation or the user hook raised (the
            traceback is sent through ``output``), otherwise ``None``
        """
        clock = self.context.clock
        current_date = clock.current_date
        logger.current_date = current_date
        current_timestamp = clock.current_timestamp

        # Start loading data ``pre_bar`` days before the backtest start.
        pre_start = str(str2datetime(self.start) - datetime.timedelta(days=clock.pre_bar))
        try:
            # NOTE(review): truthiness test kept as-is; an empty result from
            # prepare_data is re-requested on the next bar - confirm intent.
            if not self.cache_data:
                self.cache_data = self.context.prepare_data(
                    self.universe, self.all_assets, self.benchmark, self.frequency,
                    pre_start, self.end, self.timezone)
        except Exception:
            output({"display_type": "error", "error_msg": traceback.format_exc()})
            return 1

        # Record the initial positions. Only the first trigger_time is used.
        start = self.start + ' ' + self.trigger_time[0] if self.trigger_time else self.start
        if current_timestamp == str2timestamp(start):
            self._record_initial_positions()

        try:
            if str(current_date) in self.strategy_dates:
                self.handle_data(self.context)
        except Exception:
            output({"display_type": "error", "error_msg": traceback.format_exc()})
            return 1

        self.portfolio.record_history()

    def _record_initial_positions(self):
        """Snapshot the starting positions of every account onto the context."""
        context = self.context
        for account in self.portfolio.accounts.values():
            for position in account.current_position.values():
                position.context = context

        init_portfolio_position = {
            asset: PortfolioPosition(context, asset, self.portfolio.accounts).detail()
            for asset in context.asset_varieties
        }
        context.init_portfolio_position = init_portfolio_position
        context.init_position_total = sum(
            detail['total_amount'] for detail in init_portfolio_position.values())

        # Per-account totals, stored under '<account name>_position'.
        context.init_total_account_position = {}
        for name in context.accounts_name:
            total_amount = 0
            for detail in context.init_portfolio_position.values():
                if name in detail['consist_of']:
                    total_amount += detail['consist_of'][name]['amount']
            context.init_total_account_position["{}_position".format(name)] = total_amount

    def simple_report(self):
        """Run ``SimpleReport`` on the portfolio and send the result to ``output``."""
        report = SimpleReport(self.portfolio)
        simple_report = report.run()
        output(simple_report)

    def complete_report(self):
        """Run ``CompleteReport`` on the portfolio and send the result to ``output``."""
        report = CompleteReport(self.portfolio)
        complete_report = report.run()
        output(complete_report)

    def simplereport_excel(self):
        """Emit the simple report, save it as an .xls workbook and plot returns.

        The workbook is written to ``simple_report.xls`` in the current
        working directory, one column per report key: the key in row 0 and
        the value(s) from row 1 downwards. The saved path is printed, then
        ``cumulative_returns`` is plotted with matplotlib.
        """
        report = SimpleReport(self.portfolio)
        simple_report = report.run()
        output(simple_report)

        # Write the backtest report as an Excel sheet.
        workbook = xlwt.Workbook(encoding='utf-8')
        sheet1 = workbook.add_sheet('sheet1', cell_overwrite_ok=True)
        for col, (key, value) in enumerate(simple_report.items()):
            sheet1.write(0, col, key)
            if isinstance(value, list):
                for row, item in enumerate(value, start=1):
                    sheet1.write(row, col, item)
            else:
                sheet1.write(1, col, value)

        # The report is stored in the current working directory.
        report_path = os.path.join(os.getcwd(), 'simple_report.xls')
        print(report_path)
        workbook.save(report_path)

        # Plot the cumulative returns curve.
        plt.plot(simple_report["cumulative_returns"])
        plt.title("Cumulative_returns", fontsize='large')
        plt.show()