class TradingNode:
    """
    Provides an asynchronous network node for live trading.
    """

    def __init__(
        self,
        strategies: List[TradingStrategy],
        config: Dict[str, object],
    ):
        """
        Initialize a new instance of the TradingNode class.

        Parameters
        ----------
        strategies : list[TradingStrategy]
            The list of strategies to run on the trading node.
        config : dict[str, object]
            The configuration for the trading node.

        Raises
        ------
        ValueError
            If strategies is None or empty.
        ValueError
            If config is None or empty.

        """
        PyCondition.not_none(strategies, "strategies")
        PyCondition.not_none(config, "config")
        PyCondition.not_empty(strategies, "strategies")
        PyCondition.not_empty(config, "config")

        # Extract configs
        config_trader = config.get("trader", {})
        config_log = config.get("logging", {})
        config_exec_db = config.get("exec_database", {})
        config_strategy = config.get("strategy", {})
        config_adapters = config.get("adapters", {})

        self._uuid_factory = UUIDFactory()
        self._loop = asyncio.get_event_loop()
        self._executor = concurrent.futures.ThreadPoolExecutor()
        self._loop.set_default_executor(self._executor)
        self._clock = LiveClock(loop=self._loop)

        self.created_time = self._clock.utc_now()
        self._is_running = False

        # Uncomment for debugging
        # self._loop.set_debug(True)

        # Setup identifiers
        self.trader_id = TraderId(
            name=config_trader["name"],  # required key; KeyError if absent
            tag=config_trader["id_tag"],  # required key; KeyError if absent
        )

        # Setup logging
        self._logger = LiveLogger(
            clock=self._clock,
            name=self.trader_id.value,
            level_console=LogLevelParser.from_str_py(config_log.get("log_level_console")),
            level_file=LogLevelParser.from_str_py(config_log.get("log_level_file")),
            level_store=LogLevelParser.from_str_py(config_log.get("log_level_store")),
            run_in_process=config_log.get("run_in_process", True),  # Run logger in a separate process
            log_thread=config_log.get("log_thread_id", False),
            log_to_file=config_log.get("log_to_file", False),
            log_file_path=config_log.get("log_file_path", ""),
        )

        self._log = LoggerAdapter(component_name=self.__class__.__name__, logger=self._logger)
        self._log_header()
        self._log.info("Building...")

        self._setup_loop()  # Requires the logger to be initialized

        self.portfolio = Portfolio(
            clock=self._clock,
            logger=self._logger,
        )

        self._data_engine = LiveDataEngine(
            loop=self._loop,
            portfolio=self.portfolio,
            clock=self._clock,
            logger=self._logger,
            config={"qsize": 10000},
        )

        self.portfolio.register_cache(self._data_engine.cache)
        self.analyzer = PerformanceAnalyzer()

        # Select the execution database backing; anything other than "redis"
        # falls through to the in-memory bypass database.
        if config_exec_db["type"] == "redis":
            exec_db = RedisExecutionDatabase(
                trader_id=self.trader_id,
                logger=self._logger,
                command_serializer=MsgPackCommandSerializer(),
                event_serializer=MsgPackEventSerializer(),
                config={
                    "host": config_exec_db["host"],
                    "port": config_exec_db["port"],
                }
            )
        else:
            exec_db = BypassExecutionDatabase(
                trader_id=self.trader_id,
                logger=self._logger,
            )

        self._exec_engine = LiveExecutionEngine(
            loop=self._loop,
            database=exec_db,
            portfolio=self.portfolio,
            clock=self._clock,
            logger=self._logger,
            config={"qsize": 10000},
        )

        self._exec_engine.load_cache()
        self._setup_adapters(config_adapters, self._logger)

        self.trader = Trader(
            trader_id=self.trader_id,
            strategies=strategies,
            portfolio=self.portfolio,
            data_engine=self._data_engine,
            exec_engine=self._exec_engine,
            clock=self._clock,
            logger=self._logger,
        )

        self._check_residuals_delay = config_trader.get("check_residuals_delay", 5.0)
        self._load_strategy_state = config_strategy.get("load_state", True)
        self._save_strategy_state = config_strategy.get("save_state", True)

        if self._load_strategy_state:
            self.trader.load()

        self._log.info("state=INITIALIZED.")
        self.time_to_initialize = self._clock.delta(self.created_time)
        self._log.info(f"Initialized in {self.time_to_initialize.total_seconds():.3f}s.")

    @property
    def is_running(self) -> bool:
        """
        If the trading node is running.

        Returns
        -------
        bool
            True if running, else False.

        """
        return self._is_running

    def get_event_loop(self) -> asyncio.AbstractEventLoop:
        """
        Return the event loop of the trading node.

        Returns
        -------
        asyncio.AbstractEventLoop

        """
        return self._loop

    def get_logger(self) -> LiveLogger:
        """
        Return the logger for the trading node.

        Returns
        -------
        LiveLogger

        """
        return self._logger

    def start(self) -> None:
        """
        Start the trading node.
        """
        try:
            # Schedule onto a running loop, otherwise drive the loop ourselves
            if self._loop.is_running():
                self._loop.create_task(self._run())
            else:
                self._loop.run_until_complete(self._run())
        except RuntimeError as ex:
            self._log.exception(ex)

    def stop(self) -> None:
        """
        Stop the trading node gracefully.

        After a specified delay the internal `Trader` residuals will be checked.
        If save strategy is specified then strategy states will then be saved.
        """
        try:
            if self._loop.is_running():
                self._loop.create_task(self._stop())
            else:
                self._loop.run_until_complete(self._stop())
        except RuntimeError as ex:
            self._log.exception(ex)

    def dispose(self) -> None:
        """
        Dispose of the trading node.

        Gracefully shuts down the executor and event loop.
        """
        try:
            # Busy-wait (up to 5s) for any in-flight stop to complete
            timeout = self._clock.utc_now() + timedelta(seconds=5)
            while self._is_running:
                time.sleep(0.1)
                if self._clock.utc_now() >= timeout:
                    self._log.warning("Timed out (5s) waiting for node to stop.")
                    break

            self._log.info("state=DISPOSING...")
            self._log.debug(f"{self._data_engine.get_run_queue_task()}")
            self._log.debug(f"{self._exec_engine.get_run_queue_task()}")

            self.trader.dispose()
            self._data_engine.dispose()
            self._exec_engine.dispose()

            self._log.info("Shutting down executor...")
            if sys.version_info >= (3, 9):
                # cancel_futures added in Python 3.9
                self._executor.shutdown(wait=True, cancel_futures=True)
            else:
                self._executor.shutdown(wait=True)

            self._log.info("Stopping event loop...")
            self._loop.stop()
            self._cancel_all_tasks()
        except RuntimeError as ex:
            self._log.error("Shutdown coro issues will be fixed soon...")  # TODO: Remove when fixed
            self._log.exception(ex)
        finally:
            if self._loop.is_running():
                self._log.warning("Cannot close a running event loop.")
            else:
                self._log.info("Closing event loop...")
                self._loop.close()

            # Check and log if event loop is running
            if self._loop.is_running():
                self._log.warning(f"loop.is_running={self._loop.is_running()}")
            else:
                self._log.info(f"loop.is_running={self._loop.is_running()}")

            # Check and log if event loop is closed
            if not self._loop.is_closed():
                self._log.warning(f"loop.is_closed={self._loop.is_closed()}")
            else:
                self._log.info(f"loop.is_closed={self._loop.is_closed()}")

            self._log.info("state=DISPOSED.")
            self._logger.stop()  # Ensure process is stopped
            time.sleep(0.1)  # Ensure final log messages

    def _log_header(self) -> None:
        # Log the nautilus banner plus key dependency versions.
        nautilus_header(self._log)
        self._log.info(f"redis {redis.__version__}")
        self._log.info(f"msgpack {msgpack.version[0]}.{msgpack.version[1]}.{msgpack.version[2]}")
        if uvloop_version:
            self._log.info(f"uvloop {uvloop_version}")
        self._log.info("=================================================================")

    def _setup_loop(self) -> None:
        # Install POSIX signal handlers on the event loop so external
        # termination requests trigger a graceful shutdown.
        if self._loop.is_closed():
            self._log.error("Cannot setup signal handling (event loop was closed).")
            return

        signal.signal(signal.SIGINT, signal.SIG_DFL)
        signals = (signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGABRT)
        for sig in signals:
            self._loop.add_signal_handler(sig, self._loop_sig_handler, sig)
        self._log.debug(f"Event loop {signals} handling setup.")

    def _loop_sig_handler(self, sig: signal.Signals) -> None:
        # Called by the loop when a registered signal fires; de-registers
        # further signal handling then initiates a graceful stop.
        self._loop.remove_signal_handler(signal.SIGTERM)
        self._loop.add_signal_handler(signal.SIGINT, lambda: None)
        self._log.warning(f"Received {sig!s}, shutting down...")
        self.stop()

    def _setup_adapters(self, config: Dict[str, object], logger: LiveLogger) -> None:
        # Setup each data client
        # NOTE(review): the loop variable below shadows the `config` parameter;
        # each iteration rebinds `config` to the per-adapter sub-config.
        for name, config in config.items():
            if name.startswith("ccxt-"):
                try:
                    import ccxtpro  # TODO: Find a better way of doing this
                except ImportError:
                    raise ImportError("ccxtpro is not installed, "
                                      "installation instructions can be found at https://ccxt.pro")

                client_cls = getattr(ccxtpro, name.partition('-')[2].lower())
                data_client, exec_client = CCXTClientsFactory.create(
                    client_cls=client_cls,
                    config=config,
                    data_engine=self._data_engine,
                    exec_engine=self._exec_engine,
                    clock=self._clock,
                    logger=logger,
                )
            elif name == "oanda":
                data_client = OandaDataClientFactory.create(
                    config=config,
                    data_engine=self._data_engine,
                    clock=self._clock,
                    logger=logger,
                )
                exec_client = None  # TODO: Implement
            else:
                self._log.error(f"No adapter available for `{name}`.")
                continue

            if data_client is not None:
                self._data_engine.register_client(data_client)
            if exec_client is not None:
                self._exec_engine.register_client(exec_client)

    async def _run(self) -> None:
        # Main run coroutine: start engines, await connectivity, resolve
        # execution state, then start the trader and block on engine queues.
        try:
            self._log.info("state=STARTING...")
            self._is_running = True

            self._data_engine.start()
            self._exec_engine.start()

            result: bool = await self._await_engines_connected()
            if not result:
                return

            result: bool = await self._exec_engine.resolve_state()
            if not result:
                return

            self.trader.start()

            if self._loop.is_running():
                self._log.info("state=RUNNING.")
            else:
                self._log.warning("Event loop is not running.")

            # Continue to run while engines are running...
            await self._data_engine.get_run_queue_task()
            await self._exec_engine.get_run_queue_task()
        except asyncio.CancelledError as ex:
            self._log.error(str(ex))

    async def _await_engines_connected(self) -> bool:
        # Poll both engines until connected or a 5s timeout elapses.
        self._log.info("Waiting for engines to initialize...")

        # The data engine clients will be set as connected when all
        # instruments are received and updated with the data engine.
        # The execution engine clients will be set as connected when all
        # accounts are updated and the current order and position status is
        # confirmed. Thus any delay here will be due to blocking network IO.
        seconds = 5  # Hard coded for now
        timeout: timedelta = self._clock.utc_now() + timedelta(seconds=seconds)
        while True:
            await asyncio.sleep(0.1)
            if self._clock.utc_now() >= timeout:
                self._log.error(f"Timed out ({seconds}s) waiting for "
                                f"engines to initialize.")
                return False
            if not self._data_engine.check_connected():
                continue
            if not self._exec_engine.check_connected():
                continue
            break

        return True  # Engines initialized

    async def _stop(self) -> None:
        # Graceful shutdown coroutine: stop trader, check residuals after a
        # delay, optionally save state, stop engines and cancel timers.
        # NOTE(review): `_is_stopping` is not initialized in __init__ and is
        # not read anywhere in this class — confirm whether it is still used.
        self._is_stopping = True
        self._log.info("state=STOPPING...")

        if self.trader.state == ComponentState.RUNNING:
            self.trader.stop()
            self._log.info(f"Awaiting residual state ({self._check_residuals_delay}s delay)...")
            await asyncio.sleep(self._check_residuals_delay)
            self.trader.check_residuals()

        if self._save_strategy_state:
            self.trader.save()

        if self._data_engine.state == ComponentState.RUNNING:
            self._data_engine.stop()
        if self._exec_engine.state == ComponentState.RUNNING:
            self._exec_engine.stop()

        await self._await_engines_disconnected()

        # Clean up remaining timers
        timer_names = self._clock.timer_names()
        self._clock.cancel_timers()
        for name in timer_names:
            self._log.info(f"Cancelled Timer(name={name}).")

        self._log.info("state=STOPPED.")
        self._is_running = False

    async def _await_engines_disconnected(self) -> None:
        # Poll both engines until disconnected or a 5s timeout elapses.
        self._log.info("Waiting for engines to disconnect...")

        seconds = 5  # Hard coded for now
        timeout: timedelta = self._clock.utc_now() + timedelta(seconds=seconds)
        while True:
            await asyncio.sleep(0.1)
            if self._clock.utc_now() >= timeout:
                self._log.warning(f"Timed out ({seconds}s) waiting for engines to disconnect.")
                break
            if not self._data_engine.check_disconnected():
                continue
            if not self._exec_engine.check_disconnected():
                continue
            break  # Engines initialized

    def _cancel_all_tasks(self) -> None:
        # Cancel all outstanding tasks on the loop, gather them so the
        # cancellations propagate, then surface any non-cancel exceptions.
        to_cancel = asyncio.tasks.all_tasks(self._loop)
        if not to_cancel:
            self._log.info("All tasks finished.")
            return

        for task in to_cancel:
            self._log.warning(f"Cancelling pending task {task}")
            task.cancel()

        if self._loop.is_running():
            self._log.warning("Event loop still running during `cancel_all_tasks`.")
            return

        finish_all_tasks: asyncio.Future = asyncio.tasks.gather(
            *to_cancel,
            loop=self._loop,
            return_exceptions=True,
        )
        self._loop.run_until_complete(finish_all_tasks)

        self._log.debug(f"{finish_all_tasks}")

        for task in to_cancel:
            if task.cancelled():
                continue
            if task.exception() is not None:
                self._loop.call_exception_handler({
                    'message': 'unhandled exception during asyncio.run() shutdown',
                    'exception': task.exception(),
                    'task': task,
                })
