def test_simulation_broker_register_tick_multi_symbol(
    sim_broker_runner_and_streamer_1s,
):
    """ASK ticks for two symbols are routed to each symbol's own receiver.

    Runs the 1-second simulation for 4 steps with SPY and SCHW ASK
    subscriptions and checks the prices each receiver collected.
    """
    _, runner, streamer = sim_broker_runner_and_streamer_1s
    spy_prices = []
    schw_prices = []

    # (contract, receiver) pairs, subscribed in this order.
    subscriptions = (
        (
            StockContract(symbol="SPY"),
            lambda _, price: spy_prices.append(price),
        ),
        (
            StockContract(symbol="SCHW"),
            lambda _, price: schw_prices.append(price),
        ),
    )
    for contract, receiver in subscriptions:
        streamer.subscribe_to_tick_data(
            contract=contract,
            func=receiver,
            price_type=PriceType.ASK,
        )

    runner.run_sim(step_count=4)

    np.testing.assert_equal(spy_prices, [413.53, 411.52])
    np.testing.assert_equal(schw_prices, [37.442, 37.442])
def test_tick_data_delivery_order(sim_broker_runner_and_streamer_1s):
    """SPY and SCHW ticks of the same step are delivered together.

    Whenever the SPY receiver sees one of two known ASK prices, both receivers
    must have been called the same number of times. At the end the total
    number of ticks each receiver got is checked.
    """
    _, runner, streamer = sim_broker_runner_and_streamer_1s
    spy_contract = StockContract(symbol="SPY")
    schw_contract = StockContract(symbol="SCHW")
    spy_prices = []
    schw_prices = []
    # SPY ASK prices at which the two receivers must be in lock-step.
    checkpoints = (406.9, 407.57)

    def spy_receiver(_, price):
        spy_prices.append(price)
        if price in checkpoints:
            assert len(spy_prices) == len(schw_prices)

    def schw_receiver(_, price):
        schw_prices.append(price)

    for contract, receiver in (
        (spy_contract, spy_receiver),
        (schw_contract, schw_receiver),
    ):
        streamer.subscribe_to_tick_data(
            contract=contract,
            func=receiver,
            price_type=PriceType.ASK,
        )

    runner.run_sim(step_count=9)

    assert len(spy_prices) == 17
    assert len(schw_prices) == 18
def test_cancel_tick_data(streamer):
    """A cancelled MARKET tick subscription stops refreshing its receiver.

    The receiver must be called twice in a row while subscribed. After
    ``cancel_tick_data`` it must not be called again within AWAIT_TIME_OUT.
    """
    mid = None

    def update_mid(c, price_):
        nonlocal mid
        mid = price_

    def await_mid():
        # Pump the streamer until ``mid`` is set or AWAIT_TIME_OUT elapses.
        started = time.time()
        while mid is None and time.time() - started <= AWAIT_TIME_OUT:
            streamer.sleep()

    contract = StockContract(symbol="SPY")
    streamer.subscribe_to_tick_data(
        contract=contract,
        func=update_mid,
        price_type=PriceType.MARKET,
    )

    await_mid()
    assert mid is not None

    mid = None
    await_mid()
    assert mid is not None  # was refreshed

    streamer.cancel_tick_data(contract=contract, func=update_mid)
    streamer.sleep()
    mid = None
    await_mid()
    assert mid is None  # did not refresh again
def _trades_receiver(self, trade: Dict):
    """Dispatch one incoming trade message to every subscriber of its symbol.

    Args:
        trade: raw trade message. Only its ``"sym"`` key is read here; the
            rest is handed to ``self._parse_trade``.

    Each subscriber is called as ``func(tick, **fn_kwargs)``, where ``tick``
    is the parsed trade and ``fn_kwargs`` are the kwargs stored with that
    subscription. A symbol with no registered subscribers is a no-op.
    """
    contract = StockContract(symbol=trade["sym"])
    tick = self._parse_trade(trade=trade)
    with self._trade_subscribers_lock:
        # Snapshot the subscribers under the lock and invoke them outside it.
        # User callbacks must not run while the lock is held, otherwise a
        # callback that (un)subscribes would deadlock on a non-reentrant lock
        # and a slow callback would block every subscription change.
        subscribers = list(self._trade_subscribers.get(contract, {}).items())
    for func, fn_kwargs in subscribers:
        func(tick, **fn_kwargs)
