def setUp(self):
    """Build a four-day simulation and a matching single-sid trade source.

    The trade history holds four trades for sid 133, one day apart.
    """
    self.sim_params = factory.create_simulation_parameters(num_days=4)
    self.sid = 133

    prices = [10.0, 10.0, 11.0, 11.0]
    volumes = [100, 100, 100, 300]
    self.trade_history = factory.create_trade_history(
        self.sid,
        prices,
        volumes,
        timedelta(days=1),
        self.sim_params,
    )
    self.source = SpecificEquityTrades(event_list=self.trade_history)
def test_single_source(self):
    """Date-sort a single default source and compare it to its own events."""
    # Rely entirely on the built-in defaults; see zipline.sources.py.
    source = SpecificEquityTrades()
    expected = list(source)

    # Building `expected` consumed the source, so reset it before reuse.
    source.rewind()

    source_hash = source.get_hash()
    # A raw source emits no done message of its own, and the sort needs one
    # to work properly, so append it explicitly.
    with_done = chain(source, [done_message(source_hash)])
    self.run_date_sort(with_done, expected, [source_hash])
class TestAccountControls(TestCase):
    """Tests for account-level controls (here: max leverage)."""

    def setUp(self):
        self.sim_params = factory.create_simulation_parameters(num_days=4)
        self.sid = 133
        self.trade_history = factory.create_trade_history(
            self.sid,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params
        )
        self.source = SpecificEquityTrades(event_list=self.trade_history)

    def _check_algo(self, algo, handle_data, expected_exc):
        """Run ``algo`` over ``self.source`` with ``handle_data`` installed.

        If ``expected_exc`` is truthy the run must raise it; otherwise any
        exception from the run propagates and fails the test. The source is
        rewound in ``finally`` so a later check in the same test still starts
        from the first event even if this one failed.
        """
        algo._handle_data = handle_data
        expectation = (self.assertRaises(expected_exc) if expected_exc
                       else nullctx())
        try:
            with expectation:
                algo.run(self.source)
        finally:
            self.source.rewind()

    def check_algo_succeeds(self, algo, handle_data):
        """Assert that ``algo`` runs to completion without raising."""
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        """Assert that running ``algo`` raises ``AccountControlViolation``."""
        self._check_algo(algo, handle_data, AccountControlViolation)

    def test_set_max_leverage(self):
        # Both scenarios buy one share on every handle_data call; only the
        # leverage cap differs.
        def handle_data(algo, data):
            algo.order(self.sid, 1)

        # Max leverage 0: buying even one share violates the control.
        algo = SetMaxLeverageAlgorithm(0)
        self.check_algo_fails(algo, handle_data)

        # Max leverage 1: buying one share stays within the control.
        algo = SetMaxLeverageAlgorithm(1)
        self.check_algo_succeeds(algo, handle_data)
class TestAccountControls(TestCase):
    """Tests for account-level controls (here: max leverage)."""

    def setUp(self):
        self.sim_params = factory.create_simulation_parameters(num_days=4)
        # Raw integer sid; handlers resolve it with ``algo.sid`` before
        # ordering.
        self.sidint = 133
        self.trade_history = factory.create_trade_history(
            self.sidint,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params
        )
        self.source = SpecificEquityTrades(event_list=self.trade_history)

    def _check_algo(self, algo, handle_data, expected_exc):
        """Run ``algo`` over ``self.source`` with ``handle_data`` installed.

        If ``expected_exc`` is truthy the run must raise it; otherwise any
        exception from the run propagates and fails the test. The source is
        rewound in ``finally`` so a later check in the same test still starts
        from the first event even if this one failed.
        """
        algo._handle_data = handle_data
        expectation = (self.assertRaises(expected_exc) if expected_exc
                       else nullctx())
        try:
            with expectation:
                algo.run(self.source)
        finally:
            self.source.rewind()

    def check_algo_succeeds(self, algo, handle_data):
        """Assert that ``algo`` runs to completion without raising."""
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        """Assert that running ``algo`` raises ``AccountControlViolation``."""
        self._check_algo(algo, handle_data, AccountControlViolation)

    def test_set_max_leverage(self):
        # Both scenarios buy one share on every handle_data call; only the
        # leverage cap differs.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sidint), 1)

        # Max leverage 0: buying even one share violates the control.
        algo = SetMaxLeverageAlgorithm(0)
        self.check_algo_fails(algo, handle_data)

        # Max leverage 1: buying one share stays within the control.
        algo = SetMaxLeverageAlgorithm(1)
        self.check_algo_succeeds(algo, handle_data)
