Beispiel #1
0
class BinanceAPIOrderBookDataSourceUnitTests(unittest.TestCase):
    """Unit tests for ``BinanceAPIOrderBookDataSource``.

    REST requests are mocked with ``aioresponses`` and websocket connections
    by patching ``aiohttp.ClientSession.ws_connect``. The test case registers
    itself as a logging handler on the data source logger (see ``handle``) so
    emitted log records can be checked with ``_is_logged``.
    """

    # logging.Level required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        cls.trading_pair = f"{cls.base_asset}-{cls.quote_asset}"
        cls.ex_trading_pair = cls.base_asset + cls.quote_asset
        cls.domain = "com"

    def setUp(self) -> None:
        super().setUp()
        self.log_records = []
        self.listening_task = None
        self.mocking_assistant = NetworkMockingAssistant()

        self.throttler = AsyncThrottler(rate_limits=CONSTANTS.RATE_LIMITS)
        self.data_source = BinanceAPIOrderBookDataSource(
            trading_pairs=[self.trading_pair],
            throttler=self.throttler,
            domain=self.domain)
        # Capture every record the data source logs (level 1 = everything).
        self.data_source.logger().setLevel(1)
        self.data_source.logger().addHandler(self)

        self.resume_test_event = asyncio.Event()

    def tearDown(self) -> None:
        # Stop any listener a test left running.
        if self.listening_task is not None:
            self.listening_task.cancel()
        # Detach this test instance from the logger; otherwise every finished
        # test stays registered as a handler and keeps collecting records.
        self.data_source.logger().removeHandler(self)
        super().tearDown()

    def handle(self, record):
        """Logging-handler hook: store ``record`` for later assertions."""
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        """Return True if a captured record has ``log_level`` and exactly ``message``."""
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def _raise_exception(self, exception_class):
        """Raise ``exception_class``; used inside mock side-effect lambdas."""
        raise exception_class

    def _create_exception_and_unlock_test_with_event(self, exception):
        """Unblock ``resume_test_event`` and then raise ``exception``."""
        self.resume_test_event.set()
        raise exception

    def _public_rest_url_regex(self, path_url: str):
        """Return a compiled regex matching any request to the public REST ``path_url``.

        The regex is anchored at the start; ``.`` and ``?`` are escaped so the
        URL is matched literally while query strings may follow.
        """
        url = utils.public_rest_url(path_url=path_url, domain=self.domain)
        return re.compile(f"^{url}".replace(".", r"\.").replace("?", r"\?"))

    def _successfully_subscribed_event(self):
        resp = {"result": None, "id": 1}
        return resp

    def _trade_update_event(self):
        resp = {
            "e": "trade",
            "E": 123456789,
            "s": self.ex_trading_pair,
            "t": 12345,
            "p": "0.001",
            "q": "100",
            "b": 88,
            "a": 50,
            "T": 123456785,
            "m": True,
            "M": True
        }
        return resp

    def _order_diff_event(self):
        resp = {
            "e": "depthUpdate",
            "E": 123456789,
            "s": self.ex_trading_pair,
            "U": 157,
            "u": 160,
            "b": [["0.0024", "10"]],
            "a": [["0.0026", "100"]]
        }
        return resp

    def _snapshot_response(self):
        resp = {
            "lastUpdateId": 1027024,
            "bids": [["4.00000000", "431.00000000"]],
            "asks": [["4.00000200", "12.00000000"]]
        }
        return resp

    @aioresponses()
    def test_get_last_trade_prices(self, mock_api):
        regex_url = self._public_rest_url_regex(
            CONSTANTS.TICKER_PRICE_CHANGE_PATH_URL)

        mock_response: Dict[str, Any] = {
            # Truncated Response
            "lastPrice": "100",
        }

        mock_api.get(regex_url, body=ujson.dumps(mock_response))

        result: Dict[str, float] = self.ev_loop.run_until_complete(
            self.data_source.get_last_traded_prices(
                trading_pairs=[self.trading_pair], throttler=self.throttler))

        self.assertEqual(1, len(result))
        self.assertEqual(100, result[self.trading_pair])

    @aioresponses()
    @patch(
        "hummingbot.connector.exchange.binance.binance_utils.convert_from_exchange_trading_pair"
    )
    def test_get_all_mid_prices(self, mock_api, mock_utils):
        # Make the exchange->Hummingbot pair conversion yield self.trading_pair.
        mock_utils.return_value = self.trading_pair
        regex_url = self._public_rest_url_regex(
            CONSTANTS.TICKER_PRICE_CHANGE_PATH_URL)

        mock_response: List[Dict[str, Any]] = [{
            # Truncated Response
            "symbol": self.ex_trading_pair,
            "bidPrice": "99",
            "askPrice": "101",
        }]

        mock_api.get(regex_url, body=ujson.dumps(mock_response))

        result: Dict[str, float] = self.ev_loop.run_until_complete(
            self.data_source.get_all_mid_prices())

        self.assertEqual(1, len(result))
        self.assertEqual(100, result[self.trading_pair])

    @aioresponses()
    @patch(
        "hummingbot.connector.exchange.binance.binance_utils.convert_from_exchange_trading_pair"
    )
    def test_fetch_trading_pairs(self, mock_api, mock_utils):
        # Make the exchange->Hummingbot pair conversion yield self.trading_pair.
        mock_utils.return_value = self.trading_pair
        regex_url = self._public_rest_url_regex(
            CONSTANTS.EXCHANGE_INFO_PATH_URL)

        mock_response: Dict[str, Any] = {
            # Truncated Response
            "symbols": [
                {
                    "symbol": self.ex_trading_pair,
                    "status": "TRADING",
                    "baseAsset": self.base_asset,
                    "quoteAsset": self.quote_asset,
                },
            ]
        }

        mock_api.get(regex_url, body=ujson.dumps(mock_response))

        result: List[str] = self.ev_loop.run_until_complete(
            self.data_source.fetch_trading_pairs())

        self.assertEqual(1, len(result))
        self.assertIn(self.trading_pair, result)

    @aioresponses()
    @patch(
        "hummingbot.connector.exchange.binance.binance_utils.convert_from_exchange_trading_pair"
    )
    def test_fetch_trading_pairs_exception_raised(self, mock_api, mock_utils):
        # Make the exchange->Hummingbot pair conversion yield self.trading_pair.
        mock_utils.return_value = self.trading_pair
        regex_url = self._public_rest_url_regex(
            CONSTANTS.EXCHANGE_INFO_PATH_URL)

        mock_api.get(regex_url, exception=Exception)

        result: List[str] = self.ev_loop.run_until_complete(
            self.data_source.fetch_trading_pairs())

        self.assertEqual(0, len(result))

    def test_get_throttler_instance(self):
        self.assertIsInstance(
            BinanceAPIOrderBookDataSource._get_throttler_instance(),
            AsyncThrottler)

    @aioresponses()
    def test_get_snapshot_successful(self, mock_api):
        regex_url = self._public_rest_url_regex(CONSTANTS.SNAPSHOT_PATH_URL)

        mock_api.get(regex_url, body=ujson.dumps(self._snapshot_response()))

        result: Dict[str, Any] = self.ev_loop.run_until_complete(
            self.data_source.get_snapshot(self.trading_pair))

        self.assertEqual(self._snapshot_response(), result)

    @aioresponses()
    def test_get_snapshot_catch_exception(self, mock_api):
        regex_url = self._public_rest_url_regex(CONSTANTS.SNAPSHOT_PATH_URL)

        mock_api.get(regex_url, status=400)
        with self.assertRaises(IOError):
            self.ev_loop.run_until_complete(
                self.data_source.get_snapshot(self.trading_pair))

    @aioresponses()
    def test_get_new_order_book(self, mock_api):
        regex_url = self._public_rest_url_regex(CONSTANTS.SNAPSHOT_PATH_URL)

        mock_response: Dict[str, Any] = {
            "lastUpdateId": 1,
            "bids": [["4.00000000", "431.00000000"]],
            "asks": [["4.00000200", "12.00000000"]]
        }
        mock_api.get(regex_url, body=ujson.dumps(mock_response))

        result: OrderBook = self.ev_loop.run_until_complete(
            self.data_source.get_new_order_book(self.trading_pair))

        self.assertEqual(1, result.snapshot_uid)

    @patch("aiohttp.ClientSession.ws_connect")
    def test_create_websocket_connection_cancelled_when_connecting(
            self, mock_ws):
        mock_ws.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            self.ev_loop.run_until_complete(
                self.data_source._create_websocket_connection())

    @patch("aiohttp.ClientSession.ws_connect")
    def test_create_websocket_connection_exception_raised(self, mock_ws):
        mock_ws.side_effect = Exception("TEST ERROR.")

        with self.assertRaises(Exception):
            self.ev_loop.run_until_complete(
                self.data_source._create_websocket_connection())

        self.assertTrue(
            self._is_logged(
                "NETWORK",
                "Unexpected error occured when connecting to WebSocket server. Error: TEST ERROR."
            ))

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect")
    def test_listen_for_trades_cancelled_when_connecting(
            self, mock_ws, _: AsyncMock):
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.data_source.listen_for_trades(self.ev_loop, msg_queue))
            self.ev_loop.run_until_complete(self.listening_task)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_trades_exception_raised_when_connecting(self, mock_ws):
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.side_effect = lambda **_: self._create_exception_and_unlock_test_with_event(
            Exception("TEST ERROR."))

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        self.ev_loop.run_until_complete(self.resume_test_event.wait())

        self.assertTrue(
            self._is_logged(
                "NETWORK",
                "Unexpected error occured when connecting to WebSocket server. Error: TEST ERROR."
            ))
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with WebSocket connection. Retrying after 30 seconds..."
            ))

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_trades_cancelled_when_listening(
            self, mock_ws, _: AsyncMock):
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value.receive_json.side_effect = lambda: (
            self._raise_exception(asyncio.CancelledError))
        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.data_source.listen_for_trades(self.ev_loop, msg_queue))
            self.ev_loop.run_until_complete(self.listening_task)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_trades_logs_exception(self, mock_ws):
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.close.return_value = None

        incomplete_resp = {
            "m": 1,
            "i": 2,
        }
        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, incomplete_resp)
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        with self.assertRaises(asyncio.TimeoutError):
            self.ev_loop.run_until_complete(
                asyncio.wait_for(self.listening_task, 1))

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with WebSocket connection. Retrying after 30 seconds..."
            ))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_trades_iter_message_throws_exception(self, mock_ws):
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value.receive_json.side_effect = lambda: self._raise_exception(
            Exception("TEST ERROR"))
        mock_ws.close.return_value = None

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        with self.assertRaises(asyncio.TimeoutError):
            self.ev_loop.run_until_complete(
                asyncio.wait_for(self.listening_task, 1))
        self.assertTrue(
            self._is_logged(
                "NETWORK",
                "Unexpected error occured when parsing websocket payload. Error: TEST ERROR"
            ))
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with WebSocket connection. Retrying after 30 seconds..."
            ))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_trades_successful(self, mock_ws):
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.close.return_value = None

        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, self._successfully_subscribed_event())
        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, self._trade_update_event())
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.ev_loop.run_until_complete(
            msg_queue.get())

        # The trade event above carries trade id "t" = 12345.
        # (assertTrue(12345, x) would treat x as a message and always pass.)
        self.assertEqual(12345, msg.trade_id)

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect")
    def test_listen_for_order_book_diffs_cancelled_when_connecting(
            self, mock_ws, _: AsyncMock):
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.data_source.listen_for_order_book_diffs(
                    self.ev_loop, msg_queue))
            self.ev_loop.run_until_complete(self.listening_task)

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_order_book_diffs_cancelled_when_listening(
            self, mock_ws, _: AsyncMock):
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value.receive_json.side_effect = lambda: (
            self._raise_exception(asyncio.CancelledError))
        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.data_source.listen_for_order_book_diffs(
                    self.ev_loop, msg_queue))
            self.ev_loop.run_until_complete(self.listening_task)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_order_book_diffs_logs_exception(self, mock_ws):
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.close.return_value = None

        incomplete_resp = {
            "m": 1,
            "i": 2,
        }
        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, incomplete_resp)
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_diffs(
                self.ev_loop, msg_queue))

        with self.assertRaises(asyncio.TimeoutError):
            self.ev_loop.run_until_complete(
                asyncio.wait_for(self.listening_task, 1))

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with WebSocket connection. Retrying after 30 seconds..."
            ))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_order_book_diffs_successful(self, mock_ws):
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.close.return_value = None

        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, self._successfully_subscribed_event())
        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, self._order_diff_event())
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_diffs(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.ev_loop.run_until_complete(
            msg_queue.get())

        # The diff event above carries final update id "u" = 160.
        self.assertEqual(160, msg.update_id)

    @aioresponses()
    def test_listen_for_order_book_snapshots_cancelled_when_fetching_snapshot(
            self, mock_api):
        regex_url = self._public_rest_url_regex(CONSTANTS.SNAPSHOT_PATH_URL)

        mock_api.get(regex_url, exception=asyncio.CancelledError)

        with self.assertRaises(asyncio.CancelledError):
            self.ev_loop.run_until_complete(
                self.data_source.listen_for_order_book_snapshots(
                    self.ev_loop, asyncio.Queue()))

    @aioresponses()
    def test_listen_for_order_book_snapshots_log_exception(self, mock_api):
        msg_queue: asyncio.Queue = asyncio.Queue()

        regex_url = self._public_rest_url_regex(CONSTANTS.SNAPSHOT_PATH_URL)

        mock_api.get(regex_url, exception=Exception)

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, msg_queue))
        with self.assertRaises(asyncio.TimeoutError):
            self.ev_loop.run_until_complete(
                asyncio.wait_for(self.listening_task, 1))

        self.assertTrue(
            self._is_logged(
                "ERROR",
                f"Unexpected error fetching order book snapshot for {self.trading_pair}."
            ))

    @aioresponses()
    def test_listen_for_order_book_snapshots_successful(
        self,
        mock_api,
    ):
        msg_queue: asyncio.Queue = asyncio.Queue()
        regex_url = self._public_rest_url_regex(CONSTANTS.SNAPSHOT_PATH_URL)

        mock_api.get(regex_url, body=ujson.dumps(self._snapshot_response()))

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.ev_loop.run_until_complete(
            msg_queue.get())

        # The mocked snapshot response has lastUpdateId = 1027024.
        self.assertEqual(1027024, msg.update_id)


