def _create_generator(self, sim_params, source_filter=None):
    """
    Assemble the simulation generator from the sources and transforms
    attached to this algorithm.

    ::source_filter:: optional predicate that is handed events in
    date-sorted order; it returns True for events the zipline should
    process and False for events that should be skipped.

    Side effects: stamps ``self.data_frequency`` onto ``sim_params`` and
    sets ``self.data_gen``, ``self.trading_client`` and (when missing)
    ``self.perf_tracker``.
    """
    sim_params.data_frequency = self.data_frequency

    # When sim_params was given to the constructor, __init__ already built
    # the perf_tracker; otherwise it is created lazily here.
    if self.perf_tracker is None:
        self.perf_tracker = PerformanceTracker(sim_params)

    self.data_gen = self._create_data_generator(source_filter, sim_params)
    self.trading_client = AlgorithmSimulator(self, sim_params)

    # Bind the current slippage/commission models into the blotter's
    # transact hook.
    transact = transact_partial(self.slippage, self.commission)
    self.set_transact(transact)

    return self.trading_client.transform(self.data_gen)
def _create_generator(self, sim_params, source_filter=None):
    """
    Assemble the simulation generator from the sources attached to this
    algorithm, running the user's ``initialize`` first if it has not run.

    ::source_filter:: optional predicate that is handed events in
    date-sorted order; it returns True for events the zipline should
    process and False for events that should be skipped.
    """
    # Run user initialization exactly once, using the args captured at
    # construction time.
    if not self.initialized:
        self.initialize(*self.initialize_args, **self.initialize_kwargs)
        self.initialized = True

    if self.perf_tracker is None:
        # HACK: When running with the `run` method, we set perf_tracker to
        # None so that it will be overwritten here.
        self.perf_tracker = PerformanceTracker(sim_params)

    # Raise the *_needs_update flags (presumably consumed lazily by the
    # portfolio/account/performance accessors elsewhere in the class).
    self.portfolio_needs_update = True
    self.account_needs_update = True
    self.performance_needs_update = True

    self.data_gen = self._create_data_generator(source_filter, sim_params)
    self.trading_client = AlgorithmSimulator(self, sim_params)

    transact = transact_partial(self.slippage, self.commission)
    self.set_transact(transact)

    return self.trading_client.transform(self.data_gen)
def __init__(self, algo, environment):
    """
    Wire ``algo`` to a new TransactionSimulator and PerformanceTracker and
    build the AlgorithmSimulator that drives it.

    :param algo: the algorithm to simulate.
    :param environment: handed to PerformanceTracker; its ``first_open``
        attribute is used as the simulation start.
    """
    self.algo = algo
    self.environment = environment

    self.ordering_client = TransactionSimulator()
    self.perf_tracker = PerformanceTracker(self.environment)

    # The simulation begins at environment.first_open.
    self.algo_start = self.environment.first_open
    self.algo_sim = AlgorithmSimulator(
        self.ordering_client,
        self.algo,
        self.algo_start,
    )
def __init__(self, algo, sim_params):
    """
    Wire ``algo`` to a new TransactionSimulator and PerformanceTracker and
    build the AlgorithmSimulator that drives it.

    :param algo: the algorithm to simulate.
    :param sim_params: simulation parameters; passed to PerformanceTracker
        and its ``first_open`` is used as the simulation start.
    """
    self.algo = algo
    self.sim_params = sim_params

    self.ordering_client = TransactionSimulator()
    self.perf_tracker = PerformanceTracker(self.sim_params)

    self.algo_start = self.sim_params.first_open
    self.algo_sim = AlgorithmSimulator(
        self.ordering_client, self.perf_tracker, self.algo, self.algo_start)
def __init__(self, algo, sim_params, blotter=None):
    """
    Wire ``algo`` to a blotter and a new PerformanceTracker and build the
    AlgorithmSimulator that drives it.

    :param algo: the algorithm to simulate.
    :param sim_params: simulation parameters; passed to PerformanceTracker
        and its ``first_open`` is used as the simulation start.
    :param blotter: optional Blotter to route orders through; a fresh
        ``Blotter()`` is created when omitted.
    """
    self.algo = algo
    self.sim_params = sim_params

    # Always assign self.blotter: the caller's blotter when supplied,
    # otherwise a new one. (A supplied blotter used to be dropped, which
    # made the AlgorithmSimulator call below raise AttributeError.)
    self.blotter = blotter if blotter is not None else Blotter()

    self.perf_tracker = PerformanceTracker(self.sim_params)

    self.algo_start = self.sim_params.first_open
    self.algo_sim = AlgorithmSimulator(
        self.blotter, self.perf_tracker, self.algo, self.algo_start)
def _create_generator(self, sim_params, source_filter=None):
    """
    Assemble the simulation generator from the sources attached to this
    algorithm, running the user's ``initialize`` first if it has not run.

    ::source_filter:: optional predicate that is handed events in
    date-sorted order; it returns True for events the zipline should
    process and False for events that should be skipped.
    """
    if not self.initialized:
        self.initialize(*self.initialize_args, **self.initialize_kwargs)
        self.initialized = True

    if self.perf_tracker is None:
        # HACK: When running with the `run` method, we set perf_tracker to
        # None so that it will be overwritten here.
        self.perf_tracker = PerformanceTracker(
            sim_params=sim_params,
            env=self.trading_environment,
        )

    # Raise the *_needs_update flags (presumably consumed lazily by the
    # portfolio/account/performance accessors elsewhere in the class).
    self.portfolio_needs_update = True
    self.account_needs_update = True
    self.performance_needs_update = True

    self.data_gen = self._create_data_generator(source_filter, sim_params)
    self.trading_client = AlgorithmSimulator(self, sim_params)

    transact = transact_partial(self.slippage, self.commission)
    self.set_transact(transact)

    return self.trading_client.transform(self.data_gen)
def test_bts_simulation_dt(self):
    """
    A clock that emits a single before_trading_start event at ``bts_dt``
    must leave the simulator's ``simulation_dt`` equal to that timestamp.
    """
    code = """
def initialize(context):
    pass
"""
    algo = TradingAlgorithm(script=code,
                            sim_params=self.sim_params,
                            env=self.env)
    algo.perf_tracker = PerformanceTracker(
        sim_params=self.sim_params,
        trading_calendar=self.trading_calendar,
        asset_finder=self.asset_finder,
    )

    bts_dt = pd.Timestamp("2016-08-04 9:13:14", tz='US/Eastern')
    simulator = AlgorithmSimulator(
        algo,
        self.sim_params,
        self.data_portal,
        BeforeTradingStartsOnlyClock(bts_dt),
        algo._create_benchmark_source(),
        NoRestrictions(),
        None
    )

    # Drain the generator so the whole simulation runs.
    for _ in simulator.transform():
        pass

    # The clock emitted exactly one before_trading_start event, so the
    # simulator's clock must now sit at that timestamp.
    self.assertEqual(bts_dt, simulator.simulation_dt)
def _create_generator(self, sim_params, source_filter=None):
    """
    Assemble the simulation generator from the sources and transforms
    attached to this algorithm. A new PerformanceTracker is created on
    every call.

    ::source_filter:: optional predicate that is handed events in
    date-sorted order; it returns True for events the zipline should
    process and False for events that should be skipped.
    """
    # Instantiate perf_tracker
    self.perf_tracker = PerformanceTracker(sim_params)
    self.portfolio_needs_update = True

    self.data_gen = self._create_data_generator(source_filter, sim_params)
    self.trading_client = AlgorithmSimulator(self, sim_params)

    transact = transact_partial(self.slippage, self.commission)
    self.set_transact(transact)

    return self.trading_client.transform(self.data_gen)
def __init__(self, algo, sim_params):
    """
    Set up the per-run simulation state for ``algo``.

    :param algo: the algorithm being simulated.
    :param sim_params: simulation parameters; passed to PerformanceTracker
        and its ``first_open`` (truncated to midnight) becomes
        ``self.algo_start``.
    """
    # ==============
    # Simulation
    # Param Setup
    # ==============
    self.sim_params = sim_params

    # ==============
    # Perf Tracker
    # Setup
    # ==============
    self.perf_tracker = PerformanceTracker(self.sim_params)

    # Key looked up from the tracker's emission rate via the class-level
    # EMISSION_TO_PERF_KEY_MAP.
    self.perf_key = self.EMISSION_TO_PERF_KEY_MAP[
        self.perf_tracker.emission_rate]

    # ==============
    # Algo Setup
    # ==============
    self.algo = algo
    # Start at midnight of the first_open date.
    self.algo_start = self.sim_params.first_open.replace(
        hour=0, minute=0, second=0, microsecond=0)

    # ==============
    # Snapshot Setup
    # ==============

    # The algorithm's data as of our most recent event.
    # We want an object that will have empty objects as default
    # values on missing keys.
    self.current_data = BarData()

    # We don't have a datetime for the current snapshot until we
    # receive a message.
    self.simulation_dt = None
    self.snapshot_dt = None

    # =============
    # Logging Setup
    # =============

    # Processor function for injecting the algo_dt into
    # user prints/logs.
    def inject_algo_dt(record):
        # Read self.snapshot_dt at call time so each record gets the
        # snapshot that is current when it is emitted; never overwrite an
        # explicitly provided algo_dt.
        if 'algo_dt' not in record.extra:
            record.extra['algo_dt'] = self.snapshot_dt

    self.processor = Processor(inject_algo_dt)
def __init__(self, algo, sim_params):
    """
    Create the ordering client and performance tracker for ``algo`` and
    build the AlgorithmSimulator over them.

    :param algo: the algorithm to simulate.
    :param sim_params: simulation parameters; its ``first_open`` is the
        simulation start.
    """
    self.algo = algo
    self.sim_params = sim_params
    self.ordering_client = TransactionSimulator()
    self.perf_tracker = PerformanceTracker(self.sim_params)
    self.algo_start = self.sim_params.first_open

    simulator_args = (
        self.ordering_client,
        self.perf_tracker,
        self.algo,
        self.algo_start,
    )
    self.algo_sim = AlgorithmSimulator(*simulator_args)
def __init__(self, algo, sim_params, blotter=None):
    """
    Wire ``algo`` to a blotter and a new PerformanceTracker and build the
    AlgorithmSimulator that drives it.

    :param algo: the algorithm to simulate.
    :param sim_params: simulation parameters; passed to PerformanceTracker
        and its ``first_open`` is used as the simulation start.
    :param blotter: optional Blotter to route orders through; a fresh
        ``Blotter()`` is created when omitted.
    """
    self.algo = algo
    self.sim_params = sim_params

    # Always assign self.blotter. Previously a caller-supplied blotter was
    # never stored, so the AlgorithmSimulator call below raised
    # AttributeError whenever ``blotter`` was passed.
    if blotter is None:
        blotter = Blotter()
    self.blotter = blotter

    self.perf_tracker = PerformanceTracker(self.sim_params)
    self.algo_start = self.sim_params.first_open
    self.algo_sim = AlgorithmSimulator(
        self.blotter,
        self.perf_tracker,
        self.algo,
        self.algo_start,
    )