def test_subscribe_to_bars_data(streamer, bar_size):
    """Bar subscriptions deliver recent, strictly advancing bars.

    After one bar period plus a 10s grace, a bar no older than that window
    must have arrived. After a further bar period, a newer bar must have
    replaced it.
    """
    latest = None

    def update(bar):
        nonlocal latest
        latest = bar

    streamer.subscribe_to_bars(
        contract=StockContract(symbol="SPY"),
        bar_size=bar_size,
        func=update,
        rth=False,
    )

    grace = timedelta(seconds=10)
    streamer.sleep(bar_size.seconds + 10)
    assert latest is not None
    assert latest.name > datetime.now() - bar_size - grace

    previous = latest.copy()
    latest = None
    streamer.sleep(bar_size.seconds)
    assert latest is not None
    assert latest.name > datetime.now() - bar_size
    assert latest.name > previous.name
def test_retrieving_intermittently_cached_intraday(tmpdir, provider):
    """Intraday bars stitched from partially cached days cover the full range.

    Fetches 1-minute SPY bars for two isolated trading days first (so only
    those are cached), then for the whole span, which requires combining
    cached and non-cached days. Returns early when the provider raises
    NotImplementedError.
    """
    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir)
    contract = StockContract(symbol="SPY")
    dates = generate_trading_days(
        start_date=date.today() - timedelta(days=10),
        end_date=date.today() - timedelta(days=1),
    )
    # Two single-day requests, then the full span.
    date_ranges = [
        dates[1:2],
        dates[-3:-2],
        dates,
    ]
    for date_range in date_ranges:
        try:
            data = retriever.retrieve_bar_data(
                contract=contract,
                start_date=date_range[0],
                end_date=date_range[-1],
                # 1-minute bars so the *intraday* caching path is exercised.
                bar_size=timedelta(minutes=1),
            )
        except NotImplementedError:  # todo: fix
            return
    validate_data_range(data=data, start_date=dates[0], end_date=dates[-1])
def test_retrieving_intermittently_cached_trades(tmpdir, provider):
    """Trades fetched around partially cached days still cover the whole span.

    The first and last trading day are requested on their own, then the full
    range. Returns early when the provider raises NotImplementedError.
    """
    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir)
    contract = StockContract(symbol="SPY")
    trading_days = generate_trading_days(
        start_date=date(2020, 7, 21),
        end_date=date(2020, 7, 23),
    )

    for window in (trading_days[:1], trading_days[-1:], trading_days):
        first_day, last_day = window[0], window[-1]
        try:
            data = retriever.retrieve_trades_data(
                contract=contract,
                start_date=first_day,
                end_date=last_day,
            )
        except NotImplementedError:
            return

    validate_data_range(
        data=data,
        start_date=trading_days[0],
        end_date=trading_days[-1],
    )
def _validate_contract(contract: AContract) -> AContract:
    """Return a fresh StockContract rebuilt from ``contract.symbol`` only.

    Every attribute other than ``symbol`` falls back to StockContract's
    constructor defaults.

    Raises:
        TypeError: if ``contract`` is not a StockContract.
    """
    if not isinstance(contract, StockContract):
        raise TypeError(f"Unknown contract type {type(contract)}.")
    return StockContract(symbol=contract.symbol)
def test_cancel_bars_data(streamer):
    """No new bars reach a receiver after its bar subscription is cancelled."""
    latest = None

    def update(bar):
        nonlocal latest
        latest = bar

    contract = StockContract(symbol="SPY")
    window_s = 15
    streamer.subscribe_to_bars(
        contract=contract,
        bar_size=timedelta(seconds=5),
        func=update,
        rth=False,
    )

    streamer.sleep(window_s)
    assert latest is not None
    assert latest.name > datetime.now() - timedelta(seconds=window_s)

    streamer.cancel_bars(contract=contract, func=update)
    bar_at_cancel = latest
    streamer.sleep(window_s)
    assert bar_at_cancel.name == latest.name