class BinancePerpetualAPIOrderBookDataSourceUnitTests(unittest.TestCase):
    # logging.Level required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        """Set up the shared event loop and market identifiers used by every test."""
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        base, quote = "COINALPHA", "HBOT"
        cls.base_asset = base
        cls.quote_asset = quote
        # Hummingbot format uses a dash; the exchange symbol is concatenated.
        cls.trading_pair = "-".join((base, quote))
        cls.ex_trading_pair = base + quote
        cls.domain = "binance_perpetual_testnet"

    def setUp(self) -> None:
        """Create a fresh data source per test and route its log records here."""
        super().setUp()
        self.log_records = []
        self.listening_task = None

        source = BinancePerpetualAPIOrderBookDataSource(
            trading_pairs=[self.trading_pair],
            domain=self.domain,
        )
        self.data_source = source
        # Level 1 lets every record through to ``handle``.
        source_logger = source.logger()
        source_logger.setLevel(1)
        source_logger.addHandler(self)

        self.mocking_assistant = NetworkMockingAssistant()
        self.resume_test_event = asyncio.Event()

    def tearDown(self) -> None:
        """Cancel a listener task left running by the test, if any."""
        task = self.listening_task
        if task:
            task.cancel()
        super().tearDown()

    def handle(self, record):
        """Logging-handler hook: keep ``record`` so tests can inspect it."""
        self.log_records += [record]

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: float = 1):
        """Run ``coroutine`` on the test loop and return its result.

        Raises ``asyncio.TimeoutError`` if it does not finish within ``timeout`` seconds.
        """
        guarded = asyncio.wait_for(coroutine, timeout=timeout)
        return self.ev_loop.run_until_complete(guarded)

    def resume_test_callback(self, *_, **__):
        """Ignore all arguments and set ``resume_test_event``; returns None."""
        event = self.resume_test_event
        event.set()

    def _is_logged(self, log_level: str, message: str) -> bool:
        """Return True when a captured record has ``log_level`` and exactly ``message``."""
        for captured in self.log_records:
            if captured.levelname != log_level:
                continue
            if captured.getMessage() == message:
                return True
        return False

    def _raise_exception(self, exception_class):
        """Raise ``exception_class``; lets a lambda mock side effect raise."""
        to_raise = exception_class
        raise to_raise

    def _orderbook_update_event(self):
        """Build a stream-wrapped ``depthUpdate`` message for ``ex_trading_pair``."""
        symbol = self.ex_trading_pair
        depth_update = {
            "e": "depthUpdate",
            "E": 1631591424198,
            "T": 1631591424189,
            "s": symbol,
            "U": 752409354963,
            "u": 752409360466,
            "pu": 752409354901,
            "b": [["43614.31", "0.000"]],
            "a": [["45277.14", "0.257"]],
        }
        return {"stream": f"{symbol.lower()}@depth", "data": depth_update}

    def _orderbook_trade_event(self):
        """Build a stream-wrapped ``aggTrade`` message for ``ex_trading_pair``."""
        symbol = self.ex_trading_pair
        agg_trade = dict(
            e="aggTrade",
            E=1631594403486,
            a=817295132,
            s=symbol,
            p="45266.16",
            q="2.206",
            f=1437689393,
            l=1437689407,
            T=1631594403330,
            m=False,
        )
        return {"stream": f"{symbol.lower()}@aggTrade", "data": agg_trade}

    @aioresponses()
    def test_get_last_traded_prices(self, mock_api):
        """The ticker's ``lastPrice`` is returned keyed by the requested trading pair."""
        url = utils.rest_url(path_url=CONSTANTS.TICKER_PRICE_CHANGE_URL,
                             domain=self.domain)
        regex_url = re.compile(f"^{url}".replace(".",
                                                 r"\.").replace("?", r"\?"))
        mock_response: Dict[str, Any] = {
            # Truncated responses
            "lastPrice": "10.0",
        }
        mock_api.get(regex_url, body=ujson.dumps(mock_response))

        result: Dict[str, Any] = self.async_run_with_timeout(
            self.data_source.get_last_traded_prices(
                trading_pairs=[self.trading_pair], domain=self.domain))
        # assertIn reports the missing key and the container on failure.
        self.assertIn(self.trading_pair, result)
        self.assertEqual(10.0, result[self.trading_pair])

    def test_get_throttler_instance(self):
        """The data source exposes an ``AsyncThrottler`` instance."""
        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(self.data_source._get_throttler_instance(),
                              AsyncThrottler)

    @aioresponses()
    def test_fetch_trading_pairs_failure(self, mock_api):
        """A non-200 exchange-info response yields an empty trading pair list."""
        url = utils.rest_url(path_url=CONSTANTS.EXCHANGE_INFO_URL,
                             domain=self.domain)
        regex_url = re.compile(f"^{url}".replace(".",
                                                 r"\.").replace("?", r"\?"))

        # Error body as a JSON object; the previous ``{"ERROR"}`` was a Python set.
        mock_api.get(regex_url, status=400, body=ujson.dumps({"msg": "ERROR"}))

        result: Dict[str, Any] = self.async_run_with_timeout(
            self.data_source.fetch_trading_pairs(domain=self.domain))
        self.assertEqual(0, len(result))

    @aioresponses()
    @patch(
        "hummingbot.connector.derivative.binance_perpetual.binance_perpetual_utils.convert_from_exchange_trading_pair"
    )
    def test_fetch_trading_pairs_successful(self, mock_api, mock_utils):
        """One TRADING and one INACTIVE symbol in the response yield exactly one pair."""
        mock_utils.return_value = self.trading_pair
        exchange_info_url = utils.rest_url(path_url=CONSTANTS.EXCHANGE_INFO_URL,
                                           domain=self.domain)
        pattern = re.compile(
            f"^{exchange_info_url}".replace(".", r"\.").replace("?", r"\?"))

        trading_symbol = {
            "symbol": self.ex_trading_pair,
            "pair": self.ex_trading_pair,
            "baseAsset": self.base_asset,
            "quoteAsset": self.quote_asset,
            "status": "TRADING",
        }
        inactive_symbol = {"symbol": "INACTIVEMARKET", "status": "INACTIVE"}
        # Truncated Responses
        mock_response: Dict[str, Any] = {"symbols": [trading_symbol, inactive_symbol]}
        mock_api.get(pattern, status=200, body=ujson.dumps(mock_response))

        pairs: Dict[str, Any] = self.async_run_with_timeout(
            self.data_source.fetch_trading_pairs(domain=self.domain))
        self.assertEqual(1, len(pairs))

    @aioresponses()
    def test_get_snapshot_exception_raised(self, mock_api):
        """A non-200 snapshot response raises IOError naming the trading pair."""
        url = utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL, domain=self.domain)
        regex_url = re.compile(f"^{url}".replace(".",
                                                 r"\.").replace("?", r"\?"))
        # Error body as a JSON object; the previous ``{"ERROR"}`` was a Python set.
        mock_api.get(regex_url, status=400, body=ujson.dumps({"msg": "ERROR"}))

        with self.assertRaises(IOError) as context:
            self.async_run_with_timeout(
                self.data_source.get_snapshot(trading_pair=self.trading_pair,
                                              domain=self.domain))

        self.assertEqual(
            f"Error fetching Binance market snapshot for {self.trading_pair}.",
            str(context.exception))

    @aioresponses()
    def test_get_snapshot_successful(self, mock_api):
        """get_snapshot returns the decoded REST snapshot payload as-is."""
        snapshot_payload = {
            "lastUpdateId": 1027024,
            "E": 1589436922972,
            "T": 1589436922959,
            "bids": [["10", "1"]],
            "asks": [["11", "1"]]
        }
        snapshot_url = utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL, domain=self.domain)
        mock_api.get(
            re.compile(f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")),
            status=200,
            body=ujson.dumps(snapshot_payload))

        fetched: Dict[str, Any] = self.async_run_with_timeout(
            self.data_source.get_snapshot(trading_pair=self.trading_pair,
                                          domain=self.domain))
        self.assertEqual(snapshot_payload, fetched)

    @aioresponses()
    def test_get_new_order_book(self, mock_api):
        """A new OrderBook is built from the REST snapshot and keeps its lastUpdateId."""
        snapshot_payload = {
            "lastUpdateId": 1027024,
            "E": 1589436922972,
            "T": 1589436922959,
            "bids": [["10", "1"]],
            "asks": [["11", "1"]]
        }
        snapshot_url = utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL, domain=self.domain)
        mock_api.get(
            re.compile(f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")),
            status=200,
            body=ujson.dumps(snapshot_payload))

        order_book = self.async_run_with_timeout(
            self.data_source.get_new_order_book(trading_pair=self.trading_pair))
        self.assertIsInstance(order_book, OrderBook)
        self.assertEqual(1027024, order_book.snapshot_uid)

    @patch("aiohttp.ClientSession.ws_connect")
    def test_create_websocket_connection_cancelled_when_connecting(
            self, mock_ws):
        """A CancelledError raised while connecting propagates to the caller."""
        mock_ws.side_effect = asyncio.CancelledError
        connect = self.data_source._create_websocket_connection

        with self.assertRaises(asyncio.CancelledError):
            self.async_run_with_timeout(connect())

    @patch("aiohttp.ClientSession.ws_connect")
    def test_create_websocket_connection_exception_raised(self, mock_ws):
        """A connection failure is re-raised and logged at NETWORK level."""
        mock_ws.side_effect = Exception("TEST ERROR.")
        expected_log = "Unexpected error occured when connecting to WebSocket server. Error: TEST ERROR."
        connect = self.data_source._create_websocket_connection

        with self.assertRaises(Exception):
            self.async_run_with_timeout(connect())

        self.assertTrue(self._is_logged("NETWORK", expected_log))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    def test_listen_for_order_book_diffs_cancelled_when_connecting(
            self, _, mock_ws):
        """Cancellation during connect ends the diff listener with nothing queued."""
        mock_ws.side_effect = asyncio.CancelledError
        msg_queue: asyncio.Queue = asyncio.Queue()
        diff_listener = self.data_source.listen_for_order_book_diffs(
            self.ev_loop, msg_queue)

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(diff_listener)
            self.async_run_with_timeout(self.listening_task)
        self.assertEqual(msg_queue.qsize(), 0)

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch(
        "hummingbot.connector.derivative.binance_perpetual.binance_perpetual_utils.convert_from_exchange_trading_pair"
    )
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_order_book_diffs_logs_exception(
            self, mock_ws, mock_utils, *_):
        """A malformed websocket message causes the listener's retry error to be logged."""
        mock_utils.return_value = self.trading_pair
        msg_queue: asyncio.Queue = asyncio.Queue()
        ws = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value = ws
        mock_ws.close.return_value = None

        incomplete_resp = {"m": 1, "i": 2}
        for payload in (incomplete_resp, self._orderbook_update_event()):
            self.mocking_assistant.add_websocket_json_message(ws, payload)

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_diffs(self.ev_loop, msg_queue))

        self.mocking_assistant.run_until_all_json_messages_delivered(ws)

        expected_log = "Unexpected error with Websocket connection. Retrying after 30 seconds..."
        self.assertTrue(self._is_logged("ERROR", expected_log))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.connector.derivative.binance_perpetual.binance_perpetual_utils.convert_from_exchange_trading_pair"
    )
    def test_listen_for_order_book_diffs_successful(self, mock_utils, mock_ws):
        """A depthUpdate message is turned into a DIFF OrderBookMessage on the queue."""
        mock_utils.return_value = self.trading_pair
        msg_queue: asyncio.Queue = asyncio.Queue()
        ws = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value = ws
        mock_ws.close.return_value = None

        self.mocking_assistant.add_websocket_json_message(
            ws, self._orderbook_update_event())

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_diffs(self.ev_loop, msg_queue))

        diff_msg: OrderBookMessage = self.async_run_with_timeout(msg_queue.get())
        self.assertIsInstance(diff_msg, OrderBookMessage)
        self.assertEqual(OrderBookMessageType.DIFF, diff_msg.type)
        self.assertTrue(diff_msg.has_update_id)
        self.assertEqual(diff_msg.update_id, 752409360466)
        self.assertEqual(self.trading_pair, diff_msg.content["trading_pair"])
        self.assertEqual(1, len(diff_msg.content["bids"]))
        self.assertEqual(1, len(diff_msg.content["asks"]))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_trades_cancelled_error_raised(self, mock_ws):
        """A CancelledError raised by ``receive_json`` propagates out of the trade listener."""
        queue: asyncio.Queue = asyncio.Queue()
        websocket = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value = websocket

        def raise_cancelled():
            self._raise_exception(asyncio.CancelledError)

        websocket.receive_json.side_effect = raise_cancelled

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.data_source.listen_for_trades(self.ev_loop, queue))
            self.async_run_with_timeout(self.listening_task)

        # Nothing may have been queued before the cancellation surfaced.
        self.assertEqual(0, queue.qsize())

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch(
        "hummingbot.connector.derivative.binance_perpetual.binance_perpetual_utils.convert_from_exchange_trading_pair"
    )
    def test_listen_for_trades_logs_exception(self, mock_utils, _, mock_ws):
        """A malformed trade message is reported through the error log.

        The first message is incomplete and the second is a valid trade; the
        test waits until the valid trade reaches the queue and then checks the
        error log entry. ``_sleep`` is patched (bound to ``_``) so the retry
        delay is not really waited.
        """
        mock_utils.return_value = self.trading_pair
        msg_queue: asyncio.Queue = asyncio.Queue()
        ws_mock = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value = ws_mock
        incomplete_resp = {
            "m": 1,
            "i": 2,
        }
        self.mocking_assistant.add_websocket_json_message(
            ws_mock, incomplete_resp)
        self.mocking_assistant.add_websocket_json_message(
            ws_mock, self._orderbook_trade_event())

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        self.async_run_with_timeout(msg_queue.get())

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with Websocket connection. Retrying after 30 seconds..."
            ))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.connector.derivative.binance_perpetual.binance_perpetual_utils.convert_from_exchange_trading_pair"
    )
    def test_listen_for_trades_successful(self, mock_utils, mock_ws):
        """A trade event is queued as a TRADE ``OrderBookMessage``."""
        mock_utils.return_value = self.trading_pair
        msg_queue: asyncio.Queue = asyncio.Queue()
        ws_mock = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value = ws_mock

        self.mocking_assistant.add_websocket_json_message(
            ws_mock, self._orderbook_trade_event())

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        result: OrderBookMessage = self.async_run_with_timeout(msg_queue.get())
        self.assertIsInstance(result, OrderBookMessage)
        self.assertEqual(OrderBookMessageType.TRADE, result.type)
        self.assertTrue(result.has_trade_id)
        self.assertEqual(result.trade_id, 817295132)
        self.assertEqual(self.trading_pair, result.content["trading_pair"])

    @aioresponses()
    def test_listen_for_order_book_snapshots_cancelled_error_raised(
            self, mock_api):
        """A CancelledError from the snapshot request propagates and nothing is queued."""
        url = utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL, domain=self.domain)
        # re.escape makes every regex metacharacter in the URL literal, not
        # only "." and "?".
        regex_url = re.compile(f"^{re.escape(url)}")

        mock_api.get(regex_url, exception=asyncio.CancelledError)

        msg_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.data_source.listen_for_order_book_snapshots(
                    self.ev_loop, msg_queue))
            self.async_run_with_timeout(self.listening_task)

        self.assertEqual(0, msg_queue.qsize())

    @aioresponses()
    def test_listen_for_order_book_snapshots_logs_exception_error_with_response(
            self, mock_api):
        """A snapshot response missing the expected fields is reported through the error log.

        The test blocks on ``resume_test_event``; ``resume_test_callback`` is
        registered on the mocked endpoint and is expected to set that event
        when the request is made.
        """
        url = utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL, domain=self.domain)
        # re.escape makes every regex metacharacter in the URL literal, not
        # only "." and "?".
        regex_url = re.compile(f"^{re.escape(url)}")

        mock_response = {
            "m": 1,
            "i": 2,
        }
        mock_api.get(regex_url,
                     body=ujson.dumps(mock_response),
                     callback=self.resume_test_callback)

        msg_queue: asyncio.Queue = asyncio.Queue()

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, msg_queue))

        self.async_run_with_timeout(self.resume_test_event.wait())

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error occurred fetching orderbook snapshots. Retrying in 5 seconds..."
            ))

    @aioresponses()
    def test_listen_for_order_book_snapshots_successful(self, mock_api):
        """A valid REST snapshot is queued as a SNAPSHOT ``OrderBookMessage``."""
        url = utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL, domain=self.domain)
        # re.escape makes every regex metacharacter in the URL literal, not
        # only "." and "?".
        regex_url = re.compile(f"^{re.escape(url)}")

        mock_response = {
            "lastUpdateId": 1027024,
            "E": 1589436922972,
            "T": 1589436922959,
            "bids": [["10", "1"]],
            "asks": [["11", "1"]]
        }
        mock_api.get(regex_url, body=ujson.dumps(mock_response))

        msg_queue: asyncio.Queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, msg_queue))

        result = self.async_run_with_timeout(msg_queue.get())

        self.assertIsInstance(result, OrderBookMessage)
        self.assertEqual(OrderBookMessageType.SNAPSHOT, result.type)
        self.assertTrue(result.has_update_id)
        self.assertEqual(result.update_id, 1027024)
        self.assertEqual(self.trading_pair, result.content["trading_pair"])