class CCXTDataClientTests(unittest.TestCase):
    """
    Tests for the CCXT data client running against a mocked CCXT exchange
    on a fresh, isolated event loop per test.
    """

    def setUp(self):
        # Fixture Setup
        self.clock = LiveClock()
        self.uuid_factory = UUIDFactory()
        self.trader_id = TraderId("TESTER", "001")

        # Fresh isolated loop testing pattern
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        # Setup logging
        self.logger = LiveLogger(
            loop=self.loop,
            clock=self.clock,
        )

        self.portfolio = Portfolio(
            clock=self.clock,
            logger=self.logger,
        )

        self.data_engine = LiveDataEngine(
            loop=self.loop,
            portfolio=self.portfolio,
            clock=self.clock,
            logger=self.logger,
        )

        # Setup mock CCXT exchange (canned JSON responses from test data files)
        with open(TEST_PATH + "markets.json") as response:
            markets = json.load(response)
        with open(TEST_PATH + "currencies.json") as response:
            currencies = json.load(response)
        with open(TEST_PATH + "watch_order_book.json") as response:
            order_book = json.load(response)
        with open(TEST_PATH + "fetch_trades.json") as response:
            fetch_trades = json.load(response)
        with open(TEST_PATH + "watch_trades.json") as response:
            watch_trades = json.load(response)

        self.mock_ccxt = MagicMock()
        self.mock_ccxt.name = "Binance"
        self.mock_ccxt.precisionMode = 2
        self.mock_ccxt.markets = markets
        self.mock_ccxt.currencies = currencies
        self.mock_ccxt.watch_order_book = order_book
        self.mock_ccxt.watch_trades = watch_trades
        self.mock_ccxt.fetch_trades = fetch_trades

        self.client = CCXTDataClient(
            client=self.mock_ccxt,
            engine=self.data_engine,
            clock=self.clock,
            logger=self.logger,
        )

        self.data_engine.register_client(self.client)

    def tearDown(self):
        self.loop.stop()
        self.loop.close()

    def test_connect(self):
        async def run_test():
            # Arrange
            # Act
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            # Assert
            self.assertTrue(self.client.is_connected)

            # Tear down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_disconnect(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            # Act
            self.client.disconnect()
            await asyncio.sleep(0.3)

            # Assert
            self.assertFalse(self.client.is_connected)

            # Tear down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_reset_when_not_connected_successfully_resets(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            self.data_engine.stop()
            await asyncio.sleep(0.3)  # Allow engine message queue to stop

            # Act
            self.client.reset()

            # Assert
            self.assertFalse(self.client.is_connected)

        self.loop.run_until_complete(run_test())

    def test_reset_when_connected_does_not_reset(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            # Act
            self.client.reset()

            # Assert
            self.assertTrue(self.client.is_connected)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_dispose_when_not_connected_does_not_dispose(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            # Act
            self.client.dispose()

            # Assert
            self.assertTrue(self.client.is_connected)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_subscribe_instrument(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            # Act
            self.client.subscribe_instrument(BTCUSDT)

            # Assert
            self.assertIn(BTCUSDT, self.client.subscribed_instruments)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_subscribe_quote_ticks(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            # Act
            self.client.subscribe_quote_ticks(ETHUSDT)
            await asyncio.sleep(0.3)

            # Assert
            self.assertIn(ETHUSDT, self.client.subscribed_quote_ticks)
            self.assertTrue(self.data_engine.cache.has_quote_ticks(ETHUSDT))

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_subscribe_trade_ticks(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            # Act
            self.client.subscribe_trade_ticks(ETHUSDT)
            await asyncio.sleep(0.3)

            # Assert
            self.assertIn(ETHUSDT, self.client.subscribed_trade_ticks)
            self.assertTrue(self.data_engine.cache.has_trade_ticks(ETHUSDT))

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_subscribe_bars(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.5)  # Allow engine message queue to start

            bar_type = TestStubs.bartype_btcusdt_binance_100tick_last()

            # Act
            self.client.subscribe_bars(bar_type)

            # Assert
            self.assertIn(bar_type, self.client.subscribed_bars)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_unsubscribe_instrument(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            self.client.subscribe_instrument(BTCUSDT)

            # Act
            self.client.unsubscribe_instrument(BTCUSDT)

            # Assert
            self.assertNotIn(BTCUSDT, self.client.subscribed_instruments)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_unsubscribe_quote_ticks(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            self.client.subscribe_quote_ticks(ETHUSDT)
            await asyncio.sleep(0.3)

            # Act
            self.client.unsubscribe_quote_ticks(ETHUSDT)

            # Assert
            self.assertNotIn(ETHUSDT, self.client.subscribed_quote_ticks)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_unsubscribe_trade_ticks(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            self.client.subscribe_trade_ticks(ETHUSDT)

            # Act
            self.client.unsubscribe_trade_ticks(ETHUSDT)

            # Assert
            self.assertNotIn(ETHUSDT, self.client.subscribed_trade_ticks)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_unsubscribe_bars(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            bar_type = TestStubs.bartype_btcusdt_binance_100tick_last()
            self.client.subscribe_bars(bar_type)

            # Act
            self.client.unsubscribe_bars(bar_type)

            # Assert
            self.assertNotIn(bar_type, self.client.subscribed_bars)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_request_instrument(self):
        async def run_test():
            # Arrange
            self.data_engine.start()
            await asyncio.sleep(0.5)  # Allow engine message queue to start

            # Act
            self.client.request_instrument(BTCUSDT, uuid4())
            await asyncio.sleep(0.5)

            # Assert
            # Instruments additionally requested on start
            self.assertEqual(1, self.data_engine.response_count)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_request_instruments(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.5)  # Allow engine message queue to start

            # Act
            self.client.request_instruments(uuid4())
            await asyncio.sleep(0.5)

            # Assert
            # Instruments additionally requested on start
            self.assertEqual(1, self.data_engine.response_count)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_request_quote_ticks(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            # Act
            self.client.request_quote_ticks(BTCUSDT, None, None, 0, uuid4())

            # Assert
            self.assertTrue(True)  # Logs warning

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_request_trade_ticks(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            handler = ObjectStorer()

            request = DataRequest(
                client_name=BINANCE.value,
                data_type=DataType(
                    TradeTick,
                    metadata={
                        "InstrumentId": ETHUSDT,
                        "FromDateTime": None,
                        "ToDateTime": None,
                        "Limit": 100,
                    },
                ),
                callback=handler.store,
                request_id=self.uuid_factory.generate(),
                timestamp_ns=self.clock.timestamp_ns(),
            )

            # Act
            self.data_engine.send(request)
            await asyncio.sleep(1)

            # Assert
            self.assertEqual(1, self.data_engine.response_count)
            self.assertEqual(1, handler.count)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_request_bars(self):
        async def run_test():
            # Arrange
            with open(TEST_PATH + "fetch_ohlcv.json") as response:
                fetch_ohlcv = json.load(response)

            self.mock_ccxt.fetch_ohlcv = fetch_ohlcv

            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.3)  # Allow engine message queue to start

            handler = ObjectStorer()

            bar_spec = BarSpecification(1, BarAggregation.MINUTE, PriceType.LAST)
            bar_type = BarType(instrument_id=ETHUSDT, bar_spec=bar_spec)

            request = DataRequest(
                client_name=BINANCE.value,
                data_type=DataType(
                    Bar,
                    metadata={
                        "BarType": bar_type,
                        "FromDateTime": None,
                        "ToDateTime": None,
                        "Limit": 100,
                    },
                ),
                callback=handler.store,
                request_id=self.uuid_factory.generate(),
                timestamp_ns=self.clock.timestamp_ns(),
            )

            # Act
            self.data_engine.send(request)
            await asyncio.sleep(0.3)

            # Assert
            self.assertEqual(1, self.data_engine.response_count)
            self.assertEqual(1, handler.count)
            self.assertEqual(100, len(handler.get_store()[0]))

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())
class TestOandaDataClient:
    """
    Tests for the Oanda data client (pytest style) against a mocked Oanda
    API, using a fresh isolated event loop per test.
    """

    def setup(self):
        # Fixture Setup
        self.clock = LiveClock()
        self.uuid_factory = UUIDFactory()
        self.trader_id = TraderId("TESTER-001")

        # Fresh isolated loop testing pattern
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.executor = concurrent.futures.ThreadPoolExecutor()
        self.loop.set_default_executor(self.executor)
        self.loop.set_debug(True)

        # Setup logging
        # NOTE(review): two loggers are created; the local `logger` (DEBUG to
        # stdout) is given to the client only, `self.logger` to the engines.
        logger = LiveLogger(
            loop=self.loop,
            clock=self.clock,
            trader_id=self.trader_id,
            level_stdout=LogLevel.DEBUG,
        )

        self.logger = LiveLogger(
            loop=self.loop,
            clock=self.clock,
        )

        self.cache = TestStubs.cache()

        self.portfolio = Portfolio(
            cache=self.cache,
            clock=self.clock,
            logger=self.logger,
        )

        self.data_engine = LiveDataEngine(
            loop=self.loop,
            portfolio=self.portfolio,
            cache=self.cache,
            clock=self.clock,
            logger=self.logger,
        )

        self.mock_oanda = MagicMock()

        self.client = OandaDataClient(
            client=self.mock_oanda,
            account_id="001",
            engine=self.data_engine,
            clock=self.clock,
            logger=logger,
        )

        self.data_engine.register_client(self.client)

        # Canned instruments response for the mocked Oanda API
        with open(TEST_PATH + "instruments.json") as response:
            instruments = json.load(response)

        self.mock_oanda.request.return_value = instruments

    def teardown(self):
        self.executor.shutdown(wait=True)
        self.loop.stop()
        self.loop.close()

    # TODO: WIP - why is this failing??
    # def test_connect(self):
    #     async def run_test():
    #         # Arrange
    #         # Act
    #         self.data_engine.start()  # Also connects client
    #         self.client.connect()
    #         await asyncio.sleep(1)
    #
    #         # Assert
    #         assert self.client.is_connected
    #
    #         # Tear Down
    #         self.data_engine.stop()
    #
    #     self.loop.run_until_complete(run_test())

    def test_disconnect(self):
        # Arrange
        self.client.connect()

        # Act
        self.client.disconnect()

        # Assert
        assert not self.client.is_connected

    def test_reset(self):
        # Arrange
        # Act
        self.client.reset()

        # Assert
        assert not self.client.is_connected

    def test_dispose(self):
        # Arrange
        # Act
        self.client.dispose()

        # Assert
        assert not self.client.is_connected

    def test_subscribe_instrument(self):
        # Arrange
        self.client.connect()

        # Act
        self.client.subscribe_instrument(AUDUSD)

        # Assert
        assert AUDUSD in self.client.subscribed_instruments

    def test_subscribe_quote_ticks(self):
        async def run_test():
            # Arrange
            self.mock_oanda.request.return_value = {"type": {"HEARTBEAT": "0"}}
            self.data_engine.start()

            # Act
            self.client.subscribe_quote_ticks(AUDUSD)
            await asyncio.sleep(0.3)

            # Assert
            assert AUDUSD in self.client.subscribed_quote_ticks

            # Tear Down
            self.data_engine.stop()

        self.loop.run_until_complete(run_test())

    def test_subscribe_bars(self):
        # Arrange
        bar_spec = BarSpecification(1, BarAggregation.MINUTE, PriceType.MID)
        bar_type = BarType(instrument_id=AUDUSD, bar_spec=bar_spec)

        # Act
        self.client.subscribe_bars(bar_type)

        # Assert
        assert True

    def test_unsubscribe_instrument(self):
        # Arrange
        self.client.connect()

        # Act
        self.client.unsubscribe_instrument(AUDUSD)

        # Assert
        assert True

    def test_unsubscribe_quote_ticks(self):
        async def run_test():
            # Arrange
            self.mock_oanda.request.return_value = {"type": {"HEARTBEAT": "0"}}
            self.data_engine.start()

            self.client.subscribe_quote_ticks(AUDUSD)
            await asyncio.sleep(0.3)
            #
            # Act
            self.client.unsubscribe_quote_ticks(AUDUSD)
            await asyncio.sleep(0.3)

            # Assert
            assert AUDUSD not in self.client.subscribed_quote_ticks

            # Tear Down
            self.data_engine.stop()

        self.loop.run_until_complete(run_test())

    def test_unsubscribe_bars(self):
        # Arrange
        bar_spec = BarSpecification(1, BarAggregation.MINUTE, PriceType.MID)
        bar_type = BarType(instrument_id=AUDUSD, bar_spec=bar_spec)

        # Act
        self.client.unsubscribe_bars(bar_type)

        # Assert
        assert True

    def test_request_instrument(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client

            # Act
            self.client.request_instrument(AUDUSD, uuid4())
            await asyncio.sleep(1)

            # Assert
            # Instruments additionally requested on start
            assert self.data_engine.response_count == 1

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_request_instruments(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.5)

            # Act
            self.client.request_instruments(uuid4())
            await asyncio.sleep(1)

            # Assert
            # Instruments additionally requested on start
            assert self.data_engine.response_count == 1

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_request_bars(self):
        async def run_test():
            # Arrange
            with open(TEST_PATH + "instruments.json") as response:
                instruments = json.load(response)

            # Arrange
            with open(TEST_PATH + "bars.json") as response:
                bars = json.load(response)

            # First request (on start) returns instruments, second returns bars
            self.mock_oanda.request.side_effect = [instruments, bars]

            handler = ObjectStorer()
            self.data_engine.start()
            await asyncio.sleep(0.3)

            bar_spec = BarSpecification(1, BarAggregation.MINUTE, PriceType.MID)
            bar_type = BarType(instrument_id=AUDUSD, bar_spec=bar_spec)

            request = DataRequest(
                client_id=ClientId(OANDA.value),
                data_type=DataType(
                    Bar,
                    metadata={
                        "bar_type": bar_type,
                        "from_datetime": None,
                        "to_datetime": None,
                        "limit": 1000,
                    },
                ),
                callback=handler.store,
                request_id=self.uuid_factory.generate(),
                timestamp_ns=self.clock.timestamp_ns(),
            )

            # Act
            self.data_engine.send(request)

            # Allow time for request to be sent, processed and response returned
            await asyncio.sleep(1)

            # Assert
            assert self.data_engine.response_count == 1
            assert handler.count == 1
            # Final bar incomplete so becomes partial
            assert len(handler.get_store()[0]) == 99

            # Tear Down
            self.data_engine.stop()
            self.data_engine.dispose()

        self.loop.run_until_complete(run_test())
class OandaDataClientTests(unittest.TestCase):
    """
    Tests for the Oanda data client (unittest style) against a mocked Oanda
    API, using a fresh isolated event loop per test.
    """

    def setUp(self):
        # Fixture Setup
        self.clock = LiveClock()
        self.uuid_factory = UUIDFactory()
        self.trader_id = TraderId("TESTER", "001")

        # Fresh isolated loop testing pattern
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.executor = concurrent.futures.ThreadPoolExecutor()
        self.loop.set_default_executor(self.executor)
        self.loop.set_debug(True)  # TODO: Development

        # Setup logging
        # NOTE(review): the local `logger` (DEBUG levels) goes to the client;
        # `self.logger` goes to the engines.
        logger = LiveLogger(
            clock=self.clock,
            name=self.trader_id.value,
            level_console=LogLevel.DEBUG,
            level_file=LogLevel.DEBUG,
            level_store=LogLevel.WARNING,
        )

        self.logger = LiveLogger(self.clock)

        self.portfolio = Portfolio(
            clock=self.clock,
            logger=self.logger,
        )

        self.data_engine = LiveDataEngine(
            loop=self.loop,
            portfolio=self.portfolio,
            clock=self.clock,
            logger=self.logger,
        )

        self.mock_oanda = MagicMock()

        self.client = OandaDataClient(
            client=self.mock_oanda,
            account_id="001",
            engine=self.data_engine,
            clock=self.clock,
            logger=logger,
        )

        self.data_engine.register_client(self.client)

        # Canned instruments response for the mocked Oanda API
        with open(TEST_PATH + "instruments.json") as response:
            instruments = json.load(response)

        self.mock_oanda.request.return_value = instruments

    def tearDown(self):
        self.executor.shutdown(wait=True)
        self.loop.stop()
        self.loop.close()

    # TODO: WIP
    # def test_connect(self):
    #     async def run_test():
    #         # Arrange
    #         # Act
    #         self.data_engine.start()  # Also connects client
    #         await asyncio.sleep(0.3)
    #
    #         # Assert
    #         self.assertTrue(self.client.is_connected)
    #
    #         # Tear Down
    #         self.data_engine.stop()
    #
    #     self.loop.run_until_complete(run_test())

    def test_disconnect(self):
        # Arrange
        self.client.connect()

        # Act
        self.client.disconnect()

        # Assert
        self.assertFalse(self.client.is_connected)

    def test_reset(self):
        # Arrange
        # Act
        self.client.reset()

        # Assert
        self.assertFalse(self.client.is_connected)

    def test_dispose(self):
        # Arrange
        # Act
        self.client.dispose()

        # Assert
        self.assertFalse(self.client.is_connected)

    def test_subscribe_instrument(self):
        # Arrange
        self.client.connect()

        # Act
        self.client.subscribe_instrument(AUDUSD)

        # Assert
        self.assertIn(AUDUSD, self.client.subscribed_instruments)

    def test_subscribe_quote_ticks(self):
        async def run_test():
            # Arrange
            self.mock_oanda.request.return_value = {"type": {"HEARTBEAT": "0"}}
            self.data_engine.start()

            # Act
            self.client.subscribe_quote_ticks(AUDUSD)
            await asyncio.sleep(0.3)

            # Assert
            self.assertIn(AUDUSD, self.client.subscribed_quote_ticks)

            # Tear Down
            self.data_engine.stop()

        self.loop.run_until_complete(run_test())

    def test_subscribe_bars(self):
        # Arrange
        bar_spec = BarSpecification(1, BarAggregation.MINUTE, PriceType.MID)
        bar_type = BarType(instrument_id=AUDUSD, bar_spec=bar_spec)

        # Act
        self.client.subscribe_bars(bar_type)

        # Assert
        self.assertTrue(True)

    def test_unsubscribe_instrument(self):
        # Arrange
        self.client.connect()

        # Act
        self.client.unsubscribe_instrument(AUDUSD)

        # Assert
        self.assertTrue(True)

    def test_unsubscribe_quote_ticks(self):
        async def run_test():
            # Arrange
            self.mock_oanda.request.return_value = {"type": {"HEARTBEAT": "0"}}
            self.data_engine.start()

            self.client.subscribe_quote_ticks(AUDUSD)
            await asyncio.sleep(0.3)
            #
            # Act
            self.client.unsubscribe_quote_ticks(AUDUSD)
            await asyncio.sleep(0.3)

            # Assert
            self.assertNotIn(AUDUSD, self.client.subscribed_quote_ticks)

            # Tear Down
            self.data_engine.stop()

        self.loop.run_until_complete(run_test())

    def test_unsubscribe_bars(self):
        # Arrange
        bar_spec = BarSpecification(1, BarAggregation.MINUTE, PriceType.MID)
        bar_type = BarType(instrument_id=AUDUSD, bar_spec=bar_spec)

        # Act
        self.client.unsubscribe_bars(bar_type)

        # Assert
        self.assertTrue(True)

    def test_request_instrument(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.5)

            # Act
            self.client.request_instrument(AUDUSD, uuid4())
            await asyncio.sleep(0.5)

            # Assert
            # Instruments additionally requested on start
            self.assertEqual(1, self.data_engine.response_count)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_request_instruments(self):
        async def run_test():
            # Arrange
            self.data_engine.start()  # Also starts client
            await asyncio.sleep(0.5)

            # Act
            self.client.request_instruments(uuid4())
            await asyncio.sleep(0.5)

            # Assert
            # Instruments additionally requested on start
            self.assertEqual(1, self.data_engine.response_count)

            # Tear Down
            self.data_engine.stop()
            await self.data_engine.get_run_queue_task()

        self.loop.run_until_complete(run_test())

    def test_request_bars(self):
        async def run_test():
            # Arrange
            with open(TEST_PATH + "instruments.json") as response:
                instruments = json.load(response)

            # Arrange
            with open(TEST_PATH + "bars.json") as response:
                bars = json.load(response)

            # First request (on start) returns instruments, second returns bars
            self.mock_oanda.request.side_effect = [instruments, bars]

            handler = ObjectStorer()
            self.data_engine.start()
            await asyncio.sleep(0.3)

            bar_spec = BarSpecification(1, BarAggregation.MINUTE, PriceType.MID)
            bar_type = BarType(instrument_id=AUDUSD, bar_spec=bar_spec)

            request = DataRequest(
                provider=OANDA.value,
                data_type=DataType(Bar, metadata={
                    "BarType": bar_type,
                    "FromDateTime": None,
                    "ToDateTime": None,
                    "Limit": 1000,
                }),
                callback=handler.store_2,
                request_id=self.uuid_factory.generate(),
                request_timestamp=self.clock.utc_now(),
            )

            # Act
            self.data_engine.send(request)

            # Allow time for request to be sent, processed and response returned
            await asyncio.sleep(0.3)

            # Assert
            self.assertEqual(1, self.data_engine.response_count)
            self.assertEqual(1, handler.count)
            # Final bar incomplete so becomes partial
            self.assertEqual(99, len(handler.get_store()[0][1]))

            # Tear Down
            self.data_engine.stop()
            self.data_engine.dispose()

        self.loop.run_until_complete(run_test())
class TestBetfairDataClient:
    def setup(self):
        # Fixture Setup
        self.loop = asyncio.get_event_loop()
        self.loop.set_debug(True)

        self.clock = LiveClock()
        self.uuid_factory = UUIDFactory()
        self.trader_id = TestStubs.trader_id()
        self.uuid = UUID4()
        self.venue = BETFAIR_VENUE
        self.account_id = AccountId(self.venue.value, "001")

        # Setup logging
        self.logger = LiveLogger(loop=self.loop, clock=self.clock, level_stdout=LogLevel.ERROR)
        self._log = LoggerAdapter("TestBetfairExecutionClient", self.logger)

        self.msgbus = MessageBus(
            trader_id=self.trader_id,
            clock=self.clock,
            logger=self.logger,
        )

        self.cache = TestStubs.cache()
        self.cache.add_instrument(BetfairTestStubs.betting_instrument())

        self.portfolio = Portfolio(
            msgbus=self.msgbus,
            cache=self.cache,
            clock=self.clock,
            logger=self.logger,
        )

        self.data_engine = LiveDataEngine(
            loop=self.loop,
            msgbus=self.msgbus,
            cache=self.cache,
            clock=self.clock,
            logger=self.logger,
        )

        self.betfair_client = BetfairTestStubs.betfair_client(loop=self.loop, logger=self.logger)
        self.instrument_provider = BetfairTestStubs.instrument_provider(
            betfair_client=self.betfair_client
        )

        # Add a subset of instruments
        instruments = [
            ins for ins in INSTRUMENTS if ins.market_id in BetfairDataProvider.market_ids()
        ]
        self.instrument_provider.add_bulk(instruments)

        self.client = BetfairDataClient(
            loop=self.loop,
            client=self.betfair_client,
            msgbus=self.msgbus,
            cache=self.cache,
            clock=self.clock,
            logger=self.logger,
            instrument_provider=self.instrument_provider,
            market_filter={},
        )

        self.data_engine.register_client(self.client)

        # Re-route exec engine messages through `handler` so each test can
        # inspect everything that flowed over the bus via `self.messages`.
        self.messages = []

        def handler(x, endpoint):
            self.messages.append(x)
            if endpoint == "execute":
                self.data_engine.execute(x)
            elif endpoint == "process":
                self.data_engine.process(x)
            elif endpoint == "response":
                self.data_engine.response(x)

        self.msgbus.deregister(
            endpoint="DataEngine.execute", handler=self.data_engine.execute)  # type: ignore
        self.msgbus.register(
            endpoint="DataEngine.execute",
            handler=partial(handler, endpoint="execute")  # type: ignore
        )

        self.msgbus.deregister(
            endpoint="DataEngine.process", handler=self.data_engine.process)  # type: ignore
        self.msgbus.register(
            endpoint="DataEngine.process",
            handler=partial(handler, endpoint="process")  # type: ignore
        )

        self.msgbus.deregister(
            endpoint="DataEngine.response", handler=self.data_engine.response)  # type: ignore
        self.msgbus.register(
            endpoint="DataEngine.response",
            handler=partial(handler, endpoint="response")  # type: ignore
        )

    @pytest.mark.asyncio
    @patch("nautilus_trader.adapters.betfair.data.BetfairDataClient._post_connect_heartbeat")
    @patch("nautilus_trader.adapters.betfair.data.BetfairMarketStreamClient.connect")
    @patch("nautilus_trader.adapters.betfair.client.core.BetfairClient.connect")
    async def test_connect(self, mock_client_connect, mock_stream_connect, mock_post_connect_heartbeat):
        # Smoke test: connect with all network calls patched out.
        await self.client._connect()

    def test_subscriptions(self):
        # Smoke test the subscription entry points.
        self.client.subscribe_trade_ticks(BetfairTestStubs.instrument_id())
        self.client.subscribe_instrument_status_updates(BetfairTestStubs.instrument_id())
        self.client.subscribe_instrument_close_prices(BetfairTestStubs.instrument_id())

    def test_market_heartbeat(self):
        self.client._on_market_update(BetfairStreaming.mcm_HEARTBEAT())

    def test_stream_latency(self):
        # Capture log records emitted while processing a high-latency update.
        logs = []
        self.logger.register_sink(logs.append)
        self.client.start()
        self.client._on_market_update(BetfairStreaming.mcm_latency())
        warning, degrading, degraded = logs[2:]
        assert warning["level"] == "WRN"
        assert warning["msg"] == "Stream unhealthy, waiting for recover"
        assert degraded["msg"] == "DEGRADED."
def test_stream_con_true(self): logs = [] self.logger.register_sink(logs.append) self.client._on_market_update(BetfairStreaming.mcm_con_true()) (warning, ) = logs assert warning["level"] == "WRN" assert ( warning["msg"] == "Conflated stream - consuming data too slow (data received is delayed)" ) @pytest.mark.asyncio async def test_market_sub_image_market_def(self): update = BetfairStreaming.mcm_SUB_IMAGE() self.client._on_market_update(update) result = [type(event).__name__ for event in self.messages] expected = ["InstrumentStatusUpdate"] * 7 + ["OrderBookSnapshot"] * 7 assert result == expected # Check prices are probabilities result = set( float(order[0]) for ob_snap in self.messages if isinstance(ob_snap, OrderBookSnapshot) for order in ob_snap.bids + ob_snap.asks) expected = set([ 0.0010204, 0.0076923, 0.0217391, 0.0238095, 0.1724138, 0.2173913, 0.3676471, 0.3937008, 0.4587156, 0.5555556, ]) assert result == expected def test_market_sub_image_no_market_def(self): self.client._on_market_update( BetfairStreaming.mcm_SUB_IMAGE_no_market_def()) result = Counter([type(event).__name__ for event in self.messages]) expected = Counter({ "InstrumentStatusUpdate": 270, "OrderBookSnapshot": 270, "InstrumentClosePrice": 22, }) assert result == expected def test_market_resub_delta(self): self.client._on_market_update(BetfairStreaming.mcm_RESUB_DELTA()) result = [type(event).__name__ for event in self.messages] expected = ["InstrumentStatusUpdate"] * 12 + ["OrderBookDeltas"] * 269 assert result == expected def test_market_update(self): self.client._on_market_update(BetfairStreaming.mcm_UPDATE()) result = [type(event).__name__ for event in self.messages] expected = ["OrderBookDeltas"] * 1 assert result == expected result = [d.action for d in self.messages[0].deltas] expected = [BookAction.UPDATE, BookAction.DELETE] assert result == expected # Ensure order prices are coming through as probability update_op = self.messages[0].deltas[0] assert update_op.order.price == 0.212766 
def test_market_update_md(self): self.client._on_market_update(BetfairStreaming.mcm_UPDATE_md()) result = [type(event).__name__ for event in self.messages] expected = ["InstrumentStatusUpdate"] * 2 assert result == expected def test_market_update_live_image(self): self.client._on_market_update(BetfairStreaming.mcm_live_IMAGE()) result = [type(event).__name__ for event in self.messages] expected = (["OrderBookSnapshot"] + ["TradeTick"] * 13 + ["OrderBookSnapshot"] + ["TradeTick"] * 17) assert result == expected def test_market_update_live_update(self): self.client._on_market_update(BetfairStreaming.mcm_live_UPDATE()) result = [type(event).__name__ for event in self.messages] expected = ["TradeTick", "OrderBookDeltas"] assert result == expected def test_market_bsp(self): # Setup update = BetfairStreaming.mcm_BSP() provider = self.client.instrument_provider() for mc in update[0]["mc"]: market_def = {**mc["marketDefinition"], "marketId": mc["id"]} instruments = make_instruments(market_definition=market_def, currency="GBP") provider.add_bulk(instruments) for update in update: self.client._on_market_update(update) result = Counter([type(event).__name__ for event in self.messages]) expected = { "TradeTick": 95, "BSPOrderBookDelta": 30, "InstrumentStatusUpdate": 9, "OrderBookSnapshot": 8, "OrderBookDeltas": 2, } assert result == expected @pytest.mark.asyncio async def test_request_search_instruments(self): req = DataType( type=InstrumentSearch, metadata={"event_type_id": "7"}, ) self.client.request(req, self.uuid) await asyncio.sleep(0) resp = self.messages[0] assert len(resp.data.instruments) == 6800 def test_orderbook_repr(self): self.client._on_market_update(BetfairStreaming.mcm_live_IMAGE()) ob_snap = self.messages[14] ob = L2OrderBook(InstrumentId(Symbol("1"), BETFAIR_VENUE), 5, 5) ob.apply_snapshot(ob_snap) print(ob.pprint()) assert ob.best_ask_price() == 0.5882353 assert ob.best_bid_price() == 0.5847953 def test_orderbook_updates(self): order_books = {} for 
raw_update in BetfairStreaming.market_updates(): for update in on_market_update( update=raw_update, instrument_provider=self.client.instrument_provider(), ): if len(order_books) > 1 and update.instrument_id != list( order_books)[1]: continue print(update) if isinstance(update, OrderBookSnapshot): order_books[update.instrument_id] = L2OrderBook( instrument_id=update.instrument_id, price_precision=4, size_precision=4, ) order_books[update.instrument_id].apply_snapshot(update) elif isinstance(update, OrderBookDeltas): order_books[update.instrument_id].apply_deltas(update) elif isinstance(update, TradeTick): pass else: raise KeyError book = order_books[list(order_books)[0]] expected = """bids price asks -------- ------- --------- 0.8621 [932.64] 0.8547 [1275.83] 0.8475 [151.96] [147.79] 0.8403 [156.74] 0.8333 [11.19] 0.8197""" result = book.pprint() assert result == expected def test_instrument_opening_events(self): updates = BetfairDataProvider.raw_market_updates() messages = on_market_update( instrument_provider=self.client.instrument_provider(), update=updates[0]) assert len(messages) == 2 assert (isinstance(messages[0], InstrumentStatusUpdate) and messages[0].status == InstrumentStatus.PRE_OPEN) assert (isinstance(messages[1], InstrumentStatusUpdate) and messages[0].status == InstrumentStatus.PRE_OPEN) def test_instrument_in_play_events(self): events = [ msg for update in BetfairDataProvider.raw_market_updates() for msg in on_market_update( instrument_provider=self.client.instrument_provider(), update=update) if isinstance(msg, InstrumentStatusUpdate) ] assert len(events) == 14 result = [ev.status for ev in events] expected = [ InstrumentStatus.PRE_OPEN.value, InstrumentStatus.PRE_OPEN.value, InstrumentStatus.PRE_OPEN.value, InstrumentStatus.PRE_OPEN.value, InstrumentStatus.PRE_OPEN.value, InstrumentStatus.PRE_OPEN.value, InstrumentStatus.PAUSE.value, InstrumentStatus.PAUSE.value, InstrumentStatus.OPEN.value, InstrumentStatus.OPEN.value, 
InstrumentStatus.PAUSE.value, InstrumentStatus.PAUSE.value, InstrumentStatus.CLOSED.value, InstrumentStatus.CLOSED.value, ] assert result == expected def test_instrument_closing_events(self): updates = BetfairDataProvider.raw_market_updates() messages = on_market_update( instrument_provider=self.client.instrument_provider(), update=updates[-1], ) assert len(messages) == 4 assert (isinstance(messages[0], InstrumentStatusUpdate) and messages[0].status == InstrumentStatus.CLOSED) assert isinstance( messages[1], InstrumentClosePrice) and messages[1].close_price == 1.0000 assert (isinstance(messages[1], InstrumentClosePrice) and messages[1].close_type == InstrumentCloseType.EXPIRED) assert (isinstance(messages[2], InstrumentStatusUpdate) and messages[2].status == InstrumentStatus.CLOSED) assert isinstance( messages[3], InstrumentClosePrice) and messages[3].close_price == 0.0 assert (isinstance(messages[3], InstrumentClosePrice) and messages[3].close_type == InstrumentCloseType.EXPIRED) def test_betfair_ticker(self): self.client._on_market_update(BetfairStreaming.mcm_UPDATE_tv()) ticker: BetfairTicker = self.messages[1] assert ticker.last_traded_price == Price.from_str("0.3174603") assert ticker.traded_volume == Quantity.from_str("364.45") def test_betfair_orderbook(self): book = L2OrderBook( instrument_id=BetfairTestStubs.instrument_id(), price_precision=2, size_precision=2, ) for update in BetfairDataProvider.raw_market_updates(): for message in on_market_update( instrument_provider=self.instrument_provider, update=update): try: if isinstance(message, OrderBookSnapshot): book.apply_snapshot(message) elif isinstance(message, OrderBookDeltas): book.apply_deltas(message) elif isinstance(message, OrderBookDelta): book.apply_delta(message) elif isinstance(message, (Ticker, TradeTick, InstrumentStatusUpdate, InstrumentClosePrice)): pass else: raise NotImplementedError(str(type(message))) book.check_integrity() except Exception as ex: print(str(type(ex)) + " " + str(ex))