class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    class MyAlgo(TradingAlgorithm):
        def initialize(self, sids, amount):
            self.sids = sids
            self.amount = amount

        def handle_data(self, data):
            sid = self.sids[0]
            amount = self.amount
            self.order(sid, amount)
    ```
    To then run this algorithm:

    my_algo = MyAlgo([0], 100) # first argument has to be list of sids
    stats = my_algo.run(data)
    """

    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
            data_frequency : str ('daily' or 'minute')
                The duration of the bars.
            annualizer : int <optional>
                Which constant to use for annualizing risk metrics.
                If not provided, will extract from data_frequency.
            capital_base : float <default: 1.0e5>
                How much capital to start with.
            sim_params : <optional>
                Simulation parameters; its data_frequency is overwritten
                with this algorithm's data_frequency.
            blotter : Blotter <optional>
                Defaults to a new Blotter().

        All remaining positional and keyword arguments are forwarded to
        initialize().
        """
        self._portfolio = None
        self.datetime = None

        self.registered_transforms = {}
        self.transforms = []
        self.sources = []

        self._recorded_vars = {}

        self.logger = None

        self.benchmark_return_source = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()
        self.leverage = NullLeverage()

        if 'data_frequency' in kwargs:
            self.set_data_frequency(kwargs.pop('data_frequency'))
        else:
            self.data_frequency = None

        # Override annualizer if set. Popped (like the other constructor
        # options) so it is not forwarded to the user's initialize().
        if 'annualizer' in kwargs:
            self.annualizer = kwargs.pop('annualizer')

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params:
            self.sim_params.data_frequency = self.data_frequency

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        # an algorithm subclass needs to set initialized to True when
        # it is fully initialized.
        self.initialized = False

        # call to user-defined constructor method
        self.initialize(*args, **kwargs)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base},
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params):
        """
        Create a merged data generator using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.

        Returns an itertools.groupby iterator of (dt, events) pairs.
        """
        if self.benchmark_return_source is None:
            period_start = sim_params.period_start.date()
            period_end = sim_params.period_end.date()
            benchmark_return_source = [
                Event({'dt': ret.date,
                       'returns': ret.returns,
                       'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                       'source_id': 'benchmarks'})
                for ret in trading.environment.benchmark_returns
                if period_start <= ret.date.date() <= period_end
            ]
        else:
            benchmark_return_source = self.benchmark_return_source

        date_sorted = date_sorted_sources(*self.sources)

        if source_filter:
            date_sorted = ifilter(source_filter, date_sorted)

        with_tnfms = sequential_transforms(date_sorted, *self.transforms)
        with_alias_dt = alias_dt(with_tnfms)

        with_benchmarks = date_sorted_sources(benchmark_return_source,
                                              with_alias_dt)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(with_benchmarks, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Create a basic generator setup using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """
        sim_params.data_frequency = self.data_frequency
        self.data_gen = self._create_data_generator(source_filter,
                                                    sim_params)
        self.perf_tracker = PerformanceTracker(sim_params)
        self.trading_client = AlgorithmSimulator(self, sim_params)

        transact_method = transact_partial(self.slippage, self.commission)
        self.set_transact(transact_method)
        self.blotter.leverage = leverage_partial(
            self.leverage, self.perf_tracker.get_portfolio())

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Override this method to add new logic to the construction
        of the generator. Overrides can use the _create_generator
        method to get a standard construction generator.
        """
        return self._create_generator(self.sim_params)

    def initialize(self, *args, **kwargs):
        pass

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, sim_params=None, benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - zipline source
                     - list of zipline sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.
            sim_params : <optional>
               Overrides the parameters given to the constructor.
            benchmark_return_source : <optional>
               When given, replaces self.benchmark_return_source for
               this and subsequent runs.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.
        """
        if isinstance(source, (list, tuple)):
            assert self.sim_params is not None or sim_params is not None, \
                "When providing a list of sources, sim_params have to be " \
                "specified as a parameter or in the constructor."
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            source = DataPanelSource(source)

        if not isinstance(source, (list, tuple)):
            self.sources = [source]
        else:
            self.sources = source

        # The parameter used to be accepted but silently ignored; honour it
        # so _create_data_generator uses the caller's benchmark.
        if benchmark_return_source is not None:
            self.benchmark_return_source = benchmark_return_source

        # Check for override of sim_params.
        # If it isn't passed to this function,
        # use the default params set with the algorithm.
        # Else, we create simulation parameters using the start and end of
        # the source provided.
        if not sim_params:
            if not self.sim_params:
                start = source.start
                end = source.end

                sim_params = create_simulation_parameters(
                    start=start,
                    end=end,
                    capital_base=self.capital_base
                )
            else:
                sim_params = self.sim_params

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in self.registered_transforms.items():
            sf = StatefulTransform(
                trans_descr['class'],
                *trans_descr['args'],
                **trans_descr['kwargs']
            )
            sf.namestring = namestring

            self.transforms.append(sf)

        # create transforms and zipline
        self.gen = self._create_generator(sim_params)

        # loop through simulated_trading, each iteration returns a
        # perf dictionary
        perfs = list(self.gen)

        # convert perf dict to pandas dataframe
        daily_stats = self._create_daily_stats(perfs)

        return daily_stats

    def _create_daily_stats(self, perfs):
        """
        Build a DataFrame of daily performance from the emitted perf
        messages. Messages without 'daily_perf' are stored as
        self.risk_report (the last one wins).
        """
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:
                # Flatten recorded variables into the daily row.
                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars')
                )
                daily_perfs.append(perf['daily_perf'])
            else:
                self.risk_report = perf

        daily_dts = [np.datetime64(perf['period_close'], utc=True)
                     for perf in daily_perfs]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    def add_transform(self, transform_class, tag, *args, **kwargs):
        """Add a single-sid, sequential transform to the model.

        :Arguments:
            transform_class : class
                Which transform to use. E.g. mavg.
            tag : str
                How to name the transform. Can later be access via:
                data[sid].tag()

        Extra args and kwargs will be forwarded to the transform
        instantiation.
        """
        self.registered_transforms[tag] = {'class': transform_class,
                                           'args': args,
                                           'kwargs': kwargs}

    def record(self, **kwargs):
        """
        Track and record local variable (i.e. attributes) each day.
        """
        for name, value in kwargs.items():
            self._recorded_vars[name] = value

    def order(self, sid, amount, limit_price=None, stop_price=None):
        return self.blotter.order(sid, amount, limit_price, stop_price)

    def order_value(self, sid, value, limit_price=None, stop_price=None):
        """
        Place an order for ``value`` worth of ``sid`` at the most recent
        price seen by the trading client.
        """
        last_price = self.trading_client.current_data[sid].price
        return self.blotter.order_value(sid, value, last_price,
                                        limit_price=limit_price,
                                        stop_price=stop_price)

    @property
    def recorded_vars(self):
        # Return a shallow copy so callers cannot mutate internal state.
        return copy(self._recorded_vars)

    @property
    def portfolio(self):
        return self._portfolio

    def set_portfolio(self, portfolio):
        self._portfolio = portfolio

    def set_logger(self, logger):
        self.logger = logger

    def set_datetime(self, dt):
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"
        self.datetime = dt

    def get_datetime(self):
        """
        Returns a copy of the datetime.
        """
        date_copy = copy(self.datetime)
        assert date_copy.tzinfo == pytz.utc, \
            "Algorithm should have a utc datetime"
        return date_copy

    def set_transact(self, transact):
        """
        Set the method that will be called to create a
        transaction from open orders and trade events.
        """
        self.blotter.transact = transact

    def set_slippage(self, slippage):
        if not isinstance(slippage, SlippageModel):
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()
        self.slippage = slippage

    def set_commission(self, commission):
        if not isinstance(commission, (PerShare, PerTrade)):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()
        self.commission = commission

    def set_leverage(self, leverage):
        if not isinstance(leverage, LeverageModel):
            raise UnsupportedLeverageModel()

        if self.initialized:
            raise OverrideLeveragePostInit()
        self.leverage = leverage

    def set_sources(self, sources):
        assert isinstance(sources, list)
        self.sources = sources

    def set_transforms(self, transforms):
        assert isinstance(transforms, list)
        self.transforms = transforms

    def set_data_frequency(self, data_frequency):
        assert data_frequency in ('daily', 'minute')
        self.data_frequency = data_frequency
        self.annualizer = ANNUALIZER[self.data_frequency]

    def order_percent(self, sid, percent, limit_price=None, stop_price=None):
        """
        Place an order in the specified security corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50%).
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value, limit_price, stop_price)

    def target(self, sid, target, limit_price=None, stop_price=None):
        """
        Place an order to adjust a position to a target number of shares. If
        the position doesn't already exist, this is equivalent to placing a
        new order. If the position does exist, this is equivalent to placing
        an order for the difference between the target number of shares and
        the current number of shares.
        """
        if sid in self.portfolio.positions:
            current_position = self.portfolio.positions[sid].amount
            req_shares = target - current_position
            return self.order(sid, req_shares, limit_price, stop_price)
        else:
            return self.order(sid, target, limit_price, stop_price)

    def target_value(self, sid, target, limit_price=None, stop_price=None):
        """
        Place an order to adjust a position to a target value. If
        the position doesn't already exist, this is equivalent to placing a
        new order. If the position does exist, this is equivalent to placing
        an order for the difference between the target value and the
        current value.
        """
        if sid in self.portfolio.positions:
            current_position = self.portfolio.positions[sid].amount
            current_price = self.portfolio.positions[sid].last_sale_price
            current_value = current_position * current_price
            req_value = target - current_value
            return self.order_value(sid, req_value, limit_price, stop_price)
        else:
            return self.order_value(sid, target, limit_price, stop_price)

    def target_percent(self, sid, target,
                       limit_price=None, stop_price=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this
        is equivalent to placing a new order. If the position does exist,
        this is equivalent to placing an order for the difference between
        the target percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50%).
        """
        if sid in self.portfolio.positions:
            current_position = self.portfolio.positions[sid].amount
            current_price = self.portfolio.positions[sid].last_sale_price
            current_value = current_position * current_price
        else:
            current_value = 0
        target_value = self.portfolio.portfolio_value * target

        req_value = target_value - current_value
        return self.order_value(sid, req_value, limit_price, stop_price)
def transaction_sim(self, **params):
    """ This is a utility method that asserts expected
    results for conversion of orders to transactions given a
    trade history.

    Required ``params`` keys: trade_count, trade_interval, order_count,
    order_amount, order_interval, expected_txn_count,
    expected_txn_volume.

    Optional ``params`` keys:
        alternate     -- if truthy, alternate between long and short
                         orders.
        complete_fill -- if truthy, expect each transaction amount to
                         match its order's amount exactly.
    """
    trade_count = params['trade_count']
    trade_interval = params['trade_interval']
    order_count = params['order_count']
    order_amount = params['order_amount']
    order_interval = params['order_interval']
    expected_txn_count = params['expected_txn_count']
    expected_txn_volume = params['expected_txn_volume']
    # optional parameters
    # ---------------------
    # if present, alternate between long and short sales
    alternate = params.get('alternate')

    # if present, expect transaction amounts to match orders exactly.
    complete_fill = params.get('complete_fill')

    sid = 1
    sim_params = factory.create_simulation_parameters()
    blotter = Blotter()
    price = [10.1] * trade_count
    volume = [100] * trade_count
    start_date = sim_params.first_open

    generated_trades = factory.create_trade_history(
        sid,
        price,
        volume,
        trade_interval,
        sim_params,
        env=self.env,
    )

    alternator = -1 if alternate else 1

    order_date = start_date
    for i in range(order_count):

        blotter.set_date(order_date)
        blotter.order(sid, order_amount * alternator ** i, MarketOrder())

        order_date = order_date + order_interval
        # Move orders placed at or after hour 21 to 14:30 on the next day
        # (presumably the UTC close/open of the market). The old extra
        # `minute >= 00` check was always true and has been dropped.
        if order_date.hour >= 21:
            order_date = order_date + timedelta(days=1)
            order_date = order_date.replace(hour=14, minute=30)

    # there should now be one open order list stored under the sid
    open_orders = blotter.open_orders
    self.assertEqual(len(open_orders), 1)
    self.assertIn(sid, open_orders)
    order_list = open_orders[sid][:]  # make copy
    self.assertEqual(order_count, len(order_list))

    for i, order in enumerate(order_list):
        self.assertEqual(order.sid, sid)
        self.assertEqual(order.amount, order_amount * alternator ** i)

    tracker = PerformanceTracker(sim_params, env=self.env)

    benchmark_returns = [
        Event({'dt': dt,
               'returns': ret,
               'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
               'source_id': 'benchmarks'})
        for dt, ret in self.env.benchmark_returns.iteritems()
        if dt.date() >= sim_params.period_start.date() and
        dt.date() <= sim_params.period_end.date()
    ]

    generated_events = date_sorted_sources(generated_trades,
                                           benchmark_returns)

    # this approximates the loop inside TradingSimulationClient
    transactions = []
    for dt, events in itertools.groupby(generated_events,
                                        operator.attrgetter('dt')):
        for event in events:
            if event.type == DATASOURCE_TYPE.TRADE:
                for txn, order in blotter.process_trade(event):
                    transactions.append(txn)
                    tracker.process_transaction(txn)
                # The trade itself must also reach the tracker; this used
                # to live in a duplicate `elif ... TRADE` branch that could
                # never execute.
                tracker.process_trade(event)
            elif event.type == DATASOURCE_TYPE.BENCHMARK:
                tracker.process_benchmark(event)

    if complete_fill:
        self.assertEqual(len(transactions), len(order_list))

    total_volume = 0
    for i, txn in enumerate(transactions):
        total_volume += txn.amount
        if complete_fill:
            self.assertEqual(order_list[i].amount, txn.amount)

    self.assertEqual(total_volume, expected_txn_volume)
    self.assertEqual(len(transactions), expected_txn_count)

    cumulative_pos = tracker.cumulative_performance.positions[sid]
    self.assertEqual(total_volume, cumulative_pos.amount)

    # the open orders should not contain sid.
    open_orders = blotter.open_orders
    self.assertNotIn(sid, open_orders, "Entry is removed when no open orders")
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    :Arguments:
    :Optional:
        initialize : function
            Function that is called with a single argument at the
            beginning of the simulation.
        handle_data : function
            Function that is called with 2 arguments (context and data)
            on every bar.
        script : str
            Algoscript that contains initialize and
            handle_data function definition.
        data_frequency : {'daily', 'minute'}
           The duration of the bars.
        capital_base : float <default: 1.0e5>
           How much capital to start with.
        instant_fill : bool <default: False>
           Whether to fill orders immediately or on next bar.
        sim_params : <optional>
           Simulation parameters; when omitted they are built from
           capital_base and the optional ``start``/``end`` kwargs.
        env : TradingEnvironment
           Defaults to TradingEnvironment.instance().
        asset_finder : An AssetFinder object
            A new AssetFinder object to be used in this TradingEnvironment
        asset_metadata: can be either:
                        - dict
                        - pandas.DataFrame
                        - object with 'read' property
            If dict is provided, it must have the following structure:
            * keys are the identifiers
            * values are dicts containing the metadata, with the metadata
              field name as the key
            If pandas.DataFrame is provided, it must have the
            following structure:
            * column names must be the metadata fields
            * index must be the different asset identifiers
            * array contents should be the metadata value
            If an object with a 'read' property is provided, 'read' must
            return rows containing at least one of 'sid' or 'symbol' along
            with the other metadata fields.
        identifiers : List
            Any asset identifiers that are not provided in the
            asset_metadata, but will be traded by this TradingAlgorithm

    Any remaining args/kwargs are stored as initialize_args /
    initialize_kwargs for the deferred call to initialize().

    :Raises:
        ValueError
            If ``script`` lacks a handle_data function, or if ``script``
            is combined with ``initialize``/``handle_data`` callables.
    """
    self.sources = []

    # List of trading controls to be used to validate orders.
    self.trading_controls = []

    # List of account controls to be checked on each bar.
    self.account_controls = []

    self._recorded_vars = {}
    # NOTE(review): read with get(), not pop(), so 'namespace' also stays
    # in initialize_kwargs and is forwarded to initialize() -- confirm
    # this is intended.
    self.namespace = kwargs.get('namespace', {})

    self._platform = kwargs.pop('platform', 'zipline')

    self.logger = None

    self.benchmark_return_source = None

    # default components for transact
    self.slippage = VolumeShareSlippage()
    self.commission = PerShare()

    self.instant_fill = kwargs.pop('instant_fill', False)

    # set the capital base
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params is None:
        self.sim_params = create_simulation_parameters(
            capital_base=self.capital_base,
            start=kwargs.pop('start', None),
            end=kwargs.pop('end', None)
        )
    self.perf_tracker = PerformanceTracker(self.sim_params)

    # Update the TradingEnvironment with the provided asset metadata.
    # Only fall back to the shared instance when no env was supplied; using
    # it as the pop() default evaluated TradingEnvironment.instance() even
    # when the caller passed an env.
    env = kwargs.pop('env', None)
    if env is None:
        env = TradingEnvironment.instance()
    self.trading_environment = env
    self.trading_environment.update_asset_finder(
        asset_finder=kwargs.pop('asset_finder', None),
        asset_metadata=kwargs.pop('asset_metadata', None),
        identifiers=kwargs.pop('identifiers', None)
    )
    # Pull in the environment's new AssetFinder for quick reference
    self.asset_finder = self.trading_environment.asset_finder
    self.init_engine(kwargs.pop('ffc_loader', None))

    # Maps from name to Term
    self._filters = {}
    self._factors = {}
    self._classifiers = {}

    self.blotter = kwargs.pop('blotter', None)
    if not self.blotter:
        self.blotter = Blotter()

    # Set the dt initially to the period start by forcing it to change
    self.on_dt_changed(self.sim_params.period_start)

    self.portfolio_needs_update = True
    self.account_needs_update = True
    self.performance_needs_update = True
    self._portfolio = None
    self._account = None

    self.history_container_class = kwargs.pop(
        'history_container_class', HistoryContainer,
    )
    self.history_container = None
    self.history_specs = {}

    # If string is passed in, execute and get reference to
    # functions.
    self.algoscript = kwargs.pop('script', None)

    self._initialize = None
    self._before_trading_start = None
    self._analyze = None

    self.event_manager = EventManager()

    if self.algoscript is not None:
        # This check used to sit in the `elif` below, which is only reached
        # when algoscript is None, so it could never fire.
        if kwargs.get('initialize') or kwargs.get('handle_data'):
            raise ValueError(
                'You can not set script and initialize/handle_data.')

        filename = kwargs.pop('algo_filename', None)
        if filename is None:
            filename = '<string>'
        # NOTE: executes arbitrary code from the script; only pass trusted
        # algorithm source here.
        code = compile(self.algoscript, filename, 'exec')
        exec_(code, self.namespace)
        self._initialize = self.namespace.get('initialize')
        if 'handle_data' not in self.namespace:
            raise ValueError('You must define a handle_data function.')
        else:
            self._handle_data = self.namespace['handle_data']

        self._before_trading_start = \
            self.namespace.get('before_trading_start')
        # Optional analyze function, gets called after run
        self._analyze = self.namespace.get('analyze')

    elif kwargs.get('initialize') and kwargs.get('handle_data'):
        self._initialize = kwargs.pop('initialize')
        self._handle_data = kwargs.pop('handle_data')
        self._before_trading_start = kwargs.pop('before_trading_start',
                                                None)

    self.event_manager.add_event(
        zipline.utils.events.Event(
            zipline.utils.events.Always(),
            # We pass handle_data.__func__ to get the unbound method.
            # We will explicitly pass the algorithm to bind it again.
            self.handle_data.__func__,
        ),
        prepend=True,
    )

    # If method not defined, NOOP
    if self._initialize is None:
        self._initialize = lambda x: None

    # Alternative way of setting data_frequency for backwards
    # compatibility.
    if 'data_frequency' in kwargs:
        self.data_frequency = kwargs.pop('data_frequency')

    self._most_recent_data = None

    # Prepare the algo for initialization
    self.initialized = False
    self.initialize_args = args
    self.initialize_kwargs = kwargs
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    :Arguments:
    :Optional:
        initialize : function
            Function that is called with a single argument at the
            beginning of the simulation.
        handle_data : function
            Function that is called with 2 arguments (context and data)
            on every bar.
        script : str
            Algoscript that contains initialize and
            handle_data function definition.
        data_frequency : str (daily, hourly or minutely)
           The duration of the bars.
        capital_base : float <default: 1.0e5>
           How much capital to start with.
        instant_fill : bool <default: False>
           Whether to fill orders immediately or on next bar.
        environment : str <default: 'zipline'>
           The environment that this algorithm is running in.

    Remaining args/kwargs are forwarded to initialize(), which is called
    at the end of construction.

    :Raises:
        ValueError
            If ``script`` lacks a handle_data function, or if ``script``
            is combined with ``initialize``/``handle_data`` callables.
    """
    self.datetime = None

    self.registered_transforms = {}
    self.transforms = []
    self.sources = []

    # List of trading controls to be used to validate orders.
    self.trading_controls = []

    self._recorded_vars = {}
    # NOTE(review): read with get(), not pop(), so 'namespace' is also
    # forwarded to initialize() below -- confirm this is intended.
    self.namespace = kwargs.get('namespace', {})

    self._environment = kwargs.pop('environment', 'zipline')

    self.logger = None

    self.benchmark_return_source = None

    # default components for transact
    self.slippage = VolumeShareSlippage()
    self.commission = PerShare()

    self.instant_fill = kwargs.pop('instant_fill', False)

    # set the capital base
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params is None:
        self.sim_params = create_simulation_parameters(
            capital_base=self.capital_base
        )
    self.perf_tracker = PerformanceTracker(self.sim_params)

    self.blotter = kwargs.pop('blotter', None)
    if not self.blotter:
        self.blotter = Blotter()

    self.portfolio_needs_update = True
    self.account_needs_update = True
    self.performance_needs_update = True
    self._portfolio = None
    self._account = None

    self.history_container = None
    self.history_specs = {}

    # If string is passed in, execute and get reference to
    # functions.
    self.algoscript = kwargs.pop('script', None)

    self._initialize = None
    self._before_trading_start = None
    self._analyze = None

    self.event_manager = EventManager()

    if self.algoscript is not None:
        # This check used to sit in the `elif` below, which is only reached
        # when algoscript is None, so it could never fire.
        if kwargs.get('initialize') or kwargs.get('handle_data'):
            raise ValueError(
                'You can not set script and initialize/handle_data.')

        # NOTE: executes arbitrary code from the script; only pass trusted
        # algorithm source here.
        exec_(self.algoscript, self.namespace)
        self._initialize = self.namespace.get('initialize')
        if 'handle_data' not in self.namespace:
            raise ValueError('You must define a handle_data function.')
        else:
            self._handle_data = self.namespace['handle_data']

        self._before_trading_start = \
            self.namespace.get('before_trading_start')
        # Optional analyze function, gets called after run
        self._analyze = self.namespace.get('analyze')

    elif kwargs.get('initialize') and kwargs.get('handle_data'):
        self._initialize = kwargs.pop('initialize')
        self._handle_data = kwargs.pop('handle_data')
        self._before_trading_start = kwargs.pop('before_trading_start',
                                                None)

    self.event_manager.add_event(
        zipline.utils.events.Event(
            zipline.utils.events.Always(),
            # We pass handle_data.__func__ to get the unbound method.
            # We will explicitly pass the algorithm to bind it again.
            self.handle_data.__func__,
        ),
        prepend=True,
    )

    # If method not defined, NOOP
    if self._initialize is None:
        self._initialize = lambda x: None

    # Alternative way of setting data_frequency for backwards
    # compatibility.
    if 'data_frequency' in kwargs:
        self.data_frequency = kwargs.pop('data_frequency')

    # Subclasses that override initialize should only worry about
    # setting self.initialized = True if AUTO_INITIALIZE is
    # manually set to False.
    self.initialized = False
    self.initialize(*args, **kwargs)
    if self.AUTO_INITIALIZE:
        self.initialized = True
class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order

    def initialize(context):
        context.sid = 'AAPL'
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To then run this algorithm, pass these functions to TradingAlgorithm
    as keyword arguments:

    my_algo = TradingAlgorithm(initialize=initialize,
                               handle_data=handle_data)
    stats = my_algo.run(data)

    """

    # If this is set to false then it is the responsibility
    # of the overriding subclass to set initialized = true
    AUTO_INITIALIZE = True

    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
            environment : str <default: 'zipline'>
               The environment that this algorithm is running in.

        :Raises:
            ValueError
                If ``script`` does not define ``handle_data``, or if
                ``script`` is combined with ``initialize``/``handle_data``
                keyword arguments.
        """
        self.datetime = None

        self.registered_transforms = {}
        self.transforms = []
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.get('namespace', {})

        self._environment = kwargs.pop('environment', 'zipline')

        self.logger = None

        self.benchmark_return_source = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        self.instant_fill = kwargs.pop('instant_fill', False)

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base
            )
        self.perf_tracker = PerformanceTracker(self.sim_params)

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        # Cached portfolio/account objects; the *_needs_update flags tell
        # updated_portfolio()/updated_account() to refresh them.
        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
        self._portfolio = None
        self._account = None

        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager()

        if self.algoscript is not None:
            # A script supplies its own entry points; also passing them as
            # keyword arguments is ambiguous, so reject the combination up
            # front (before executing the script).
            if kwargs.get('initialize') or kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )
            exec_(self.algoscript, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')

            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        # Subclasses that override initialize should only worry about
        # setting self.initialized = True if AUTO_INITIALIZE is
        # manually set to False.
        self.initialized = False
        self.initialize(*args, **kwargs)
        if self.AUTO_INITIALIZE:
            self.initialized = True

    def initialize(self, *args, **kwargs):
        """
        Call self._initialize with `self` made available to Zipline API
        functions.
        """
        with ZiplineAPI(self):
            self._initialize(self)

    def before_trading_start(self):
        """Invoke the user's before_trading_start hook, if one was given."""
        if self._before_trading_start is None:
            return

        self._before_trading_start(self)

    def handle_data(self, data):
        """Update history (if enabled) and call the user's handle_data."""
        # The history container must see the new bar before user code can
        # query it through history().
        if self.history_container:
            self.history_container.update(data, self.datetime)

        self._handle_data(self, data)

    def analyze(self, perf):
        """Invoke the user's analyze hook (if any) with the perf frame."""
        if self._analyze is None:
            return

        with ZiplineAPI(self):
            self._analyze(self, perf)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base},
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params=None):
        """
        Create a merged data generator using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """
        if sim_params is None:
            sim_params = self.sim_params

        if self.benchmark_return_source is None:
            env = trading.environment
            # In minute mode, benchmark events are stamped at the market
            # close of their day; otherwise the raw date is used.
            if (sim_params.data_frequency == 'minute'
                    or sim_params.emission_rate == 'minute'):
                update_time = lambda date: env.get_open_and_close(date)[1]
            else:
                update_time = lambda date: date
            benchmark_return_source = [
                Event({'dt': update_time(dt),
                       'returns': ret,
                       'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                       'source_id': 'benchmarks'})
                for dt, ret in trading.environment.benchmark_returns.iterkv()
                if dt.date() >= sim_params.period_start.date()
                and dt.date() <= sim_params.period_end.date()
            ]
        else:
            benchmark_return_source = self.benchmark_return_source

        date_sorted = date_sorted_sources(*self.sources)

        if source_filter:
            date_sorted = filter(source_filter, date_sorted)

        with_tnfms = sequential_transforms(date_sorted,
                                           *self.transforms)

        with_benchmarks = date_sorted_sources(benchmark_return_source,
                                              with_tnfms)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(with_benchmarks, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Create a basic generator setup using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """
        if self.perf_tracker is None:
            # HACK: When running with the `run` method, we set perf_tracker to
            # None so that it will be overwritten here.
            self.perf_tracker = PerformanceTracker(sim_params)

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True

        self.data_gen = self._create_data_generator(source_filter,
                                                    sim_params)

        self.trading_client = AlgorithmSimulator(self, sim_params)

        transact_method = transact_partial(self.slippage, self.commission)
        self.set_transact(transact_method)

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Override this method to add new logic to the construction
        of the generator. Overrides can use the _create_generator
        method to get a standard construction generator.
        """
        return self._create_generator(self.sim_params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, overwrite_sim_params=True,
            benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - zipline source
                     - list of sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, list):
            if overwrite_sim_params:
                warnings.warn("""List of sources passed, will not attempt to extract sids, and start and end dates. Make sure to set the correct fields in sim_params passed to __init__().""", UserWarning)  # noqa
                overwrite_sim_params = False
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            source = DataPanelSource(source)

        if isinstance(source, list):
            self.set_sources(source)
        else:
            self.set_sources([source])

        # Override sim_params if params are provided by the source.
        if overwrite_sim_params:
            if hasattr(source, 'start'):
                self.sim_params.period_start = source.start
            if hasattr(source, 'end'):
                self.sim_params.period_end = source.end
            all_sids = [sid for s in self.sources for sid in s.sids]
            self.sim_params.sids = set(all_sids)
            # Changing period_start and period_close might require updating
            # of first_open and last_close.
            self.sim_params._update_internal()

        # Create history containers
        if len(self.history_specs) != 0:
            self.history_container = HistoryContainer(
                self.history_specs,
                self.sim_params.sids,
                self.sim_params.first_open)

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in iteritems(self.registered_transforms):
            sf = StatefulTransform(
                trans_descr['class'],
                *trans_descr['args'],
                **trans_descr['kwargs']
            )
            sf.namestring = namestring

            self.transforms.append(sf)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(self.sim_params)

        with ZiplineAPI(self):
            # loop through simulated_trading, each iteration returns a
            # perf dictionary
            perfs = []
            for perf in self.gen:
                perfs.append(perf)

            # convert perf dict to pandas dataframe
            daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats

    def _create_daily_stats(self, perfs):
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:
                # Flatten recorded variables into the daily row.
                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars')
                )
                daily_perfs.append(perf['daily_perf'])
            else:
                # Packets without 'daily_perf' are kept as the risk report.
                self.risk_report = perf

        daily_dts = [np.datetime64(perf['period_close'], utc=True)
                     for perf in daily_perfs]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    def add_transform(self, transform_class, tag, *args, **kwargs):
        """Add a single-sid, sequential transform to the model.

        :Arguments:
            transform_class : class
                Which transform to use. E.g. mavg.
            tag : str
                How to name the transform. Can later be accessed via:
                data[sid].tag()

        Extra args and kwargs will be forwarded to the transform
        instantiation.

        """
        self.registered_transforms[tag] = {'class': transform_class,
                                           'args': args,
                                           'kwargs': kwargs}

    @api_method
    def get_environment(self):
        return self._environment

    def add_event(self, rule=None, callback=None):
        """
        Adds an event to the algorithm's EventManager.
        """
        self.event_manager.add_event(
            zipline.utils.events.Event(rule, callback),
        )

    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedules a function to be called with some timed rules.
        """
        if self.sim_params.data_frequency != 'minute':
            raise IncompatibleScheduleFunctionDataFrequency()

        date_rule = date_rule or DateRuleFactory.every_day()
        time_rule = time_rule or TimeRuleFactory.market_open()

        self.add_event(
            make_eventrule(date_rule, time_rule, half_days),
            func,
        )

    @api_method
    def record(self, *args, **kwargs):
        """
        Track and record local variable (i.e. attributes) each day.
        """
        # Make 2 objects both referencing the same iterator
        args = [iter(args)] * 2

        # Zip generates list entries by calling `next` on each iterator it
        # receives.  In this case the two iterators are the same object, so the
        # call to next on args[0] will also advance args[1], resulting in zip
        # returning (a,b) (c,d) (e,f) rather than (a,a) (b,b) (c,c) etc.
        positionals = zip(*args)
        for name, value in chain(positionals, iteritems(kwargs)):
            self._recorded_vars[name] = value

    @api_method
    def symbol(self, symbol_str, as_of_date=None):
        """
        Default symbol lookup for any source that directly maps the
        symbol to the identifier (e.g. yahoo finance).

        Keyword argument as_of_date is ignored.
        """
        return symbol_str

    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order using the specified parameters.
        """

        def round_if_near_integer(a, epsilon=1e-4):
            """
            Round a to the nearest integer if that integer is within an epsilon
            of a.
            """
            if abs(a - round(a)) <= epsilon:
                return round(a)
            else:
                return a

        # Truncate to the integer share count that's either within .0001 of
        # amount or closer to zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        amount = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid,
                                   amount,
                                   limit_price,
                                   stop_price,
                                   style)

        # Convert deprecated limit_price and stop_price parameters to use
        # ExecutionStyle objects.
        style = self.__convert_order_params_for_blotter(limit_price,
                                                        stop_price,
                                                        style)
        return self.blotter.order(sid, amount, style)

    def validate_order_params(self,
                              sid,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Helper method for validating parameters to the order API function.

        Raises OrderDuringInitialize if called before initialization has
        finished, UnsupportedOrderParameters if style is combined with
        limit_price/stop_price, and whatever each registered trading
        control raises from its validate() call.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style:
            if limit_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both limit_price and style is not supported."
                )

            if stop_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both stop_price and style is not supported."
                )

        for control in self.trading_controls:
            control.validate(sid,
                             amount,
                             self.updated_portfolio(),
                             self.get_datetime(),
                             self.trading_client.current_data)

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Helper method for converting deprecated limit_price and stop_price
        arguments into ExecutionStyle instances.

        This function assumes that either style == None or (limit_price,
        stop_price) == (None, None).
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style
        if limit_price and stop_price:
            return StopLimitOrder(limit_price, stop_price)
        if limit_price:
            return LimitOrder(limit_price)
        if stop_price:
            return StopOrder(stop_price)
        else:
            return MarketOrder()

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.
        If the requested sid is found in the universe, the requested value is
        divided by its price to imply the number of shares to transact.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)
        """
        last_price = self.trading_client.current_data[sid].price
        if np.allclose(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=sid
            )
            if self.logger:
                self.logger.debug(zero_message)
            # Don't place any order
            return
        else:
            amount = value / last_price
            return self.order(sid, amount,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

    @property
    def recorded_vars(self):
        return copy(self._recorded_vars)

    @property
    def portfolio(self):
        return self.updated_portfolio()

    def updated_portfolio(self):
        """Return the cached portfolio, refreshing it first if stale."""
        if self.portfolio_needs_update:
            self._portfolio = \
                self.perf_tracker.get_portfolio(self.performance_needs_update)
            self.portfolio_needs_update = False
            self.performance_needs_update = False
        return self._portfolio

    @property
    def account(self):
        return self.updated_account()

    def updated_account(self):
        """Return the cached account, refreshing it first if stale."""
        if self.account_needs_update:
            self._account = \
                self.perf_tracker.get_account(self.performance_needs_update)
            self.account_needs_update = False
            self.performance_needs_update = False
        return self._account

    def set_logger(self, logger):
        self.logger = logger

    def on_dt_changed(self, dt):
        """
        Callback triggered by the simulation loop whenever the current dt
        changes.

        Any logic that should happen exactly once at the start of each datetime
        group should happen here.
        """
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

    @api_method
    def get_datetime(self):
        """
        Returns a copy of the datetime.
        """
        date_copy = copy(self.datetime)
        assert date_copy.tzinfo == pytz.utc, \
            "Algorithm should have a utc datetime"
        return date_copy

    def set_transact(self, transact):
        """
        Set the method that will be called to create a
        transaction from open orders and trade events.
        """
        self.blotter.transact = transact

    def update_dividends(self, dividend_frame):
        """
        Set DataFrame used to process dividends.  DataFrame columns should
        contain at least the entries in zp.DIVIDEND_FIELDS.
        """
        self.perf_tracker.update_dividends(dividend_frame)

    @api_method
    def set_slippage(self, slippage):
        if not isinstance(slippage, SlippageModel):
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()
        self.slippage = slippage

    @api_method
    def set_commission(self, commission):
        if not isinstance(commission, (PerShare, PerTrade, PerDollar)):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()
        self.commission = commission

    def set_sources(self, sources):
        assert isinstance(sources, list)
        self.sources = sources

    def set_transforms(self, transforms):
        assert isinstance(transforms, list)
        self.transforms = transforms

    # Remain backwards compatibility
    @property
    def data_frequency(self):
        return self.sim_params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        assert value in ('daily', 'minute')
        self.sim_params.data_frequency = value

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified security corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50%).
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target number of shares and the
        current number of shares.
        """
        if sid in self.portfolio.positions:
            current_position = self.portfolio.positions[sid].amount
            req_shares = target - current_position
            return self.order(sid, req_shares,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)
        else:
            return self.order(sid, target,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target value and the
        current value.
        """
        last_price = self.trading_client.current_data[sid].price
        if np.allclose(last_price, 0):
            # Don't place an order
            if self.logger:
                zero_message = "Price of 0 for {psid}; can't infer value"
                self.logger.debug(zero_message.format(psid=sid))
            return
        target_amount = target / last_price
        return self.order_target(sid, target_amount,
                                 limit_price=limit_price,
                                 stop_price=stop_price,
                                 style=style)

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50%).
        """
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return open orders as API objects: a dict keyed by sid (omitting
        sids with no open orders) when sid is None, otherwise a list for
        that sid (empty if it has none).
        """
        if sid is None:
            return {
                key: [order.to_api_obj() for order in orders]
                for key, orders in iteritems(self.blotter.open_orders)
                if orders
            }
        if sid in self.blotter.open_orders:
            orders = self.blotter.open_orders[sid]
            return [order.to_api_obj() for order in orders]
        return []

    @api_method
    def get_order(self, order_id):
        """Return the order with order_id as an API object, or None."""
        if order_id in self.blotter.orders:
            return self.blotter.orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """Cancel an order given either its id or an Order object."""
        order_id = order_param
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id

        self.blotter.cancel(order_id)

    @api_method
    def add_history(self, bar_count, frequency, field,
                    ffill=True):
        data_frequency = self.sim_params.data_frequency
        daily_at_midnight = (data_frequency == 'daily')

        history_spec = HistorySpec(bar_count, frequency, field, ffill,
                                   daily_at_midnight=daily_at_midnight,
                                   data_frequency=data_frequency)
        self.history_specs[history_spec.key_str] = history_spec

    @api_method
    def history(self, bar_count, frequency, field, ffill=True):
        spec_key_str = HistorySpec.spec_key(
            bar_count, frequency, field, ffill)
        history_spec = self.history_specs[spec_key_str]
        return self.history_container.get_history(history_spec, self.datetime)

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.
        """
        if self.initialized:
            raise RegisterTradingControlPostInit()
        self.trading_controls.append(control)

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid. Limits are treated as absolute values and are enforced at
        the time that the algo attempts to place an order for sid. This means
        that it's possible to end up with more than the max number of shares
        due to splits/dividends, and more than the max notional due to price
        improvement.

        If an algorithm attempts to place an order that would result in
        increasing the absolute value of shares/dollar value exceeding one of
        these limits, raise a TradingControlException.
        """
        control = MaxPositionSize(sid=sid,
                                  max_shares=max_shares,
                                  max_notional=max_notional)
        self.register_trading_control(control)

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid.  Limits are treated as absolute values and are
        enforced at the time that the algo attempts to place an order for sid.

        If an algorithm attempts to place an order that would result in
        exceeding one of these limits, raise a TradingControlException.
        """
        control = MaxOrderSize(sid=sid,
                               max_shares=max_shares,
                               max_notional=max_notional)
        self.register_trading_control(control)

    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed within the given
        time interval.
        """
        control = MaxOrderCount(max_count)
        self.register_trading_control(control)

    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short positions.
        """
        self.register_trading_control(LongOnly())

    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.
        """
        # dict.values() works on both Python 2 and 3; itervalues() does not
        # exist on Python 3's class __dict__ (a mappingproxy).
        return [fn for fn in cls.__dict__.values()
                if getattr(fn, 'is_api_method', False)]
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    :Arguments:
    :Optional:
        initialize : function
            Function that is called with a single
            argument at the beginning of the simulation.
        handle_data : function
            Function that is called with 2 arguments
            (context and data) on every bar.
        script : str
            Algoscript that contains initialize and
            handle_data function definition.
        data_frequency : {'daily', 'minute'}
           The duration of the bars.
        capital_base : float <default: 1.0e5>
           How much capital to start with.
        instant_fill : bool <default: False>
           Whether to fill orders immediately or on next bar.
        env : TradingEnvironment
           The environment to use; a new TradingEnvironment is created
           when not provided.
        start, end :
           Passed to create_simulation_parameters when sim_params is
           not provided.
        asset_finder : An AssetFinder object
            A new AssetFinder object to be used in this TradingEnvironment
        equities_metadata : can be either:
                    - dict
                    - pandas.DataFrame
                    - object with 'read' property
            If dict is provided, it must have the following structure:
            * keys are the identifiers
            * values are dicts containing the metadata, with the metadata
              field name as the key
            If pandas.DataFrame is provided, it must have the
            following structure:
            * column names must be the metadata fields
            * index must be the different asset identifiers
            * array contents should be the metadata value
            If an object with a 'read' property is provided, 'read' must
            return rows containing at least one of 'sid' or 'symbol' along
            with the other metadata fields.
        identifiers : List
            Any asset identifiers that are not provided in the
            equities_metadata, but will be traded by this TradingAlgorithm

    :Raises:
        ValueError
            If ``script`` does not define ``handle_data``, or if ``script``
            is combined with ``initialize``/``handle_data`` keyword
            arguments.
    """
    self.sources = []

    # List of trading controls to be used to validate orders.
    self.trading_controls = []

    # List of account controls to be checked on each bar.
    self.account_controls = []

    self._recorded_vars = {}
    self.namespace = kwargs.pop('namespace', {})

    self._platform = kwargs.pop('platform', 'zipline')

    self.logger = None

    self.benchmark_return_source = None

    # default components for transact
    self.slippage = VolumeShareSlippage()
    self.commission = PerShare()

    self.instant_fill = kwargs.pop('instant_fill', False)

    # If an env has been provided, pop it
    self.trading_environment = kwargs.pop('env', None)

    if self.trading_environment is None:
        self.trading_environment = TradingEnvironment()

    # Update the TradingEnvironment with the provided asset metadata
    self.trading_environment.write_data(
        equities_data=kwargs.pop('equities_metadata', {}),
        equities_identifiers=kwargs.pop('identifiers', []),
        futures_data=kwargs.pop('futures_metadata', {}),
    )

    # set the capital base
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params is None:
        self.sim_params = create_simulation_parameters(
            capital_base=self.capital_base,
            start=kwargs.pop('start', None),
            end=kwargs.pop('end', None),
            env=self.trading_environment,
        )
    else:
        self.sim_params.update_internal_from_env(self.trading_environment)

    # Build a perf_tracker
    self.perf_tracker = PerformanceTracker(sim_params=self.sim_params,
                                           env=self.trading_environment)

    # Pull in the environment's new AssetFinder for quick reference
    self.asset_finder = self.trading_environment.asset_finder
    self.init_engine(kwargs.pop('ffc_loader', None))

    # Maps from name to Term
    self._filters = {}
    self._factors = {}
    self._classifiers = {}

    self.blotter = kwargs.pop('blotter', None)
    if not self.blotter:
        self.blotter = Blotter()

    # Set the dt initially to the period start by forcing it to change.
    # This relies on perf_tracker and blotter already being set above.
    self.on_dt_changed(self.sim_params.period_start)

    # The symbol lookup date specifies the date to use when resolving
    # symbols to sids, and can be set using set_symbol_lookup_date()
    self._symbol_lookup_date = None

    self.portfolio_needs_update = True
    self.account_needs_update = True
    self.performance_needs_update = True
    self._portfolio = None
    self._account = None

    self.history_container_class = kwargs.pop(
        'history_container_class', HistoryContainer,
    )
    self.history_container = None
    self.history_specs = {}

    # If string is passed in, execute and get reference to
    # functions.
    self.algoscript = kwargs.pop('script', None)

    self._initialize = None
    self._before_trading_start = None
    self._analyze = None

    self.event_manager = EventManager()

    if self.algoscript is not None:
        # A script supplies its own entry points; also passing them as
        # keyword arguments is ambiguous, so reject the combination up
        # front (before compiling/executing the script).
        if kwargs.get('initialize') or kwargs.get('handle_data'):
            raise ValueError(
                'You can not set script and initialize/handle_data.'
            )
        filename = kwargs.pop('algo_filename', None)
        if filename is None:
            filename = '<string>'
        # Compiling with a filename makes tracebacks point at the script.
        code = compile(self.algoscript, filename, 'exec')
        exec_(code, self.namespace)
        self._initialize = self.namespace.get('initialize')
        if 'handle_data' not in self.namespace:
            raise ValueError('You must define a handle_data function.')
        else:
            self._handle_data = self.namespace['handle_data']

        self._before_trading_start = \
            self.namespace.get('before_trading_start')

        # Optional analyze function, gets called after run
        self._analyze = self.namespace.get('analyze')

    elif kwargs.get('initialize') and kwargs.get('handle_data'):
        self._initialize = kwargs.pop('initialize')
        self._handle_data = kwargs.pop('handle_data')
        self._before_trading_start = kwargs.pop('before_trading_start',
                                                None)

    self.event_manager.add_event(
        zipline.utils.events.Event(
            zipline.utils.events.Always(),
            # We pass handle_data.__func__ to get the unbound method.
            # We will explicitly pass the algorithm to bind it again.
            self.handle_data.__func__,
        ),
        prepend=True,
    )

    # If method not defined, NOOP
    if self._initialize is None:
        self._initialize = lambda x: None

    # Alternative way of setting data_frequency for backwards
    # compatibility.
    if 'data_frequency' in kwargs:
        self.data_frequency = kwargs.pop('data_frequency')

    self._most_recent_data = None

    # Prepare the algo for initialization. Initialization itself is
    # deferred; the remaining args/kwargs are stored for that later call.
    self.initialized = False
    self.initialize_args = args
    self.initialize_kwargs = kwargs
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    :Arguments:
        data_frequency : str (daily, hourly or minutely)
           The duration of the bars.
        annualizer : int <optional>
           Which constant to use for annualizing risk metrics.
           If not provided, will extract from data_frequency.
        capital_base : float <default: 1.0e5>
           How much capital to start with.
        instant_fill : bool <default: False>
           Whether to fill orders immediately or on next bar.

    Any remaining positional and keyword arguments are forwarded to
    self.initialize().
    """
    self.datetime = None

    self.registered_transforms = {}
    self.transforms = []
    self.sources = []

    self._recorded_vars = {}

    self.logger = None

    self.benchmark_return_source = None
    self.perf_tracker = None

    # default components for transact
    self.slippage = VolumeShareSlippage()
    self.commission = PerShare()

    if 'data_frequency' in kwargs:
        self.set_data_frequency(kwargs.pop('data_frequency'))
    else:
        self.data_frequency = None

    self.instant_fill = kwargs.pop('instant_fill', False)

    # Override annualizer if set. Pop it (like every other constructor
    # option) so it is not forwarded to the user-defined initialize().
    if 'annualizer' in kwargs:
        self.annualizer = kwargs.pop('annualizer')

    # set the capital base
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params:
        # The performance tracker is only built here when sim_params was
        # supplied; otherwise perf_tracker stays None.
        self.sim_params.data_frequency = self.data_frequency
        self.perf_tracker = PerformanceTracker(self.sim_params)

    self.blotter = kwargs.pop('blotter', None)
    if not self.blotter:
        self.blotter = Blotter()

    self.portfolio_needs_update = True
    self._portfolio = None

    # an algorithm subclass needs to set initialized to True when
    # it is fully initialized.
    self.initialized = False

    # call to user-defined constructor method
    self.initialize(*args, **kwargs)
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    :Arguments:
    :Optional:
        initialize : function
            Function that is called with a single
            argument at the beginning of the simulation.
        handle_data : function
            Function that is called with 2 arguments
            (context and data) on every bar.
        script : str
            Algoscript that contains initialize and
            handle_data function definition.
        data_frequency : str (daily, hourly or minutely)
           The duration of the bars.
        capital_base : float <default: 1.0e5>
           How much capital to start with.
        instant_fill : bool <default: False>
           Whether to fill orders immediately or on next bar.

    :Raises:
        ValueError
            If ``script`` does not define ``handle_data``, or if
            ``script`` is combined with ``initialize``/``handle_data``
            keyword arguments.
    """
    self.datetime = None

    self.registered_transforms = {}
    self.transforms = []
    self.sources = []

    # List of trading controls to be used to validate orders.
    self.trading_controls = []

    self._recorded_vars = {}
    self.namespace = kwargs.get('namespace', {})

    self.logger = None

    self.benchmark_return_source = None

    # default components for transact
    self.slippage = VolumeShareSlippage()
    self.commission = PerShare()

    self.instant_fill = kwargs.pop('instant_fill', False)

    # set the capital base
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params is None:
        self.sim_params = create_simulation_parameters(
            capital_base=self.capital_base
        )
    self.perf_tracker = PerformanceTracker(self.sim_params)

    self.blotter = kwargs.pop('blotter', None)
    if not self.blotter:
        self.blotter = Blotter()

    self.portfolio_needs_update = True
    self._portfolio = None

    self.history_container = None
    self.history_specs = {}

    # If string is passed in, execute and get reference to
    # functions.
    self.algoscript = kwargs.pop('script', None)

    self._initialize = None
    self._analyze = None

    if self.algoscript is not None:
        # This conflict check used to live inside the `elif` below, where
        # it could never run (that branch is only reached when no script
        # was given), so the functions were silently ignored.
        if kwargs.get('initialize', False) and kwargs.get('handle_data'):
            raise ValueError(
                'You can not set script and initialize/handle_data.')
        # NOTE(security): the script runs with full interpreter
        # privileges; only pass trusted code.
        exec_(self.algoscript, self.namespace)
        self._initialize = self.namespace.get('initialize', None)
        if 'handle_data' not in self.namespace:
            raise ValueError('You must define a handle_data function.')
        else:
            self._handle_data = self.namespace['handle_data']

        # Optional analyze function, gets called after run
        self._analyze = self.namespace.get('analyze', None)

    elif kwargs.get('initialize', False) and kwargs.get('handle_data'):
        self._initialize = kwargs.pop('initialize')
        self._handle_data = kwargs.pop('handle_data')

    # If method not defined, NOOP
    if self._initialize is None:
        self._initialize = lambda x: None

    # Alternative way of setting data_frequency for backwards
    # compatibility.
    if 'data_frequency' in kwargs:
        self.data_frequency = kwargs.pop('data_frequency')

    # Subclasses that override initialize should only worry about
    # setting self.initialized = True if AUTO_INITIALIZE is
    # manually set to False.
    self.initialized = False
    self.initialize(*args, **kwargs)
    if self.AUTO_INITIALIZE:
        self.initialized = True
class AlgorithmSimulator(object):
    """
    Drive an algorithm through a stream of ``(date, snapshot)`` pairs:
    feed each event to the blotter and performance tracker, call the
    algorithm's ``handle_data`` and yield performance messages.
    """

    # Maps the perf tracker's emission rate to the key of the section of
    # the emitted perf packet that carries per-period results.
    EMISSION_TO_PERF_KEY_MAP = {
        'minute': 'intraday_perf',
        'daily': 'daily_perf'
    }

    def get_hash(self):
        """
        There should only ever be one TSC in the system, so
        we don't bother passing args into the hash.
        """
        return self.__class__.__name__ + hash_args()

    def __init__(self, algo, sim_params):

        # ==============
        # Simulation
        # Param Setup
        # ==============
        self.sim_params = sim_params

        # ==============
        # Perf Tracker
        # Setup
        # ==============
        self.perf_tracker = PerformanceTracker(self.sim_params)

        # Raises KeyError for emission rates other than 'minute'/'daily'.
        self.perf_key = self.EMISSION_TO_PERF_KEY_MAP[
            self.perf_tracker.emission_rate]

        # ==============
        # Algo Setup
        # ==============
        self.algo = algo
        # Truncate the first open to midnight; events dated before this
        # are treated as warmup in transform().
        self.algo_start = self.sim_params.first_open
        self.algo_start = self.algo_start.replace(hour=0, minute=0,
                                                  second=0,
                                                  microsecond=0)

        # ==============
        # Snapshot Setup
        # ==============

        # The algorithm's data as of our most recent event.
        # We want an object that will have empty objects as default
        # values on missing keys.
        self.current_data = BarData()

        # We don't have a datetime for the current snapshot until we
        # receive a message.
        self.simulation_dt = None
        self.snapshot_dt = None

        # =============
        # Logging Setup
        # =============

        # Processor function for injecting the algo_dt into
        # user prints/logs.
        def inject_algo_dt(record):
            if 'algo_dt' not in record.extra:
                record.extra['algo_dt'] = self.snapshot_dt
        self.processor = Processor(inject_algo_dt)

    def transform(self, stream_in):
        """
        Main generator work loop.

        :Arguments:
            stream_in : iterator of ``(date, snapshot)`` pairs, where each
                snapshot is an iterable of events.

        :Yields:
            A perf message from ``get_message`` each time a benchmark
            event has been seen for a date, then (when emitting minutely)
            a daily rollup, then the result of
            ``perf_tracker.handle_simulation_end()``.

        An empty ``stream_in`` yields nothing.
        """
        # Set the simulation date to be the first event we see.
        try:
            peek_date, peek_snapshot = next(stream_in)
        except StopIteration:
            # Nothing to simulate. Return explicitly: letting StopIteration
            # escape a generator is turned into RuntimeError (PEP 479).
            return
        self.simulation_dt = peek_date

        # Stitch back together the generator by placing the peeked
        # event back in front
        stream = itertools.chain([(peek_date, peek_snapshot)],
                                 stream_in)

        # inject the current algo
        # snapshot time to any log record generated.
        with self.processor.threadbound():

            updated = False
            bm_updated = False
            for date, snapshot in stream:
                self.perf_tracker.set_date(date)
                self.algo.blotter.set_date(date)
                # If we're still in the warmup period.  Use the event to
                # update our universe, but don't yield any perf messages,
                # and don't send a snapshot to handle_data.
                if date < self.algo_start:
                    for event in snapshot:
                        if event.type in (DATASOURCE_TYPE.TRADE,
                                          DATASOURCE_TYPE.CUSTOM):
                            self.update_universe(event)
                        self.perf_tracker.process_event(event)

                else:

                    for event in snapshot:
                        if event.type in (DATASOURCE_TYPE.TRADE,
                                          DATASOURCE_TYPE.CUSTOM):
                            self.update_universe(event)
                            updated = True
                        if event.type == DATASOURCE_TYPE.BENCHMARK:
                            bm_updated = True
                        txns, orders = self.algo.blotter.process_trade(event)
                        for data in chain([event], txns, orders):
                            self.perf_tracker.process_event(data)

                    # Update our portfolio.
                    self.algo.set_portfolio(self.perf_tracker.get_portfolio())

                    # Send the current state of the universe
                    # to the user's algo.
                    if updated:
                        self.simulate_snapshot(date)
                        updated = False

                        # run orders placed in the algorithm call
                        # above through perf tracker before emitting
                        # the perf packet, so that the perf includes
                        # placed orders
                        for order in self.algo.blotter.new_orders:
                            self.perf_tracker.process_event(order)
                        self.algo.blotter.new_orders = []

                    # The benchmark is our internal clock. When it
                    # updates, we need to emit a performance message.
                    if bm_updated:
                        bm_updated = False
                        yield self.get_message(date)

            risk_message = self.perf_tracker.handle_simulation_end()

            # When emitting minutely, it is still useful to have a final
            # packet with the entire days performance rolled up.
            if self.perf_tracker.emission_rate == 'minute':
                daily_rollup = self.perf_tracker.to_dict(
                    emission_type='daily'
                )
                daily_rollup['daily_perf']['recorded_vars'] = \
                    self.algo.recorded_vars
                yield daily_rollup

            yield risk_message

    def get_message(self, date):
        """
        Build the perf packet for ``date`` with the algorithm's recorded
        variables attached under the emission-rate specific key.
        """
        rvars = self.algo.recorded_vars
        if self.perf_tracker.emission_rate == 'daily':
            perf_message = \
                self.perf_tracker.handle_market_close()
            perf_message['daily_perf']['recorded_vars'] = rvars
            return perf_message

        elif self.perf_tracker.emission_rate == 'minute':
            self.perf_tracker.handle_minute_close(date)
            perf_message = self.perf_tracker.to_dict()
            perf_message['intraday_perf']['recorded_vars'] = rvars
            return perf_message

    def update_universe(self, event):
        """
        Update the universe with new event information.
        """
        # Update our knowledge of this event's sid
        sid_data = self.current_data[event.sid]
        sid_data.__dict__.update(event.__dict__)

    def simulate_snapshot(self, date):
        """
        Run the user's algo against our current snapshot and update the
        algo's simulated time.
        """
        # Needs to be set so that we inject the proper date into algo
        # log/print lines.
        self.snapshot_dt = date
        self.algo.set_datetime(self.snapshot_dt)
        # Update the simulation time.
        self.simulation_dt = date
        self.algo.handle_data(self.current_data)
def transaction_sim(self, **params):
    """ This is a utility method that asserts expected
    results for conversion of orders to transactions given a
    trade history.

    Required params: trade_count, trade_interval, order_count,
    order_amount, order_interval, expected_txn_count,
    expected_txn_volume. Optional: alternate (alternate long/short
    orders), complete_fill (expect each transaction to match its order's
    amount exactly).
    """
    trade_count = params['trade_count']
    trade_interval = params['trade_interval']
    order_count = params['order_count']
    order_amount = params['order_amount']
    order_interval = params['order_interval']
    expected_txn_count = params['expected_txn_count']
    expected_txn_volume = params['expected_txn_volume']
    # optional parameters
    # ---------------------
    # if present, alternate between long and short sales
    alternate = params.get('alternate')

    # if present, expect transaction amounts to match orders exactly.
    complete_fill = params.get('complete_fill')

    sid = 1
    sim_params = factory.create_simulation_parameters()
    blotter = Blotter()
    price = [10.1] * trade_count
    volume = [100] * trade_count
    start_date = sim_params.first_open

    generated_trades = factory.create_trade_history(
        sid,
        price,
        volume,
        trade_interval,
        sim_params
    )

    alternator = -1 if alternate else 1

    order_date = start_date
    for i in range(order_count):

        blotter.set_date(order_date)
        blotter.order(sid, order_amount * alternator**i, MarketOrder())

        order_date = order_date + order_interval
        # move after market orders to just after market next
        # market open. (The former `minute >= 00` guard was always true.)
        if order_date.hour >= 21:
            order_date = order_date + timedelta(days=1)
            order_date = order_date.replace(hour=14, minute=30)

    # there should now be one open order list stored under the sid
    oo = blotter.open_orders
    self.assertEqual(len(oo), 1)
    self.assertIn(sid, oo)
    order_list = oo[sid][:]  # make copy
    self.assertEqual(order_count, len(order_list))

    for i, order in enumerate(order_list):
        self.assertEqual(order.sid, sid)
        self.assertEqual(order.amount, order_amount * alternator**i)

    tracker = PerformanceTracker(sim_params)

    benchmark_returns = [
        Event({'dt': dt,
               'returns': ret,
               'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
               'source_id': 'benchmarks'})
        for dt, ret in trading.environment.benchmark_returns.iteritems()
        if dt.date() >= sim_params.period_start.date()
        and dt.date() <= sim_params.period_end.date()
    ]

    generated_events = date_sorted_sources(generated_trades,
                                           benchmark_returns)

    # this approximates the loop inside TradingSimulationClient
    transactions = []
    for dt, events in itertools.groupby(generated_events,
                                        operator.attrgetter('dt')):
        for event in events:
            if event.type == DATASOURCE_TYPE.TRADE:
                # NOTE(review): a later `elif ... == TRADE:
                # tracker.process_trade(event)` branch was unreachable
                # (shadowed by this one) and has been removed; trades only
                # reach the tracker through the transactions they fill.
                for txn, order in blotter.process_trade(event):
                    transactions.append(txn)
                    tracker.process_transaction(txn)
            elif event.type == DATASOURCE_TYPE.BENCHMARK:
                tracker.process_benchmark(event)

    if complete_fill:
        self.assertEqual(len(transactions), len(order_list))

    total_volume = 0
    for i, txn in enumerate(transactions):
        total_volume += txn.amount
        if complete_fill:
            # Safe: lengths were asserted equal above.
            self.assertEqual(order_list[i].amount, txn.amount)

    self.assertEqual(total_volume, expected_txn_volume)
    self.assertEqual(len(transactions), expected_txn_count)

    cumulative_pos = tracker.cumulative_performance.positions[sid]
    self.assertEqual(total_volume, cumulative_pos.amount)

    # the open orders should not contain sid.
    oo = blotter.open_orders
    self.assertNotIn(sid, oo, "Entry is removed when no open orders")