class BybitPerpetualUserStreamDataSourceTests(TestCase):
    # the level is required to receive logs from the data source loger
    level = 0

    def setUp(self) -> None:
        super().setUp()
        self.api_key = 'testAPIKey'
        self.secret = 'testSecret'
        self.log_records = []
        self.listening_task = None

        self.data_source = BybitPerpetualUserStreamDataSource(
            auth_assistant=BybitPerpetualAuth(api_key=self.api_key,
                                              secret_key=self.secret))
        self.data_source.logger().setLevel(1)
        self.data_source.logger().addHandler(self)

        self.mocking_assistant = NetworkMockingAssistant()

    def tearDown(self) -> None:
        self.listening_task and self.listening_task.cancel()
        if self.data_source._session is not None:
            asyncio.get_event_loop().run_until_complete(
                self.data_source._session.close())
        super().tearDown()

    def handle(self, record):
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def _authentication_response(self, authenticated: bool) -> str:
        request = {
            "op": "auth",
            "args": ['testAPIKey', 'testExpires', 'testSignature']
        }
        message = {
            "success": authenticated,
            "ret_msg": "",
            "conn_id": "testConnectionID",
            "request": request
        }

        return message

    def _subscription_response(self, subscribed: bool,
                               subscription: str) -> str:
        request = {"op": "subscribe", "args": [subscription]}
        message = {
            "success": subscribed,
            "ret_msg": "",
            "conn_id": "testConnectionID",
            "request": request
        }

        return message

    def _raise_exception(self, exception_class):
        raise exception_class

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_authenticates_and_subscribes_to_events(
            self, ws_connect_mock):
        messages = asyncio.Queue()
        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )
        initial_last_recv_time = self.data_source.last_recv_time

        self.listening_task = asyncio.get_event_loop().create_task(
            self.data_source._listen_for_user_stream_on_url(
                "test_url", messages))
        # Add the authentication response for the websocket
        self.mocking_assistant.add_websocket_json_message(
            ws_connect_mock.return_value, self._authentication_response(True))
        self.mocking_assistant.add_websocket_json_message(
            ws_connect_mock.return_value,
            self._subscription_response(
                True, CONSTANTS.WS_SUBSCRIPTION_POSITIONS_ENDPOINT_NAME))
        self.mocking_assistant.add_websocket_json_message(
            ws_connect_mock.return_value,
            self._subscription_response(
                True, CONSTANTS.WS_SUBSCRIPTION_ORDERS_ENDPOINT_NAME))
        self.mocking_assistant.add_websocket_json_message(
            ws_connect_mock.return_value,
            self._subscription_response(
                True, CONSTANTS.WS_SUBSCRIPTION_EXECUTIONS_ENDPOINT_NAME))

        # Add a dummy message for the websocket to read and include in the "messages" queue
        self.mocking_assistant.add_websocket_json_message(
            ws_connect_mock.return_value, json.dumps('dummyMessage'))

        asyncio.get_event_loop().run_until_complete(messages.get())

        self.assertTrue(
            self._is_logged('INFO', "Authenticating to User Stream..."))
        self.assertTrue(
            self._is_logged('INFO',
                            "Successfully authenticated to User Stream."))
        self.assertTrue(
            self._is_logged(
                'INFO',
                "Successful subscription to the topic ['position'] on test_url"
            ))
        self.assertTrue(
            self._is_logged(
                "INFO",
                "Successful subscription to the topic ['order'] on test_url"))
        self.assertTrue(
            self._is_logged(
                "INFO",
                "Successful subscription to the topic ['execution'] on test_url"
            ))

        sent_messages = self.mocking_assistant.json_messages_sent_through_websocket(
            ws_connect_mock.return_value)
        self.assertEqual(4, len(sent_messages))
        authentication_request = sent_messages[0]
        subscription_positions_request = sent_messages[1]
        subscription_orders_request = sent_messages[2]
        subscription_executions_request = sent_messages[3]

        self.assertEqual(
            CONSTANTS.WS_AUTHENTICATE_USER_ENDPOINT_NAME,
            BybitPerpetualWebSocketAdaptor.endpoint_from_message(
                authentication_request))
        self.assertEqual(
            CONSTANTS.WS_SUBSCRIPTION_POSITIONS_ENDPOINT_NAME,
            BybitPerpetualWebSocketAdaptor.endpoint_from_message(
                subscription_positions_request))
        self.assertEqual(
            CONSTANTS.WS_SUBSCRIPTION_ORDERS_ENDPOINT_NAME,
            BybitPerpetualWebSocketAdaptor.endpoint_from_message(
                subscription_orders_request))
        self.assertEqual(
            CONSTANTS.WS_SUBSCRIPTION_EXECUTIONS_ENDPOINT_NAME,
            BybitPerpetualWebSocketAdaptor.endpoint_from_message(
                subscription_executions_request))

        subscription_positions_payload = BybitPerpetualWebSocketAdaptor.payload_from_message(
            subscription_positions_request)
        expected_payload = {"op": "subscribe", "args": ["position"]}
        self.assertEqual(expected_payload, subscription_positions_payload)

        subscription_orders_payload = BybitPerpetualWebSocketAdaptor.payload_from_message(
            subscription_orders_request)
        expected_payload = {"op": "subscribe", "args": ["order"]}
        self.assertEqual(expected_payload, subscription_orders_payload)

        subscription_executions_payload = BybitPerpetualWebSocketAdaptor.payload_from_message(
            subscription_executions_request)
        expected_payload = {"op": "subscribe", "args": ["execution"]}
        self.assertEqual(expected_payload, subscription_executions_payload)

        self.assertGreater(self.data_source.last_recv_time,
                           initial_last_recv_time)

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_fails_when_authentication_fails(
            self, ws_connect_mock):
        messages = asyncio.Queue()
        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )
        # Make the close function raise an exception to finish the execution
        ws_connect_mock.return_value.close.side_effect = lambda: self._raise_exception(
            Exception)

        self.listening_task = asyncio.get_event_loop().create_task(
            self.data_source.listen_for_user_stream(messages))
        # Add the authentication response for the websocket
        self.mocking_assistant.add_websocket_json_message(
            ws_connect_mock.return_value, self._authentication_response(False))

        try:
            asyncio.get_event_loop().run_until_complete(self.listening_task)
        except Exception:
            pass

        self.assertTrue(
            self._is_logged(
                "ERROR", "Error occurred when authenticating to user stream "
                "(Could not authenticate websocket connection with Bybit Perpetual)"
            ))
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with Bybit Perpetual WebSocket connection on"
                " wss://stream.bybit.com/realtime_private. Retrying in 30 seconds. "
                "(Could not authenticate websocket connection with Bybit Perpetual)"
            ))

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_canceled_when_cancel_exception_during_initialization(
            self, ws_connect_mock):
        messages = asyncio.Queue()
        ws_connect_mock.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = asyncio.get_event_loop().create_task(
                self.data_source.listen_for_user_stream(messages))
            asyncio.get_event_loop().run_until_complete(self.listening_task)

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_canceled_when_cancel_exception_during_authentication(
            self, ws_connect_mock):
        messages = asyncio.Queue()
        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )
        ws_connect_mock.return_value.send_json.side_effect = lambda sent_message: (
            self._raise_exception(asyncio.CancelledError)
            if CONSTANTS.WS_AUTHENTICATE_USER_ENDPOINT_NAME in sent_message[
                "op"] else self.mocking_assistant.add_websocket_json_message(
                    ws_connect_mock.return_value, sent_message))

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = asyncio.get_event_loop().create_task(
                self.data_source.listen_for_user_stream(messages))
            asyncio.get_event_loop().run_until_complete(self.listening_task)

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_canceled_when_cancel_exception_during_positions_subscription(
            self, ws_connect_mock):
        messages = asyncio.Queue()
        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )
        ws_connect_mock.return_value.send_json.side_effect = lambda sent_message: (
            self._raise_exception(asyncio.CancelledError) if CONSTANTS.
            WS_SUBSCRIPTION_POSITIONS_ENDPOINT_NAME in sent_message[
                "args"] else self.mocking_assistant.add_websocket_json_message(
                    ws_connect_mock.return_value, sent_message))

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = asyncio.get_event_loop().create_task(
                self.data_source.listen_for_user_stream(messages))
            # Add the authentication response for the websocket
            self.mocking_assistant.add_websocket_json_message(
                ws_connect_mock.return_value,
                self._authentication_response(True))
            asyncio.get_event_loop().run_until_complete(self.listening_task)

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_canceled_when_cancel_exception_during_orders_subscription(
            self, ws_connect_mock):
        messages = asyncio.Queue()
        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )
        ws_connect_mock.return_value.send_json.side_effect = lambda sent_message: (
            self._raise_exception(asyncio.CancelledError)
            if CONSTANTS.WS_SUBSCRIPTION_ORDERS_ENDPOINT_NAME in sent_message[
                "args"] else self.mocking_assistant.add_websocket_json_message(
                    ws_connect_mock.return_value, sent_message))

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = asyncio.get_event_loop().create_task(
                self.data_source.listen_for_user_stream(messages))
            # Add the authentication response for the websocket
            self.mocking_assistant.add_websocket_json_message(
                ws_connect_mock.return_value,
                self._authentication_response(True))
            asyncio.get_event_loop().run_until_complete(self.listening_task)

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_canceled_when_cancel_exception_during_executions_subscription(
            self, ws_connect_mock):
        messages = asyncio.Queue()
        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )
        ws_connect_mock.return_value.send_json.side_effect = lambda sent_message: (
            self._raise_exception(asyncio.CancelledError) if CONSTANTS.
            WS_SUBSCRIPTION_EXECUTIONS_ENDPOINT_NAME in sent_message[
                "args"] else self.mocking_assistant.add_websocket_json_message(
                    ws_connect_mock.return_value, sent_message))

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = asyncio.get_event_loop().create_task(
                self.data_source.listen_for_user_stream(messages))
            # Add the authentication response for the websocket
            self.mocking_assistant.add_websocket_json_message(
                ws_connect_mock.return_value,
                self._authentication_response(True))
            asyncio.get_event_loop().run_until_complete(self.listening_task)

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_logs_exception_details_during_initialization(
            self, ws_connect_mock):
        ws_connect_mock.side_effect = Exception

        with self.assertRaises(Exception):
            self.listening_task = asyncio.get_event_loop().create_task(
                self.data_source._create_websocket_connection("test_url"))
            asyncio.get_event_loop().run_until_complete(self.listening_task)
        self.assertTrue(
            self._is_logged(
                "NETWORK",
                "Unexpected error occurred during bybit_perpetual WebSocket Connection on test_url ()"
            ))

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_logs_exception_details_during_authentication(
            self, ws_connect_mock):
        messages = asyncio.Queue()
        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )
        ws_connect_mock.return_value.send_json.side_effect = lambda sent_message: (
            self._raise_exception(Exception)
            if CONSTANTS.WS_AUTHENTICATE_USER_ENDPOINT_NAME in sent_message[
                "op"] else self.mocking_assistant.add_websocket_json_message(
                    sent_message))
        # Make the close function raise an exception to finish the execution
        ws_connect_mock.return_value.close.side_effect = lambda: self._raise_exception(
            Exception)

        try:
            self.listening_task = asyncio.get_event_loop().create_task(
                self.data_source.listen_for_user_stream(messages))
            asyncio.get_event_loop().run_until_complete(self.listening_task)
        except Exception:
            pass

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Error occurred when authenticating to user stream ()"))
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with Bybit Perpetual WebSocket connection on"
                " wss://stream.bybit.com/realtime_private. Retrying in 30 seconds. ()"
            ))

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_logs_exception_during_positions_subscription(
            self, ws_connect_mock):
        messages = asyncio.Queue()
        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )
        ws_connect_mock.return_value.send_json.side_effect = lambda sent_message: (
            self._raise_exception(Exception) if CONSTANTS.
            WS_SUBSCRIPTION_POSITIONS_ENDPOINT_NAME in sent_message[
                "args"] else self.mocking_assistant.add_websocket_json_message(
                    ws_connect_mock.return_value, sent_message))
        # Make the close function raise an exception to finish the execution
        ws_connect_mock.return_value.close.side_effect = lambda: self._raise_exception(
            Exception)

        try:
            self.listening_task = asyncio.get_event_loop().create_task(
                self.data_source.listen_for_user_stream(messages))
            # Add the authentication response for the websocket
            self.mocking_assistant.add_websocket_json_message(
                ws_connect_mock.return_value,
                self._authentication_response(True))
            asyncio.get_event_loop().run_until_complete(self.listening_task)
        except Exception:
            pass

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Error occurred subscribing to bybit_perpetual private channels ()"
            ))
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with Bybit Perpetual WebSocket connection on"
                " wss://stream.bybit.com/realtime_private. Retrying in 30 seconds. ()"
            ))

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_logs_exception_during_orders_subscription(
            self, ws_connect_mock):
        messages = asyncio.Queue()
        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )
        ws_connect_mock.return_value.send_json.side_effect = lambda sent_message: (
            self._raise_exception(Exception)
            if CONSTANTS.WS_SUBSCRIPTION_ORDERS_ENDPOINT_NAME in sent_message[
                "args"] else self.mocking_assistant.add_websocket_json_message(
                    ws_connect_mock.return_value, sent_message))
        # Make the close function raise an exception to finish the execution
        ws_connect_mock.return_value.close.side_effect = lambda: self._raise_exception(
            Exception)

        try:
            self.listening_task = asyncio.get_event_loop().create_task(
                self.data_source.listen_for_user_stream(messages))
            # Add the authentication response for the websocket
            self.mocking_assistant.add_websocket_json_message(
                ws_connect_mock.return_value,
                self._authentication_response(True))
            asyncio.get_event_loop().run_until_complete(self.listening_task)
        except Exception:
            pass

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Error occurred subscribing to bybit_perpetual private channels ()"
            ))
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with Bybit Perpetual WebSocket connection on"
                " wss://stream.bybit.com/realtime_private. Retrying in 30 seconds. ()"
            ))

    @patch('aiohttp.ClientSession.ws_connect', new_callable=AsyncMock)
    def test_listening_process_logs_exception_during_executions_subscription(
            self, ws_connect_mock):
        messages = asyncio.Queue()
        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )
        ws_connect_mock.return_value.send_json.side_effect = lambda sent_message: (
            self._raise_exception(Exception) if CONSTANTS.
            WS_SUBSCRIPTION_EXECUTIONS_ENDPOINT_NAME in sent_message[
                "args"] else self.mocking_assistant.add_websocket_json_message(
                    ws_connect_mock.return_value, sent_message))
        # Make the close function raise an exception to finish the execution
        ws_connect_mock.return_value.close.side_effect = lambda: self._raise_exception(
            Exception)

        try:
            self.listening_task = asyncio.get_event_loop().create_task(
                self.data_source.listen_for_user_stream(messages))
            # Add the authentication response for the websocket
            self.mocking_assistant.add_websocket_json_message(
                ws_connect_mock.return_value,
                self._authentication_response(True))
            asyncio.get_event_loop().run_until_complete(self.listening_task)
        except Exception:
            pass

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Error occurred subscribing to bybit_perpetual private channels ()"
            ))
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with Bybit Perpetual WebSocket connection on"
                " wss://stream.bybit.com/realtime_private. Retrying in 30 seconds. ()"
            ))