def test_simulation_broker_register_tick_resolution_fail(
    sim_broker_runner_and_streamer_15m,
):
    """Subscribing to tick data on the ``_15m`` simulation fixture raises ValueError.

    Presumably ticks are unavailable at that resolution (inferred from the
    fixture and test names).
    """
    _, _, streamer = sim_broker_runner_and_streamer_15m
    spy = StockContract(symbol="SPY")
    with pytest.raises(ValueError):
        streamer.subscribe_to_tick_data(contract=spy, func=lambda x: None)
def test_subscribe_to_position_updates():
    """Each simulated partial fill publishes exactly one position update.

    Starts from 10 SPY @ 345.6, places a 2-share BUY limit order and fills it
    one share at a time. After every ``step()`` exactly one new update must
    have been published. It must carry the increased size and the
    share-weighted average fill price.
    """
    import math

    con = StockContract(symbol="SPY")
    pos = Position(
        account=DEFAULT_SIM_ACC,
        contract=con,
        position=10,
        ave_fill_price=345.6,
    )
    starting_positions = {
        DEFAULT_SIM_ACC: {con: pos},
    }
    untethered_sim_broker = SimulationBroker(
        sim_streamer=SimulationDataStreamer(),
        starting_funds={Currency.USD: 10_000},
        transaction_cost=1,
        starting_positions=starting_positions,
    )
    pos_updates = []

    def pos_updates_receiver(pos_: Position):
        pos_updates.append(pos_)

    untethered_sim_broker.subscribe_to_position_updates(
        func=pos_updates_receiver,
    )
    sim_order = LimitOrder(action=OrderAction.BUY, quantity=2, limit_price=346)
    sim_trade = Trade(contract=con, order=sim_order)
    untethered_sim_broker.place_trade(trade=sim_trade)

    # First share fills at 345.8.
    untethered_sim_broker.simulate_trade_execution(
        trade=sim_trade,
        price=345.8,
        n_shares=1,
    )
    untethered_sim_broker.step()
    assert len(pos_updates) == 1
    updated_pos: Position = pos_updates[0]
    target_ave_price = (10 * 345.6 + 345.8) / 11
    assert updated_pos.contract == con
    assert updated_pos.position == 11
    # The broker may average in a different order than the expression above,
    # so compare floats with a tolerance rather than ``==``.
    assert math.isclose(updated_pos.ave_fill_price, target_ave_price)
    assert updated_pos.account == DEFAULT_SIM_ACC

    # Second share fills at 345.9; exactly one more update must arrive.
    untethered_sim_broker.simulate_trade_execution(
        trade=sim_trade,
        price=345.9,
        n_shares=1,
    )
    untethered_sim_broker.step()
    assert len(pos_updates) == 2
    updated_pos = pos_updates[1]
    target_ave_price = (target_ave_price * 11 + 345.9) / 12
    assert updated_pos.contract == con
    assert updated_pos.position == 12
    assert math.isclose(updated_pos.ave_fill_price, target_ave_price)
    assert updated_pos.account == DEFAULT_SIM_ACC
def __init__(self, start_date: date, end_date: date, bar_size: timedelta):
    """Load the cached SPY bars for ``[start_date, end_date]`` at ``bar_size``.

    Data is read from TEST_DATA_DIR with ``cache_only=True``, so no provider
    is contacted.
    """
    self._sim_data = HistoricalRetriever(
        hist_data_dir=TEST_DATA_DIR,
    ).retrieve_bar_data(
        contract=StockContract(symbol="SPY"),
        bar_size=bar_size,
        start_date=start_date,
        end_date=end_date,
        cache_only=True,
    )
    # Presumably the row of ``_sim_data`` the next received bar is compared
    # against; its use is not visible here — TODO confirm.
    self._sim_idx = 0