def test_moving_stddev(self):
    """MovingStandardDev over a 3-bar market-aware window matches np.std."""
    trade_history = factory.create_trade_history(
        133,
        [10.0, 15.0, 13.0, 12.0],
        [100, 100, 100, 100],
        timedelta(days=1),
        self.sim_params
    )

    stddev = MovingStandardDev(
        market_aware=True,
        window_length=3,
    )

    self.source = SpecificEquityTrades(event_list=trade_history)

    transformed = list(stddev.transform(self.source))

    vals = [message[stddev.get_hash()] for message in transformed]

    # Sample standard deviation (ddof=1) over the trailing window; the
    # first event has no value.
    expected = [
        None,
        np.std([10.0, 15.0], ddof=1),
        np.std([10.0, 15.0, 13.0], ddof=1),
        np.std([15.0, 13.0, 12.0], ddof=1),
    ]

    # zip() would silently drop unmatched trailing values, letting a short
    # transform output pass, so compare the lengths first.
    self.assertEqual(len(vals), len(expected))

    # Compare rounded values to tolerate floating-point differences between
    # the transform and np.std; cf.
    # https://numpy.org/doc/stable/reference/generated/numpy.std.html
    for v1, v2 in zip(vals, expected):
        if v1 is None or v2 is None:
            # Both must be missing; checking both sides turns a mismatch
            # into a test failure instead of a TypeError from round(None).
            self.assertIsNone(v1)
            self.assertIsNone(v2)
            continue
        self.assertEqual(round(v1, 5), round(v2, 5))
def test_returns(self):
    """Check Returns values for one-day and two-day windows."""
    # --- One-day returns over the source built in setUp. ---
    daily = Returns(1)
    daily_messages = list(daily.transform(self.source))
    daily_vals = [msg.tnfm_value for msg in daily_messages]

    # The first event has no previous close, so its return is 0.0.
    assert daily_vals == [0.0, 0.0, 0.1, 0.0]

    # --- Two-day returns. ---
    # The factory skips a weekend before the last event; results should
    # not notice this blip.
    two_day_prices = [10.0, 15.0, 13.0, 12.0, 13.0]
    trade_history = factory.create_trade_history(
        133,
        two_day_prices,
        [100, 100, 100, 300, 100],
        timedelta(days=1),
        self.trading_environment,
    )
    self.source = SpecificEquityTrades(event_list=trade_history)

    two_day = StatefulTransform(Returns, 2)
    two_day_messages = list(two_day.transform(self.source))
    two_day_vals = [msg.tnfm_value for msg in two_day_messages]

    assert two_day_vals == [
        0.0,
        0.0,
        (13.0 - 10.0) / 10.0,
        (12.0 - 15.0) / 15.0,
        (13.0 - 13.0) / 13.0,
    ]
def test_moving_stddev(self):
    """Time-windowed MovingStandardDev over hourly trades matches np.std."""
    trade_history = factory.create_trade_history(
        133,
        [10.0, 15.0, 13.0, 12.0],
        [100, 100, 100, 100],
        timedelta(hours=1),
        self.trading_environment
    )

    # A 150-minute window over hourly trades spans at most three trades,
    # matching the windows in `expected` below.
    stddev = MovingStandardDev(
        market_aware=False,
        delta=timedelta(minutes=150),
    )
    self.source = SpecificEquityTrades(event_list=trade_history)

    transformed = list(stddev.transform(self.source))

    vals = [message.tnfm_value for message in transformed]

    expected = [
        None,
        np.std([10.0, 15.0], ddof=1),
        np.std([10.0, 15.0, 13.0], ddof=1),
        np.std([15.0, 13.0, 12.0], ddof=1),
    ]

    # zip() would silently drop unmatched trailing values, letting a short
    # transform output pass, so compare the lengths first.
    assert len(vals) == len(expected), (vals, expected)

    # Compare rounded values to tolerate floating-point differences between
    # the transform and np.std; cf.
    # https://numpy.org/doc/stable/reference/generated/numpy.std.html
    for v1, v2 in zip(vals, expected):
        if v1 is None or v2 is None:
            # Both must be missing; avoids a TypeError from round(None).
            assert v1 is None and v2 is None, (v1, v2)
            continue
        assert round(v1, 5) == round(v2, 5)