class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order

    def initialize(context):
        context.sid = 'AAPL'
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To run this algorithm, pass these functions to TradingAlgorithm
    as keyword arguments:

    my_algo = TradingAlgorithm(initialize=initialize,
                               handle_data=handle_data)
    stats = my_algo.run(data)

    """
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            data_frequency : str (daily, hourly or minutely)
               The duration of the bars.
            annualizer : int <optional>
               Which constant to use for annualizing risk metrics.
               If not provided, will extract from data_frequency.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.

        :Raises:
            ValueError
                If ``script`` lacks ``initialize`` or ``handle_data``, or
                is combined with ``initialize``/``handle_data`` kwargs.
        """
        self.datetime = None

        self.registered_transforms = {}
        self.transforms = []
        self.sources = []

        self._recorded_vars = {}

        self.logger = None

        self.benchmark_return_source = None
        self.perf_tracker = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        if 'data_frequency' in kwargs:
            self.set_data_frequency(kwargs.pop('data_frequency'))
        else:
            self.data_frequency = None

        self.instant_fill = kwargs.pop('instant_fill', False)

        # Override annualizer if set
        if 'annualizer' in kwargs:
            self.annualizer = kwargs['annualizer']

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params:
            if self.data_frequency is None:
                self.data_frequency = self.sim_params.data_frequency
            else:
                self.sim_params.data_frequency = self.data_frequency

            self.perf_tracker = PerformanceTracker(self.sim_params)

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        self.portfolio_needs_update = True
        self._portfolio = None

        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None

        if self.algoscript is not None:
            # Previously this check sat inside the `elif` below where it
            # could never run, so passed functions were silently ignored.
            if kwargs.get('initialize', False) and kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.')
            self.ns = {}
            # NOTE(security): the script runs with full interpreter
            # privileges; only pass trusted code.
            exec_(self.algoscript, self.ns)
            if 'initialize' not in self.ns:
                raise ValueError('You must define an initialize function.')
            if 'handle_data' not in self.ns:
                raise ValueError('You must define a handle_data function.')
            self._initialize = self.ns['initialize']
            self._handle_data = self.ns['handle_data']

        # If two functions are passed in assume initialize and
        # handle_data are passed in.
        elif kwargs.get('initialize', False) and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')

        if self._initialize is None:
            self._initialize = lambda x: None

        # an algorithm subclass needs to set initialized to True when
        # it is fully initialized.
        self.initialized = False

        self.initialize(*args, **kwargs)

    def initialize(self, *args, **kwargs):
        """
        Call self._initialize with `self` made available to Zipline API
        functions.
        """
        with ZiplineAPI(self):
            self._initialize(self)

    def handle_data(self, data):
        """
        Update the history container (if any) with ``data``, then call the
        user's handle_data with ``self`` as the context.
        """
        if self.history_container:
            self.history_container.update(data, self.datetime)

        self._handle_data(self, data)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base},
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params):
        """
        Create a merged data generator using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """
        if self.benchmark_return_source is None:
            env = trading.environment
            if (self.data_frequency == 'minute'
                    or sim_params.emission_rate == 'minute'):
                # Stamp benchmark events at the market close of their day.
                update_time = lambda date: env.get_open_and_close(date)[1]
            else:
                update_time = lambda date: date
            # NOTE(review): `iterkv` is a deprecated pandas alias; confirm
            # against the pandas version in use.
            benchmark_return_source = [
                Event({'dt': update_time(dt),
                       'returns': ret,
                       'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                       'source_id': 'benchmarks'})
                for dt, ret in trading.environment.benchmark_returns.iterkv()
                if dt.date() >= sim_params.period_start.date()
                and dt.date() <= sim_params.period_end.date()
            ]
        else:
            benchmark_return_source = self.benchmark_return_source

        date_sorted = date_sorted_sources(*self.sources)

        if source_filter:
            date_sorted = filter(source_filter, date_sorted)

        with_tnfms = sequential_transforms(date_sorted,
                                           *self.transforms)

        with_benchmarks = date_sorted_sources(benchmark_return_source,
                                              with_tnfms)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(with_benchmarks, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Create a basic generator setup using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """
        # Keep the algorithm and sim_params agreeing on data frequency.
        # An algorithm without its own frequency adopts the one sim_params
        # carries rather than overwriting it with None (mirrors __init__).
        if self.data_frequency is None:
            self.data_frequency = getattr(sim_params, 'data_frequency', None)
        sim_params.data_frequency = self.data_frequency

        # perf_tracker will be instantiated in __init__ if a sim_params
        # is passed to the constructor. If not, we instantiate here.
        if self.perf_tracker is None:
            self.perf_tracker = PerformanceTracker(sim_params)

        self.data_gen = self._create_data_generator(source_filter,
                                                    sim_params)

        self.trading_client = AlgorithmSimulator(self, sim_params)

        transact_method = transact_partial(self.slippage, self.commission)
        self.set_transact(transact_method)

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Override this method to add new logic to the construction
        of the generator. Overrides can use the _create_generator
        method to get a standard construction generator.
        """
        return self._create_generator(self.sim_params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, sim_params=None, benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - zipline source
                     - list of zipline sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.
            sim_params : <optional>
               Overrides the simulation parameters set on the algorithm.
            benchmark_return_source : <optional>
               Benchmark events to use instead of the trading
               environment's benchmark returns.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, (list, tuple)):
            assert self.sim_params is not None or sim_params is not None, \
                "When providing a list of sources, sim_params have to be " \
                "specified as a parameter or in the constructor."
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            source = DataPanelSource(source)

        if not isinstance(source, (list, tuple)):
            self.sources = [source]
        else:
            self.sources = source

        # Check for override of sim_params.
        # If it isn't passed to this function,
        # use the default params set with the algorithm.
        # Else, we create simulation parameters using the start and end of the
        # source provided.
        if sim_params is None:
            if self.sim_params is None:
                start = source.start
                end = source.end
                sim_params = create_simulation_parameters(
                    start=start,
                    end=end,
                    capital_base=self.capital_base,
                )
            else:
                sim_params = self.sim_params

        # update sim params to ensure it's set
        self.sim_params = sim_params
        if self.sim_params.sids is None:
            all_sids = [sid for s in self.sources for sid in s.sids]
            self.sim_params.sids = set(all_sids)

        # Create history containers
        if len(self.history_specs) != 0:
            self.history_container = HistoryContainer(
                self.history_specs,
                self.sim_params.sids,
                self.sim_params.first_open)

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in iteritems(self.registered_transforms):
            sf = StatefulTransform(
                trans_descr['class'],
                *trans_descr['args'],
                **trans_descr['kwargs']
            )
            sf.namestring = namestring

            self.transforms.append(sf)

        # Honor an explicitly supplied benchmark source; this argument was
        # previously accepted but silently ignored.
        if benchmark_return_source is not None:
            self.benchmark_return_source = benchmark_return_source

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(sim_params)

        with ZiplineAPI(self):
            # loop through simulated_trading, each iteration returns a
            # perf dictionary
            perfs = []
            for perf in self.gen:
                perfs.append(perf)

            # convert perf dict to pandas dataframe
            daily_stats = self._create_daily_stats(perfs)

        return daily_stats

    def _create_daily_stats(self, perfs):
        """
        Build a DataFrame of the 'daily_perf' packets in ``perfs``, indexed
        by each packet's period_close. Packets without 'daily_perf' are
        stored as ``self.risk_report``.
        """
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:

                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars')
                )
                daily_perfs.append(perf['daily_perf'])
            else:
                self.risk_report = perf

        daily_dts = [np.datetime64(perf['period_close'], utc=True)
                     for perf in daily_perfs]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    def add_transform(self, transform_class, tag, *args, **kwargs):
        """Add a single-sid, sequential transform to the model.

        :Arguments:
            transform_class : class
                Which transform to use. E.g. mavg.
            tag : str
                How to name the transform. Can later be access via:
                data[sid].tag()

        Extra args and kwargs will be forwarded to the transform
        instantiation.

        """
        self.registered_transforms[tag] = {'class': transform_class,
                                           'args': args,
                                           'kwargs': kwargs}

    @api_method
    def record(self, **kwargs):
        """
        Track and record local variable (i.e. attributes) each day.
        """
        for name, value in kwargs.items():
            self._recorded_vars[name] = value

    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order using the specified parameters.
        """
        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid,
                                   amount,
                                   limit_price,
                                   stop_price,
                                   style)

        # Convert deprecated limit_price and stop_price parameters to use
        # ExecutionStyle objects.
        style = self.__convert_order_params_for_blotter(limit_price,
                                                        stop_price,
                                                        style)
        return self.blotter.order(sid, amount, style)

    def validate_order_params(self,
                              sid,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Helper method for validating parameters to the order API function.

        Raises an UnsupportedOrderParameters if invalid arguments are found.
        """
        if style:
            if limit_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both limit_price and style is not supported."
                )

            if stop_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both stop_price and style is not supported."
                )

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Helper method for converting deprecated limit_price and stop_price
        arguments into ExecutionStyle instances.

        This function assumes that either style == None or (limit_price,
        stop_price) == (None, None).
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style
        if limit_price and stop_price:
            return StopLimitOrder(limit_price, stop_price)
        if limit_price:
            return LimitOrder(limit_price)
        if stop_price:
            return StopOrder(stop_price)
        else:
            return MarketOrder()

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.
        If the requested sid is found in the universe, the requested value is
        divided by its price to imply the number of shares to transact.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order(sid, value)
        Limit order:     order(sid, value, limit_price)
        Stop order:      order(sid, value, None, stop_price)
        StopLimit order: order(sid, value, limit_price, stop_price)
        """
        last_price = self.trading_client.current_data[sid].price
        if np.allclose(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=sid
            )
            if self.logger:
                self.logger.debug(zero_message)
            # Don't place any order
            return
        else:
            amount = value / last_price
            return self.order(sid, amount,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

    @property
    def recorded_vars(self):
        """A shallow copy of the variables recorded via ``record``."""
        return copy(self._recorded_vars)

    @property
    def portfolio(self):
        # internally this will cause a refresh of the
        # period performance calculations.
        return self.perf_tracker.get_portfolio()

    def updated_portfolio(self):
        # internally this will cause a refresh of the
        # period performance calculations.
        if self.portfolio_needs_update:
            self._portfolio = self.perf_tracker.get_portfolio()
            self.portfolio_needs_update = False
        return self._portfolio

    def set_logger(self, logger):
        self.logger = logger

    def set_datetime(self, dt):
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"
        self.datetime = dt

    @api_method
    def get_datetime(self):
        """
        Returns a copy of the datetime.
        """
        date_copy = copy(self.datetime)
        assert date_copy.tzinfo == pytz.utc, \
            "Algorithm should have a utc datetime"
        return date_copy

    def set_transact(self, transact):
        """
        Set the method that will be called to create a
        transaction from open orders and trade events.
        """
        self.blotter.transact = transact

    @api_method
    def set_slippage(self, slippage):
        if not isinstance(slippage, SlippageModel):
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()
        self.slippage = slippage

    @api_method
    def set_commission(self, commission):
        if not isinstance(commission, (PerShare, PerTrade, PerDollar)):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()
        self.commission = commission

    def set_sources(self, sources):
        assert isinstance(sources, list)
        self.sources = sources

    def set_transforms(self, transforms):
        assert isinstance(transforms, list)
        self.transforms = transforms

    def set_data_frequency(self, data_frequency):
        assert data_frequency in ('daily', 'minute')
        self.data_frequency = data_frequency
        self.annualizer = ANNUALIZER[self.data_frequency]

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified security corresponding to the given
        percent of the current portfolio value.

        Note that percent must expressed as a decimal (0.50 means 50\%).
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target number of shares and the
        current number of shares.
        """
        if sid in self.portfolio.positions:
            current_position = self.portfolio.positions[sid].amount
            req_shares = target - current_position
            return self.order(sid, req_shares,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)
        else:
            return self.order(sid, target,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target value and the
        current value.
        """
        if sid in self.portfolio.positions:
            current_position = self.portfolio.positions[sid].amount
            current_price = self.trading_client.current_data[sid].price
            current_value = current_position * current_price
            req_value = target - current_value
            return self.order_value(sid, req_value,
                                    limit_price=limit_price,
                                    stop_price=stop_price,
                                    style=style)
        else:
            return self.order_value(sid, target,
                                    limit_price=limit_price,
                                    stop_price=stop_price,
                                    style=style)

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must expressed as a decimal (0.50 means 50\%).
        """
        if sid in self.portfolio.positions:
            current_position = self.portfolio.positions[sid].amount
            current_price = self.trading_client.current_data[sid].price
            current_value = current_position * current_price
        else:
            current_value = 0
        target_value = self.portfolio.portfolio_value * target

        req_value = target_value - current_value
        return self.order_value(sid, req_value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return API views of open orders: a dict of sid -> list for all sids
        when ``sid`` is None, otherwise the list for ``sid`` (possibly empty).
        """
        if sid is None:
            # `iteritems` helper (also used by run) instead of the
            # Python-2-only dict.iteritems method.
            return {
                key: [order.to_api_obj() for order in orders]
                for key, orders in iteritems(self.blotter.open_orders)
            }
        if sid in self.blotter.open_orders:
            orders = self.blotter.open_orders[sid]
            return [order.to_api_obj() for order in orders]
        return []

    @api_method
    def get_order(self, order_id):
        """Return the API view of ``order_id``, or None if unknown."""
        if order_id in self.blotter.orders:
            return self.blotter.orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """Cancel an order given either its id or the Order object."""
        order_id = order_param
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id

        self.blotter.cancel(order_id)

    def raw_positions(self):
        """
        Returns the current portfolio for the algorithm.

        N.B. this is not done as a property, so that the function can be
        passed and called from within a source.
        """
        # Return the 'internal' positions object, as in the one that is
        # not passed to the algo, and thus should not have tainted keys.
        return self.perf_tracker.cumulative_performance.positions

    def raw_orders(self):
        """
        Returns the current open orders from the blotter.

        N.B. this is not a property, so that the function can be passed
        and called back from within a source.
        """

        return self.blotter.open_orders

    @api_method
    def add_history(self, bar_count, frequency, field,
                    ffill=True):
        """Register a history spec so run() builds a container for it."""
        history_spec = HistorySpec(bar_count, frequency, field, ffill)
        self.history_specs[history_spec.key_str] = history_spec

    @api_method
    def history(self, bar_count, frequency, field, ffill=True):
        """Return history for a spec previously registered via add_history."""
        spec_key_str = HistorySpec.spec_key(
            bar_count, frequency, field, ffill)
        history_spec = self.history_specs[spec_key_str]
        return self.history_container.get_history(history_spec, self.datetime)
class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order

    def initialize(context):
        context.sid = 'AAPL'
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To run this algorithm, pass these functions to TradingAlgorithm as
    keyword arguments:

    my_algo = TradingAlgorithm(initialize=initialize,
                               handle_data=handle_data)
    stats = my_algo.run(data)
    """

    # If this is set to false then it is the responsibility
    # of the overriding subclass to set initialized = true
    AUTO_INITIALIZE = True

    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single argument at the
                beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments (context and data)
                on every bar.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.

        ``initialize`` and ``handle_data`` are read from keyword arguments
        only, and are used only when both are given. Supplying them together
        with ``script`` raises ValueError. Other recognised keyword arguments
        are ``sim_params``, ``blotter`` and ``namespace`` (the dict the script
        is executed in). Positional ``args`` and any remaining ``kwargs`` are
        forwarded to ``self.initialize``.
        """
        self.datetime = None

        self.registered_transforms = {}
        self.transforms = []
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.get('namespace', {})

        self.logger = None

        self.benchmark_return_source = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        self.instant_fill = kwargs.pop('instant_fill', False)

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base
            )
        self.perf_tracker = PerformanceTracker(self.sim_params)

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        self.portfolio_needs_update = True
        self._portfolio = None

        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._analyze = None

        # Explicit callbacks are honoured only when both are supplied.
        callbacks_given = bool(kwargs.get('initialize', False) and
                               kwargs.get('handle_data'))

        if self.algoscript is not None:
            # A script and explicit callbacks are mutually exclusive; check
            # here (before executing the script) so the conflict is
            # actually detected.
            if callbacks_given:
                raise ValueError('You can not set script and '
                                 'initialize/handle_data.')

            exec_(self.algoscript, self.namespace)
            self._initialize = self.namespace.get('initialize', None)
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze', None)

        elif callbacks_given:
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        # Subclasses that override initialize should only worry about
        # setting self.initialized = True if AUTO_INITIALIZE is
        # manually set to False.
        self.initialized = False
        self.initialize(*args, **kwargs)
        if self.AUTO_INITIALIZE:
            self.initialized = True

    def initialize(self, *args, **kwargs):
        """
        Call self._initialize with `self` made available to Zipline API
        functions. Any arguments are accepted and ignored.
        """
        with ZiplineAPI(self):
            self._initialize(self)

    def handle_data(self, data):
        """
        Feed ``data`` into the history container (when one exists) and then
        call the user's handle_data callback with (self, data).
        """
        if self.history_container:
            self.history_container.update(data, self.datetime)

        self._handle_data(self, data)

    def analyze(self, perf):
        """
        Call the user's optional analyze callback with (self, perf) under
        the Zipline API context; do nothing if none was defined.
        """
        if self._analyze is None:
            return

        with ZiplineAPI(self):
            self._analyze(self, perf)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base},
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params=None):
        """
        Create a merged data generator using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.

        Returns an itertools.groupby iterator of (dt, events) pairs.
        """
        if sim_params is None:
            sim_params = self.sim_params

        if self.benchmark_return_source is None:
            env = trading.environment
            if (sim_params.data_frequency == 'minute'
                    or sim_params.emission_rate == 'minute'):
                # Re-stamp each benchmark return with element [1] of
                # get_open_and_close (presumably the session close).
                update_time = lambda date: env.get_open_and_close(date)[1]
            else:
                update_time = lambda date: date
            start_date = sim_params.period_start.date()
            end_date = sim_params.period_end.date()
            benchmark_return_source = [
                Event({'dt': update_time(dt),
                       'returns': ret,
                       'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                       'source_id': 'benchmarks'})
                # iteritems replaces the deprecated pandas alias iterkv.
                for dt, ret in env.benchmark_returns.iteritems()
                if start_date <= dt.date() <= end_date
            ]
        else:
            benchmark_return_source = self.benchmark_return_source

        date_sorted = date_sorted_sources(*self.sources)

        if source_filter:
            # A generator expression keeps filtering lazy on every Python
            # version; Python 2's filter() would build a full list up front.
            date_sorted = (event for event in date_sorted
                           if source_filter(event))

        with_tnfms = sequential_transforms(date_sorted,
                                           *self.transforms)

        with_benchmarks = date_sorted_sources(benchmark_return_source,
                                              with_tnfms)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(with_benchmarks, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Create a basic generator setup using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """
        # Instantiate perf_tracker
        self.perf_tracker = PerformanceTracker(sim_params)
        self.portfolio_needs_update = True

        self.data_gen = self._create_data_generator(source_filter,
                                                    sim_params)

        self.trading_client = AlgorithmSimulator(self, sim_params)

        transact_method = transact_partial(self.slippage, self.commission)
        self.set_transact(transact_method)

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Override this method to add new logic to the construction
        of the generator. Overrides can use the _create_generator
        method to get a standard construction generator.
        """
        return self._create_generator(self.sim_params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, overwrite_sim_params=True,
            benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - pandas.Panel
                     - zipline source
                     - list of sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.
            overwrite_sim_params : bool <default: True>
               Take period_start/period_end from the source (when it has
               them) and the sids from all sources. Forced to False when a
               list of sources is passed.
            benchmark_return_source : optional
               Benchmark events to use instead of the trading environment's
               benchmark returns. When None, the current
               ``self.benchmark_return_source`` is kept.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, list):
            if overwrite_sim_params:
                warnings.warn("""List of sources passed, will not attempt to extract sids, and start and end
 dates. Make sure to set the correct fields in sim_params passed to
 __init__().""", UserWarning)
                overwrite_sim_params = False
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            source = DataPanelSource(source)

        if isinstance(source, list):
            self.set_sources(source)
        else:
            self.set_sources([source])

        # Honour an explicitly supplied benchmark source (previously this
        # argument was accepted but silently ignored).
        if benchmark_return_source is not None:
            self.benchmark_return_source = benchmark_return_source

        # Override sim_params if params are provided by the source.
        if overwrite_sim_params:
            if hasattr(source, 'start'):
                self.sim_params.period_start = source.start
            if hasattr(source, 'end'):
                self.sim_params.period_end = source.end
            all_sids = [sid for s in self.sources for sid in s.sids]
            self.sim_params.sids = set(all_sids)
            # Changing period_start and period_close might require updating
            # of first_open and last_close.
            self.sim_params._update_internal()

        # Create history containers
        if len(self.history_specs) != 0:
            self.history_container = HistoryContainer(
                self.history_specs,
                self.sim_params.sids,
                self.sim_params.first_open)

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in iteritems(self.registered_transforms):
            sf = StatefulTransform(
                trans_descr['class'],
                *trans_descr['args'],
                **trans_descr['kwargs']
            )
            sf.namestring = namestring

            self.transforms.append(sf)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(self.sim_params)

        with ZiplineAPI(self):
            # loop through simulated_trading, each iteration returns a
            # perf dictionary
            perfs = []
            for perf in self.gen:
                perfs.append(perf)

            # convert perf dict to pandas dataframe
            daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats

    def _create_daily_stats(self, perfs):
        """
        Build a DataFrame from the 'daily_perf' messages in ``perfs``,
        indexed by each message's period_close. A message without a
        'daily_perf' key is stored as self.risk_report.
        """
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:

                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars')
                )
                daily_perfs.append(perf['daily_perf'])
            else:
                self.risk_report = perf

        daily_dts = [np.datetime64(perf['period_close'], utc=True)
                     for perf in daily_perfs]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    def add_transform(self, transform_class, tag, *args, **kwargs):
        """Add a single-sid, sequential transform to the model.

        :Arguments:
            transform_class : class
                Which transform to use. E.g. mavg.
            tag : str
                How to name the transform. Can later be access via:
                data[sid].tag()

        Extra args and kwargs will be forwarded to the transform
        instantiation.

        """
        self.registered_transforms[tag] = {'class': transform_class,
                                           'args': args,
                                           'kwargs': kwargs}

    @api_method
    def record(self, *args, **kwargs):
        """
        Track and record local variable (i.e. attributes) each day.

        Positional arguments are consumed as (name, value) pairs; a trailing
        unpaired positional argument is ignored.
        """
        # Make 2 objects both referencing the same iterator
        args = [iter(args)] * 2

        # Zip generates list entries by calling `next` on each iterator it
        # receives.  In this case the two iterators are the same object, so the
        # call to next on args[0] will also advance args[1], resulting in zip
        # returning (a,b) (c,d) (e,f) rather than (a,a) (b,b) (c,c) etc.
        positionals = zip(*args)
        for name, value in chain(positionals, iteritems(kwargs)):
            self._recorded_vars[name] = value

    @api_method
    def symbol(self, symbol_str, as_of_date=None):
        """
        Default symbol lookup for any source that directly maps the
        symbol to the identifier (e.g. yahoo finance).

        Keyword argument as_of_date is ignored.
        """
        return symbol_str

    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order using the specified parameters.

        ``amount`` is converted to an integer share count (see below) and
        the result of ``self.blotter.order`` is returned.
        """

        def round_if_near_integer(a, epsilon=1e-4):
            """
            Round a to the nearest integer if that integer is within an epsilon
            of a.
            """
            if abs(a - round(a)) <= epsilon:
                return round(a)
            else:
                return a

        # Truncate to the integer share count that's either within .0001 of
        # amount or closer to zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        amount = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid,
                                   amount,
                                   limit_price,
                                   stop_price,
                                   style)

        # Convert deprecated limit_price and stop_price parameters to use
        # ExecutionStyle objects.
        style = self.__convert_order_params_for_blotter(limit_price,
                                                        stop_price,
                                                        style)
        return self.blotter.order(sid, amount, style)

    def validate_order_params(self,
                              sid,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Helper method for validating parameters to the order API function.

        Raises OrderDuringInitialize when called before initialization has
        finished, UnsupportedOrderParameters when ``style`` is combined with
        ``limit_price`` or ``stop_price``, and whatever the registered
        trading controls raise.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style:
            # Compare against None so that a (nonsensical) price of 0 is
            # also rejected here instead of tripping the assertion in
            # __convert_order_params_for_blotter.
            if limit_price is not None:
                raise UnsupportedOrderParameters(
                    msg="Passing both limit_price and style is not supported."
                )

            if stop_price is not None:
                raise UnsupportedOrderParameters(
                    msg="Passing both stop_price and style is not supported."
                )

        for control in self.trading_controls:
            control.validate(sid,
                             amount,
                             self.updated_portfolio(),
                             self.get_datetime(),
                             self.trading_client.current_data)

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Helper method for converting deprecated limit_price and stop_price
        arguments into ExecutionStyle instances.

        This function assumes that either style == None or (limit_price,
        stop_price) == (None, None).
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style
        if limit_price and stop_price:
            return StopLimitOrder(limit_price, stop_price)
        if limit_price:
            return LimitOrder(limit_price)
        if stop_price:
            return StopOrder(stop_price)
        else:
            return MarketOrder()

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.
        If the requested sid is found in the universe, the requested value is
        divided by its price to imply the number of shares to transact.
        If the last price is (close to) zero, no order is placed and None is
        returned.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order(sid, value)
        Limit order:     order(sid, value, limit_price)
        Stop order:      order(sid, value, None, stop_price)
        StopLimit order: order(sid, value, limit_price, stop_price)
        """
        last_price = self.trading_client.current_data[sid].price
        if np.allclose(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=sid
            )
            if self.logger:
                self.logger.debug(zero_message)
            # Don't place any order
            return
        else:
            amount = value / last_price
            return self.order(sid, amount,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

    @property
    def recorded_vars(self):
        """A shallow copy of the variables stored via record()."""
        return copy(self._recorded_vars)

    @property
    def portfolio(self):
        """The current portfolio, refreshed lazily via updated_portfolio()."""
        return self.updated_portfolio()

    def updated_portfolio(self):
        """
        Return the cached portfolio, re-fetching it from the performance
        tracker first if portfolio_needs_update is set.
        """
        if self.portfolio_needs_update:
            self._portfolio = self.perf_tracker.get_portfolio()
            self.portfolio_needs_update = False
        return self._portfolio

    def set_logger(self, logger):
        self.logger = logger

    def on_dt_changed(self, dt):
        """
        Callback triggered by the simulation loop whenever the current dt
        changes.

        Any logic that should happen exactly once at the start of each datetime
        group should happen here.
        """
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

    @api_method
    def get_datetime(self):
        """
        Returns a copy of the datetime.
        """
        date_copy = copy(self.datetime)
        assert date_copy.tzinfo == pytz.utc, \
            "Algorithm should have a utc datetime"
        return date_copy

    def set_transact(self, transact):
        """
        Set the method that will be called to create a
        transaction from open orders and trade events.
        """
        self.blotter.transact = transact

    def update_dividends(self, dividend_frame):
        """
        Set DataFrame used to process dividends.  DataFrame columns should
        contain at least the entries in zp.DIVIDEND_FIELDS.
        """
        self.perf_tracker.update_dividends(dividend_frame)

    @api_method
    def set_slippage(self, slippage):
        if not isinstance(slippage, SlippageModel):
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()
        self.slippage = slippage

    @api_method
    def set_commission(self, commission):
        if not isinstance(commission, (PerShare, PerTrade, PerDollar)):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()
        self.commission = commission

    def set_sources(self, sources):
        assert isinstance(sources, list)
        self.sources = sources

    def set_transforms(self, transforms):
        assert isinstance(transforms, list)
        self.transforms = transforms

    # Remain backwards compatibility
    @property
    def data_frequency(self):
        return self.sim_params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        assert value in ('daily', 'minute')
        self.sim_params.data_frequency = value

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified security corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50\%).
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target number of shares and the
        current number of shares.
        """
        if sid in self.portfolio.positions:
            current_position = self.portfolio.positions[sid].amount
            req_shares = target - current_position
            return self.order(sid, req_shares,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)
        else:
            return self.order(sid, target,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target value and the
        current value.
        """
        last_price = self.trading_client.current_data[sid].price
        if np.allclose(last_price, 0):
            # Don't place an order
            if self.logger:
                zero_message = "Price of 0 for {psid}; can't infer value"
                self.logger.debug(zero_message.format(psid=sid))
            return
        target_amount = target / last_price
        return self.order_target(sid, target_amount,
                                 limit_price=limit_price,
                                 stop_price=stop_price,
                                 style=style)

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50\%).
        """
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return open orders as API objects: a dict keyed by sid (sids with no
        open orders omitted) when ``sid`` is None, otherwise the list for
        that sid (empty if none).
        """
        if sid is None:
            return {
                key: [order.to_api_obj() for order in orders]
                for key, orders in iteritems(self.blotter.open_orders)
                if orders
            }
        if sid in self.blotter.open_orders:
            orders = self.blotter.open_orders[sid]
            return [order.to_api_obj() for order in orders]
        return []

    @api_method
    def get_order(self, order_id):
        """Return the order with ``order_id`` as an API object, or None."""
        if order_id in self.blotter.orders:
            return self.blotter.orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """Cancel an order given either its id or a protocol Order object."""
        order_id = order_param
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id

        self.blotter.cancel(order_id)

    @api_method
    def add_history(self, bar_count, frequency, field, ffill=True):
        """Register a HistorySpec so history() can later be called with it."""
        daily_at_midnight = (self.sim_params.data_frequency == 'daily')

        history_spec = HistorySpec(bar_count, frequency, field, ffill,
                                   daily_at_midnight=daily_at_midnight)
        self.history_specs[history_spec.key_str] = history_spec

    @api_method
    def history(self, bar_count, frequency, field, ffill=True):
        """
        Return history for a spec previously registered with add_history
        using the same arguments (a KeyError is raised otherwise).
        """
        spec_key_str = HistorySpec.spec_key(
            bar_count, frequency, field, ffill)
        history_spec = self.history_specs[spec_key_str]
        return self.history_container.get_history(history_spec, self.datetime)

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.
        """
        if self.initialized:
            raise RegisterTradingControlPostInit()
        self.trading_controls.append(control)

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid. Limits are treated as absolute values and are enforced at
        the time that the algo attempts to place an order for sid. This means
        that it's possible to end up with more than the max number of shares
        due to splits/dividends, and more than the max notional due to price
        improvement.

        If an algorithm attempts to place an order that would result in
        increasing the absolute value of shares/dollar value exceeding one of
        these limits, raise a TradingControlException.
        """
        control = MaxPositionSize(sid=sid,
                                  max_shares=max_shares,
                                  max_notional=max_notional)
        self.register_trading_control(control)

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid.  Limits are treated as absolute values and are
        enforced at the time that the algo attempts to place an order for sid.

        If an algorithm attempts to place an order that would result in
        exceeding one of these limits, raise a TradingControlException.
        """
        control = MaxOrderSize(sid=sid,
                               max_shares=max_shares,
                               max_notional=max_notional)
        self.register_trading_control(control)

    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed within the given
        time interval.
        """
        control = MaxOrderCount(max_count)
        self.register_trading_control(control)

    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short positions.
        """
        self.register_trading_control(LongOnly())

    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.
        """
        # .values() works on both Python 2 and 3; dict.itervalues (and the
        # mappingproxy returned by cls.__dict__) does not exist on Python 3.
        return [fn for fn in cls.__dict__.values()
                if getattr(fn, 'is_api_method', False)]
def transaction_sim(self, **params):
    """
    Drive a TransactionSimulator with synthetic orders and trades and check
    that the resulting transactions match expectations.

    Required ``params`` keys: trade_count, trade_interval, order_count,
    order_amount, order_interval, expected_txn_count, expected_txn_volume.

    Optional ``params`` keys:
        trade_delay   -- added to every trade's dt before it is processed.
        alternate     -- if truthy, order signs alternate (+, -, +, ...).
        complete_fill -- if truthy, each transaction's amount must equal the
                         amount of the order at the same position.
    """
    trade_count = params['trade_count']
    trade_interval = params['trade_interval']
    trade_delay = params.get('trade_delay')
    order_count = params['order_count']
    order_amount = params['order_amount']
    order_interval = params['order_interval']
    expected_txn_count = params['expected_txn_count']
    expected_txn_volume = params['expected_txn_volume']
    alternate = params.get('alternate')
    complete_fill = params.get('complete_fill')

    sid = 1
    sim_params = factory.create_simulation_parameters()
    trade_sim = TransactionSimulator()
    prices = [10.1] * trade_count
    volumes = [100] * trade_count
    start_date = sim_params.first_open

    generated_trades = factory.create_trade_history(
        sid, prices, volumes, trade_interval, sim_params)

    sign = -1 if alternate else 1

    def expected_amount(idx):
        # Signed size of the idx-th order.
        return order_amount * sign ** idx

    order_date = start_date
    for idx in range(order_count):
        trade_sim.place_order(ndict({
            'sid': sid,
            'amount': expected_amount(idx),
            'dt': order_date,
        }))
        order_date = order_date + order_interval
        # Orders that land at hour 21 or later are moved to 14:30 on the
        # following day, i.e. just after the next market open.
        if order_date.hour >= 21:
            order_date = (order_date + timedelta(days=1)).replace(
                hour=14, minute=30)

    # Exactly one open-order list, stored under our sid.
    pending = trade_sim.open_orders
    self.assertEqual(len(pending), 1)
    self.assertTrue(sid in pending)
    placed_orders = pending[sid]
    self.assertEqual(order_count, len(placed_orders))

    for idx, placed in enumerate(placed_orders):
        self.assertEqual(placed.sid, sid)
        self.assertEqual(placed.amount, expected_amount(idx))

    tracker = PerformanceTracker(sim_params)

    # Approximates the event loop inside TradingSimulationClient.
    txns = []
    for trade in generated_trades:
        if trade_delay:
            trade.dt = trade.dt + trade_delay
        trade_sim.update(trade)
        if trade.TRANSACTION:
            txns.append(trade.TRANSACTION)
        tracker.process_event(trade)

    if complete_fill:
        self.assertEqual(len(txns), len(placed_orders))

    filled_volume = 0
    for idx, txn in enumerate(txns):
        filled_volume += txn.amount
        if complete_fill:
            self.assertEqual(placed_orders[idx].amount, txn.amount)

    self.assertEqual(filled_volume, expected_txn_volume)
    self.assertEqual(len(txns), expected_txn_count)

    position = tracker.cumulative_performance.positions[sid]
    self.assertEqual(filled_volume, position.amount)

    # Every order has been consumed; the sid's list remains but is empty.
    pending = trade_sim.open_orders
    self.assertTrue(sid in pending)
    self.assertEqual(0, len(pending[sid]))
class TradeSimulationClient(object):
    """
    Generator-style driver for a simulated trading run.

    The constructor takes the user's algorithm and a trading environment.
    The merged event stream is supplied later to :meth:`simulate`, which
    pipes it through three stages:

    1. ``self.ordering_client`` (a TransactionSimulator) fills orders that
       the algorithm has left open against the sid of each incoming trade,
       and records any fill in the event's TRANSACTION field.
    2. ``self.perf_tracker`` (a PerformanceTracker) consumes those events,
       keeps daily and cumulative performance metrics, replaces the
       TRANSACTION field with a portfolio field, and adds a perf_message
       that carries a report once per day.
    3. ``self.algo_sim`` (an AlgorithmSimulator) batches events that share
       a dt into one snapshot and hands it to the algorithm, which places
       its orders directly into the TransactionSimulator's order book.

    The messages produced by the last stage (daily results, plus a final
    risk report) are what :meth:`simulate` yields.
    """

    def __init__(self, algo, environment):
        self.algo, self.environment = algo, environment
        self.ordering_client = TransactionSimulator()
        self.perf_tracker = PerformanceTracker(environment)
        self.algo_start = environment.first_open
        self.algo_sim = AlgorithmSimulator(
            self.ordering_client, algo, self.algo_start)

    def get_hash(self):
        """
        Only one TSC is expected to exist, so the hash is built from the
        class name alone, with no constructor arguments.
        """
        class_name = self.__class__.__name__
        return class_name + hash_args()

    def simulate(self, stream_in):
        """
        Main generator work loop: yield every message emitted by the
        algorithm stage for the events in ``stream_in``.
        """
        # Stage 1: fill any orders left open by the algorithm's last bar.
        filled = self.ordering_client.transform(stream_in)
        # Stage 2: swap transactions for portfolio state / perf messages.
        tracked = self.perf_tracker.transform(filled)
        # Stage 3: batch by dt and run the user's algorithm.
        for message in self.algo_sim.transform(tracked):
            yield message
class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order, symbol

    def initialize(context):
        context.sid = symbol('AAPL')
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To run this algorithm, pass these functions to TradingAlgorithm as
    keyword arguments. Positional arguments are NOT treated as the
    algorithm's functions; they are stored and forwarded to initialize().

    my_algo = TradingAlgorithm(initialize=initialize,
                               handle_data=handle_data)
    stats = my_algo.run(data)

    """

    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            before_trading_start : function
                Optional; only read when initialize and handle_data are
                passed as functions.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            algo_filename : str <default: '<string>'>
                Filename used when compiling ``script``.
            namespace : dict <default: {}>
                Globals dict the script is executed in.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
            sim_params : SimulationParameters
               If omitted, built from capital_base and the optional
               ``start`` / ``end`` keyword arguments.
            env : TradingEnvironment
               Defaults to ``TradingEnvironment.instance()``.
            platform : str <default: 'zipline'>
               Reported by get_environment().
            blotter : Blotter
               Defaults to a new Blotter().
            ffc_loader :
               Passed to init_engine(); None selects NoOpFFCEngine.
            history_container_class : type <default: HistoryContainer>
               Class used to build the history container.
            asset_finder : An AssetFinder object
                A new AssetFinder object to be used in this TradingEnvironment
            asset_metadata: can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            identifiers : List
                Any asset identifiers that are not provided in the
                asset_metadata, but will be traded by this TradingAlgorithm

        Any positional arguments and unconsumed keyword arguments are
        stored and forwarded to initialize() when the generator is created.

        :Raises:
            ValueError
                If ``script`` does not define handle_data, or if ``script``
                is combined with initialize/handle_data functions.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        # NOTE: read with get(), so 'namespace' also stays in
        # initialize_kwargs.
        self.namespace = kwargs.get('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.benchmark_return_source = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        self.instant_fill = kwargs.pop('instant_fill', False)

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None)
            )
        self.perf_tracker = PerformanceTracker(self.sim_params)

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment = kwargs.pop('env',
                                              TradingEnvironment.instance())
        self.trading_environment.update_asset_finder(
            asset_finder=kwargs.pop('asset_finder', None),
            asset_metadata=kwargs.pop('asset_metadata', None),
            identifiers=kwargs.pop('identifiers', None)
        )
        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder
        self.init_engine(kwargs.pop('ffc_loader', None))

        # Maps from name to Term
        self._filters = {}
        self._factors = {}
        self._classifiers = {}

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        # Set the dt initially to the period start by forcing it to change.
        # Requires perf_tracker and blotter to already exist.
        self.on_dt_changed(self.sim_params.period_start)

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
        self._portfolio = None
        self._account = None

        self.history_container_class = kwargs.pop(
            'history_container_class', HistoryContainer,
        )
        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager()

        if self.algoscript is not None:
            # A script together with initialize/handle_data functions is
            # ambiguous; reject it before executing any user code.
            if kwargs.get('initialize') or kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )
            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            # NOTE: the script is executed verbatim; only pass trusted code.
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        self._most_recent_data = None

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        self.initialize_kwargs = kwargs

    def init_engine(self, loader):
        """
        Construct and save an FFCEngine from loader.

        If loader is None, constructs a NoOpFFCEngine.
        """
        if loader is not None:
            self.engine = SimpleFFCEngine(
                loader,
                self.trading_environment.trading_days,
                self.asset_finder,
            )
        else:
            self.engine = NoOpFFCEngine()

    def initialize(self, *args, **kwargs):
        """
        Call self._initialize with `self` made available to Zipline API
        functions.

        The positional/keyword arguments are accepted but not forwarded.
        """
        with ZiplineAPI(self):
            self._initialize(self)

    def before_trading_start(self):
        """
        Call the user's before_trading_start function, if one was provided.
        """
        if self._before_trading_start is None:
            return

        self._before_trading_start(self)

    def handle_data(self, data):
        """
        Process one bar: remember it and feed it to the history container
        (if any). Then call the user's handle_data and check the account
        controls.
        """
        self._most_recent_data = data
        if self.history_container:
            self.history_container.update(data, self.datetime)

        self._handle_data(self, data)

        # Unlike trading controls which remain constant unless placing an
        # order, account controls can change each bar. Thus, must check
        # every bar no matter if the algorithm places an order or not.
        self.validate_account_controls()

    def analyze(self, perf):
        """
        Call the user's analyze function (if any) with the run's results,
        with `self` made available to Zipline API functions.
        """
        if self._analyze is None:
            return

        with ZiplineAPI(self):
            self._analyze(self, perf)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base},
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params=None):
        """
        Create a merged data generator using the sources attached to this
        algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.

        Returns an iterator of (dt, events) groups, where benchmark return
        events are merged with the (optionally filtered) source events.
        """
        if sim_params is None:
            sim_params = self.sim_params

        if self.benchmark_return_source is None:
            # In minute mode, stamp each benchmark return at that day's
            # market close rather than the date itself.
            if sim_params.data_frequency == 'minute' or \
               sim_params.emission_rate == 'minute':
                def update_time(date):
                    return self.trading_environment.get_open_and_close(date)[1]
            else:
                def update_time(date):
                    return date
            benchmark_return_source = [
                Event({'dt': update_time(dt),
                       'returns': ret,
                       'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                       'source_id': 'benchmarks'})
                for dt, ret in
                self.trading_environment.benchmark_returns.iteritems()
                if dt.date() >= sim_params.period_start.date() and
                dt.date() <= sim_params.period_end.date()
            ]
        else:
            benchmark_return_source = self.benchmark_return_source

        date_sorted = date_sorted_sources(*self.sources)

        if source_filter:
            date_sorted = filter(source_filter, date_sorted)

        with_benchmarks = date_sorted_sources(benchmark_return_source,
                                              date_sorted)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(with_benchmarks, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Create a basic generator setup using the sources to this algorithm.

        Runs initialize() on first use, (re)creates the performance tracker
        if it was reset, and wires up the AlgorithmSimulator and the
        blotter's transact method.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """

        if not self.initialized:
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
            self.initialized = True

        if self.perf_tracker is None:
            # HACK: When running with the `run` method, we set perf_tracker to
            # None so that it will be overwritten here.
            self.perf_tracker = PerformanceTracker(sim_params)

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True

        self.data_gen = self._create_data_generator(source_filter, sim_params)

        self.trading_client = AlgorithmSimulator(self, sim_params)

        transact_method = transact_partial(self.slippage, self.commission)
        self.set_transact(transact_method)

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Override this method to add new logic to the construction
        of the generator. Overrides can use the _create_generator
        method to get a standard construction generator.
        """
        return self._create_generator(self.sim_params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, overwrite_sim_params=True,
            benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - pandas.Panel
                     - zipline source
                     - list of sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must be the different asset identifiers
               * index must be DatetimeIndex
               * array contents should be price info.
            overwrite_sim_params : bool <default: True>
               Take period start/end from the source when it provides
               them. Ignored (with a warning) for a list of sources.
            benchmark_return_source : <default: None>
               If given, used instead of the environment's benchmark
               returns for this and later runs.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """

        # Ensure that source is a DataSource object
        if isinstance(source, list):
            if overwrite_sim_params:
                warnings.warn(
                    "List of sources passed, will not attempt to extract "
                    "start and end dates. Make sure to set the correct "
                    "fields in sim_params passed to __init__().",
                    UserWarning)
                overwrite_sim_params = False
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, map columns to sids and wrap
            # in DataFrameSource
            copy_frame = source.copy()
            copy_frame.columns = \
                self.asset_finder.map_identifier_index_to_sids(
                    source.columns, source.index[0]
                )
            source = DataFrameSource(copy_frame)

        elif isinstance(source, pd.Panel):
            # If Panel provided, map items to sids and wrap
            # in DataPanelSource
            copy_panel = source.copy()
            copy_panel.items = self.asset_finder.map_identifier_index_to_sids(
                source.items, source.major_axis[0]
            )
            source = DataPanelSource(copy_panel)

        if isinstance(source, list):
            self.set_sources(source)
        else:
            self.set_sources([source])

        # Override sim_params if params are provided by the source.
        if overwrite_sim_params:
            if hasattr(source, 'start'):
                self.sim_params.period_start = source.start
            if hasattr(source, 'end'):
                self.sim_params.period_end = source.end
            # Changing period_start and period_close might require updating
            # of first_open and last_close.
            self.sim_params._update_internal()

        # Honor an explicitly supplied benchmark source; previously this
        # argument was accepted but silently ignored.
        if benchmark_return_source is not None:
            self.benchmark_return_source = benchmark_return_source

        # The sids field of the source is the reference for the universe at
        # the start of the run. (A distinct loop name avoids shadowing the
        # `source` argument.)
        self._current_universe = set()
        for attached_source in self.sources:
            self._current_universe.update(attached_source.sids)
        # Check that all sids from the source are accounted for in
        # the AssetFinder. This retrieve call will raise an exception if the
        # sid is not found.
        for sid in self._current_universe:
            self.asset_finder.retrieve_asset(sid)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create zipline
        self.gen = self._create_generator(self.sim_params)

        # Create history containers
        if self.history_specs:
            self.history_container = self.history_container_class(
                self.history_specs,
                self.current_universe(),
                self.sim_params.first_open,
                self.sim_params.data_frequency,
            )

        # loop through simulated_trading, each iteration returns a
        # perf dictionary
        perfs = list(self.gen)

        # convert perf dict to pandas dataframe
        daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats

    def _create_daily_stats(self, perfs):
        """
        Build a DataFrame of daily stats from the perf messages.

        Messages with a 'daily_perf' entry become rows (recorded vars and
        cumulative risk metrics merged in). Any other message is stored as
        self.risk_report.
        """
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:

                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars')
                )
                perf['daily_perf'].update(perf['cumulative_risk_metrics'])
                daily_perfs.append(perf['daily_perf'])
            else:
                self.risk_report = perf

        daily_dts = [np.datetime64(perf['period_close'], utc=True)
                     for perf in daily_perfs]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    @api_method
    def add_transform(self, transform, days=None):
        """
        Ensures that the history container will have enough size to service
        a simple transform.

        :Arguments:
            transform : string
                The transform to add. must be an element of:
                {'mavg', 'stddev', 'vwap', 'returns'}.
            days : int <default=None>
                The maximum amount of days you will want for this transform.
                This is not needed for 'returns'.

        :Raises:
            ValueError
                For an unknown transform, for days given with 'returns', or
                for days missing with any other transform.
        """
        if transform not in {'mavg', 'stddev', 'vwap', 'returns'}:
            raise ValueError('Invalid transform')

        if transform == 'returns':
            if days is not None:
                raise ValueError('returns does not use days')

            self.add_history(2, '1d', 'price')
            return
        elif days is None:
            raise ValueError('no number of days specified')

        if self.sim_params.data_frequency == 'daily':
            mult = 1
            freq = '1d'
        else:
            # 390 minute bars per regular trading day.
            mult = 390
            freq = '1m'

        bars = mult * days
        self.add_history(bars, freq, 'price')

        if transform == 'vwap':
            self.add_history(bars, freq, 'volume')

    @api_method
    def get_environment(self, field='platform'):
        """
        Return one simulation-environment value by name, or the whole
        mapping when ``field == '*'``. Unknown names raise KeyError.
        """
        env = {
            'arena': self.sim_params.arena,
            'data_frequency': self.sim_params.data_frequency,
            'start': self.sim_params.first_open,
            'end': self.sim_params.last_close,
            'capital_base': self.sim_params.capital_base,
            'platform': self._platform
        }
        if field == '*':
            return env
        else:
            return env[field]

    def add_event(self, rule=None, callback=None):
        """
        Adds an event to the algorithm's EventManager.
        """
        self.event_manager.add_event(
            zipline.utils.events.Event(rule, callback),
        )

    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedules a function to be called with some timed rules.

        Defaults to every day at market open. In daily mode the time_rule
        is ignored.
        """
        date_rule = date_rule or DateRuleFactory.every_day()
        time_rule = ((time_rule or TimeRuleFactory.market_open())
                     if self.sim_params.data_frequency == 'minute' else
                     # If we are in daily mode the time_rule is ignored.
                     zipline.utils.events.Always())

        self.add_event(
            make_eventrule(date_rule, time_rule, half_days),
            func,
        )

    @api_method
    def record(self, *args, **kwargs):
        """
        Track and record local variable (i.e. attributes) each day.

        Accepts alternating name/value positional pairs and/or keyword
        arguments.
        """
        # Make 2 objects both referencing the same iterator
        args = [iter(args)] * 2

        # Zip generates list entries by calling `next` on each iterator it
        # receives.  In this case the two iterators are the same object, so the
        # call to next on args[0] will also advance args[1], resulting in zip
        # returning (a,b) (c,d) (e,f) rather than (a,a) (b,b) (c,c) etc.
        positionals = zip(*args)
        for name, value in chain(positionals, iteritems(kwargs)):
            self._recorded_vars[name] = value

    @api_method
    def symbol(self, symbol_str):
        """
        Default symbol lookup for any source that directly maps the
        symbol to the Asset (e.g. yahoo finance).
        """
        return self.asset_finder.lookup_symbol_resolve_multiple(
            symbol_str,
            as_of_date=self.datetime
        )

    @api_method
    def symbols(self, *args):
        """
        Default symbols lookup for any source that directly maps the
        symbol to the Asset (e.g. yahoo finance).
        """
        return [self.symbol(identifier) for identifier in args]

    @api_method
    def sid(self, a_sid):
        """
        Default sid lookup for any source that directly maps the integer sid
        to the Asset.
        """
        return self.asset_finder.retrieve_asset(a_sid)

    @api_method
    def future_chain(self, root_symbol, as_of_date=None):
        """ Look up a future chain with the specified parameters.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        """
        return FutureChain(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol.upper(),
            as_of_date=as_of_date
        )

    def _calculate_order_value_amount(self, asset, value):
        """
        Calculates how many shares/contracts to order based on the type of
        asset being ordered.

        Returns 0 (and logs at debug level, if a logger is set) when the
        last price is effectively zero.
        """
        last_price = self.trading_client.current_data[asset].price

        if tolerant_equals(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset
            )
            if self.logger:
                self.logger.debug(zero_message)
            # Don't place any order
            return 0

        if isinstance(asset, Future):
            value_multiplier = asset.contract_multiplier
        else:
            value_multiplier = 1

        return value / (last_price * value_multiplier)

    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order using the specified parameters.
        """

        def round_if_near_integer(a, epsilon=1e-4):
            """
            Round a to the nearest integer if that integer is within an epsilon
            of a.
            """
            if abs(a - round(a)) <= epsilon:
                return round(a)
            else:
                return a

        # Truncate to the integer share count that's either within .0001 of
        # amount or closer to zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        amount = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid,
                                   amount,
                                   limit_price,
                                   stop_price,
                                   style)

        # Convert deprecated limit_price and stop_price parameters to use
        # ExecutionStyle objects.
        style = self.__convert_order_params_for_blotter(limit_price,
                                                        stop_price,
                                                        style)
        return self.blotter.order(sid, amount, style)

    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Helper method for validating parameters to the order API function.

        Raises OrderDuringInitialize if called before initialization, and
        UnsupportedOrderParameters if invalid arguments are found. Each
        registered trading control is then validated.
        """

        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style:
            if limit_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both limit_price and style is not supported."
                )

            if stop_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both stop_price and style is not supported."
                )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        for control in self.trading_controls:
            control.validate(asset,
                             amount,
                             self.updated_portfolio(),
                             self.get_datetime(),
                             self.trading_client.current_data)

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Helper method for converting deprecated limit_price and stop_price
        arguments into ExecutionStyle instances.

        This function assumes that either style == None or (limit_price,
        stop_price) == (None, None).
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style
        if limit_price and stop_price:
            return StopLimitOrder(limit_price, stop_price)
        if limit_price:
            return LimitOrder(limit_price)
        if stop_price:
            return StopOrder(stop_price)
        else:
            return MarketOrder()

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.
        If the requested sid is found in the universe, the requested value is
        divided by its price to imply the number of shares to transact.
        If the Asset being ordered is a Future, the 'value' calculated
        is actually the exposure, as Futures have no 'value'.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)
        """
        amount = self._calculate_order_value_amount(sid, value)
        return self.order(sid, amount,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @property
    def recorded_vars(self):
        """A shallow copy of the variables recorded via record()."""
        return copy(self._recorded_vars)

    @property
    def portfolio(self):
        """The current portfolio, refreshed lazily from the perf tracker."""
        return self.updated_portfolio()

    def updated_portfolio(self):
        """
        Return the cached portfolio, re-fetching it from the performance
        tracker when it has been flagged as stale.
        """
        if self.portfolio_needs_update:
            self._portfolio = \
                self.perf_tracker.get_portfolio(self.performance_needs_update)
            self.portfolio_needs_update = False
            self.performance_needs_update = False
        return self._portfolio

    @property
    def account(self):
        """The current account, refreshed lazily from the perf tracker."""
        return self.updated_account()

    def updated_account(self):
        """
        Return the cached account, re-fetching it from the performance
        tracker when it has been flagged as stale.
        """
        if self.account_needs_update:
            self._account = \
                self.perf_tracker.get_account(self.performance_needs_update)
            self.account_needs_update = False
            self.performance_needs_update = False
        return self._account

    def set_logger(self, logger):
        self.logger = logger

    def on_dt_changed(self, dt):
        """
        Callback triggered by the simulation loop whenever the current dt
        changes.

        Any logic that should happen exactly once at the start of each datetime
        group should happen here.
        """
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

    @api_method
    def get_datetime(self, tz=None):
        """
        Returns the simulation datetime, optionally converted to ``tz``
        (a timezone name or tzinfo).
        """
        dt = self.datetime
        assert dt.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

        if tz is not None:
            # Convert to the given timezone passed as a string or tzinfo.
            if isinstance(tz, string_types):
                tz = pytz.timezone(tz)
            dt = dt.astimezone(tz)

        return dt  # datetime.datetime objects are immutable.

    def set_transact(self, transact):
        """
        Set the method that will be called to create a
        transaction from open orders and trade events.
        """
        self.blotter.transact = transact

    def update_dividends(self, dividend_frame):
        """
        Set DataFrame used to process dividends.  DataFrame columns should
        contain at least the entries in zp.DIVIDEND_FIELDS.
        """
        self.perf_tracker.update_dividends(dividend_frame)

    @api_method
    def set_slippage(self, slippage):
        """
        Replace the slippage model. Only allowed before initialization.
        """
        if not isinstance(slippage, SlippageModel):
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()
        self.slippage = slippage

    @api_method
    def set_commission(self, commission):
        """
        Replace the commission model. Only allowed before initialization.
        """
        if not isinstance(commission, (PerShare, PerTrade, PerDollar)):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()
        self.commission = commission

    def set_sources(self, sources):
        """Replace the list of data sources."""
        assert isinstance(sources, list)
        self.sources = sources

    # Remain backwards compatibility
    @property
    def data_frequency(self):
        return self.sim_params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        assert value in ('daily', 'minute')
        self.sim_params.data_frequency = value

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50\%).
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target number of shares and the
        current number of shares.
        """
        if sid in self.portfolio.positions:
            current_position = self.portfolio.positions[sid].amount
            req_shares = target - current_position
            return self.order(sid, req_shares,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)
        else:
            return self.order(sid, target,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target value and the
        current value.
        If the Asset being ordered is a Future, the 'target value' calculated
        is actually the target exposure, as Futures have no 'value'.
        """
        target_amount = self._calculate_order_value_amount(sid, target)
        return self.order_target(sid, target_amount,
                                 limit_price=limit_price,
                                 stop_price=stop_price,
                                 style=style)

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50\%).
        """
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return open orders as API objects.

        With no sid: a dict mapping each sid that has open orders to its
        list of orders. With a sid: that sid's list (empty if none).
        """
        if sid is None:
            return {
                key: [open_order.to_api_obj() for open_order in orders]
                for key, orders in iteritems(self.blotter.open_orders)
                if orders
            }
        if sid in self.blotter.open_orders:
            orders = self.blotter.open_orders[sid]
            return [open_order.to_api_obj() for open_order in orders]
        return []

    @api_method
    def get_order(self, order_id):
        """
        Return the order with ``order_id`` as an API object, or None if
        the blotter has no such order.
        """
        if order_id in self.blotter.orders:
            return self.blotter.orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order, given either an Order object or its id.
        """
        order_id = order_param
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id

        self.blotter.cancel(order_id)

    @api_method
    def add_history(self, bar_count, frequency, field, ffill=True):
        """
        Register a HistorySpec. If the algorithm is already initialized,
        make sure the history container can serve it.
        """
        data_frequency = self.sim_params.data_frequency
        history_spec = HistorySpec(bar_count, frequency, field, ffill,
                                   data_frequency=data_frequency)
        self.history_specs[history_spec.key_str] = history_spec
        if self.initialized:
            if self.history_container:
                self.history_container.ensure_spec(
                    history_spec, self.datetime, self._most_recent_data,
                )
            else:
                self.history_container = self.history_container_class(
                    self.history_specs,
                    self.current_universe(),
                    self.sim_params.first_open,
                    self.sim_params.data_frequency,
                )

    def get_history_spec(self, bar_count, frequency, field, ffill):
        """
        Return the HistorySpec for these parameters. On first use it is
        created, registered and ensured in the history container (creating
        the container if necessary).
        """
        spec_key = HistorySpec.spec_key(bar_count, frequency, field, ffill)
        if spec_key not in self.history_specs:
            data_freq = self.sim_params.data_frequency
            spec = HistorySpec(
                bar_count,
                frequency,
                field,
                ffill,
                data_frequency=data_freq,
            )
            self.history_specs[spec_key] = spec
            if not self.history_container:
                self.history_container = self.history_container_class(
                    self.history_specs,
                    self.current_universe(),
                    self.datetime,
                    self.sim_params.data_frequency,
                    bar_data=self._most_recent_data,
                )
            self.history_container.ensure_spec(
                spec, self.datetime, self._most_recent_data,
            )
        return self.history_specs[spec_key]

    @api_method
    def history(self, bar_count, frequency, field, ffill=True):
        """
        Return the trailing window of ``field`` described by the arguments,
        as of the current simulation datetime.
        """
        history_spec = self.get_history_spec(
            bar_count,
            frequency,
            field,
            ffill,
        )
        return self.history_container.get_history(history_spec, self.datetime)

    ####################
    # Account Controls #
    ####################

    def register_account_control(self, control):
        """
        Register a new AccountControl to be checked on each bar.
        """
        if self.initialized:
            raise RegisterAccountControlPostInit()
        self.account_controls.append(control)

    def validate_account_controls(self):
        """Run every registered account control against current state."""
        for control in self.account_controls:
            control.validate(self.updated_portfolio(),
                             self.updated_account(),
                             self.get_datetime(),
                             self.trading_client.current_data)

    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Set a limit on the maximum leverage of the algorithm.
        """
        control = MaxLeverage(max_leverage)
        self.register_account_control(control)

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.
        """
        if self.initialized:
            raise RegisterTradingControlPostInit()
        self.trading_controls.append(control)

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid. Limits are treated as absolute values and are enforced at
        the time that the algo attempts to place an order for sid. This means
        that it's possible to end up with more than the max number of shares
        due to splits/dividends, and more than the max notional due to price
        improvement.

        If an algorithm attempts to place an order that would result in
        increasing the absolute value of shares/dollar value exceeding one of
        these limits, raise a TradingControlException.
        """
        control = MaxPositionSize(asset=sid,
                                  max_shares=max_shares,
                                  max_notional=max_notional)
        self.register_trading_control(control)

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid.  Limits are treated as absolute values and are
        enforced at the time that the algo attempts to place an order for sid.

        If an algorithm attempts to place an order that would result in
        exceeding one of these limits, raise a TradingControlException.
        """
        control = MaxOrderSize(asset=sid,
                               max_shares=max_shares,
                               max_notional=max_notional)
        self.register_trading_control(control)

    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed within the given
        time interval.
        """
        control = MaxOrderCount(max_count)
        self.register_trading_control(control)

    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Set a restriction on which sids can be ordered.
        """
        control = RestrictedListOrder(restricted_list)
        self.register_trading_control(control)

    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short positions.
        """
        self.register_trading_control(LongOnly())

    ###########
    # FFC API #
    ###########
    @api_method
    @require_not_initialized(AddTermPostInit())
    def add_factor(self, factor, name):
        """
        Register ``factor`` under ``name``.

        Raises ValueError if the name is already used.
        """
        if name in self._factors:
            raise ValueError("Name %r is already a factor!" % name)
        self._factors[name] = factor

    @api_method
    @require_not_initialized(AddTermPostInit())
    def add_filter(self, filter):
        """Register ``filter`` under an auto-generated name."""
        name = "anon_filter_%d" % len(self._filters)
        self._filters[name] = filter

    # Note: add_classifier is not yet implemented since you can't do anything
    # useful with classifiers yet.

    def _all_terms(self):
        """Return one dict merging filters, factors and classifiers."""
        # Merge all three dicts.
        return dict(
            chain.from_iterable(
                iteritems(terms)
                for terms in (self._filters, self._factors, self._classifiers)
            )
        )

    def compute_factor_matrix(self, start_date):
        """
        Compute a factor matrix starting at start_date.

        The window ends at start_date + 252 trading days or at the
        simulation's last day, whichever comes first. Returns
        ``(factor_matrix, end_date)``.
        """
        days = self.trading_environment.trading_days
        start_date_loc = days.get_loc(start_date)
        sim_end = self.sim_params.last_close.normalize()
        end_loc = min(start_date_loc + 252, days.get_loc(sim_end))
        end_date = days[end_loc]
        return self.engine.factor_matrix(
            self._all_terms(),
            start_date,
            end_date,
        ), end_date

    def current_universe(self):
        """The set of sids collected from the sources at the start of run()."""
        return self._current_universe

    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.
        """
        return [
            fn for fn in itervalues(vars(cls))
            if getattr(fn, 'is_api_method', False)
        ]
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    :Arguments:
    :Optional:
        initialize : function
            Function that is called with a single
            argument at the beginning of the simulation.
        handle_data : function
            Function that is called with 2 arguments
            (context and data) on every bar.
        script : str
            Algoscript that contains initialize and
            handle_data function definition.
        data_frequency : str (daily, hourly or minutely)
           The duration of the bars.
        annualizer : int <optional>
           Which constant to use for annualizing risk metrics.
           If not provided, will extract from data_frequency.
        capital_base : float <default: 1.0e5>
           How much capital to start with.
        instant_fill : bool <default: False>
           Whether to fill orders immediately or on next bar.

    Any positional arguments, and any keyword arguments not consumed
    above, are forwarded to ``self.initialize`` at the end of
    construction.

    :Raises:
        ValueError
            If ``script`` does not define ``initialize`` or
            ``handle_data``, or if ``script`` is combined with
            ``initialize``/``handle_data`` keyword arguments.
    """
    self.datetime = None

    self.registered_transforms = {}
    self.transforms = []
    self.sources = []

    self._recorded_vars = {}

    self.logger = None

    self.benchmark_return_source = None
    self.perf_tracker = None

    # default components for transact
    self.slippage = VolumeShareSlippage()
    self.commission = PerShare()

    if 'data_frequency' in kwargs:
        self.set_data_frequency(kwargs.pop('data_frequency'))
    else:
        self.data_frequency = None

    self.instant_fill = kwargs.pop('instant_fill', False)

    # Override annualizer if set.  Popped, like every other constructor
    # option, so that it is not forwarded to self.initialize() below.
    if 'annualizer' in kwargs:
        self.annualizer = kwargs.pop('annualizer')

    # set the capital base
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params:
        # Keep the algorithm's and the sim_params' data_frequency in sync;
        # an explicitly passed data_frequency takes precedence.
        if self.data_frequency is None:
            self.data_frequency = self.sim_params.data_frequency
        else:
            self.sim_params.data_frequency = self.data_frequency

        self.perf_tracker = PerformanceTracker(self.sim_params)

    self.blotter = kwargs.pop('blotter', None)
    if not self.blotter:
        self.blotter = Blotter()

    self.portfolio_needs_update = True
    self._portfolio = None

    # If string is passed in, execute and get reference to
    # functions.
    self.algoscript = kwargs.pop('script', None)

    if self.algoscript is not None:
        # A script and explicit initialize/handle_data functions are
        # mutually exclusive.  Reject the combination up front instead of
        # silently preferring the script and forwarding the functions to
        # self.initialize() as stray keyword arguments.
        if kwargs.get('initialize') or kwargs.get('handle_data'):
            raise ValueError(
                'You can not set script and initialize/handle_data.'
            )
        self.ns = {}
        exec_(self.algoscript, self.ns)
        if 'initialize' not in self.ns:
            raise ValueError('You must define an initialize function.')
        if 'handle_data' not in self.ns:
            raise ValueError('You must define a handle_data function.')
        self._initialize = self.ns['initialize']
        self._handle_data = self.ns['handle_data']

    # If two functions are passed in assume initialize and
    # handle_data are passed in.
    elif kwargs.get('initialize', False) and kwargs.get('handle_data'):
        self._initialize = kwargs.pop('initialize')
        self._handle_data = kwargs.pop('handle_data')

    # an algorithm subclass needs to set initialized to True when
    # it is fully initialized.
    self.initialized = False

    self.initialize(*args, **kwargs)