def get_1_spy_mkt_trade(buy: bool) -> Trade:
    """Build a 1-share SPY market-order trade: BUY when ``buy`` is true, else SELL."""
    contract = StockContract(symbol="SPY")
    action = OrderAction.BUY if buy else OrderAction.SELL
    return Trade(
        contract=contract,
        order=MarketOrder(action=action, quantity=1),
    )
def test_simulation_broker_init():
    """A fresh SimulationBroker reports its starting cash, a flat position and its fee."""
    broker = SimulationBroker(
        sim_streamer=SimulationDataStreamer(),
        starting_funds={Currency.USD: 1_000},
        transaction_cost=1,
    )
    spy = StockContract(symbol="SPY")

    assert broker.acc_cash[Currency.USD] == 1_000
    assert broker.get_position(contract=spy) == 0
    assert broker.get_transaction_fee() == 1
def test_cancel_trades_data(streamer):
    """Trades stop reaching a receiver once its trade subscription is cancelled."""
    tick = None
    first_trade = threading.Event()

    def receiver(t):
        nonlocal tick
        tick = t
        if not first_trade.is_set():
            first_trade.set()

    streamer.subscribe_to_trades(
        contract=StockContract(symbol="SPY"),
        func=receiver,
    )
    first_trade.wait(1)
    assert tick is not None

    # Still streaming: a new trade arrives within a second.
    tick = None
    time.sleep(1)
    assert tick is not None

    streamer.cancel_trades(
        contract=StockContract(symbol="SPY"),
        func=receiver,
    )
    # Presumably gives any in-flight delivery time to finish before resetting.
    time.sleep(1)
    tick = None
    time.sleep(1)
    assert tick is None
def _from_ib_contract(self, ib_contract: _IBContract):
    """Convert an IB contract into the matching local contract type.

    Stock, option and forex contracts are converted. For any other type a
    warning is logged and ``None`` is returned. A ``conId`` of 0 is mapped
    to ``con_id=None``.

    Raises:
        ValueError: if an option's ``right`` is neither "C" nor "P".
    """
    exchange = self._from_ib_exchange(ib_exchange=ib_contract.exchange)
    currency = self._from_ib_currency(ib_currency=ib_contract.currency)
    con_id = None if ib_contract.conId == 0 else ib_contract.conId

    if isinstance(ib_contract, _IBStock):
        return StockContract(
            con_id=con_id,
            symbol=ib_contract.symbol,
            exchange=exchange,
            currency=currency,
        )

    if isinstance(ib_contract, _IBOption):
        last_trade_date = _get_opt_trade_date(
            last_trade_date_str=ib_contract.lastTradeDateOrContractMonth,
        )
        rights_by_code = {"C": Right.CALL, "P": Right.PUT}
        if ib_contract.right not in rights_by_code:
            raise ValueError(f"Unknown right type {ib_contract.right}.")
        return OptionContract(
            con_id=con_id,
            symbol=ib_contract.symbol,
            strike=ib_contract.strike,
            right=rights_by_code[ib_contract.right],
            multiplier=float(ib_contract.multiplier),
            last_trade_date=last_trade_date,
            exchange=exchange,
            currency=currency,
        )

    if isinstance(ib_contract, _IBForex):
        return ForexContract(
            symbol=ib_contract.symbol,
            con_id=con_id,
            exchange=exchange,
            currency=currency,
        )

    logging.warning(
        f"Contract type {ib_contract.secType} not understood."
        f" No contract was built."
    )
    return None
def test_retrieve_non_cached_trades_data(tmpdir, provider):
    """Trades fetched from the provider into an empty cache cover the range.

    Returns early when the provider raises NotImplementedError.
    """
    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir)
    contract = StockContract(symbol="SPY")
    start_date, end_date = date(2020, 7, 22), date(2020, 7, 23)

    try:
        data = retriever.retrieve_trades_data(
            contract=contract,
            start_date=start_date,
            end_date=end_date,
        )
    except NotImplementedError:
        return

    validate_data_range(data=data, start_date=start_date, end_date=end_date)