def create_trade_source(sids, trade_time_increment, sim_params, asset_finder,
                        trading_calendar):
    """Build a ``SpecificEquityTrades`` source covering ``sim_params``.

    The source starts at ``sim_params.first_open``. It ends at
    ``sim_params.end_session`` when
    ``trading_calendar.is_open_on_minute(sim_params.end_session)`` is true,
    and at ``sim_params.last_close`` otherwise.
    """
    end_is_open = trading_calendar.is_open_on_minute(sim_params.end_session)
    end = sim_params.end_session if end_is_open else sim_params.last_close

    return SpecificEquityTrades(
        sids=sids,
        start=sim_params.first_open,
        end=end,
        delta=trade_time_increment,
        trading_calendar=trading_calendar,
        asset_finder=asset_finder,
    )
def test_algo_with_rl_violation_cumulative(self):
    """
    Add a new restriction, run a test long after both knowledge dates, and
    make sure a stock from the original restriction set is still disallowed.
    """
    base_date = list(LEVERAGED_ETFS.keys())[0]
    sim_params = factory.create_simulation_parameters(
        start=base_date + timedelta(days=7),
        num_days=4,
    )

    with security_list_copy():
        # New restriction: AAPL is added. BZQ belongs to the original set.
        add_security_data(['AAPL'], [])

        trade_history = factory.create_trade_history(
            'BZQ',
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            sim_params,
            env=self.env,
        )
        self.source = SpecificEquityTrades(
            event_list=trade_history,
            env=self.env,
        )
        algo = RestrictedAlgoWithoutCheck(
            symbol='BZQ',
            sim_params=sim_params,
            env=self.env,
        )

        with self.assertRaises(TradingControlViolation) as ctx:
            algo.run(self.source)

        self.check_algo_exception(algo, ctx, 0)
def test_set_max_order_count(self):
    """Max order count is enforced per trading day, not over the whole run."""
    # Override the default setUp to use six-hour intervals instead of full
    # days so we can exercise trading-session rollover logic.
    trade_history = factory.create_trade_history(
        self.sid,
        [10.0, 10.0, 11.0, 11.0],
        [100, 100, 100, 300],
        timedelta(hours=6),
        self.sim_params
    )
    self.source = SpecificEquityTrades(event_list=trade_history)

    # Place five one-share orders per handle_data call, counting each one
    # only after it was accepted.
    def handle_data(algo, data):
        for _ in range(5):
            algo.order(self.sid, 1)
            algo.order_count += 1

    # Limit of 3: the fourth order of the first call fails.
    algo = SetMaxOrderCountAlgorithm(3)
    self.check_algo_fails(algo, handle_data, 3)

    # Second call to handle_data is the same day as the first, so the last
    # order of the second call should fail.
    algo = SetMaxOrderCountAlgorithm(9)
    self.check_algo_fails(algo, handle_data, 9)

    # Only ten orders are placed per day, so a limit of 10 should pass even
    # though 20 orders are placed in total.
    algo = SetMaxOrderCountAlgorithm(10)
    self.check_algo_succeeds(algo, handle_data, order_count=20)
def create_trade_source(sids, trade_time_increment, sim_params, env,
                        concurrent=False):
    """Build a ``SpecificEquityTrades`` source covering ``sim_params``.

    The source starts at ``sim_params.first_open`` and is filtered to
    ``sids``. It ends at ``sim_params.period_end`` when
    ``env.is_market_hours(sim_params.period_end)`` is true, and at
    ``sim_params.last_close`` otherwise.
    """
    end_in_market = env.is_market_hours(sim_params.period_end)
    end = sim_params.period_end if end_in_market else sim_params.last_close

    return SpecificEquityTrades(
        sids=sids,
        start=sim_params.first_open,
        end=end,
        delta=trade_time_increment,
        filter=sids,
        concurrent=concurrent,
        env=env,
    )