Beispiel #4
0
class ProbitAPIUserStreamDataSourceTest(unittest.TestCase):
    # logging.Level required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        cls.base_asset = "BTC"
        cls.quote_asset = "USDT"
        cls.trading_pair = f"{cls.base_asset}-{cls.quote_asset}"

    def setUp(self) -> None:
        """Build a fresh data source and capture the records its logger emits."""
        super().setUp()

        self.ev_loop = asyncio.get_event_loop()

        # Initialise capture state before registering this test as a log
        # handler, so handle() always has a list to append to.
        self.log_records = []
        self.mocking_assistant = NetworkMockingAssistant()
        self.async_task: Optional[asyncio.Task] = None

        self.api_key = "someKey"
        self.api_secret = "someSecret"
        self.auth = ProbitAuth(self.api_key, self.api_secret)
        self.data_source = ProbitAPIUserStreamDataSource(
            self.auth, trading_pairs=[self.trading_pair])
        # Logger level 1 plus the class-level handler `level = 0` let records of
        # every standard level reach handle().
        self.data_source.logger().setLevel(1)
        self.data_source.logger().addHandler(self)

    def tearDown(self) -> None:
        self.async_task and self.async_task.cancel()
        super().tearDown()

    def handle(self, record):
        self.log_records.append(record)

    def check_is_logged(self, log_level: str, message: str) -> bool:
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: float = 1):
        ret = self.ev_loop.run_until_complete(
            asyncio.wait_for(coroutine, timeout))
        return ret

    @patch("aiohttp.client.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.connector.exchange.probit.probit_auth.ProbitAuth.get_ws_auth_payload",
        new_callable=AsyncMock,
    )
    def test_listen_for_user_stream(self, get_ws_auth_payload_mock,
                                    ws_connect_mock):
        """The user stream authenticates, subscribes to the private channels and forwards messages.

        The auth payload must be the first text frame sent. Every JSON frame must
        be a subscription to a distinct private channel, with none left out. The
        incoming message must reach the output queue.
        """
        auth_msg = {"type": "authorization", "token": "someToken"}
        get_ws_auth_payload_mock.return_value = auth_msg

        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )
        self.mocking_assistant.add_websocket_json_message(
            ws_connect_mock.return_value,
            message={"result": "ok"}  # authentication
        )
        self.mocking_assistant.add_websocket_aiohttp_message(
            ws_connect_mock.return_value,
            message=json.dumps({"my_msg": "test"})  # first message
        )

        output_queue = asyncio.Queue()
        self.async_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(self.ev_loop,
                                                    output_queue))

        self.mocking_assistant.run_until_all_json_messages_delivered(
            ws_connect_mock.return_value)
        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            ws_connect_mock.return_value)

        self.assertFalse(output_queue.empty())

        sent_text_msgs = self.mocking_assistant.text_messages_sent_through_websocket(
            ws_connect_mock.return_value)
        self.assertEqual(auth_msg, json.loads(sent_text_msgs[0]))

        # Work on a copy. CONSTANTS.WS_PRIVATE_CHANNELS is module-level state
        # (presumably also read by the data source), so removing entries from
        # it directly would leave it empty for every later test in the run.
        pending_channels = list(CONSTANTS.WS_PRIVATE_CHANNELS)

        sent_json_msgs = self.mocking_assistant.json_messages_sent_through_websocket(
            ws_connect_mock.return_value)
        for sent_json_msg in sent_json_msgs:
            self.assertEqual("subscribe", sent_json_msg["type"])
            # Removing each matched channel also catches duplicate subscriptions.
            self.assertIn(sent_json_msg["channel"], pending_channels)
            pending_channels.remove(sent_json_msg["channel"])

        self.assertEqual(0, len(pending_channels))
        self.assertNotEqual(0, self.data_source.last_recv_time)

    @patch("aiohttp.client.ClientSession.ws_connect")
    @patch(
        "hummingbot.connector.exchange.probit.probit_api_user_stream_data_source.ProbitAPIUserStreamDataSource._sleep",
        new_callable=AsyncMock,
    )
    def test_listen_for_user_stream_attempts_again_on_exception(
            self, sleep_mock, ws_connect_mock):
        """A failing ws_connect is logged as an error and leads to the _sleep back-off call."""
        called_event = asyncio.Event()

        async def _sleep(delay):
            # Signal the test as soon as the data source starts backing off.
            called_event.set()
            await asyncio.sleep(delay)

        sleep_mock.side_effect = _sleep

        ws_connect_mock.side_effect = Exception
        self.async_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(self.ev_loop,
                                                    asyncio.Queue()))

        self.async_run_with_timeout(called_event.wait())

        # check_is_logged only returns a bool; it must be asserted to have effect.
        self.assertTrue(
            self.check_is_logged(
                log_level="ERROR",
                message=
                "Unexpected error with Probit WebSocket connection. Retrying after 30 seconds...",
            ))

    @patch("aiohttp.client.ClientSession.ws_connect")
    def test_listen_for_user_stream_stops_on_asyncio_cancelled_error(
            self, ws_connect_mock):
        """CancelledError raised while connecting must propagate out of the stream."""
        ws_connect_mock.side_effect = asyncio.CancelledError
        stream = self.data_source.listen_for_user_stream(
            self.ev_loop, asyncio.Queue())

        with self.assertRaises(asyncio.CancelledError):
            self.async_run_with_timeout(stream)

    @patch("aiohttp.client.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.connector.exchange.probit.probit_auth.ProbitAuth.get_ws_auth_payload",
        new_callable=AsyncMock,
    )
    def test_listen_for_user_stream_registers_ping_msg(
            self, get_ws_auth_payload_mock, ws_connect_mock):
        """A PING frame is answered with pong() and nothing reaches the output queue."""
        get_ws_auth_payload_mock.return_value = {
            "type": "authorization",
            "token": "someToken",
        }

        websocket = self.mocking_assistant.create_websocket_mock()
        ws_connect_mock.return_value = websocket
        # Authentication acknowledgement first, then a bare PING control frame.
        self.mocking_assistant.add_websocket_json_message(
            websocket, message={"result": "ok"})
        self.mocking_assistant.add_websocket_aiohttp_message(
            websocket, message="", message_type=WSMsgType.PING)

        received = asyncio.Queue()
        self.async_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(self.ev_loop, received))

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(websocket)

        self.assertTrue(received.empty())
        websocket.pong.assert_called()