def test_simulation_broker_register_daily(sim_broker_runner_and_streamer_15m):
    """Daily SPY bar subscribers receive the bars BarChecker expects.

    The check itself is ``BarChecker.assert_all_received``.
    """
    _, runner, streamer = sim_broker_runner_and_streamer_15m
    daily = timedelta(days=1)
    checker = BarChecker(
        start_date=date(2020, 4, 6),
        end_date=date(2020, 4, 7),
        bar_size=daily,
    )

    streamer.subscribe_to_bars(
        contract=StockContract(symbol="SPY"),
        bar_size=daily,
        func=checker.bar_receiver,
    )
    runner.run_sim()

    checker.assert_all_received()
def test_retrieve_cached_trades():
    """Cached SPY trades in TEST_DATA_DIR are non-empty and span the requested days."""
    first_day, last_day = date(2020, 6, 17), date(2020, 6, 19)

    data = HistoricalRetriever(hist_data_dir=TEST_DATA_DIR).retrieve_trades_data(
        contract=StockContract(symbol="SPY"),
        start_date=first_day,
        end_date=last_day,
        cache_only=True,
    )

    assert len(data) > 0
    validate_data_range(data=data, start_date=first_day, end_date=last_day)
def test_retriever_cached_daily():
    """Cached daily SPY bars in TEST_DATA_DIR are non-empty and span the requested days."""
    first_day, last_day = date(2020, 4, 1), date(2020, 4, 2)

    data = HistoricalRetriever(hist_data_dir=TEST_DATA_DIR).retrieve_bar_data(
        contract=StockContract(symbol="SPY"),
        start_date=first_day,
        end_date=last_day,
        bar_size=timedelta(days=1),
        cache_only=True,
    )

    assert len(data) > 0
    validate_data_range(data=data, start_date=first_day, end_date=last_day)
def test_retrieve_non_cached_daily(tmpdir, provider):
    """Daily bars fetched from the provider into an empty cache cover the range.

    Returns early when the provider raises NotImplementedError.
    """
    first_day, last_day = date(2020, 4, 1), date(2020, 4, 2)
    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir)
    spy = StockContract(symbol="SPY")

    try:
        data = retriever.retrieve_bar_data(
            contract=spy,
            start_date=first_day,
            end_date=last_day,
            bar_size=timedelta(days=1),
        )
    except NotImplementedError:
        return

    validate_data_range(data=data, start_date=first_day, end_date=last_day)
def test_subscribe_to_tick_data(streamer):
    """ASK, BID and MARKET tick subscriptions on one contract all deliver prices.

    Pumps the streamer until every receiver has fired (or AWAIT_TIME_OUT
    elapses). It then checks that the ASK receiver got the subscribed
    contract, that ask > bid, and that the MARKET price equals the ask/bid
    midpoint.
    """
    con = None
    ask = None
    bid = None
    mid = None

    def update_ask(contract_, price_):
        nonlocal con, ask
        con = contract_
        ask = price_

    def update_bid(c, price_):
        nonlocal bid
        bid = price_

    def update_mid(c, price_):
        nonlocal mid
        mid = price_

    contract = StockContract(symbol="SPY")
    streamer.subscribe_to_tick_data(
        contract=contract,
        func=update_ask,
        price_type=PriceType.ASK,
    )
    streamer.subscribe_to_tick_data(
        contract=contract,
        func=update_bid,
        price_type=PriceType.BID,
    )
    streamer.subscribe_to_tick_data(
        contract=contract,
        func=update_mid,
        price_type=PriceType.MARKET,
    )

    t0 = time.time()
    # Keep pumping until *all* receivers have fired. Stopping when any single
    # one fires would let the comparisons below run against ``None``.
    while (
        any(value is None for value in (con, ask, bid, mid))
        and time.time() - t0 <= AWAIT_TIME_OUT
    ):
        streamer.sleep()

    assert ask is not None and bid is not None and mid is not None
    assert con == contract
    assert ask > bid
    assert mid == (ask + bid) / 2