def setUp(self):
    """Create a trading environment, logging, and a four-trade daily source."""
    self.trading_environment = factory.create_trading_environment()
    setup_logger(self)

    one_day = timedelta(days=1)
    trade_history = factory.create_trade_history(
        133,
        [10.0, 10.0, 11.0, 11.0],
        [100, 100, 100, 300],
        one_day,
        self.trading_environment,
    )
    self.source = SpecificEquityTrades(event_list=trade_history)
def setUp(self):
    """Build default sim params, a four-trade daily source, and a df source.

    ``self.df_source``/``self.df`` are whatever pair
    ``factory.create_test_df_source`` returns for these sim params.
    """
    self.sim_params = factory.create_simulation_parameters()

    prices = [10.0, 10.0, 11.0, 11.0]
    volumes = [100, 100, 100, 300]
    trade_history = factory.create_trade_history(
        133, prices, volumes, timedelta(days=1), self.sim_params,
    )
    self.source = SpecificEquityTrades(event_list=trade_history)

    self.df_source, self.df = factory.create_test_df_source(self.sim_params)
def test_algo_with_rl_violation(self):
    """Running on a restricted symbol raises TradingControlViolation.

    The scenario runs once for BZQ and again for JFT, a symbol from a
    different lookup date.
    """
    sim_params = factory.create_simulation_parameters(
        start=list(LEVERAGED_ETFS.keys())[0], num_days=4)

    for restricted in ('BZQ', 'JFT'):
        trade_history = factory.create_trade_history(
            restricted,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            sim_params
        )
        self.source = SpecificEquityTrades(event_list=trade_history)
        self.df_source, self.df = \
            factory.create_test_df_source(sim_params)

        algo = RestrictedAlgoWithoutCheck(sid=restricted,
                                          sim_params=sim_params)
        with self.assertRaises(TradingControlViolation) as ctx:
            algo.run(self.source)

        self.check_algo_exception(algo, ctx, 0)
def setUp(self):
    """Create four days of sim params and a trade source for sid 133.

    The trades are one day apart, priced 10, 10, 11, 11.
    """
    self.sim_params = factory.create_simulation_parameters(num_days=4)
    self.sid = 133

    history_args = (
        self.sid,
        [10.0, 10.0, 11.0, 11.0],
        [100, 100, 100, 300],
        timedelta(days=1),
        self.sim_params,
    )
    self.trade_history = factory.create_trade_history(*history_args)
    self.source = SpecificEquityTrades(event_list=self.trade_history)
def test_algo_without_rl_violation(self):
    """Running the restricted-list algo on AAPL must not raise."""
    start_date = list(LEVERAGED_ETFS.keys())[0]
    sim_params = factory.create_simulation_parameters(start=start_date,
                                                      num_days=4)

    trade_history = factory.create_trade_history(
        'AAPL',
        [10.0, 10.0, 11.0, 11.0],
        [100, 100, 100, 300],
        timedelta(days=1),
        sim_params,
    )
    self.source = SpecificEquityTrades(event_list=trade_history)

    algo = RestrictedAlgoWithoutCheck(symbol='AAPL', sim_params=sim_params)
    # No assertRaises here: any exception from the run fails the test.
    algo.run(self.source)
def test_iterate_over_rl(self):
    """Running IterateRLAlgo on BZQ should leave ``algo.found`` truthy."""
    start_date = list(LEVERAGED_ETFS.keys())[0]
    sim_params = factory.create_simulation_parameters(start=start_date,
                                                      num_days=4)

    trade_history = factory.create_trade_history(
        'BZQ',
        [10.0, 10.0, 11.0, 11.0],
        [100, 100, 100, 300],
        timedelta(days=1),
        sim_params,
    )
    self.source = SpecificEquityTrades(event_list=trade_history)

    algo = IterateRLAlgo(symbol='BZQ', sim_params=sim_params)
    algo.run(self.source)

    self.assertTrue(algo.found)