class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order, symbol

    def initialize(context):
        context.sid = symbol('AAPL')
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To run this algorithm, pass these functions to TradingAlgorithm as
    keyword arguments (positional arguments are forwarded to initialize):

    my_algo = TradingAlgorithm(initialize=initialize,
                               handle_data=handle_data)
    stats = my_algo.run(data)

    """
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
            asset_finder : An AssetFinder object
                A new AssetFinder object to be used in this TradingEnvironment
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm

        Positional arguments and any unconsumed keyword arguments are stored
        and later passed to ``initialize``.

        :Raises:
            ValueError
                If ``script`` lacks a handle_data function, or if ``script``
                is combined with ``initialize``/``handle_data`` keyword
                arguments.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.benchmark_return_source = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        self.instant_fill = kwargs.pop('instant_fill', False)

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        # Build a perf_tracker
        self.perf_tracker = PerformanceTracker(sim_params=self.sim_params,
                                               env=self.trading_environment)

        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder
        self.init_engine(kwargs.pop('ffc_loader', None))

        # Maps from name to Term
        self._filters = {}
        self._factors = {}
        self._classifiers = {}

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        # Set the dt initially to the period start by forcing it to change
        self.on_dt_changed(self.sim_params.period_start)

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
        self._portfolio = None
        self._account = None

        self.history_container_class = kwargs.pop(
            'history_container_class', HistoryContainer,
        )
        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager()

        if self.algoscript is not None:
            # A script and explicit initialize/handle_data functions are
            # mutually exclusive.  Reject the combination up front instead
            # of silently preferring the script and forwarding the functions
            # to initialize() as stray keyword arguments.
            if kwargs.get('initialize') or kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )
            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')

            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        self._most_recent_data = None

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        self.initialize_kwargs = kwargs

    def init_engine(self, loader):
        """
        Construct and save an FFCEngine from loader.

        If loader is None, constructs a NoOpFFCEngine.
        """
        if loader is not None:
            self.engine = SimpleFFCEngine(
                loader,
                self.trading_environment.trading_days,
                self.asset_finder,
            )
        else:
            self.engine = NoOpFFCEngine()

    def initialize(self, *args, **kwargs):
        """
        Call self._initialize with `self` made available to Zipline API
        functions.
        """
        with ZiplineAPI(self):
            self._initialize(self, *args, **kwargs)

    def before_trading_start(self, data):
        """
        Call the user-supplied before_trading_start function, if any, with
        this algorithm and ``data``.
        """
        if self._before_trading_start is None:
            return

        self._before_trading_start(self, data)

    def handle_data(self, data):
        """
        Process one bar: remember it, update the history container (if
        any), call the user's handle_data, then check account controls.
        """
        self._most_recent_data = data
        if self.history_container:
            self.history_container.update(data, self.datetime)

        self._handle_data(self, data)

        # Unlike trading controls which remain constant unless placing an
        # order, account controls can change each bar. Thus, must check
        # every bar no matter if the algorithm places an order or not.
        self.validate_account_controls()

    def analyze(self, perf):
        """
        Call the user-supplied analyze function, if any, with the
        performance results while the Zipline API is bound to this algorithm.
        """
        if self._analyze is None:
            return

        with ZiplineAPI(self):
            self._analyze(self, perf)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base},
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params=None):
        """
        Create a merged data generator using the sources attached to this
        algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """
        if sim_params is None:
            sim_params = self.sim_params

        if self.benchmark_return_source is None:
            # In minute mode, stamp benchmark events at the session close;
            # otherwise keep the benchmark's own date.
            if sim_params.data_frequency == 'minute' or \
               sim_params.emission_rate == 'minute':
                def update_time(date):
                    return self.trading_environment.get_open_and_close(date)[1]
            else:
                def update_time(date):
                    return date
            benchmark_return_source = [
                Event({'dt': update_time(dt),
                       'returns': ret,
                       'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                       'source_id': 'benchmarks'})
                for dt, ret in
                self.trading_environment.benchmark_returns.iteritems()
                if dt.date() >= sim_params.period_start.date() and
                dt.date() <= sim_params.period_end.date()
            ]
        else:
            benchmark_return_source = self.benchmark_return_source

        date_sorted = date_sorted_sources(*self.sources)

        if source_filter:
            date_sorted = filter(source_filter, date_sorted)

        with_benchmarks = date_sorted_sources(benchmark_return_source,
                                              date_sorted)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(with_benchmarks, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Create a basic generator setup using the sources to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """

        if not self.initialized:
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
            self.initialized = True

        if self.perf_tracker is None:
            # HACK: When running with the `run` method, we set perf_tracker to
            # None so that it will be overwritten here.
            self.perf_tracker = PerformanceTracker(
                sim_params=sim_params, env=self.trading_environment
            )

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True

        self.data_gen = self._create_data_generator(source_filter, sim_params)

        self.trading_client = AlgorithmSimulator(self, sim_params)

        transact_method = transact_partial(self.slippage, self.commission)
        self.set_transact(transact_method)

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Override this method to add new logic to the construction
        of the generator. Overrides can use the _create_generator
        method to get a standard construction generator.
        """
        return self._create_generator(self.sim_params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, overwrite_sim_params=True,
            benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - zipline source
                     - list of sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must be the different asset identifiers
               * index must be DatetimeIndex
               * array contents should be price info.

            overwrite_sim_params : bool <default: True>
               Take start/end dates from the source when it provides them.
               Ignored (with a warning) when a list of sources is passed.

            benchmark_return_source : optional
               When given, replaces ``self.benchmark_return_source`` and is
               used instead of the environment's benchmark returns.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """

        # Ensure that source is a DataSource object
        if isinstance(source, list):
            if overwrite_sim_params:
                warnings.warn("List of sources passed, will not attempt to "
                              "extract start and end dates. Make sure to set "
                              "the correct fields in sim_params passed to "
                              "__init__().", UserWarning)
                overwrite_sim_params = False
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, map columns to sids and wrap
            # in DataFrameSource
            copy_frame = source.copy()
            copy_frame.columns = self._write_and_map_id_index_to_sids(
                source.columns, source.index[0],
            )
            source = DataFrameSource(copy_frame)

        elif isinstance(source, pd.Panel):
            # If Panel provided, map items to sids and wrap
            # in DataPanelSource
            copy_panel = source.copy()
            copy_panel.items = self._write_and_map_id_index_to_sids(
                source.items, source.major_axis[0],
            )
            source = DataPanelSource(copy_panel)

        if isinstance(source, list):
            self.set_sources(source)
        else:
            self.set_sources([source])

        # Honor an explicitly supplied benchmark source; it is consumed by
        # _create_data_generator via self.benchmark_return_source.
        if benchmark_return_source is not None:
            self.benchmark_return_source = benchmark_return_source

        # Override sim_params if params are provided by the source.
        if overwrite_sim_params:
            if hasattr(source, 'start'):
                self.sim_params.period_start = source.start
            if hasattr(source, 'end'):
                self.sim_params.period_end = source.end
            # Changing period_start and period_close might require updating
            # of first_open and last_close.
            self.sim_params.update_internal_from_env(
                env=self.trading_environment
            )

        # The sids field of the source is the reference for the universe at
        # the start of the run
        self._current_universe = set()
        for data_source in self.sources:
            for sid in data_source.sids:
                self._current_universe.add(sid)
        # Check that all sids from the source are accounted for in
        # the AssetFinder. This retrieve call will raise an exception if the
        # sid is not found.
        for sid in self._current_universe:
            self.asset_finder.retrieve_asset(sid)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create zipline
        self.gen = self._create_generator(self.sim_params)

        # Create history containers
        if self.history_specs:
            self.history_container = self.history_container_class(
                self.history_specs,
                self.current_universe(),
                self.sim_params.first_open,
                self.sim_params.data_frequency,
                self.trading_environment,
            )

        # loop through simulated_trading, each iteration returns a
        # perf dictionary
        perfs = list(self.gen)

        # convert perf dict to pandas dataframe
        daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats

    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
        """
        Write any identifiers that do not resolve to a known Asset/sid into
        the trading environment, then return the result of mapping
        ``identifiers`` to sids as of ``as_of_date``.
        """
        # Build new Assets for identifiers that can't be resolved as
        # sids/Assets
        identifiers_to_build = []
        for identifier in identifiers:
            asset = None

            if isinstance(identifier, Asset):
                asset = self.asset_finder.retrieve_asset(sid=identifier.sid,
                                                         default_none=True)

            elif hasattr(identifier, '__int__'):
                asset = self.asset_finder.retrieve_asset(sid=identifier,
                                                         default_none=True)
            if asset is None:
                identifiers_to_build.append(identifier)

        self.trading_environment.write_data(
            equities_identifiers=identifiers_to_build)

        return self.asset_finder.map_identifier_index_to_sids(
            identifiers, as_of_date,
        )

    def _create_daily_stats(self, perfs):
        """
        Build a DataFrame of daily performance from the simulation's perf
        packets.

        Each packet containing 'daily_perf' becomes one row (merged with its
        recorded variables and cumulative risk metrics), indexed by its
        'period_close'.  Any other packet is stored as ``self.risk_report``.
        """
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:

                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars')
                )
                perf['daily_perf'].update(perf['cumulative_risk_metrics'])
                daily_perfs.append(perf['daily_perf'])
            else:
                self.risk_report = perf

        daily_dts = [np.datetime64(perf['period_close'], utc=True)
                     for perf in daily_perfs]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    @api_method
    def add_transform(self, transform, days=None):
        """
        Ensures that the history container will have enough size to service
        a simple transform.

        :Arguments:
            transform : string
                The transform to add. must be an element of:
                {'mavg', 'stddev', 'vwap', 'returns'}.
            days : int <default=None>
                The maximum amount of days you will want for this transform.
                This is not needed for 'returns'.
        """
        if transform not in {'mavg', 'stddev', 'vwap', 'returns'}:
            raise ValueError('Invalid transform')

        if transform == 'returns':
            if days is not None:
                raise ValueError('returns does not use days')

            self.add_history(2, '1d', 'price')
            return
        elif days is None:
            raise ValueError('no number of days specified')

        if self.sim_params.data_frequency == 'daily':
            mult = 1
            freq = '1d'
        else:
            # 390 minute bars per full trading day.
            mult = 390
            freq = '1m'

        bars = mult * days
        self.add_history(bars, freq, 'price')

        if transform == 'vwap':
            self.add_history(bars, freq, 'volume')

    @api_method
    def get_environment(self, field='platform'):
        """
        Return one field describing the simulation environment, or the whole
        dict of fields when ``field == '*'``.  Unknown fields raise KeyError.
        """
        env = {
            'arena': self.sim_params.arena,
            'data_frequency': self.sim_params.data_frequency,
            'start': self.sim_params.first_open,
            'end': self.sim_params.last_close,
            'capital_base': self.sim_params.capital_base,
            'platform': self._platform
        }
        if field == '*':
            return env
        else:
            return env[field]

    def add_event(self, rule=None, callback=None):
        """
        Adds an event to the algorithm's EventManager.
        """
        self.event_manager.add_event(
            zipline.utils.events.Event(rule, callback),
        )

    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedules a function to be called with some timed rules.
        """
        date_rule = date_rule or DateRuleFactory.every_day()
        time_rule = ((time_rule or TimeRuleFactory.market_open())
                     if self.sim_params.data_frequency == 'minute' else
                     # If we are in daily mode the time_rule is ignored.
                     zipline.utils.events.Always())

        self.add_event(
            make_eventrule(date_rule, time_rule, half_days),
            func,
        )

    @api_method
    def record(self, *args, **kwargs):
        """
        Track and record local variable (i.e. attributes) each day.
        """
        # Make 2 objects both referencing the same iterator
        args = [iter(args)] * 2

        # Zip generates list entries by calling `next` on each iterator it
        # receives.  In this case the two iterators are the same object, so the
        # call to next on args[0] will also advance args[1], resulting in zip
        # returning (a,b) (c,d) (e,f) rather than (a,a) (b,b) (c,c) etc.
        positionals = zip(*args)
        for name, value in chain(positionals, iteritems(kwargs)):
            self._recorded_vars[name] = value

    @api_method
    def symbol(self, symbol_str):
        """
        Default symbol lookup for any source that directly maps the
        symbol to the Asset (e.g. yahoo finance).
        """
        # If the user has not set the symbol lookup date,
        # use the period_end as the date for symbol->sid resolution.
        _lookup_date = self._symbol_lookup_date if self._symbol_lookup_date is not None \
            else self.sim_params.period_end

        return self.asset_finder.lookup_symbol(
            symbol_str,
            as_of_date=_lookup_date,
        )

    @api_method
    def symbols(self, *args):
        """
        Default symbols lookup for any source that directly maps the
        symbol to the Asset (e.g. yahoo finance).
        """
        return [self.symbol(identifier) for identifier in args]

    @api_method
    def sid(self, a_sid):
        """
        Default sid lookup for any source that directly maps the integer sid
        to the Asset.
        """
        return self.asset_finder.retrieve_asset(a_sid)

    @api_method
    def future_chain(self, root_symbol, as_of_date=None):
        """ Look up a future chain with the specified parameters.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        """
        if as_of_date:
            try:
                as_of_date = pd.Timestamp(as_of_date, tz='UTC')
            except ValueError:
                raise UnsupportedDatetimeFormat(input=as_of_date,
                                                method='future_chain')
        return FutureChain(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol.upper(),
            as_of_date=as_of_date
        )

    def _calculate_order_value_amount(self, asset, value):
        """
        Calculates how many shares/contracts to order based on the type of
        asset being ordered.
        """
        last_price = self.trading_client.current_data[asset].price

        if tolerant_equals(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset
            )
            if self.logger:
                self.logger.debug(zero_message)
            # Don't place any order
            return 0

        if isinstance(asset, Future):
            value_multiplier = asset.contract_multiplier
        else:
            value_multiplier = 1

        return value / (last_price * value_multiplier)

    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order using the specified parameters.
        """

        def round_if_near_integer(a, epsilon=1e-4):
            """
            Round a to the nearest integer if that integer is within an epsilon
            of a.
            """
            if abs(a - round(a)) <= epsilon:
                return round(a)
            else:
                return a

        # Truncate to the integer share count that's either within .0001 of
        # amount or closer to zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        amount = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid,
                                   amount,
                                   limit_price,
                                   stop_price,
                                   style)

        # Convert deprecated limit_price and stop_price parameters to use
        # ExecutionStyle objects.
        style = self.__convert_order_params_for_blotter(limit_price,
                                                        stop_price,
                                                        style)
        return self.blotter.order(sid, amount, style)

    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Helper method for validating parameters to the order API function.

        Raises an UnsupportedOrderParameters if invalid arguments are found.
        """

        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style:
            if limit_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both limit_price and style is not supported."
                )

            if stop_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both stop_price and style is not supported."
                )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        for control in self.trading_controls:
            control.validate(asset,
                             amount,
                             self.updated_portfolio(),
                             self.get_datetime(),
                             self.trading_client.current_data)

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Helper method for converting deprecated limit_price and stop_price
        arguments into ExecutionStyle instances.

        This function assumes that either style == None or (limit_price,
        stop_price) == (None, None).
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style
        if limit_price and stop_price:
            return StopLimitOrder(limit_price, stop_price)
        if limit_price:
            return LimitOrder(limit_price)
        if stop_price:
            return StopOrder(stop_price)
        else:
            return MarketOrder()

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.
        If the requested sid is found in the universe, the requested value is
        divided by its price to imply the number of shares to transact.
        If the Asset being ordered is a Future, the 'value' calculated
        is actually the exposure, as Futures have no 'value'.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order(sid, value)
        Limit order:     order(sid, value, limit_price)
        Stop order:      order(sid, value, None, stop_price)
        StopLimit order: order(sid, value, limit_price, stop_price)
        """
        amount = self._calculate_order_value_amount(sid, value)
        return self.order(sid, amount,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @property
    def recorded_vars(self):
        """A shallow copy of the variables recorded via ``record``."""
        return copy(self._recorded_vars)

    @property
    def portfolio(self):
        """The current portfolio, refreshed from the tracker if stale."""
        return self.updated_portfolio()

    def updated_portfolio(self):
        """
        Return the cached portfolio, first re-fetching it from the
        performance tracker when ``portfolio_needs_update`` is set.
        """
        if self.portfolio_needs_update:
            self._portfolio = \
                self.perf_tracker.get_portfolio(self.performance_needs_update)
            self.portfolio_needs_update = False
            self.performance_needs_update = False
        return self._portfolio

    @property
    def account(self):
        """The current account, refreshed from the tracker if stale."""
        return self.updated_account()

    def updated_account(self):
        """
        Return the cached account, first re-fetching it from the
        performance tracker when ``account_needs_update`` is set.
        """
        if self.account_needs_update:
            self._account = \
                self.perf_tracker.get_account(self.performance_needs_update)
            self.account_needs_update = False
            self.performance_needs_update = False
        return self._account

    def set_logger(self, logger):
        self.logger = logger

    def on_dt_changed(self, dt):
        """
        Callback triggered by the simulation loop whenever the current dt
        changes.

        Any logic that should happen exactly once at the start of each datetime
        group should happen here.
        """
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

    @api_method
    def get_datetime(self, tz=None):
        """
        Returns the simulation datetime.
        """
        dt = self.datetime
        assert dt.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

        if tz is not None:
            # Convert to the given timezone passed as a string or tzinfo.
            if isinstance(tz, string_types):
                tz = pytz.timezone(tz)
            dt = dt.astimezone(tz)

        return dt  # datetime.datetime objects are immutable.

    def set_transact(self, transact):
        """
        Set the method that will be called to create a
        transaction from open orders and trade events.
        """
        self.blotter.transact = transact

    def update_dividends(self, dividend_frame):
        """
        Set DataFrame used to process dividends.  DataFrame columns should
        contain at least the entries in zp.DIVIDEND_FIELDS.
        """
        self.perf_tracker.update_dividends(dividend_frame)

    @api_method
    def set_slippage(self, slippage):
        """
        Set the slippage model.

        Raises UnsupportedSlippageModel if ``slippage`` is not a
        SlippageModel, and OverrideSlippagePostInit once initialized.
        """
        if not isinstance(slippage, SlippageModel):
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()
        self.slippage = slippage

    @api_method
    def set_commission(self, commission):
        """
        Set the commission model.

        Raises UnsupportedCommissionModel unless ``commission`` is a
        PerShare, PerTrade or PerDollar instance, and
        OverrideCommissionPostInit once initialized.
        """
        if not isinstance(commission, (PerShare, PerTrade, PerDollar)):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()
        self.commission = commission

    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date for which symbols will be resolved to their sids
        (symbols may map to different firms or underlying assets at
        different times)
        """
        try:
            self._symbol_lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')

    def set_sources(self, sources):
        """Replace the list of data sources used by the simulation."""
        assert isinstance(sources, list)
        self.sources = sources

    # Remain backwards compatibility
    @property
    def data_frequency(self):
        """Bar frequency; proxies ``self.sim_params.data_frequency``."""
        return self.sim_params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        assert value in ('daily', 'minute')
        self.sim_params.data_frequency = value

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        Note that percent must expressed as a decimal (0.50 means 50%).
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target number of shares and the
        current number of shares.
        """
        if sid in self.portfolio.positions:
            current_position = self.portfolio.positions[sid].amount
            req_shares = target - current_position
            return self.order(sid, req_shares,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)
        else:
            return self.order(sid, target,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target value and the
        current value.
        If the Asset being ordered is a Future, the 'target value' calculated
        is actually the target exposure, as Futures have no 'value'.
        """
        target_amount = self._calculate_order_value_amount(sid, target)
        return self.order_target(sid, target_amount,
                                 limit_price=limit_price,
                                 stop_price=stop_price,
                                 style=style)

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must expressed as a decimal (0.50 means 50%).
        """
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return open orders as API objects.

        With ``sid=None``, return a dict mapping each sid that has open
        orders to its list of orders; otherwise return the list for ``sid``
        (empty if it has none).
        """
        if sid is None:
            return {
                key: [order.to_api_obj() for order in orders]
                for key, orders in iteritems(self.blotter.open_orders)
                if orders
            }
        if sid in self.blotter.open_orders:
            orders = self.blotter.open_orders[sid]
            return [order.to_api_obj() for order in orders]
        return []

    @api_method
    def get_order(self, order_id):
        """Return the order ``order_id`` as an API object, or None."""
        if order_id in self.blotter.orders:
            return self.blotter.orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """Cancel an order given either its id or a protocol Order."""
        order_id = order_param
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id

        self.blotter.cancel(order_id)

    @api_method
    def add_history(self, bar_count, frequency, field, ffill=True):
        """
        Register a HistorySpec.  If the algorithm is already initialized,
        make sure the history container can serve it, creating the
        container when none exists yet.
        """
        data_frequency = self.sim_params.data_frequency
        history_spec = HistorySpec(bar_count, frequency, field, ffill,
                                   data_frequency=data_frequency,
                                   env=self.trading_environment)
        self.history_specs[history_spec.key_str] = history_spec
        if self.initialized:
            if self.history_container:
                self.history_container.ensure_spec(
                    history_spec, self.datetime, self._most_recent_data,
                )
            else:
                self.history_container = self.history_container_class(
                    self.history_specs,
                    self.current_universe(),
                    self.sim_params.first_open,
                    self.sim_params.data_frequency,
                    env=self.trading_environment,
                )

    def get_history_spec(self, bar_count, frequency, field, ffill):
        """
        Return the HistorySpec for these parameters, creating and
        registering it (and the history container, if needed) on first use.
        """
        spec_key = HistorySpec.spec_key(bar_count, frequency, field, ffill)
        if spec_key not in self.history_specs:
            data_freq = self.sim_params.data_frequency
            spec = HistorySpec(
                bar_count,
                frequency,
                field,
                ffill,
                data_frequency=data_freq,
                env=self.trading_environment,
            )
            self.history_specs[spec_key] = spec
            if not self.history_container:
                self.history_container = self.history_container_class(
                    self.history_specs,
                    self.current_universe(),
                    self.datetime,
                    self.sim_params.data_frequency,
                    bar_data=self._most_recent_data,
                    env=self.trading_environment,
                )
            self.history_container.ensure_spec(
                spec, self.datetime, self._most_recent_data,
            )
        return self.history_specs[spec_key]

    @api_method
    def history(self, bar_count, frequency, field, ffill=True):
        """Return history for the given spec at the current datetime."""
        history_spec = self.get_history_spec(
            bar_count,
            frequency,
            field,
            ffill,
        )
        return self.history_container.get_history(history_spec, self.datetime)

    ####################
    # Account Controls #
    ####################

    def register_account_control(self, control):
        """
        Register a new AccountControl to be checked on each bar.
        """
        if self.initialized:
            raise RegisterAccountControlPostInit()
        self.account_controls.append(control)

    def validate_account_controls(self):
        """
        Check every registered account control against the current
        portfolio, account and datetime.
        """
        for control in self.account_controls:
            control.validate(self.updated_portfolio(),
                             self.updated_account(),
                             self.get_datetime(),
                             self.trading_client.current_data)

    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Set a limit on the maximum leverage of the algorithm.
        """
        control = MaxLeverage(max_leverage)
        self.register_account_control(control)

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.
        """
        if self.initialized:
            raise RegisterTradingControlPostInit()
        self.trading_controls.append(control)

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid. Limits are treated as absolute values and are enforced at
        the time that the algo attempts to place an order for sid. This means
        that it's possible to end up with more than the max number of shares
        due to splits/dividends, and more than the max notional due to price
        improvement.

        If an algorithm attempts to place an order that would result in
        increasing the absolute value of shares/dollar value exceeding one of
        these limits, raise a TradingControlException.
        """
        control = MaxPositionSize(asset=sid,
                                  max_shares=max_shares,
                                  max_notional=max_notional)
        self.register_trading_control(control)

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid.  Limits are treated as absolute values and are
        enforced at the time that the algo attempts to place an order for sid.

        If an algorithm attempts to place an order that would result in
        exceeding one of these limits, raise a TradingControlException.
        """
        control = MaxOrderSize(asset=sid,
                               max_shares=max_shares,
                               max_notional=max_notional)
        self.register_trading_control(control)

    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed within the given
        time interval.
        """
        control = MaxOrderCount(max_count)
        self.register_trading_control(control)

    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Set a restriction on which sids can be ordered.
        """
        control = RestrictedListOrder(restricted_list)
        self.register_trading_control(control)

    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short positions.
        """
        self.register_trading_control(LongOnly())

    ###########
    # FFC API #
    ###########
    @api_method
    @require_not_initialized(AddTermPostInit())
    def add_factor(self, factor, name):
        """
        Register ``factor`` under ``name``; raises ValueError if that name
        is already used by a factor.
        """
        if name in self._factors:
            raise ValueError("Name %r is already a factor!" % name)
        self._factors[name] = factor

    @api_method
    @require_not_initialized(AddTermPostInit())
    def add_filter(self, filter):
        """
        Register ``filter`` under a generated name ``anon_filter_<n>``.
        """
        name = "anon_filter_%d" % len(self._filters)
        self._filters[name] = filter

    # Note: add_classifier is not yet implemented since you can't do anything
    # useful with classifiers yet.
def _all_terms(self):
    """
    Return one dict containing every registered term.

    The filter, factor and classifier registries are merged in that
    order, so a name present in more than one registry maps to the entry
    from the later registry.
    """
    merged = {}
    for registry in (self._filters, self._factors, self._classifiers):
        merged.update(registry)
    return merged

def compute_factor_matrix(self, start_date):
    """
    Compute a factor matrix that covers at least `start_date`.

    The requested window starts at `start_date` itself and extends forward
    by at most 252 trading days (an arbitrary, roughly one-year horizon).
    It is truncated at the trading day on which the simulation's last
    close falls.

    Returns a 2-tuple ``(matrix, end_date)``: the result of
    ``self.engine.factor_matrix`` for all registered terms, and the last
    trading day of the requested window.
    """
    trading_days = self.trading_environment.trading_days
    first_loc = trading_days.get_loc(start_date)

    # Never request data past the final simulation day.
    final_sim_loc = trading_days.get_loc(
        self.sim_params.last_close.normalize()
    )
    end_date = trading_days[min(first_loc + 252, final_sim_loc)]

    matrix = self.engine.factor_matrix(
        self._all_terms(),
        start_date,
        end_date,
    )
    return matrix, end_date

def current_universe(self):
    """Return the value stored in ``self._current_universe``."""
    return self._current_universe

@classmethod
def all_api_methods(cls):
    """
    Return a list of all the TradingAlgorithm API methods.

    Only members defined directly on ``cls`` (its own ``vars()``) are
    inspected. A member counts as an API method when it carries a truthy
    ``is_api_method`` attribute.
    """
    api_methods = []
    for member in itervalues(vars(cls)):
        if getattr(member, 'is_api_method', False):
            api_methods.append(member)
    return api_methods
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    Positional arguments, and keyword arguments not consumed below, are
    forwarded to ``self.initialize``.

    :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single argument at the
                beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments (context and data)
                on every bar.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            data_frequency : str (daily, hourly or minutely)
               The duration of the bars.
            annualizer : int <optional>
               Which constant to use for annualizing risk metrics.
               If not provided, will extract from data_frequency.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.

    :Raises:
        ValueError
            If ``script`` lacks an ``initialize`` or ``handle_data``
            definition, or if ``script`` is combined with explicit
            ``initialize``/``handle_data`` functions.
    """
    self.datetime = None

    self.registered_transforms = {}
    self.transforms = []
    self.sources = []

    self._recorded_vars = {}

    self.logger = None

    self.benchmark_return_source = None
    self.perf_tracker = None

    # default components for transact
    self.slippage = VolumeShareSlippage()
    self.commission = PerShare()

    if 'data_frequency' in kwargs:
        self.set_data_frequency(kwargs.pop('data_frequency'))
    else:
        self.data_frequency = None

    self.instant_fill = kwargs.pop('instant_fill', False)

    # Override annualizer if set. Popped, like the other constructor
    # options, so it is not forwarded to initialize().
    if 'annualizer' in kwargs:
        self.annualizer = kwargs.pop('annualizer')

    # set the capital base
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params:
        # Adopt the frequency of the supplied sim_params unless one was
        # given explicitly, in which case it overrides theirs.
        if self.data_frequency is None:
            self.data_frequency = self.sim_params.data_frequency
        else:
            self.sim_params.data_frequency = self.data_frequency

        self.perf_tracker = PerformanceTracker(self.sim_params)

    self.blotter = kwargs.pop('blotter', None)
    if not self.blotter:
        self.blotter = Blotter()

    self.portfolio_needs_update = True
    self._portfolio = None

    # If string is passed in, execute and get reference to
    # functions.
    self.algoscript = kwargs.pop('script', None)

    self._initialize = None

    # A script and explicitly passed functions are mutually exclusive.
    # This check has to run before the branch below: placed inside the
    # function branch it could never fire, because that branch is only
    # reached when no script was given.
    if (self.algoscript is not None and
            kwargs.get('initialize', False) and kwargs.get('handle_data')):
        raise ValueError('You can not set script and '
                         'initialize/handle_data.')

    if self.algoscript is not None:
        # NOTE(review): exec_ runs arbitrary code; the script must come
        # from a trusted source.
        self.ns = {}
        exec_(self.algoscript, self.ns)
        if 'initialize' not in self.ns:
            raise ValueError('You must define an initialize function.')
        if 'handle_data' not in self.ns:
            raise ValueError('You must define a handle_data function.')
        self._initialize = self.ns['initialize']
        self._handle_data = self.ns['handle_data']

    # If two functions are passed in assume initialize and
    # handle_data are passed in.
    elif kwargs.get('initialize', False) and kwargs.get('handle_data'):
        self._initialize = kwargs.pop('initialize')
        self._handle_data = kwargs.pop('handle_data')

    # No initialize function supplied: fall back to a no-op.
    if self._initialize is None:
        self._initialize = lambda x: None

    # an algorithm subclass needs to set initialized to True when
    # it is fully initialized.
    self.initialized = False
    self.initialize(*args, **kwargs)
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    Positional arguments, and keyword arguments not consumed below, are
    forwarded to ``self.initialize``.

    :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single argument at the
                beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments (context and data)
                on every bar.
            before_trading_start : function
                Only read when ``initialize`` and ``handle_data`` are also
                passed directly; a script defines it in its own namespace.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            algo_filename : str <default: '<string>'>
                Filename used when compiling ``script``.
            namespace : dict <default: {}>
                Namespace in which ``script`` is executed.
            data_frequency : str (daily, hourly or minutely)
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            sim_params : <optional>
               Simulation parameters. When omitted,
               ``create_simulation_parameters(capital_base=capital_base)``
               is used.
            blotter : <default: Blotter()>
            history_container_class : <default: HistoryContainer>
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
            platform : str <default: 'zipline'>
               The platform that this algorithm is running on.

    :Raises:
        ValueError
            If ``script`` does not define ``handle_data``, or if ``script``
            is combined with explicit ``initialize``/``handle_data``.
    """
    self.datetime = None

    self.sources = []

    # List of trading controls to be used to validate orders.
    self.trading_controls = []

    # List of account controls to be checked on each bar.
    self.account_controls = []

    self._recorded_vars = {}
    # NOTE: read with get(), not pop(), so 'namespace' is also forwarded
    # to initialize() below.
    self.namespace = kwargs.get('namespace', {})

    self._platform = kwargs.pop('platform', 'zipline')

    self.logger = None

    self.benchmark_return_source = None

    # default components for transact
    self.slippage = VolumeShareSlippage()
    self.commission = PerShare()

    self.instant_fill = kwargs.pop('instant_fill', False)

    # set the capital base
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params is None:
        self.sim_params = create_simulation_parameters(
            capital_base=self.capital_base
        )
    self.perf_tracker = PerformanceTracker(self.sim_params)

    self.blotter = kwargs.pop('blotter', None)
    if not self.blotter:
        self.blotter = Blotter()

    self.portfolio_needs_update = True
    self.account_needs_update = True
    self.performance_needs_update = True
    self._portfolio = None
    self._account = None

    self.history_container_class = kwargs.pop(
        'history_container_class', HistoryContainer,
    )
    self.history_container = None
    self.history_specs = {}

    # If string is passed in, execute and get reference to
    # functions.
    self.algoscript = kwargs.pop('script', None)

    self._initialize = None
    self._before_trading_start = None
    self._analyze = None

    self.event_manager = EventManager()

    # A script and explicitly passed functions are mutually exclusive.
    # The check must precede the branch below: inside the function branch
    # it could never fire, because that branch only runs without a script.
    if (self.algoscript is not None and
            kwargs.get('initialize') and kwargs.get('handle_data')):
        raise ValueError('You can not set script and '
                         'initialize/handle_data.')

    if self.algoscript is not None:
        filename = kwargs.pop('algo_filename', None)
        if filename is None:
            filename = '<string>'
        # NOTE(review): exec_ runs arbitrary code; the script must come
        # from a trusted source.
        code = compile(self.algoscript, filename, 'exec')
        exec_(code, self.namespace)
        self._initialize = self.namespace.get('initialize')
        if 'handle_data' not in self.namespace:
            raise ValueError('You must define a handle_data function.')
        else:
            self._handle_data = self.namespace['handle_data']

        self._before_trading_start = \
            self.namespace.get('before_trading_start')
        # Optional analyze function, gets called after run
        self._analyze = self.namespace.get('analyze')

    elif kwargs.get('initialize') and kwargs.get('handle_data'):
        self._initialize = kwargs.pop('initialize')
        self._handle_data = kwargs.pop('handle_data')
        self._before_trading_start = kwargs.pop('before_trading_start', None)

    self.event_manager.add_event(
        zipline.utils.events.Event(
            zipline.utils.events.Always(),
            # We pass handle_data.__func__ to get the unbound method.
            # We will explicitly pass the algorithm to bind it again.
            self.handle_data.__func__,
        ),
        prepend=True,
    )

    # If method not defined, NOOP
    if self._initialize is None:
        self._initialize = lambda x: None

    # Alternative way of setting data_frequency for backwards
    # compatibility.
    if 'data_frequency' in kwargs:
        self.data_frequency = kwargs.pop('data_frequency')

    self._most_recent_data = None

    # Subclasses that override initialize should only worry about
    # setting self.initialized = True if AUTO_INITIALIZE is
    # manually set to False.
    self.initialized = False
    self.initialize(*args, **kwargs)
    if self.AUTO_INITIALIZE:
        self.initialized = True
class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    class MyAlgo(TradingAlgorithm):
        def initialize(self, sids, amount):
            self.sids = sids
            self.amount = amount

        def handle_data(self, data):
            sid = self.sids[0]
            amount = self.amount
            self.order(sid, amount)
    ```
    To then run this algorithm:

    my_algo = MyAlgo([0], 100) # first argument has to be list of sids
    stats = my_algo.run(data)

    """

    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        Positional arguments, and keyword arguments not consumed below, are
        forwarded to ``self.initialize``.

        :Arguments:
            data_frequency : str ('daily' or 'minute')
               The duration of the bars.
            annualizer : int <optional>
               Which constant to use for annualizing risk metrics.
               If not provided, will extract from data_frequency.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
        """
        self.datetime = None

        self.registered_transforms = {}
        self.transforms = []
        self.sources = []

        self._recorded_vars = {}

        self.logger = None

        self.benchmark_return_source = None
        self.perf_tracker = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        if 'data_frequency' in kwargs:
            self.set_data_frequency(kwargs.pop('data_frequency'))
        else:
            self.data_frequency = None

        self.instant_fill = kwargs.pop('instant_fill', False)

        # Override annualizer if set. Popped, like the other constructor
        # options, so it is not forwarded to the user-defined initialize().
        if 'annualizer' in kwargs:
            self.annualizer = kwargs.pop('annualizer')

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params:
            # Adopt the supplied sim_params' frequency when none was given
            # explicitly, instead of overwriting it with None.
            if self.data_frequency is None:
                self.data_frequency = self.sim_params.data_frequency
            else:
                self.sim_params.data_frequency = self.data_frequency
            self.perf_tracker = PerformanceTracker(self.sim_params)

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        self.portfolio_needs_update = True
        self._portfolio = None

        # an algorithm subclass needs to set initialized to True when
        # it is fully initialized.
        self.initialized = False

        # call to user-defined constructor method
        self.initialize(*args, **kwargs)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base},
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params):
        """
        Create a merged data generator using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.

        Returns an ``itertools.groupby`` iterator of ``(dt, events)``
        pairs.
        """
        if self.benchmark_return_source is None:
            # Build benchmark events from the trading environment,
            # restricted to the simulation's date range (inclusive).
            first_day = sim_params.period_start.date()
            last_day = sim_params.period_end.date()
            benchmark_return_source = [
                Event({
                    'dt': dt,
                    'returns': ret,
                    'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                    'source_id': 'benchmarks'
                })
                for dt, ret in trading.environment.benchmark_returns.iterkv()
                if first_day <= dt.date() <= last_day
            ]
        else:
            benchmark_return_source = self.benchmark_return_source

        date_sorted = date_sorted_sources(*self.sources)

        if source_filter:
            date_sorted = ifilter(source_filter, date_sorted)

        with_tnfms = sequential_transforms(date_sorted, *self.transforms)
        with_alias_dt = alias_dt(with_tnfms)

        with_benchmarks = date_sorted_sources(benchmark_return_source,
                                              with_alias_dt)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(with_benchmarks, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Create a basic generator setup using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """
        sim_params.data_frequency = self.data_frequency

        # perf_tracker will be instantiated in __init__ if a sim_params
        # is passed to the constructor. If not, we instantiate here.
        if self.perf_tracker is None:
            self.perf_tracker = PerformanceTracker(sim_params)

        self.data_gen = self._create_data_generator(source_filter,
                                                    sim_params)

        self.trading_client = AlgorithmSimulator(self, sim_params)

        transact_method = transact_partial(self.slippage, self.commission)
        self.set_transact(transact_method)

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Override this method to add new logic to the construction
        of the generator. Overrides can use the _create_generator
        method to get a standard construction generator.
        """
        return self._create_generator(self.sim_params)

    def initialize(self, *args, **kwargs):
        """Hook for subclasses; receives the constructor's leftover args."""
        pass

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, sim_params=None, benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - pandas.Panel
                     - zipline source
                     - list of zipline sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.
            sim_params : <optional>
               Overrides the simulation parameters given to the
               constructor. If neither is available, they are built from
               ``source.start`` and ``source.end``.
            benchmark_return_source : <optional>
               Benchmark events to use instead of those derived from the
               trading environment. When given, it is stored on the
               algorithm (``self.benchmark_return_source``) and therefore
               also applies to later runs.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, (list, tuple)):
            assert self.sim_params is not None or sim_params is not None, \
                "When providing a list of sources, sim_params have to be " \
                "specified as a parameter or in the constructor."
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            source = DataPanelSource(source)

        if not isinstance(source, (list, tuple)):
            self.sources = [source]
        else:
            self.sources = source

        # Previously this argument was accepted but silently ignored.
        if benchmark_return_source is not None:
            self.benchmark_return_source = benchmark_return_source

        # Check for override of sim_params.
        # If it isn't passed to this function,
        # use the default params set with the algorithm.
        # Else, we create simulation parameters using the start and end of the
        # source provided.
        if not sim_params:
            if not self.sim_params:
                start = source.start
                end = source.end

                sim_params = create_simulation_parameters(
                    start=start,
                    end=end,
                    capital_base=self.capital_base)
            else:
                sim_params = self.sim_params

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in self.registered_transforms.iteritems():
            sf = StatefulTransform(trans_descr['class'],
                                   *trans_descr['args'],
                                   **trans_descr['kwargs'])
            sf.namestring = namestring

            self.transforms.append(sf)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(sim_params)

        # exhaust simulated trading; each item is a perf dictionary
        perfs = list(self.gen)

        # convert perf dict to pandas dataframe
        daily_stats = self._create_daily_stats(perfs)

        return daily_stats

    def _create_daily_stats(self, perfs):
        """
        Build a DataFrame of the 'daily_perf' entries in `perfs`, indexed
        by each entry's 'period_close'.

        Recorded variables are flattened into each daily row. Any perf
        dict without a 'daily_perf' key is kept as ``self.risk_report``
        (the last such dict wins).
        """
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:

                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars'))
                daily_perfs.append(perf['daily_perf'])
            else:
                self.risk_report = perf

        daily_dts = [
            np.datetime64(perf['period_close'], utc=True)
            for perf in daily_perfs
        ]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    def add_transform(self, transform_class, tag, *args, **kwargs):
        """Add a single-sid, sequential transform to the model.

        :Arguments:
            transform_class : class
                Which transform to use. E.g. mavg.
            tag : str
                How to name the transform. Can later be access via:
                data[sid].tag()

        Extra args and kwargs will be forwarded to the transform
        instantiation.

        """
        self.registered_transforms[tag] = {'class': transform_class,
                                           'args': args,
                                           'kwargs': kwargs}

    def record(self, **kwargs):
        """
        Track and record local variable (i.e. attributes) each day.
        """
        for name, value in kwargs.items():
            self._recorded_vars[name] = value

    def order(self, sid, amount, limit_price=None, stop_price=None):
        """Place an order for `amount` shares of `sid` via the blotter."""
        return self.blotter.order(sid, amount, limit_price, stop_price)

    def order_value(self, sid, value, limit_price=None, stop_price=None):
        """
        Place an order by desired value rather than desired number of shares.
        The requested value is divided by the sid's last price (from the
        trading client's current data) to imply the number of shares to
        transact. If that price is 0, no order is placed and None is
        returned.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)
        """
        last_price = self.trading_client.current_data[sid].price
        if np.allclose(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=sid
            )
            # logger stays None until set_logger() is called.
            if self.logger is not None:
                self.logger.debug(zero_message)
            # Don't place any order
            return

        amount = value / last_price
        return self.order(sid, amount, limit_price, stop_price)

    @property
    def recorded_vars(self):
        """A shallow copy of the variables stored via record()."""
        return copy(self._recorded_vars)

    @property
    def portfolio(self):
        # internally this will cause a refresh of the
        # period performance calculations.
        return self.perf_tracker.get_portfolio()

    def updated_portfolio(self):
        """
        Return the portfolio, fetching it from the perf tracker only when
        ``portfolio_needs_update`` is set; otherwise reuse the cached one.
        """
        # internally this will cause a refresh of the
        # period performance calculations.
        if self.portfolio_needs_update:
            self._portfolio = self.perf_tracker.get_portfolio()
            self.portfolio_needs_update = False
        return self._portfolio

    def set_logger(self, logger):
        self.logger = logger

    def set_datetime(self, dt):
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"
        self.datetime = dt

    def get_datetime(self):
        """
        Returns a copy of the datetime.
        """
        date_copy = copy(self.datetime)
        assert date_copy.tzinfo == pytz.utc, \
            "Algorithm should have a utc datetime"
        return date_copy

    def set_transact(self, transact):
        """
        Set the method that will be called to create a
        transaction from open orders and trade events.
        """
        self.blotter.transact = transact

    def set_slippage(self, slippage):
        """Replace the slippage model; only allowed before initialization."""
        if not isinstance(slippage, SlippageModel):
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()
        self.slippage = slippage

    def set_commission(self, commission):
        """Replace the commission model; only allowed before initialization."""
        if not isinstance(commission, (PerShare, PerTrade, PerDollar)):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()
        self.commission = commission

    def set_sources(self, sources):
        assert isinstance(sources, list)
        self.sources = sources

    def set_transforms(self, transforms):
        assert isinstance(transforms, list)
        self.transforms = transforms

    def set_data_frequency(self, data_frequency):
        """Set the bar frequency ('daily' or 'minute') and its annualizer."""
        assert data_frequency in ('daily', 'minute')
        self.data_frequency = data_frequency
        self.annualizer = ANNUALIZER[self.data_frequency]

    def order_percent(self, sid, percent, limit_price=None, stop_price=None):
        """
        Place an order in the specified security corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50\%).
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value, limit_price, stop_price)

    def order_target(self, sid, target, limit_price=None, stop_price=None):
        """
        Place an order to adjust a position to a target number of shares. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target number of shares and the
        current number of shares.
        """
        # Fetch the portfolio once: each access refreshes it.
        positions = self.portfolio.positions
        if sid in positions:
            req_shares = target - positions[sid].amount
            return self.order(sid, req_shares, limit_price, stop_price)
        return self.order(sid, target, limit_price, stop_price)

    def order_target_value(self, sid, target, limit_price=None,
                           stop_price=None):
        """
        Place an order to adjust a position to a target value. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target value and the
        current value.
        """
        # Fetch the portfolio once: each access refreshes it.
        positions = self.portfolio.positions
        if sid in positions:
            position = positions[sid]
            current_value = position.amount * position.last_sale_price
            req_value = target - current_value
            return self.order_value(sid, req_value, limit_price, stop_price)
        return self.order_value(sid, target, limit_price, stop_price)

    def order_target_percent(self, sid, target, limit_price=None,
                             stop_price=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50\%).
        """
        # Fetch the portfolio once: each access refreshes it.
        portfolio = self.portfolio
        if sid in portfolio.positions:
            position = portfolio.positions[sid]
            current_value = position.amount * position.last_sale_price
        else:
            current_value = 0
        target_value = portfolio.portfolio_value * target

        req_value = target_value - current_value
        return self.order_value(sid, req_value, limit_price, stop_price)
class TradingAlgorithm(object): """ Base class for trading algorithms. Inherit and overload initialize() and handle_data(data). A new algorithm could look like this: ``` from zipline.api import order def initialize(context): context.sid = 'AAPL' context.amount = 100 def handle_data(self, data): sid = context.sid amount = context.amount order(sid, amount) ``` To then to run this algorithm pass these functions to TradingAlgorithm: my_algo = TradingAlgorithm(initialize, handle_data) stats = my_algo.run(data) """ def __init__(self, *args, **kwargs): """Initialize sids and other state variables. :Arguments: :Optional: initialize : function Function that is called with a single argument at the begninning of the simulation. handle_data : function Function that is called with 2 arguments (context and data) on every bar. script : str Algoscript that contains initialize and handle_data function definition. data_frequency : str (daily, hourly or minutely) The duration of the bars. annualizer : int <optional> Which constant to use for annualizing risk metrics. If not provided, will extract from data_frequency. capital_base : float <default: 1.0e5> How much capital to start with. instant_fill : bool <default: False> Whether to fill orders immediately or on next bar. 
""" self.datetime = None self.registered_transforms = {} self.transforms = [] self.sources = [] self._recorded_vars = {} self.logger = None self.benchmark_return_source = None self.perf_tracker = None # default components for transact self.slippage = VolumeShareSlippage() self.commission = PerShare() if 'data_frequency' in kwargs: self.set_data_frequency(kwargs.pop('data_frequency')) else: self.data_frequency = None self.instant_fill = kwargs.pop('instant_fill', False) # Override annualizer if set if 'annualizer' in kwargs: self.annualizer = kwargs['annualizer'] # set the capital base self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE) self.sim_params = kwargs.pop('sim_params', None) if self.sim_params: if self.data_frequency is None: self.data_frequency = self.sim_params.data_frequency else: self.sim_params.data_frequency = self.data_frequency self.perf_tracker = PerformanceTracker(self.sim_params) self.blotter = kwargs.pop('blotter', None) if not self.blotter: self.blotter = Blotter() self.portfolio_needs_update = True self._portfolio = None # If string is passed in, execute and get reference to # functions. self.algoscript = kwargs.pop('script', None) if self.algoscript is not None: self.ns = {} exec_(self.algoscript, self.ns) if 'initialize' not in self.ns: raise ValueError('You must define an initialze function.') if 'handle_data' not in self.ns: raise ValueError('You must define a handle_data function.') self._initialize = self.ns['initialize'] self._handle_data = self.ns['handle_data'] # If two functions are passed in assume initialize and # handle_data are passed in. elif kwargs.get('initialize', False) and kwargs.get('handle_data'): if self.algoscript is not None: raise ValueError('You can not set script and \ initialize/handle_data.') self._initialize = kwargs.pop('initialize') self._handle_data = kwargs.pop('handle_data') # an algorithm subclass needs to set initialized to True when # it is fully initialized. 
# -- tail of a constructor whose beginning lies before this section (kept as-is) --
self.initialized = False
self.initialize(*args, **kwargs)


def initialize(self, *args, **kwargs):
    """
    Run the user-supplied ``_initialize`` callback.

    The algorithm instance is registered as the global algo instance for
    the duration of the call and always cleared afterwards, even if the
    callback raises. ``args``/``kwargs`` are accepted but not forwarded.
    """
    # store algo reference in global space
    set_algo_instance(self)
    try:
        self._initialize(self)
    finally:
        set_algo_instance(None)


def handle_data(self, data):
    """Forward one bar of ``data`` to the user-supplied ``_handle_data``."""
    self._handle_data(self, data)


def __repr__(self):
    """
    N.B. this does not yet represent a string that can be used
    to instantiate an exact copy of an algorithm.

    However, it is getting close, and provides some value as something
    that can be inspected interactively.
    """
    # Every field line is comma-separated; capital_base previously lacked
    # its trailing comma, producing a malformed pseudo-constructor.
    return """
{class_name}(
    capital_base={capital_base},
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))


def _create_data_generator(self, source_filter, sim_params):
    """
    Create a merged data generator using the sources and
    transforms attached to this algorithm.

    ::source_filter:: is a method that receives events in date
    sorted order, and returns True for those events that should be
    processed by the zipline, and False for those that should be
    skipped.

    Returns the ``itertools.groupby`` iterator of ``(dt, events)``
    pairs over the source/transform stream merged with benchmark events.
    """
    if self.benchmark_return_source is None:
        env = trading.environment
        if (self.data_frequency == 'minute'
                or sim_params.emission_rate == 'minute'):
            # Stamp each benchmark return with the close of its day.
            update_time = lambda date: env.get_open_and_close(date)[1]
        else:
            update_time = lambda date: date
        # Bounds are loop-invariant; compute them once.
        first_day = sim_params.period_start.date()
        last_day = sim_params.period_end.date()
        benchmark_return_source = [
            Event({'dt': update_time(dt),
                   'returns': ret,
                   'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                   'source_id': 'benchmarks'})
            # six.iteritems replaces the deprecated pandas ``iterkv``.
            for dt, ret in iteritems(env.benchmark_returns)
            if first_day <= dt.date() <= last_day
        ]
    else:
        benchmark_return_source = self.benchmark_return_source

    date_sorted = date_sorted_sources(*self.sources)

    if source_filter:
        date_sorted = filter(source_filter, date_sorted)

    with_tnfms = sequential_transforms(date_sorted,
                                       *self.transforms)
    with_alias_dt = alias_dt(with_tnfms)

    with_benchmarks = date_sorted_sources(benchmark_return_source,
                                          with_alias_dt)

    # Group together events with the same dt field. This depends on the
    # events already being sorted.
    return groupby(with_benchmarks, attrgetter('dt'))


def _create_generator(self, sim_params, source_filter=None):
    """
    Create a basic generator setup using the sources and
    transforms attached to this algorithm.

    ::source_filter:: is a method that receives events in date
    sorted order, and returns True for those events that should be
    processed by the zipline, and False for those that should be
    skipped.

    Side effects: overwrites ``sim_params.data_frequency``, may create
    ``self.perf_tracker``, and sets ``self.data_gen``,
    ``self.trading_client`` and the blotter's transact method.
    """
    sim_params.data_frequency = self.data_frequency

    # perf_tracker will be instantiated in __init__ if a sim_params
    # is passed to the constructor. If not, we instantiate here.
    if self.perf_tracker is None:
        self.perf_tracker = PerformanceTracker(sim_params)

    self.data_gen = self._create_data_generator(source_filter,
                                                sim_params)

    self.trading_client = AlgorithmSimulator(self, sim_params)

    transact_method = transact_partial(self.slippage, self.commission)
    self.set_transact(transact_method)

    return self.trading_client.transform(self.data_gen)


def get_generator(self):
    """
    Override this method to add new logic to the construction
    of the generator. Overrides can use the _create_generator
    method to get a standard construction generator.
    """
    return self._create_generator(self.sim_params)


# TODO: make a new subclass, e.g. BatchAlgorithm, and move
# the run method to the subclass, and refactor to put the
# generator creation logic into get_generator.
def run(self, source, sim_params=None, benchmark_return_source=None):
    """Run the algorithm.

    :Arguments:
        source : can be either:
                 - pandas.DataFrame
                 - zipline source
                 - list of zipline sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.
        sim_params : optional simulation parameters; falls back to the
               ones given to the constructor, then to the source's
               start/end.
        benchmark_return_source : optional benchmark event source. When
               given it is stored on ``self.benchmark_return_source``
               and used instead of the environment's benchmark returns.

    :Returns:
        daily_stats : pandas.DataFrame
          Daily performance metrics such as returns, alpha etc.

    """
    if isinstance(source, (list, tuple)):
        assert self.sim_params is not None or sim_params is not None, \
            """When providing a list of sources, \
            sim_params have to be specified as a parameter
            or in the constructor."""
    elif isinstance(source, pd.DataFrame):
        # if DataFrame provided, wrap in DataFrameSource
        source = DataFrameSource(source)
    elif isinstance(source, pd.Panel):
        source = DataPanelSource(source)

    if not isinstance(source, (list, tuple)):
        self.sources = [source]
    else:
        self.sources = source

    # The parameter used to be accepted and silently ignored; honor it.
    if benchmark_return_source is not None:
        self.benchmark_return_source = benchmark_return_source

    # Check for override of sim_params.
    # If it isn't passed to this function,
    # use the default params set with the algorithm.
    # Else, we create simulation parameters using the start and end of the
    # source provided.
    if not sim_params:
        if not self.sim_params:
            start = source.start
            end = source.end

            sim_params = create_simulation_parameters(
                start=start,
                end=end,
                capital_base=self.capital_base
            )
        else:
            sim_params = self.sim_params

    # Create transforms by wrapping them into StatefulTransforms
    self.transforms = []
    for namestring, trans_descr in iteritems(self.registered_transforms):
        sf = StatefulTransform(
            trans_descr['class'],
            *trans_descr['args'],
            **trans_descr['kwargs']
        )
        sf.namestring = namestring

        self.transforms.append(sf)

    # force a reset of the performance tracker, in case
    # this is a repeat run of the algorithm.
    self.perf_tracker = None

    # create transforms and zipline
    self.gen = self._create_generator(sim_params)

    # store algo reference in global space
    set_algo_instance(self)
    try:
        # drain simulated trading; each item is a perf dictionary
        perfs = list(self.gen)

        # convert perf dict to pandas dataframe
        daily_stats = self._create_daily_stats(perfs)
    finally:
        # remove algo from global space
        set_algo_instance(None)

    return daily_stats


def _create_daily_stats(self, perfs):
    """
    Build a DataFrame of daily performance from the perf messages.

    Messages carrying 'daily_perf' become rows (their 'recorded_vars'
    are flattened into columns), indexed by 'period_close'. Any other
    message is stored on ``self.risk_report`` (the last one wins).
    """
    # create daily and cumulative stats dataframe
    daily_perfs = []
    # TODO: the loop here could overwrite expected properties
    # of daily_perf. Could potentially raise or log a
    # warning.
    for perf in perfs:
        if 'daily_perf' in perf:

            perf['daily_perf'].update(
                perf['daily_perf'].pop('recorded_vars')
            )
            daily_perfs.append(perf['daily_perf'])
        else:
            self.risk_report = perf

    daily_dts = [np.datetime64(perf['period_close'], utc=True)
                 for perf in daily_perfs]
    daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

    return daily_stats


def add_transform(self, transform_class, tag, *args, **kwargs):
    """Add a single-sid, sequential transform to the model.

    :Arguments:
        transform_class : class
            Which transform to use. E.g. mavg.
        tag : str
            How to name the transform. Can later be access via:
            data[sid].tag()

    Extra args and kwargs will be forwarded to the transform
    instantiation.

    """
    self.registered_transforms[tag] = {'class': transform_class,
                                       'args': args,
                                       'kwargs': kwargs}


@api_method
def record(self, **kwargs):
    """
    Track and record local variable (i.e. attributes) each day.
    """
    self._recorded_vars.update(kwargs)


@api_method
def order(self, sid, amount, limit_price=None, stop_price=None):
    """Place an order for ``amount`` shares of ``sid`` via the blotter."""
    return self.blotter.order(sid, amount, limit_price, stop_price)


@api_method
def order_value(self, sid, value, limit_price=None, stop_price=None):
    """
    Place an order by desired value rather than desired number of shares.
    If the requested sid is found in the universe, the requested value is
    divided by its price to imply the number of shares to transact.

    value > 0 :: Buy/Cover
    value < 0 :: Sell/Short
    Market order:    order(sid, value)
    Limit order:     order(sid, value, limit_price)
    Stop order:      order(sid, value, None, stop_price)
    StopLimit order: order(sid, value, limit_price, stop_price)

    Returns None (and places nothing) when the last price is ~0.
    """
    last_price = self.trading_client.current_data[sid].price
    if np.allclose(last_price, 0):
        zero_message = "Price of 0 for {psid}; can't infer value".format(
            psid=sid
        )
        # logger defaults to None until set_logger is called.
        if self.logger is not None:
            self.logger.debug(zero_message)
        # Don't place any order
        return
    else:
        amount = value / last_price
        return self.order(sid, amount, limit_price, stop_price)


@property
def recorded_vars(self):
    """Shallow copy of the variables stored via ``record``."""
    return copy(self._recorded_vars)


@property
def portfolio(self):
    """Current portfolio, fetched fresh from the performance tracker."""
    # internally this will cause a refresh of the
    # period performance calculations.
    return self.perf_tracker.get_portfolio()


def updated_portfolio(self):
    """Return the cached portfolio, refreshing it only when flagged."""
    # internally this will cause a refresh of the
    # period performance calculations.
    if self.portfolio_needs_update:
        self._portfolio = self.perf_tracker.get_portfolio()
        self.portfolio_needs_update = False
    return self._portfolio


def set_logger(self, logger):
    self.logger = logger


def set_datetime(self, dt):
    """Set the algorithm's current time; ``dt`` must be a UTC datetime."""
    assert isinstance(dt, datetime), \
        "Attempt to set algorithm's current time with non-datetime"
    assert dt.tzinfo == pytz.utc, \
        "Algorithm expects a utc datetime"
    self.datetime = dt


@api_method
def get_datetime(self):
    """
    Returns a copy of the datetime.
    """
    date_copy = copy(self.datetime)
    assert date_copy.tzinfo == pytz.utc, \
        "Algorithm should have a utc datetime"
    return date_copy


def set_transact(self, transact):
    """
    Set the method that will be called to create a
    transaction from open orders and trade events.
    """
    self.blotter.transact = transact


@api_method
def set_slippage(self, slippage):
    """Replace the slippage model; only allowed before initialization."""
    if not isinstance(slippage, SlippageModel):
        raise UnsupportedSlippageModel()
    if self.initialized:
        raise OverrideSlippagePostInit()
    self.slippage = slippage


@api_method
def set_commission(self, commission):
    """Replace the commission model; only allowed before initialization."""
    if not isinstance(commission, (PerShare, PerTrade, PerDollar)):
        raise UnsupportedCommissionModel()

    if self.initialized:
        raise OverrideCommissionPostInit()
    self.commission = commission


def set_sources(self, sources):
    assert isinstance(sources, list)
    self.sources = sources


def set_transforms(self, transforms):
    assert isinstance(transforms, list)
    self.transforms = transforms


def set_data_frequency(self, data_frequency):
    """Set bar frequency ('daily' or 'minute') and the matching annualizer."""
    assert data_frequency in ('daily', 'minute')
    self.data_frequency = data_frequency
    self.annualizer = ANNUALIZER[self.data_frequency]


@api_method
def order_percent(self, sid, percent, limit_price=None, stop_price=None):
    """
    Place an order in the specified security corresponding to the given
    percent of the current portfolio value.

    Note that percent must be expressed as a decimal (0.50 means 50%).
    """
    value = self.portfolio.portfolio_value * percent
    return self.order_value(sid, value, limit_price, stop_price)


@api_method
def order_target(self, sid, target, limit_price=None, stop_price=None):
    """
    Place an order to adjust a position to a target number of shares. If
    the position doesn't already exist, this is equivalent to placing a new
    order. If the position does exist, this is equivalent to placing an
    order for the difference between the target number of shares and the
    current number of shares.
    """
    # Fetch once: each ``self.portfolio`` access asks the tracker again.
    positions = self.portfolio.positions
    if sid in positions:
        req_shares = target - positions[sid].amount
        return self.order(sid, req_shares, limit_price, stop_price)
    else:
        return self.order(sid, target, limit_price, stop_price)


@api_method
def order_target_value(self, sid, target,
                       limit_price=None, stop_price=None):
    """
    Place an order to adjust a position to a target value. If
    the position doesn't already exist, this is equivalent to placing a new
    order. If the position does exist, this is equivalent to placing an
    order for the difference between the target value and the
    current value.
    """
    positions = self.portfolio.positions
    if sid in positions:
        position = positions[sid]
        current_value = position.amount * position.last_sale_price
        req_value = target - current_value
        return self.order_value(sid, req_value, limit_price, stop_price)
    else:
        return self.order_value(sid, target, limit_price, stop_price)


@api_method
def order_target_percent(self, sid, target,
                         limit_price=None, stop_price=None):
    """
    Place an order to adjust a position to a target percent of the
    current portfolio value. If the position doesn't already exist, this is
    equivalent to placing a new order. If the position does exist, this is
    equivalent to placing an order for the difference between the target
    percent and the current percent.

    Note that target must be expressed as a decimal (0.50 means 50%).
    """
    # Use a single portfolio snapshot for both position and total value.
    portfolio = self.portfolio
    if sid in portfolio.positions:
        position = portfolio.positions[sid]
        current_value = position.amount * position.last_sale_price
    else:
        current_value = 0
    target_value = portfolio.portfolio_value * target

    req_value = target_value - current_value
    return self.order_value(sid, req_value, limit_price, stop_price)


@api_method
def get_open_orders(self, sid=None):
    """
    Return open orders as API objects: a dict keyed by sid when ``sid``
    is None, otherwise the list for ``sid`` (empty if none).
    """
    if sid is None:
        # six.iteritems works on both Python 2 and 3 dicts.
        return {key: [order.to_api_obj() for order in orders]
                for key, orders in iteritems(self.blotter.open_orders)}
    if sid in self.blotter.open_orders:
        orders = self.blotter.open_orders[sid]
        return [order.to_api_obj() for order in orders]
    return []


@api_method
def get_order(self, order_id):
    """Return the API object for ``order_id``, or None if unknown."""
    if order_id in self.blotter.orders:
        return self.blotter.orders[order_id].to_api_obj()


@api_method
def cancel_order(self, order_param):
    """Cancel an order given either its id or a ``zipline.protocol.Order``."""
    order_id = order_param
    if isinstance(order_param, zipline.protocol.Order):
        order_id = order_param.id

    self.blotter.cancel(order_id)


def raw_positions(self):
    """
    Returns the current portfolio for the algorithm.

    N.B. this is not done as a property, so that the function can be
    passed and called from within a source.
    """
    # Return the 'internal' positions object, as in the one that is
    # not passed to the algo, and thus should not have tainted keys.
    return self.perf_tracker.cumulative_performance.positions


def raw_orders(self):
    """
    Returns the current open orders from the blotter.

    N.B. this is not a property, so that the function can be passed
    and called back from within a source.
    """

    return self.blotter.open_orders
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    :Arguments:
        data_frequency : str ('daily' or 'minute')
            The duration of the bars. Validated by set_data_frequency;
            when omitted, data_frequency is left as None.
        annualizer : int <optional>
            Which constant to use for annualizing risk metrics.
            If not provided, will extract from data_frequency.
        capital_base : float <default: 1.0e5>
            How much capital to start with.
        instant_fill : bool <default: False>
        sim_params : <optional>
            When given, its data_frequency is overwritten and a
            PerformanceTracker is created immediately.
        blotter : <optional>
            A new Blotter() is created when absent.

    Every option above is consumed here; any remaining positional and
    keyword arguments are forwarded to self.initialize.
    """
    self.datetime = None

    self.registered_transforms = {}
    self.transforms = []
    self.sources = []

    self._recorded_vars = {}

    self.logger = None

    self.benchmark_return_source = None
    self.perf_tracker = None

    # default components for transact
    self.slippage = VolumeShareSlippage()
    self.commission = PerShare()

    if 'data_frequency' in kwargs:
        self.set_data_frequency(kwargs.pop('data_frequency'))
    else:
        self.data_frequency = None

    self.instant_fill = kwargs.pop('instant_fill', False)

    # Override annualizer if set. Pop it like every other option consumed
    # here so it is not leaked into the user's initialize(**kwargs).
    if 'annualizer' in kwargs:
        self.annualizer = kwargs.pop('annualizer')

    # set the capital base
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params:
        self.sim_params.data_frequency = self.data_frequency
        self.perf_tracker = PerformanceTracker(self.sim_params)

    self.blotter = kwargs.pop('blotter', None)
    if not self.blotter:
        self.blotter = Blotter()

    self.portfolio_needs_update = True
    self._portfolio = None

    # an algorithm subclass needs to set initialized to True when
    # it is fully initialized.
    self.initialized = False

    # call to user-defined constructor method
    self.initialize(*args, **kwargs)
def transaction_sim(self, **params):
    """
    Assert expected results for the conversion of orders to
    transactions given a trade history.

    Required ``params`` keys: trade_count, trade_interval, order_count,
    order_amount, order_interval, expected_txn_count,
    expected_txn_volume. Optional keys: alternate (alternate long/short
    orders), complete_fill (expect each order to fill exactly),
    default_slippage (use the blotter's default instead of
    FixedSlippage).
    """
    # The context manager cleans up the temp directory on every exit
    # path, exactly as the previous try/finally did.
    with TempDirectory() as tempdir:
        trade_count = params['trade_count']
        trade_interval = params['trade_interval']
        order_count = params['order_count']
        order_amount = params['order_amount']
        order_interval = params['order_interval']
        expected_txn_count = params['expected_txn_count']
        expected_txn_volume = params['expected_txn_volume']
        # optional parameters
        # ---------------------
        # if present, alternate between long and short sales
        alternate = params.get('alternate')

        # if present, expect transaction amounts to match orders exactly.
        complete_fill = params.get('complete_fill')

        env = TradingEnvironment()

        sid = 1

        if trade_interval < timedelta(days=1):
            sim_params = factory.create_simulation_parameters(
                data_frequency="minute")

            minutes = env.market_minute_window(
                sim_params.first_open,
                int((trade_interval.total_seconds() / 60) * trade_count)
                + 100)

            price_data = np.array([10.1] * len(minutes))
            assets = {
                sid: pd.DataFrame({
                    "open": price_data,
                    "high": price_data,
                    "low": price_data,
                    "close": price_data,
                    "volume": np.array([100] * len(minutes)),
                    "dt": minutes
                }).set_index("dt")
            }

            write_bcolz_minute_data(
                env,
                env.days_in_range(minutes[0], minutes[-1]),
                tempdir.path,
                assets)

            equity_minute_reader = BcolzMinuteBarReader(tempdir.path)

            data_portal = DataPortal(
                env,
                equity_minute_reader=equity_minute_reader,
            )
        else:
            sim_params = factory.create_simulation_parameters(
                data_frequency="daily")

            days = sim_params.trading_days

            assets = {
                1: pd.DataFrame({
                    "open": [10.1] * len(days),
                    "high": [10.1] * len(days),
                    "low": [10.1] * len(days),
                    "close": [10.1] * len(days),
                    "volume": [100] * len(days),
                    "day": [day.value for day in days]
                }, index=days)
            }

            path = os.path.join(tempdir.path, "testdata.bcolz")
            DailyBarWriterFromDataFrames(assets).write(
                path, days, assets)

            equity_daily_reader = BcolzDailyBarReader(path)

            data_portal = DataPortal(
                env,
                equity_daily_reader=equity_daily_reader,
            )

        # Missing or falsy "default_slippage" -> fixed slippage.
        if not params.get("default_slippage"):
            slippage_func = FixedSlippage()
        else:
            slippage_func = None

        # NOTE(review): the blotter and tracker use self.env while the
        # equities below are written to the local env -- confirm intent.
        blotter = Blotter(sim_params.data_frequency, self.env.asset_finder,
                          slippage_func)

        env.write_data(equities_data={
            sid: {
                "start_date": sim_params.trading_days[0],
                "end_date": sim_params.trading_days[-1]
            }
        })

        start_date = sim_params.first_open

        alternator = -1 if alternate else 1

        tracker = PerformanceTracker(sim_params, self.env)

        # replicate what tradesim does by going through every minute or day
        # of the simulation and processing open orders each time
        if sim_params.data_frequency == "minute":
            ticks = minutes
        else:
            ticks = days

        transactions = []

        order_list = []
        order_date = start_date
        for tick in ticks:
            blotter.current_dt = tick
            if tick >= order_date and len(order_list) < order_count:
                # place an order
                direction = alternator ** len(order_list)
                order_id = blotter.order(
                    blotter.asset_finder.retrieve_asset(sid),
                    order_amount * direction,
                    MarketOrder())
                order_list.append(blotter.orders[order_id])
                order_date = order_date + order_interval
                # move after market orders to just after market next
                # market open. (datetime.minute is always >= 0, so only
                # the hour needs checking.)
                if order_date.hour >= 21:
                    order_date = order_date + timedelta(days=1)
                    order_date = order_date.replace(hour=14, minute=30)
            else:
                bar_data = BarData(data_portal, lambda: tick,
                                   sim_params.data_frequency)
                txns, _ = blotter.get_transactions(bar_data)
                for txn in txns:
                    tracker.process_transaction(txn)
                    transactions.append(txn)

        for i in range(order_count):
            order = order_list[i]
            self.assertEqual(order.sid, sid)
            self.assertEqual(order.amount, order_amount * alternator ** i)

        if complete_fill:
            self.assertEqual(len(transactions), len(order_list))

        total_volume = 0
        for i, txn in enumerate(transactions):
            total_volume += txn.amount
            if complete_fill:
                order = order_list[i]
                self.assertEqual(order.amount, txn.amount)

        self.assertEqual(total_volume, expected_txn_volume)

        self.assertEqual(len(transactions), expected_txn_count)

        cumulative_pos = tracker.position_tracker.positions[sid]
        if total_volume == 0:
            self.assertIsNone(cumulative_pos)
        else:
            self.assertEqual(total_volume, cumulative_pos.amount)

        # the open orders should not contain sid.
        oo = blotter.open_orders
        self.assertNotIn(sid, oo, "Entry is removed when no open orders")
def transaction_sim(self, **params):
    """
    Assert expected results for the conversion of orders to
    transactions given a trade history.

    Required ``params`` keys: trade_count, trade_interval, order_count,
    order_amount, order_interval, expected_txn_count,
    expected_txn_volume. Optional keys: alternate (alternate long/short
    orders), complete_fill (expect each order to fill exactly).
    """
    trade_count = params['trade_count']
    trade_interval = params['trade_interval']
    order_count = params['order_count']
    order_amount = params['order_amount']
    order_interval = params['order_interval']
    expected_txn_count = params['expected_txn_count']
    expected_txn_volume = params['expected_txn_volume']
    # optional parameters
    # ---------------------
    # if present, alternate between long and short sales
    alternate = params.get('alternate')
    # if present, expect transaction amounts to match orders exactly.
    complete_fill = params.get('complete_fill')

    sid = 1
    sim_params = factory.create_simulation_parameters()
    trade_sim = TransactionSimulator()
    price = [10.1] * trade_count
    volume = [100] * trade_count
    start_date = sim_params.first_open

    generated_trades = factory.create_trade_history(
        sid,
        price,
        volume,
        trade_interval,
        sim_params
    )

    alternator = -1 if alternate else 1

    order_date = start_date
    # range (not the Python-2-only xrange) works on both major versions.
    for i in range(order_count):
        order = Order(sid=sid,
                      amount=order_amount * alternator ** i,
                      dt=order_date)

        trade_sim.place_order(order)

        order_date = order_date + order_interval
        # move after market orders to just after market next
        # market open. (datetime.minute is always >= 0, so only the
        # hour needs checking.)
        if order_date.hour >= 21:
            order_date = order_date + timedelta(days=1)
            order_date = order_date.replace(hour=14, minute=30)

    # there should now be one open order list stored under the sid
    oo = trade_sim.open_orders
    self.assertEqual(len(oo), 1)
    self.assertIn(sid, oo)
    order_list = oo[sid]
    self.assertEqual(order_count, len(order_list))

    for i in range(order_count):
        order = order_list[i]
        self.assertEqual(order.sid, sid)
        self.assertEqual(order.amount, order_amount * alternator ** i)

    tracker = PerformanceTracker(sim_params)

    # this approximates the loop inside TradingSimulationClient
    transactions = []
    for dt, trades in itertools.groupby(generated_trades,
                                        operator.attrgetter('dt')):
        for trade in trades:
            trade_sim.update(trade)
            if trade.TRANSACTION:
                transactions.append(trade.TRANSACTION)

            tracker.process_event(trade)

    if complete_fill:
        self.assertEqual(len(transactions), len(order_list))

    total_volume = 0
    for i, txn in enumerate(transactions):
        total_volume += txn.amount
        if complete_fill:
            order = order_list[i]
            self.assertEqual(order.amount, txn.amount)

    self.assertEqual(total_volume, expected_txn_volume)
    self.assertEqual(len(transactions), expected_txn_count)

    cumulative_pos = tracker.cumulative_performance.positions[sid]
    self.assertEqual(total_volume, cumulative_pos.amount)

    # the open orders should now be empty
    oo = trade_sim.open_orders
    self.assertIn(sid, oo)
    order_list = oo[sid]
    self.assertEqual(0, len(order_list))
class TradeSimulationClient(object):
    """
    Generator-style driver that pushes a merged event stream through the
    order-filling, performance-tracking and algorithm stages.

    The constructor takes the user's algorithm and a trading environment
    and builds three stages, which ``simulate`` chains in this order:

    1. ``TransactionSimulator`` -- holds the algorithm's open orders and
       attaches fills to the events that trigger them.
    2. ``PerformanceTracker`` -- consumes those events, maintains daily and
       cumulative performance, and supplies portfolio state downstream.
    3. ``AlgorithmSimulator`` -- batches events that share a ``dt`` into a
       single snapshot for the user's algorithm.

    Whatever the last stage emits (daily results and, at the end, a risk
    report) is yielded to the caller unchanged.
    """

    def __init__(self, algo, environment):
        self.algo = algo
        self.environment = environment

        order_book = TransactionSimulator()
        tracker = PerformanceTracker(environment)
        self.ordering_client = order_book
        self.perf_tracker = tracker

        self.algo_start = environment.first_open
        self.algo_sim = AlgorithmSimulator(
            order_book,
            tracker,
            algo,
            self.algo_start,
        )

    def get_hash(self):
        """
        Identify this client by class name alone.

        Only one client is expected per system, so no constructor
        arguments contribute to the hash.
        """
        return type(self).__name__ + hash_args()

    def simulate(self, stream_in):
        """
        Lazily run ``stream_in`` through every stage and yield each
        message produced by the algorithm stage.
        """
        # Order matters: fills first, then performance bookkeeping,
        # then the user's algorithm.
        stages = (self.ordering_client, self.perf_tracker, self.algo_sim)
        stream = stream_in
        for stage in stages:
            stream = stage.transform(stream)
        for message in stream:
            yield message
def transaction_sim(self, **params):
    """Assert expected results for the conversion of orders to
    transactions given a trade history.

    Required ``params`` keys: trade_count, trade_interval, order_count,
    order_amount, order_interval, expected_txn_count,
    expected_txn_volume. Optional keys: alternate (alternate long/short
    orders), complete_fill (expect each order to fill exactly),
    default_slippage (use the blotter's default instead of FixedSlippage).
    """
    trade_count = params["trade_count"]
    trade_interval = params["trade_interval"]
    order_count = params["order_count"]
    order_amount = params["order_amount"]
    order_interval = params["order_interval"]
    expected_txn_count = params["expected_txn_count"]
    expected_txn_volume = params["expected_txn_volume"]

    # optional parameters
    # ---------------------
    # if present, alternate between long and short sales
    alternate = params.get("alternate")

    # if present, expect transaction amounts to match orders exactly.
    complete_fill = params.get("complete_fill")

    sid = 1
    metadata = make_simple_equity_info([sid], self.start, self.end)
    with TempDirectory() as tempdir, \
            tmp_trading_env(equities=metadata) as env:

        if trade_interval < timedelta(days=1):
            sim_params = factory.create_simulation_parameters(
                start=self.start,
                end=self.end,
                data_frequency="minute"
            )

            minutes = env.market_minute_window(
                sim_params.first_open,
                int((trade_interval.total_seconds() / 60) * trade_count)
                + 100
            )

            price_data = np.array([10.1] * len(minutes))
            assets = {
                sid: pd.DataFrame({
                    "open": price_data,
                    "high": price_data,
                    "low": price_data,
                    "close": price_data,
                    "volume": np.array([100] * len(minutes)),
                    "dt": minutes,
                }).set_index("dt")
            }

            write_bcolz_minute_data(
                env,
                env.days_in_range(minutes[0], minutes[-1]),
                tempdir.path,
                iteritems(assets),
            )

            equity_minute_reader = BcolzMinuteBarReader(tempdir.path)

            data_portal = DataPortal(
                env,
                first_trading_day=equity_minute_reader.first_trading_day,
                equity_minute_reader=equity_minute_reader,
            )
        else:
            sim_params = factory.create_simulation_parameters(
                data_frequency="daily")

            days = sim_params.trading_days

            assets = {
                1: pd.DataFrame({
                    "open": [10.1] * len(days),
                    "high": [10.1] * len(days),
                    "low": [10.1] * len(days),
                    "close": [10.1] * len(days),
                    "volume": [100] * len(days),
                    "day": [day.value for day in days],
                }, index=days)
            }

            path = os.path.join(tempdir.path, "testdata.bcolz")
            BcolzDailyBarWriter(path, days).write(assets.items())

            equity_daily_reader = BcolzDailyBarReader(path)

            data_portal = DataPortal(
                env,
                first_trading_day=equity_daily_reader.first_trading_day,
                equity_daily_reader=equity_daily_reader,
            )

        # Missing or falsy "default_slippage" -> fixed slippage.
        if not params.get("default_slippage"):
            slippage_func = FixedSlippage()
        else:
            slippage_func = None

        # NOTE(review): blotter/tracker use self.env, not the temporary
        # env created above -- confirm this is intended.
        blotter = Blotter(sim_params.data_frequency, self.env.asset_finder,
                          slippage_func)

        start_date = sim_params.first_open

        alternator = -1 if alternate else 1

        tracker = PerformanceTracker(sim_params, self.env)

        # replicate what tradesim does by going through every minute or day
        # of the simulation and processing open orders each time
        if sim_params.data_frequency == "minute":
            ticks = minutes
        else:
            ticks = days

        transactions = []

        order_list = []
        order_date = start_date
        for tick in ticks:
            blotter.current_dt = tick
            if tick >= order_date and len(order_list) < order_count:
                # place an order
                direction = alternator ** len(order_list)
                order_id = blotter.order(
                    blotter.asset_finder.retrieve_asset(sid),
                    order_amount * direction,
                    MarketOrder(),
                )
                order_list.append(blotter.orders[order_id])
                order_date = order_date + order_interval
                # move after market orders to just after market next
                # market open. (datetime.minute is always >= 0, so only
                # the hour needs checking.)
                if order_date.hour >= 21:
                    order_date = order_date + timedelta(days=1)
                    order_date = order_date.replace(hour=14, minute=30)
            else:
                bar_data = BarData(data_portal, lambda: tick,
                                   sim_params.data_frequency)
                txns, _, closed_orders = blotter.get_transactions(bar_data)
                for txn in txns:
                    tracker.process_transaction(txn)
                    transactions.append(txn)

                blotter.prune_orders(closed_orders)

        for i in range(order_count):
            order = order_list[i]
            self.assertEqual(order.sid, sid)
            self.assertEqual(order.amount, order_amount * alternator ** i)

        if complete_fill:
            self.assertEqual(len(transactions), len(order_list))

        total_volume = 0
        for i, txn in enumerate(transactions):
            total_volume += txn.amount
            if complete_fill:
                order = order_list[i]
                self.assertEqual(order.amount, txn.amount)

        self.assertEqual(total_volume, expected_txn_volume)

        self.assertEqual(len(transactions), expected_txn_count)

        cumulative_pos = tracker.position_tracker.positions[sid]
        if total_volume == 0:
            self.assertIsNone(cumulative_pos)
        else:
            self.assertEqual(total_volume, cumulative_pos.amount)

        # the open orders should not contain sid.
        oo = blotter.open_orders
        self.assertNotIn(sid, oo, "Entry is removed when no open orders")