def test_simulation_broker_buy(
    sim_broker_runner_and_streamer_15m,
    buy_1_spy_mkt_trade,
):
    """Buying 1 SPY debits the 2020-04-06 9:45 open price plus a 1-unit fee.

    The 1000 starting cash and the fee of 1 are presumably set by the fixture
    — confirm there. Afterwards the position must be +1.
    """
    broker, runner, _ = sim_broker_runner_and_streamer_15m
    spy = StockContract(symbol="SPY")
    fill_price = 259.79  # SPY open, 2020-04-06 9:45

    assert broker.get_position(contract=spy) == 0
    runner.run_sim(step_count=1, cache_only=True)
    broker.place_trade(trade=buy_1_spy_mkt_trade)
    runner.run_sim(step_count=1, cache_only=True)

    expected_cash = 1000 - fill_price - 1
    assert np.isclose(broker.acc_cash[Currency.USD], expected_cash)
    assert broker.get_position(contract=spy) == 1
def test_retrieve_non_cached_trades_data_today_partial(tmpdir, provider):
    """With ``allow_partial=True`` trades up to today are returned, ending today.

    Requests yesterday through today outside regular trading hours. Returns
    early when the provider raises NotImplementedError.
    """
    end_date = date.today()
    start_date = end_date - timedelta(days=1)
    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir)
    spy = StockContract(symbol="SPY")

    try:
        data = retriever.retrieve_trades_data(
            contract=spy,
            start_date=start_date,
            end_date=end_date,
            allow_partial=True,
            rth=False,
        )
    except NotImplementedError:
        return

    validate_data_range(data=data, start_date=start_date, end_date=end_date)
    last_trade_day = data.iloc[-1].name.date()
    assert last_trade_day == date.today()
def test_simulation_broker_register_tick_single_symbol(
    sim_broker_runner_and_streamer_1s,
):
    """Three tick subscriptions on SPY each receive their own price series.

    The first subscription is registered without an explicit ``price_type``
    and is checked as the market price. The other two request ASK and BID.
    """
    spy = StockContract(symbol="SPY")
    _, runner, streamer = sim_broker_runner_and_streamer_1s
    mkt_prices, ask_prices, bid_prices = [], [], []

    streamer.subscribe_to_tick_data(
        contract=spy,
        func=lambda _, price: mkt_prices.append(price),
    )
    streamer.subscribe_to_tick_data(
        contract=spy,
        func=lambda _, price: ask_prices.append(price),
        price_type=PriceType.ASK,
    )
    streamer.subscribe_to_tick_data(
        contract=spy,
        func=lambda _, price: bid_prices.append(price),
        price_type=PriceType.BID,
    )

    runner.run_sim(step_count=4)

    np.testing.assert_equal(mkt_prices, [409.99, 410.385])
    np.testing.assert_equal(ask_prices, [413.53, 411.52])
    np.testing.assert_equal(bid_prices, [406.45, 409.25])
def test_simulation_broker_sell(
    sim_broker_runner_and_streamer_15m,
    sell_1_spy_mkt_trade,
):
    """Selling 1 SPY credits the fill price minus a 1-unit fee to cash.

    The 1000 starting cash and the fee of 1 are presumably set by the fixture
    — confirm there.
    """
    broker, runner, _ = sim_broker_runner_and_streamer_15m
    spy_stock_contract = StockContract(symbol="SPY")
    assert broker.get_position(contract=spy_stock_contract) == 0
    runner.run_sim(step_count=1)
    broker.place_trade(trade=sell_1_spy_mkt_trade)
    runner.run_sim(step_count=1)
    spy_2020_4_6_9_30_close = 259.79
    # np.isclose's third positional parameter is *rtol*. Passing 0.01 there
    # allowed ~1% (about $12.6 on ~$1258) of slack, which would hide a wrong
    # fill price. A one-cent absolute tolerance is what this check intends.
    assert np.isclose(
        broker.acc_cash[Currency.USD],
        1000 + spy_2020_4_6_9_30_close - 1,
        atol=0.01,
    )