def test_algo_with_rl_violation_after_add(self):
    """After adding AAPL to the restricted list, running on AAPL must raise.

    The simulation starts on the trading day before the first knowledge
    date; the raised violation is checked with index 2 by
    ``check_algo_exception``.
    """
    with security_list_copy():
        # Work on a copy of the security lists, adding AAPL as a restriction.
        add_security_data(['AAPL'], [])

        sim_params = factory.create_simulation_parameters(
            start=self.trading_day_before_first_kd,
            num_days=4,
        )
        trade_history = factory.create_trade_history(
            'AAPL',
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            sim_params,
        )
        self.source = SpecificEquityTrades(event_list=trade_history)
        algo = RestrictedAlgoWithoutCheck(symbol='AAPL',
                                          sim_params=sim_params)

        with self.assertRaises(TradingControlViolation) as ctx:
            algo.run(self.source)

        self.check_algo_exception(algo, ctx, 2)
def test_algo_without_rl_violation_after_delete(self):
    """After a delete statement removes BZQ, running on BZQ must not raise.

    The simulation starts at ``self.extra_knowledge_date`` and lasts three
    days.
    """
    with security_list_copy():
        # Write a new delete statement that removes BZQ.
        add_security_data([], ['BZQ'])

        sim_params = factory.create_simulation_parameters(
            start=self.extra_knowledge_date,
            num_days=3,
        )
        trade_history = factory.create_trade_history(
            'BZQ',
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            sim_params,
        )
        self.source = SpecificEquityTrades(event_list=trade_history)
        algo = RestrictedAlgoWithoutCheck(symbol='BZQ',
                                          sim_params=sim_params)

        # No assertRaises here: any exception from the run fails the test.
        algo.run(self.source)
def test_algo_with_rl_violation_after_knowledge_date(self):
    """Running on BZQ a week past the first LEVERAGED_ETFS key still raises."""
    week_later = list(LEVERAGED_ETFS.keys())[0] + timedelta(days=7)
    sim_params = factory.create_simulation_parameters(start=week_later,
                                                      num_days=5)

    trade_history = factory.create_trade_history(
        'BZQ',
        [10.0, 10.0, 11.0, 11.0],
        [100, 100, 100, 300],
        timedelta(days=1),
        sim_params,
    )
    self.source = SpecificEquityTrades(event_list=trade_history)
    algo = RestrictedAlgoWithoutCheck(symbol='BZQ', sim_params=sim_params)

    with self.assertRaises(TradingControlViolation) as ctx:
        algo.run(self.source)

    self.check_algo_exception(algo, ctx, 0)
def test_algo_with_rl_violation_on_knowledge_date(self):
    """Starting the day before the first knowledge date, BZQ must raise.

    The raised violation is checked with index 1 by
    ``check_algo_exception``.
    """
    sim_params = factory.create_simulation_parameters(
        start=self.trading_day_before_first_kd,
        num_days=4,
    )

    trade_history = factory.create_trade_history(
        'BZQ',
        [10.0, 10.0, 11.0, 11.0],
        [100, 100, 100, 300],
        timedelta(days=1),
        sim_params,
    )
    self.source = SpecificEquityTrades(event_list=trade_history)

    # NOTE(review): sibling tests pass `symbol=`; confirm that `sid=` is the
    # keyword RestrictedAlgoWithoutCheck accepts in this version.
    algo = RestrictedAlgoWithoutCheck(sid='BZQ', sim_params=sim_params)

    with self.assertRaises(TradingControlViolation) as ctx:
        algo.run(self.source)

    self.check_algo_exception(algo, ctx, 1)
def setUp(self):
    """Create a 251-day simulation with a flat price/volume trade source."""
    days = 251
    self.sim_params = factory.create_simulation_parameters(num_days=days)
    setup_logger(self)

    flat_prices = [10.0] * days
    flat_volumes = [100] * days
    trade_history = factory.create_trade_history(
        133,
        flat_prices,
        flat_volumes,
        timedelta(days=1),
        self.sim_params,
    )
    self.source = SpecificEquityTrades(event_list=trade_history)

    self.df_source, self.df = factory.create_test_df_source(self.sim_params)

    # NOTE(review): this config uses sid 0 while the trade history above
    # uses sid 133 -- confirm that is intended.
    self.zipline_test_config = {
        'sid': 0,
    }
