def test_empty_pipeline(self):
    """Run an algorithm whose pipeline screen is expected to match nothing.

    The attached pipeline screens on ``VWAP < 0``; each
    ``before_trading_start`` call asserts that the pipeline output is
    empty, and the test finally asserts that ``before_trading_start``
    was invoked at least once.
    """
    # One entry is appended per before_trading_start invocation, so we
    # can verify afterwards that the callback actually ran.
    bts_calls = []

    def initialize(context):
        attached = attach_pipeline(Pipeline(), 'test')

        ten_day_vwap = VWAP(window_length=10)
        attached.add(ten_day_vwap, 'vwap')

        # Nothing should have prices less than 0.
        attached.set_screen(ten_day_vwap < 0)

    def handle_data(context, data):
        """Intentionally does nothing."""

    def before_trading_start(context, data):
        output = pipeline_output('test')
        context.results = output
        self.assertTrue(context.results.empty)
        bts_calls.append(1)

    def loader_for(column):
        # Every column is served by the same loader.
        return self.pipeline_loader

    algo = TradingAlgorithm(
        initialize=initialize,
        handle_data=handle_data,
        before_trading_start=before_trading_start,
        data_frequency='daily',
        get_pipeline_loader=loader_for,
        start=self.dates[0],
        end=self.dates[-1],
        env=self.env,
    )

    # Keep the start/end dates given to TradingAlgorithm above.
    algo.run(FakeDataPortal(), overwrite_sim_params=False)

    self.assertTrue(len(bts_calls) > 0)
def test_before_trading_start(self, test_name, num_days, freq, emission_rate):
    """Check that ``before_trading_start`` is recorded once per trading day.

    Parameters
    ----------
    test_name
        Not used in the body; presumably a label supplied by a
        parameterizing decorator -- verify against the caller.
    num_days
        Number of days passed to ``create_simulation_parameters`` and the
        expected value of ``algo.perf_tracker.day_count``.
    freq
        Forwarded as ``data_frequency`` to the simulation parameters.
    emission_rate
        Forwarded as ``emission_rate`` to the simulation parameters.
    """
    params = factory.create_simulation_parameters(
        num_days=num_days,
        data_frequency=freq,
        emission_rate=emission_rate,
    )

    # BenchmarkSource.get_value is replaced by the test's own
    # fake_minutely_benchmark for the duration of the run.
    with patch.object(BenchmarkSource, "get_value",
                      self.fake_minutely_benchmark):
        algo = BeforeTradingAlgorithm(sim_params=params)
        algo.run(FakeDataPortal())

        self.assertEqual(algo.perf_tracker.day_count, num_days)

        # The timestamps collected in algo.before_trading_at must line up
        # exactly with the simulation's trading days.
        self.assertTrue(
            params.trading_days.equals(
                pd.DatetimeIndex(algo.before_trading_at)),
            "Expected %s but was %s."
            % (params.trading_days, algo.before_trading_at))
def test_before_trading_start(self, test_name, num_days, freq, emission_rate):
    """Check that ``before_trading_start`` is recorded at the expected times.

    The expected timestamps are built with
    ``days_at_time(params.sessions, time(8, 45), "US/Eastern")`` and
    compared against ``algo.before_trading_at``.

    Parameters
    ----------
    test_name
        Not used in the body; presumably a label supplied by a
        parameterizing decorator -- verify against the caller.
    num_days
        Number of days passed to ``create_simulation_parameters`` and the
        expected number of sessions in ``algo.sim_params``.
    freq
        Forwarded as ``data_frequency`` to the simulation parameters.
    emission_rate
        Forwarded as ``emission_rate`` to the simulation parameters.
    """
    params = factory.create_simulation_parameters(
        num_days=num_days,
        data_frequency=freq,
        emission_rate=emission_rate,
    )

    # BenchmarkSource.get_value is replaced by the test's own
    # fake_minutely_benchmark for the duration of the run.
    with patch.object(BenchmarkSource, "get_value",
                      self.fake_minutely_benchmark):
        algo = BeforeTradingAlgorithm(sim_params=params, env=self.env)
        algo.run(FakeDataPortal(self.env))

        self.assertEqual(len(algo.sim_params.sessions), num_days)

        bts_minutes = days_at_time(
            params.sessions, time(8, 45), "US/Eastern"
        )

        # Report the value actually being compared (bts_minutes) so a
        # failure message shows the real expectation rather than the raw
        # session labels.
        self.assertTrue(
            bts_minutes.equals(pd.DatetimeIndex(algo.before_trading_at)),
            "Expected %s but was %s."
            % (bts_minutes, algo.before_trading_at))
def test_handle_adjustment(self, set_screen):
    """Compare pipeline VWAP and filter output against expected values.

    A pipeline with one VWAP factor per window length plus a
    ``close.latest > 300`` filter is attached; the same checks run from
    both ``handle_data`` and ``before_trading_start``.

    Parameters
    ----------
    set_screen
        When truthy, the filter is also installed as the pipeline screen,
        and assets expected to fail it must be absent from the output.
    """
    AAPL, MSFT, BRK_A = self.AAPL, self.MSFT, self.BRK_A
    assets = (AAPL, MSFT, BRK_A)

    window_lengths = [1, 2, 5, 10]
    expected_vwaps = self.compute_expected_vwaps(window_lengths)

    def column_name(window):
        return "vwap_%d" % window

    def initialize(context):
        pipe = Pipeline()
        context.vwaps = []
        for window in expected_vwaps:
            vwap_factor = VWAP(window_length=window)
            context.vwaps.append(vwap_factor)
            pipe.add(vwap_factor, name=column_name(window))

        over_300 = USEquityPricing.close.latest > 300
        pipe.add(over_300, 'filter')
        if set_screen:
            pipe.set_screen(over_300)

        attach_pipeline(pipe, 'test')

    def handle_data(context, data):
        now = get_datetime()
        output = pipeline_output('test')

        # Expected result of the > 300 filter for each asset today.
        passes_filter = {MSFT: False, BRK_A: True}
        passes_filter[AAPL] = now < self.AAPL_split_date

        for asset in assets:
            should_pass = passes_filter[asset]
            if should_pass or not set_screen:
                row = output.loc[asset]
                self.assertEqual(row['filter'], should_pass)
                for window in expected_vwaps:
                    actual = output.loc[asset, column_name(window)]
                    wanted = expected_vwaps[window][asset].loc[now]
                    # Only having two places of precision here is a bit
                    # unfortunate.
                    assert_almost_equal(actual, wanted, decimal=2)
            else:
                # Screened-out assets must not appear in the output.
                self.assertNotIn(asset, output.index)

    algo_kwargs = dict(
        initialize=initialize,
        handle_data=handle_data,
        # Do the same checks in before_trading_start.
        before_trading_start=handle_data,
        data_frequency='daily',
        get_pipeline_loader=lambda column: self.pipeline_loader,
        start=self.dates[max(window_lengths)],
        end=self.dates[-1],
        env=self.env,
    )
    algo = TradingAlgorithm(**algo_kwargs)

    # Yes, I really do want to use the start and end dates I passed to
    # TradingAlgorithm.
    algo.run(FakeDataPortal(), overwrite_sim_params=False)