parser.add_argument("--debug", action='store_true',
                    help="Set to debug mode. Example --debug")
args = parser.parse_args()
print("input args: {}".format(args))

# action='store_true' always yields a bool (never None), so the flag must be
# tested for truth; the former `is not None` check enabled DEBUG unconditionally.
# Without --debug, INFO level keeps informational log messages visible.
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
logger = logging.getLogger(__name__)
if args.debug:
    logger.debug('Debug mode on')

# The universe file must be a single CSV row; its fields are passed to
# GQData.get_data as the symbol list.
# newline='' is what the csv module expects for files it reads.
with open(args.universe_file, 'r', newline='') as f:
    universe = list(csv.reader(f, delimiter=','))
if len(universe) != 1:
    raise ValueError(
        "universe file should contain 1 line, current results: {}".format(
            universe))
universe = universe[0]
print("universe: {}".format(universe))

gqdata = GQData()
df = gqdata.get_data(universe, args.freq, args.startdate,
                     end_date=args.enddate,
                     datasource=args.source,
                     data_type=DATATYPE_TICKER)
class GQBacktest(object):
    """Backtest a GQAlgo over historical bars.

    On construction the requested symbols are fetched through ``GQData``,
    each symbol's data file is added to a ``csvfeed.GenericBarFeed``, and a
    ``MyStrategy`` wrapping the algorithm is created with a
    ``returns.Returns`` analyzer attached.
    """

    def __init__(
            self,
            algo: GQAlgo,
            symbols,
            start_datetime,
            end_datetime,
            data_freq=FREQ_DAY,
            initial_cash=10000,
    ):
        """
        :param algo: GQAlgo
            algorithm to backtest; its ``datasource`` is used to fetch data
        :param symbols: list
            symbols to load into the data feed
        :param start_datetime: start of the data range
        :param end_datetime: end of the data range
        :param data_freq: string
            FREQ_DAY or FREQ_MINUTE
        :param initial_cash: starting cash handed to MyStrategy
        :raises KeyError: if ``data_freq`` is not FREQ_DAY or FREQ_MINUTE
        """
        self.algo = algo
        self.symbols = symbols
        self.start_datetime = start_datetime
        self.end_datetime = end_datetime
        self.data_freq = data_freq
        self.datasource = algo.datasource
        self.initial_cash = initial_cash
        self.data = GQData()

        freq_map = {
            FREQ_DAY: Frequency.DAY,
            FREQ_MINUTE: Frequency.MINUTE,
        }
        try:
            bar_frequency = freq_map[data_freq]
        except KeyError:
            # Same exception type as before, but with an actionable message.
            raise KeyError(
                "unsupported data_freq {!r}, expected one of {}".format(
                    data_freq, list(freq_map))) from None
        self.datafeed = csvfeed.GenericBarFeed(bar_frequency)
        self._load_datafeed()

        self.my_strategy = MyStrategy(self.datafeed, self.symbols, self.algo,
                                      self.initial_cash)
        self.returnsAnalyzer = returns.Returns()
        self.my_strategy.attachAnalyzer(self.returnsAnalyzer)

    def _load_datafeed(self):
        """Fetch data for all symbols and add each symbol's data file to the feed."""
        # Fetch every symbol in one call (use_cache=True, NaNs forward-filled)
        # before looking up per-symbol file paths below.
        self.data.get_data(symbols=self.symbols,
                           freq=self.data_freq,
                           start_date=self.start_datetime,
                           end_date=self.end_datetime,
                           datasource=self.datasource,
                           use_cache=True,
                           fill_nan_method="ffill")
        for symbol in self.symbols:
            # NOTE(review): get_data_key is called without datasource — confirm
            # the cached file path does not depend on it.
            data_key = self.data.get_data_key(
                symbol=symbol,
                freq=self.data_freq,
                start_date=self.start_datetime,
                end_date=self.end_datetime,
            )
            file_path = self.data.get_data_file_path(data_key)
            logger.debug("add csv file {} into data feed".format(file_path))
            self.datafeed.addBarsFromCSV(symbol, file_path,
                                         skipMalformedBars=True)

    def run(self, plot=True):
        """Run the backtest and log the final portfolio value.

        :param plot: bool
            when True, plot simple returns together with every metric the
            algorithm recorded in ``algo.metrics``
        """
        backtest_plt = None
        if plot:
            # The plotter must be attached before the strategy runs.
            backtest_plt = plotter.StrategyPlotter(self.my_strategy)
            backtest_plt.getOrCreateSubplot("returns").addDataSeries(
                "Simple returns", self.returnsAnalyzer.getReturns())

        self.my_strategy.run()
        logger.info("Final portfolio value: $%.2f",
                    self.my_strategy.getBroker().getEquity())

        if plot:
            # Each metric entry is a (series, figure index) pair.
            for ds, fig_idx in self.algo.metrics.values():
                plt.figure(fig_idx)
                ds.plot(legend=True)
            backtest_plt.plot()