def create_trade_source(sids, trade_count, trade_time_increment, sim_params,
                        concurrent=False):
    """Build a ``SpecificEquityTrades`` source of ``trade_count`` events.

    Events start at ``sim_params.first_open``, are spaced by
    ``trade_time_increment``, and are filtered to ``sids``.
    """
    source = SpecificEquityTrades(
        count=trade_count,
        sids=sids,
        start=sim_params.first_open,
        delta=trade_time_increment,
        filter=sids,
        concurrent=concurrent,
    )
    # TODO: do we need to set the simulation's end to the same dt as the
    # last trade in the history?
    return source
def test_returns(self, name, add_custom_events):
    """Check one-day and two-day Returns values.

    ``name`` is not used in the body; presumably it only labels the
    parameterized case (TODO confirm against the decorator). When
    ``add_custom_events`` is true, custom events are interspersed into the
    one-day source and filtered out of the results before comparison.
    """
    # Daily returns.
    returns = Returns(1)

    if add_custom_events:
        self.source = self.intersperse_custom_events(self.source)

    transformed = list(returns.transform(self.source))
    tnfm_vals = [
        message[returns.get_hash()]
        for message in transformed
        if message.type != DATASOURCE_TYPE.CUSTOM
    ]

    # No returns for the first event because we don't have a
    # previous close.
    expected = [0.0, 0.0, 0.1, 0.0]

    self.assertEqual(tnfm_vals, expected)

    # Two-day returns. An extra kink here is that the
    # factory will automatically skip a weekend for the
    # last event. Results shouldn't notice this blip.

    trade_history = factory.create_trade_history(
        133,
        [10.0, 15.0, 13.0, 12.0, 13.0],
        [100, 100, 100, 300, 100],
        timedelta(days=1),
        self.sim_params
    )
    self.source = SpecificEquityTrades(event_list=trade_history)

    returns = StatefulTransform(Returns, 2)

    transformed = list(returns.transform(self.source))
    tnfm_vals = [message[returns.get_hash()] for message in transformed]

    expected = [
        0.0,
        0.0,
        (13.0 - 10.0) / 10.0,
        (12.0 - 15.0) / 15.0,
        (13.0 - 13.0) / 13.0
    ]

    self.assertEqual(tnfm_vals, expected)
def test_sort_composite(self):
    """Merge four sources with date_sorted_sources and check the result.

    Every event must come from one of the sources. None may come from
    source d, whose sid is outside the filter. The merged stream must
    already be in ``comp`` order (by dt, with source_id as tiebreaker).
    """
    from functools import cmp_to_key

    sid_filter = [1, 2]

    # Source a: count=100, sid 1, one hour between events.
    source_a = SpecificEquityTrades(
        count=100,
        sids=[1],
        start=datetime(2012, 6, 6, 0),
        delta=timedelta(hours=1),
        filter=sid_filter,
    )

    # Source b: count=50, sid 2, one day between events.
    source_b = SpecificEquityTrades(
        count=50,
        sids=[2],
        start=datetime(2012, 6, 6, 0),
        delta=timedelta(days=1),
        filter=sid_filter,
    )

    # Source c: count=150 over sids 1 and 2, one minute between events.
    source_c = SpecificEquityTrades(
        count=150,
        sids=[1, 2],
        start=datetime(2012, 6, 6, 0),
        delta=timedelta(minutes=1),
        filter=sid_filter,
    )

    # Source d: this should produce no events because its internal sids
    # don't match the filter.
    source_d = SpecificEquityTrades(
        count=50,
        sids=[3],
        start=datetime(2012, 6, 6, 0),
        delta=timedelta(minutes=1),
        filter=sid_filter,
    )

    sources = [source_a, source_b, source_c, source_d]
    hashes = [source.get_hash() for source in sources]

    sort_out = date_sorted_sources(*sources)

    # Read all the values from sort and assert that they arrive in
    # the correct sorting with the expected hash values.
    to_list = list(sort_out)

    # We should have 300 events (100 from a, 50 from b, 150 from c and none
    # from d).
    assert len(to_list) == 300

    for e in to_list:
        # All events should match one of our expected source_ids.
        assert e.source_id in hashes
        # But none of them should match source_d.
        assert e.source_id != source_d.get_hash()

    # The events should be sorted by dt, with source_id as tiebreaker.
    # Python 3's sorted() accepts no positional cmp function, so adapt
    # `comp` with cmp_to_key (also available on Python 2.7). sorted()
    # returns a new list, so to_list itself is left untouched.
    expected = sorted(to_list, key=cmp_to_key(comp))
    assert to_list == expected