def test_subscribe_to_trades_data(streamer):
    """The first trade delivered for SPY is a Tick carrying symbol "SPY"."""
    tick = None
    got_first = threading.Event()

    def receiver(t):
        nonlocal tick
        if got_first.is_set():
            return  # keep only the first trade
        tick = t
        got_first.set()

    streamer.subscribe_to_trades(
        contract=StockContract(symbol="SPY"),
        func=receiver,
    )
    got_first.wait(2)

    assert isinstance(tick, Tick)
    assert tick.symbol == "SPY"
def test_historical_bar_aggregator():
    """1-minute SPY bars aggregate into correct 5- and 10-minute OHLCV bars.

    Uses two cached trading days (2020-04-06..07). The expected bar counts,
    156 and 78, are consistent with two 6.5-hour sessions. The first 5-minute
    bar is checked field by field against the first five 1-minute bars.
    """
    start_date = date(2020, 4, 6)
    end_date = date(2020, 4, 7)
    contract = StockContract(symbol="SPY")
    one_minute = timedelta(minutes=1)

    base_data = HistoricalRetriever(hist_data_dir=TEST_DATA_DIR).retrieve_bar_data(
        contract=contract,
        start_date=start_date,
        end_date=end_date,
        bar_size=one_minute,
        cache_only=True,
    )
    aggregator = HistoricalAggregator(hist_data_dir=TEST_DATA_DIR)

    def aggregate(minutes):
        return aggregator.aggregate_data(
            contract=contract,
            start_date=start_date,
            end_date=end_date,
            base_bar_size=one_minute,
            target_bar_size=timedelta(minutes=minutes),
        )

    five_minute = aggregate(5)
    assert len(five_minute) == 156
    first_bar = five_minute.iloc[0]
    first_window = base_data.iloc[:5]
    assert first_bar["open"] == base_data.iloc[0]["open"]
    assert first_bar["close"] == base_data.iloc[4]["close"]
    assert first_bar["high"] == first_window["high"].max()
    assert first_bar["low"] == first_window["low"].min()
    assert first_bar["volume"] == first_window["volume"].sum()

    assert len(aggregate(10)) == 78
def test_retrieving_intermittently_cached_daily(tmpdir, provider):
    """Daily bars spanning cached and non-cached days cover the full range.

    Caches 2020-03-03 and 2020-03-05 individually, then requests
    2020-03-02..06. Returns early when the provider raises
    NotImplementedError on the first request.
    """
    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir)
    contract = StockContract(symbol="SPY")
    daily = timedelta(days=1)

    def fetch(first_day, last_day):
        return retriever.retrieve_bar_data(
            contract=contract,
            start_date=first_day,
            end_date=last_day,
            bar_size=daily,
        )

    try:
        fetch(date(2020, 3, 3), date(2020, 3, 3))
    except NotImplementedError:  # todo: fix
        return
    fetch(date(2020, 3, 5), date(2020, 3, 5))

    start_date, end_date = date(2020, 3, 2), date(2020, 3, 6)
    data = fetch(start_date, end_date)
    validate_data_range(data=data, start_date=start_date, end_date=end_date)
def test_retrieve_non_cached_intraday(tmpdir, provider):
    """A week of 1-minute bars fetched into an empty cache has a plausible size.

    Accepts either roughly five 6.5-hour sessions of minutes (±7 hours) or
    exactly 4800 bars. The latter is, per the original note, the count seen
    for IB data outside regular trading hours. Returns early when the
    provider raises NotImplementedError.
    """
    start_date = date.today() - timedelta(days=7)
    end_date = date.today() - timedelta(days=1)
    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir)
    spy = StockContract(symbol="SPY")

    try:
        data = retriever.retrieve_bar_data(
            contract=spy,
            start_date=start_date,
            end_date=end_date,
            bar_size=timedelta(minutes=1),
        )
    except NotImplementedError:
        return

    validate_data_range(data=data, start_date=start_date, end_date=end_date)
    bar_count = len(data)
    expected_rth_minutes = 5 * 6.5 * 60
    assert (
        np.isclose(bar_count, expected_rth_minutes, atol=7 * 60)
        or bar_count == 4800  # for outside RTHs IB
    )