class GQAlgo(object):
    """Base class for trading algorithms.

    Subclasses override ``init()`` for setup and ``run()`` for their logic.
    When ``trading_platform`` is TRADING_BACKTEST, ``algo_get_data`` uses the
    algorithm clock set by ``prerun()``; otherwise it uses the current UTC time.
    """

    def __init__(self, trading_platform, datasource):
        """
        :param trading_platform: platform identifier passed to get_account_class()
        :param datasource: data source forwarded to GQData.get_data()
        """
        self.trading_platform = trading_platform
        self.datasource = datasource
        self.account = get_account_class(trading_platform)
        self.data = GQData()
        self.t = None  # current time in UTC
        self.backtest_strategy = None
        # metric name -> (pd.Series of values indexed by algorithm time, figure group)
        self.metrics = {}
        self.init()

    def init(self):
        """Hook for subclass setup; called at the end of __init__."""
        pass

    def run(self) -> list:
        """Algorithm logic; subclasses must override."""
        raise NotImplementedError

    def get_trading_platform(self):
        return self.trading_platform

    def get_time(self):
        return self.t

    def get_cash(self):
        """
        get current cash in USD
        :return:
        """
        return self.account.get_cash()

    def get_positions(self):
        """
        get current positions
        :return: dict symbol->position
        """
        return self.account.get_positions()

    def algo_get_data(self, symbols, interval_timedelta, freq,
                      fill_nan_method=None, remove_nan_rows=True):
        """
        get data until now (time t, get from get_time())
        :param symbols: list
            list of symbols
        :param interval_timedelta: deltatime
            used to calculate start time
        :param freq: string
            day, minute data level
        :param fill_nan_method: string
            fill nan method, default not fill, see more parameters here:
            https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.fillna.html
        :param remove_nan_rows: bool
            remove nan rows after fill nan
        :return: result of GQData.get_data(..., dict_output=True)
        :raises ValueError: in backtest mode when prerun() has not set the time
        """
        end_datetime = datetime.now(timezone.utc)
        if self.trading_platform == TRADING_BACKTEST:
            # Backtests must not look past the simulated clock.
            end_datetime = self.get_time()
            if end_datetime is None:
                raise ValueError("please run prerun() function first")
        start_datetime = end_datetime - interval_timedelta
        data = self.data.get_data(symbols=symbols,
                                  freq=freq,
                                  start_date=start_datetime,
                                  end_date=end_datetime,
                                  datasource=self.datasource,
                                  dict_output=True,
                                  fill_nan_method=fill_nan_method,
                                  remove_nan_rows=remove_nan_rows
                                  )
        return data

    # Quoted annotation: the parameter name shadows the `strategy` module in
    # the signature, and quoting avoids evaluating it at class-definition time.
    def init_backtest(self, strategy: "strategy.BacktestingStrategy"):
        """Attach the backtesting strategy to this algo and its account."""
        self.backtest_strategy = strategy
        self.account.set_backtest_strategy(strategy)

    def prerun(self, t, verbose=True):
        """Set the algorithm clock to ``t``, optionally logging account state first.

        :param t: algorithm time, stored in ``self.t``
        :param verbose: bool
            log time, cash and positions at INFO level
        """
        if verbose:
            msg = "=============\nAlgorithm Time: {}\nCash: {}\nPositions: {}\n".format(
                t,
                self.get_cash(),
                self.get_positions()
            )
            logger.info(msg)
        self.t = t

    def record_metric(self, key, value, figure_group=1):
        """Append ``value`` at the current algorithm time to metric ``key``.

        The metric is stored in ``self.metrics[key]`` as ``(series, figure_group)``;
        the figure group from the most recent call is the one kept.

        :param key: metric name, also used as the series name
        :param value: value recorded at index ``self.t``
        :param figure_group: int
            non-negative figure index used when plotting
        :raises ValueError: if ``figure_group`` is negative
        """
        if figure_group < 0:
            raise ValueError(
                "figure_group must be non-negative, got figure_group {}".format(figure_group))
        new_point = pd.Series([value], index=[self.t], name=key)
        previous = self.metrics.get(key)
        if previous is None:
            series = new_point
        else:
            # Series.append was removed in pandas 2.0; pd.concat replaces it.
            series = pd.concat([previous[0], new_point])
        self.metrics[key] = (series, figure_group)