def test_multi_source(self):
    """Date-sort two overlapping sources fed interleaved or sequentially.

    Both feed orders must yield every event from both sources ordered by
    ``comp`` (dt, with source_id as a tiebreaker).
    """
    from functools import cmp_to_key

    sid_filter = [2, 3]

    source_a = SpecificEquityTrades(
        count=100,
        sids=[1, 2, 3],
        start=datetime(2012, 1, 3, 15, tzinfo=pytz.utc),
        delta=timedelta(minutes=6),
        filter=sid_filter,
    )

    source_b = SpecificEquityTrades(
        count=100,
        sids=[2, 3, 4],
        start=datetime(2012, 1, 3, 15, tzinfo=pytz.utc),
        delta=timedelta(minutes=5),
        filter=sid_filter,
    )

    all_events = list(chain(source_a, source_b))

    # The expected output is all events, sorted by dt with
    # source_id as a tiebreaker. Python 3's sorted() accepts no positional
    # cmp function, so adapt `comp` with cmp_to_key.
    expected = sorted(all_events, key=cmp_to_key(comp))

    source_ids = [source_a.get_hash(), source_b.get_hash()]

    def rewind_all():
        # Generating the events list consumes the sources; rewind them
        # before feeding them to the sort again.
        source_a.rewind()
        source_b.rewind()

    def with_done(source):
        # Append a done message so the sort knows this source is finished.
        return chain(source, [done_message(source.get_hash())])

    # Test sort with alternating messages from source_a and source_b.
    rewind_all()
    interleaved = alternate(with_done(source_a), with_done(source_b))
    self.run_date_sort(interleaved, expected, source_ids)

    # Test sort with all messages from a, followed by all messages from b.
    rewind_all()
    sequential = chain(with_done(source_a), with_done(source_b))
    self.run_date_sort(sequential, expected, source_ids)
class TestTradingControls(TestCase):
    """Tests for algorithm trading controls.

    Covers max position size, max order size, max order count, long-only,
    and rejecting control registration after initialization.
    """

    def setUp(self):
        self.sim_params = factory.create_simulation_parameters(num_days=4)
        self.sid = 133
        self.trade_history = factory.create_trade_history(
            self.sid,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params
        )
        self.source = SpecificEquityTrades(event_list=self.trade_history)

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_order_count,
                    expected_exc):
        """Run ``algo`` with ``handle_data`` and check how far it got.

        If ``expected_exc`` is truthy the run must raise it. In either case
        ``algo.order_count`` must equal ``expected_order_count`` afterwards.
        The source is rewound in ``finally`` so the next check in the same
        test starts from the first event even if this one failed.
        """
        algo._handle_data = handle_data
        expectation = (self.assertRaises(expected_exc) if expected_exc
                       else nullctx())
        try:
            with expectation:
                algo.run(self.source)
            self.assertEqual(algo.order_count, expected_order_count)
        finally:
            self.source.rewind()

    def check_algo_succeeds(self, algo, handle_data, order_count=4):
        # Default for order_count assumes one order per handle_data call
        # (setUp's source has four events).
        self._check_algo(algo, handle_data, order_count, None)

    def check_algo_fails(self, algo, handle_data, order_count):
        self._check_algo(algo,
                         handle_data,
                         order_count,
                         TradingControlViolation)

    def test_set_max_position_size(self):

        # Buy one share four times. Should be fine.
        def buy_one(algo, data):
            algo.order(self.sid, 1)
            algo.order_count += 1

        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0)
        self.check_algo_succeeds(algo, buy_one)

        # Buy three shares on every call.
        def buy_three(algo, data):
            algo.order(self.sid, 3)
            algo.order_count += 1

        # Buy three shares four times. Should bail on the fourth before it's
        # placed (max_shares).
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0)
        self.check_algo_fails(algo, buy_three, 3)

        # Buy three shares four times again. Should bail due to max_notional
        # on the third attempt.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=61.0)
        self.check_algo_fails(algo, buy_three, 2)

        # BUY ALL THE THINGS! on every call.
        def buy_all_the_things(algo, data):
            algo.order(self.sid, 10000)
            algo.order_count += 1

        # Control on a different sid: should continue normally.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=61.0)
        self.check_algo_succeeds(algo, buy_all_the_things)

        # No sid given: should fail because a control without a sid applies
        # to all sids.
        algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0)
        self.check_algo_fails(algo, buy_all_the_things, 0)

    def test_set_max_order_size(self):

        # Buy one share.
        def buy_one(algo, data):
            algo.order(self.sid, 1)
            algo.order_count += 1

        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=500.0)
        self.check_algo_succeeds(algo, buy_one)

        # Buy 1, then 2, then 3, then 4 shares.
        def buy_increasing(algo, data):
            algo.order(self.sid, algo.order_count + 1)
            algo.order_count += 1

        # Bail on the last attempt because we exceed shares.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=3,
                                        max_notional=500.0)
        self.check_algo_fails(algo, buy_increasing, 3)

        # Bail on the last attempt because we exceed notional.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=40.0)
        self.check_algo_fails(algo, buy_increasing, 3)

        # BUY ALL THE THINGS! on every call.
        def buy_all_the_things(algo, data):
            algo.order(self.sid, 10000)
            algo.order_count += 1

        # Control on a different sid: should continue normally.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
                                        max_shares=1,
                                        max_notional=1.0)
        self.check_algo_succeeds(algo, buy_all_the_things)

        # No sid given: should fail because a control without a sid applies
        # to all sids.
        algo = SetMaxOrderSizeAlgorithm(max_shares=1, max_notional=1.0)
        self.check_algo_fails(algo, buy_all_the_things, 0)

    def test_set_max_order_count(self):

        # Override the default setUp to use six-hour intervals instead of
        # full days so we can exercise trading-session rollover logic.
        trade_history = factory.create_trade_history(
            self.sid,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(hours=6),
            self.sim_params
        )
        self.source = SpecificEquityTrades(event_list=trade_history)

        # Place five one-share orders per handle_data call.
        def handle_data(algo, data):
            for _ in range(5):
                algo.order(self.sid, 1)
                algo.order_count += 1

        algo = SetMaxOrderCountAlgorithm(3)
        self.check_algo_fails(algo, handle_data, 3)

        # Second call to handle_data is the same day as the first, so the
        # last order of the second call should fail.
        algo = SetMaxOrderCountAlgorithm(9)
        self.check_algo_fails(algo, handle_data, 9)

        # Only ten orders are placed per day, so a limit of 10 should pass
        # even though 20 orders are placed in total.
        algo = SetMaxOrderCountAlgorithm(10)
        self.check_algo_succeeds(algo, handle_data, order_count=20)

    def test_long_only(self):

        # Sell immediately -> fail immediately.
        def handle_data(algo, data):
            algo.order(self.sid, -1)
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm()
        self.check_algo_fails(algo, handle_data, 0)

        # Buy on even days, sell on odd days. Never takes a short position,
        # so should succeed.
        def handle_data(algo, data):
            if (algo.order_count % 2) == 0:
                algo.order(self.sid, 1)
            else:
                algo.order(self.sid, -1)
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm()
        self.check_algo_succeeds(algo, handle_data)

        # Buy on first three days, then sell off holdings. Should succeed.
        def handle_data(algo, data):
            amounts = [1, 1, 1, -3]
            algo.order(self.sid, amounts[algo.order_count])
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm()
        self.check_algo_succeeds(algo, handle_data)

        # Buy on first three days, then sell off holdings plus an extra
        # share. Should fail on the last sale.
        def handle_data(algo, data):
            amounts = [1, 1, 1, -4]
            algo.order(self.sid, amounts[algo.order_count])
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm()
        self.check_algo_fails(algo, handle_data, 3)

    def test_register_post_init(self):

        def initialize(algo):
            algo.initialized = True

        # Every control registration attempted from handle_data (i.e. after
        # initialize) must raise RegisterTradingControlPostInit.
        def handle_data(algo, data):

            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_position_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_count(1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_long_only()

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data)
        algo.run(self.source)
        self.source.rewind()