class TestKucoinAPIOrderBookDataSource(unittest.TestCase):
    """Unit tests for ``KucoinAPIOrderBookDataSource``.

    REST calls are intercepted with ``aioresponses`` and websocket traffic is
    simulated through ``NetworkMockingAssistant``. The test case registers
    itself as a handler on the data source logger (see ``handle``) so that
    tests can assert on the log records emitted by the code under test.
    """

    # logging.Level required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        """Create the event loop and the constants shared by every test."""
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        cls.trading_pair = f"{cls.base_asset}-{cls.quote_asset}"
        cls.ws_endpoint = "ws://someEndpoint"
        cls.api_key = "someKey"
        cls.api_passphrase = "somePassPhrase"
        cls.api_secret_key = "someSecretKey"

    def setUp(self) -> None:
        """Build a fresh data source and hook this test case into its logger."""
        super().setUp()
        self.log_records = []
        self.async_task = None
        # Tests keep their background task here so tearDown can cancel it.
        self.listening_task = None

        self.mocking_assistant = NetworkMockingAssistant()
        self.throttler = AsyncThrottler(CONSTANTS.RATE_LIMITS)
        self.time_synchronizer = TimeSynchronizer()
        self.time_synchronizer.add_time_offset_ms_sample(1000)
        self.ob_data_source = KucoinAPIOrderBookDataSource(
            trading_pairs=[self.trading_pair],
            throttler=self.throttler,
            time_synchronizer=self.time_synchronizer)

        self.ob_data_source.logger().setLevel(1)
        self.ob_data_source.logger().addHandler(self)

        KucoinAPIOrderBookDataSource._trading_pair_symbol_map = {
            CONSTANTS.DEFAULT_DOMAIN:
            bidict({self.trading_pair: self.trading_pair})
        }

    def tearDown(self) -> None:
        """Cancel any task left running by the test and reset class-level state."""
        for task in (self.async_task, self.listening_task):
            if task is not None:
                task.cancel()
        KucoinAPIOrderBookDataSource._trading_pair_symbol_map = {}
        super().tearDown()

    def handle(self, record):
        """Logging-handler hook: store every record emitted by the data source logger."""
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        """Return True if a record with exactly this level name and message was captured."""
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: int = 1):
        """Run ``coroutine`` on the test event loop, failing after ``timeout`` seconds."""
        ret = self.ev_loop.run_until_complete(
            asyncio.wait_for(coroutine, timeout))
        return ret

    @staticmethod
    def get_snapshot_mock() -> Dict:
        """Return a minimal order book snapshot response (one bid, one ask)."""
        snapshot = {
            "code": "200000",
            "data": {
                "time": 1630556205455,
                "sequence": "1630556205456",
                "bids": [["0.3003", "4146.5645"]],
                "asks": [["0.3004", "1553.6412"]]
            }
        }
        return snapshot

    @staticmethod
    def _url_regex(url: str):
        """Compile a pattern matching every request whose URL starts with ``url``.

        Only ``.`` and ``?`` are escaped, which is sufficient for the URLs
        used in these tests.
        """
        return re.compile(f"^{url}".replace(".", r"\.").replace("?", r"\?"))

    @staticmethod
    def _ticker_mock(price: str) -> Dict:
        """Return a ticker response whose last traded price is ``price``."""
        return {
            "code": "200000",
            "data": {
                "sequence": "1550467636704",
                "bestAsk": "0.03715004",
                "size": "0.17",
                "price": price,
                "bestBidSize": "3.803",
                "bestBid": "0.03710768",
                "bestAskSize": "1.788",
                "time": 1550653727731
            }
        }

    @staticmethod
    def _public_ws_token_mock(ping_interval: int = 50000) -> Dict:
        """Return the response of the public websocket token endpoint."""
        return {
            "code": "200000",
            "data": {
                "instanceServers": [{
                    "endpoint": "wss://test.url/endpoint",
                    "protocol": "websocket",
                    "encrypt": True,
                    "pingInterval": ping_interval,
                    "pingTimeout": 10000
                }],
                "token":
                "testToken"
            }
        }

    def _mock_public_ws_token_request(self, mock_api, ping_interval: int = 50000):
        """Register the mocked POST that provides the websocket connection details."""
        url = web_utils.rest_url(path_url=CONSTANTS.PUBLIC_WS_DATA_PATH_URL)
        mock_api.post(
            url, body=json.dumps(self._public_ws_token_mock(ping_interval)))

    def _create_ws_mock_with_subscription_acks(self, ws_connect_mock):
        """Install a websocket mock and queue the acks for subscriptions 1 and 2."""
        ws_mock = self.mocking_assistant.create_websocket_mock()
        ws_connect_mock.return_value = ws_mock
        for message_id in (1, 2):
            self.mocking_assistant.add_websocket_aiohttp_message(
                websocket_mock=ws_mock,
                message=json.dumps({"type": "ack", "id": message_id}))
        return ws_mock

    @aioresponses()
    def test_get_last_traded_prices(self, mock_api):
        """One ticker request is sent per trading pair and prices are mapped by pair."""
        KucoinAPIOrderBookDataSource._trading_pair_symbol_map[
            CONSTANTS.DEFAULT_DOMAIN]["TKN1-TKN2"] = "TKN1-TKN2"

        ticker_url = web_utils.rest_url(
            path_url=CONSTANTS.TICKER_PRICE_CHANGE_PATH_URL,
            domain=CONSTANTS.DEFAULT_DOMAIN)

        url1 = f"{ticker_url}?symbol={self.trading_pair}"
        mock_api.get(self._url_regex(url1),
                     body=json.dumps(self._ticker_mock("100")))

        url2 = f"{ticker_url}?symbol=TKN1-TKN2"
        mock_api.get(self._url_regex(url2),
                     body=json.dumps(self._ticker_mock("200")))

        ret = self.async_run_with_timeout(
            coroutine=KucoinAPIOrderBookDataSource.get_last_traded_prices(
                [self.trading_pair, "TKN1-TKN2"]))

        ticker_requests = [(key, value)
                           for key, value in mock_api.requests.items()
                           if key[1].human_repr().startswith(url1)
                           or key[1].human_repr().startswith(url2)]

        request_params = ticker_requests[0][1][0].kwargs["params"]
        self.assertEqual(f"{self.base_asset}-{self.quote_asset}",
                         request_params["symbol"])
        request_params = ticker_requests[1][1][0].kwargs["params"]
        self.assertEqual("TKN1-TKN2", request_params["symbol"])

        self.assertEqual(ret[self.trading_pair], 100)
        self.assertEqual(ret["TKN1-TKN2"], 200)

    @aioresponses()
    def test_fetch_trading_pairs(self, mock_api):
        """Only pairs with trading enabled are returned."""
        KucoinAPIOrderBookDataSource._trading_pair_symbol_map = {}
        url = web_utils.rest_url(path_url=CONSTANTS.EXCHANGE_INFO_PATH_URL)

        resp = {
            "data": [{
                "symbol": self.trading_pair,
                "name": self.trading_pair,
                "baseCurrency": self.base_asset,
                "quoteCurrency": self.quote_asset,
                "enableTrading": True,
            }, {
                "symbol": "SOME-PAIR",
                "name": "SOME-PAIR",
                "baseCurrency": "SOME",
                "quoteCurrency": "PAIR",
                "enableTrading": False,
            }]
        }
        mock_api.get(url, body=json.dumps(resp))

        ret = self.async_run_with_timeout(
            coroutine=KucoinAPIOrderBookDataSource.fetch_trading_pairs(
                throttler=self.throttler,
                time_synchronizer=self.time_synchronizer,
            ))

        self.assertEqual(1, len(ret))
        self.assertEqual(self.trading_pair, ret[0])

    @aioresponses()
    def test_fetch_trading_pairs_exception_raised(self, mock_api):
        """A failing exchange-info request yields an empty result instead of raising."""
        KucoinAPIOrderBookDataSource._trading_pair_symbol_map = {}

        url = web_utils.rest_url(path_url=CONSTANTS.EXCHANGE_INFO_PATH_URL)
        mock_api.get(self._url_regex(url), exception=Exception)

        result = self.async_run_with_timeout(
            self.ob_data_source.fetch_trading_pairs())

        self.assertEqual(0, len(result))

    @aioresponses()
    def test_get_snapshot_raises(self, mock_api):
        """An HTTP 500 on the snapshot endpoint surfaces as IOError."""
        url = web_utils.rest_url(path_url=CONSTANTS.SNAPSHOT_NO_AUTH_PATH_URL)
        mock_api.get(self._url_regex(url), status=500)

        with self.assertRaises(IOError):
            self.async_run_with_timeout(
                coroutine=self.ob_data_source.get_snapshot(self.trading_pair))

    @aioresponses()
    def test_get_snapshot(self, mock_api):
        """The snapshot response is returned unchanged."""
        url = web_utils.rest_url(path_url=CONSTANTS.SNAPSHOT_NO_AUTH_PATH_URL)
        resp = self.get_snapshot_mock()
        mock_api.get(self._url_regex(url), body=json.dumps(resp))

        ret = self.async_run_with_timeout(
            coroutine=self.ob_data_source.get_snapshot(self.trading_pair))

        self.assertEqual(ret, resp)  # shallow comparison ok

    @aioresponses()
    def test_get_new_order_book(self, mock_api):
        """The new order book contains the snapshot's bid and ask entries."""
        url = web_utils.rest_url(path_url=CONSTANTS.SNAPSHOT_NO_AUTH_PATH_URL)
        resp = self.get_snapshot_mock()
        mock_api.get(self._url_regex(url), body=json.dumps(resp))

        ret = self.async_run_with_timeout(
            coroutine=self.ob_data_source.get_new_order_book(
                self.trading_pair))
        bid_entries = list(ret.bid_entries())
        ask_entries = list(ret.ask_entries())
        expected_update_id = int(resp["data"]["sequence"])

        self.assertEqual(1, len(bid_entries))
        self.assertEqual(0.3003, bid_entries[0].price)
        self.assertEqual(4146.5645, bid_entries[0].amount)
        self.assertEqual(expected_update_id, bid_entries[0].update_id)
        self.assertEqual(1, len(ask_entries))
        self.assertEqual(0.3004, ask_entries[0].price)
        self.assertEqual(1553.6412, ask_entries[0].amount)
        self.assertEqual(expected_update_id, ask_entries[0].update_id)

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.connector.exchange.kucoin.kucoin_web_utils.next_message_id"
    )
    def test_listen_for_subscriptions_subscribes_to_trades_and_order_diffs(
            self, mock_api, id_mock, ws_connect_mock):
        """Trade and level2 subscriptions are sent, in that order, and success is logged."""
        id_mock.side_effect = [1, 2]
        self._mock_public_ws_token_request(mock_api)
        ws_mock = self._create_ws_mock_with_subscription_acks(ws_connect_mock)

        self.listening_task = self.ev_loop.create_task(
            self.ob_data_source.listen_for_subscriptions())

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            ws_mock)

        sent_subscription_messages = self.mocking_assistant.json_messages_sent_through_websocket(
            websocket_mock=ws_mock)

        self.assertEqual(2, len(sent_subscription_messages))
        expected_trade_subscription = {
            "id": 1,
            "type": "subscribe",
            "topic": f"/market/match:{self.trading_pair}",
            "privateChannel": False,
            "response": False
        }
        self.assertEqual(expected_trade_subscription,
                         sent_subscription_messages[0])
        expected_diff_subscription = {
            "id": 2,
            "type": "subscribe",
            "topic": f"/market/level2:{self.trading_pair}",
            "privateChannel": False,
            "response": False
        }
        self.assertEqual(expected_diff_subscription,
                         sent_subscription_messages[1])

        self.assertTrue(
            self._is_logged(
                "INFO",
                "Subscribed to public order book and trade channels..."))

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.connector.exchange.kucoin.kucoin_web_utils.next_message_id"
    )
    @patch(
        "hummingbot.connector.exchange.kucoin.kucoin_api_order_book_data_source.KucoinAPIOrderBookDataSource._time"
    )
    def test_listen_for_subscriptions_sends_ping_message_before_ping_interval_finishes(
            self, mock_api, time_mock, id_mock, ws_connect_mock):
        """A ping is the last message sent once the mocked clock passes the ping interval."""
        id_mock.side_effect = [1, 2, 3, 4]
        time_mock.side_effect = [
            1000, 1100, 1101, 1102
        ]  # Simulate first ping interval is already due
        self._mock_public_ws_token_request(mock_api, ping_interval=20000)
        ws_mock = self._create_ws_mock_with_subscription_acks(ws_connect_mock)

        self.listening_task = self.ev_loop.create_task(
            self.ob_data_source.listen_for_subscriptions())

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            ws_mock)

        sent_messages = self.mocking_assistant.json_messages_sent_through_websocket(
            websocket_mock=ws_mock)

        expected_ping_message = {
            "id": 3,
            "type": "ping",
        }
        self.assertEqual(expected_ping_message, sent_messages[-1])

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    def test_listen_for_subscriptions_raises_cancel_exception(
            self, mock_api, _, ws_connect_mock):
        """CancelledError raised while connecting propagates to the caller."""
        self._mock_public_ws_token_request(mock_api)

        ws_connect_mock.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.ob_data_source.listen_for_subscriptions())
            self.async_run_with_timeout(self.listening_task)

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    def test_listen_for_subscriptions_logs_exception_details(
            self, mock_api, sleep_mock, ws_connect_mock):
        """A connection error is logged before the retry sleep is reached."""
        self._mock_public_ws_token_request(mock_api)

        # The patched sleep raises CancelledError to end the run after the first failure.
        sleep_mock.side_effect = asyncio.CancelledError
        ws_connect_mock.side_effect = Exception("TEST ERROR.")

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.ob_data_source.listen_for_subscriptions())
            self.async_run_with_timeout(self.listening_task)

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error occurred when listening to order book streams. Retrying in 5 seconds..."
            ))

    def test_listen_for_trades_cancelled_when_listening(self):
        """CancelledError raised by the trade queue propagates to the caller."""
        mock_queue = MagicMock()
        mock_queue.get.side_effect = asyncio.CancelledError()
        self.ob_data_source._message_queue[
            CONSTANTS.TRADE_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.ob_data_source.listen_for_trades(self.ev_loop, msg_queue))
            self.async_run_with_timeout(self.listening_task)

    def test_listen_for_trades_logs_exception(self):
        """A trade event without a ``data`` section is reported as an error."""
        incomplete_resp = {
            "type": "message",
            "topic": f"/market/match:{self.trading_pair}",
        }

        mock_queue = AsyncMock()
        mock_queue.get.side_effect = [
            incomplete_resp, asyncio.CancelledError()
        ]
        self.ob_data_source._message_queue[
            CONSTANTS.TRADE_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        self.listening_task = self.ev_loop.create_task(
            self.ob_data_source.listen_for_trades(self.ev_loop, msg_queue))

        try:
            self.async_run_with_timeout(self.listening_task)
        except asyncio.CancelledError:
            # Expected: the second queue read cancels the listener.
            pass

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error when processing public trade updates from exchange"
            ))

    def test_listen_for_trades_successful(self):
        """A trade event is converted into a message carrying the exchange trade id."""
        mock_queue = AsyncMock()
        trade_event = {
            "type": "message",
            "topic": f"/market/match:{self.trading_pair}",
            "subject": "trade.l3match",
            "data": {
                "sequence": "1545896669145",
                "type": "match",
                "symbol": self.trading_pair,
                "side": "buy",
                "price": "0.08200000000000000000",
                "size": "0.01022222000000000000",
                "tradeId": "5c24c5da03aa673885cd67aa",
                "takerOrderId": "5c24c5d903aa6772d55b371e",
                "makerOrderId": "5c2187d003aa677bd09d5c93",
                "time": "1545913818099033203"
            }
        }
        mock_queue.get.side_effect = [trade_event, asyncio.CancelledError()]
        self.ob_data_source._message_queue[
            CONSTANTS.TRADE_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        # create_task only schedules the coroutine; it cannot raise CancelledError here.
        self.listening_task = self.ev_loop.create_task(
            self.ob_data_source.listen_for_trades(self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.async_run_with_timeout(msg_queue.get())

        # assertEqual, not assertTrue: assertTrue would treat the second argument as
        # the failure message and pass for any truthy first argument.
        self.assertEqual(trade_event["data"]["tradeId"], msg.trade_id)

    def test_listen_for_order_book_diffs_cancelled(self):
        """CancelledError raised by the diff queue propagates to the caller."""
        mock_queue = AsyncMock()
        mock_queue.get.side_effect = asyncio.CancelledError()
        self.ob_data_source._message_queue[
            CONSTANTS.DIFF_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.ob_data_source.listen_for_order_book_diffs(
                    self.ev_loop, msg_queue))
            self.async_run_with_timeout(self.listening_task)

    def test_listen_for_order_book_diffs_logs_exception(self):
        """A diff event without a ``data`` section is reported as an error."""
        incomplete_resp = {
            "type": "message",
            "topic": f"/market/level2:{self.trading_pair}",
        }

        mock_queue = AsyncMock()
        mock_queue.get.side_effect = [
            incomplete_resp, asyncio.CancelledError()
        ]
        self.ob_data_source._message_queue[
            CONSTANTS.DIFF_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        self.listening_task = self.ev_loop.create_task(
            self.ob_data_source.listen_for_order_book_diffs(
                self.ev_loop, msg_queue))

        try:
            self.async_run_with_timeout(self.listening_task)
        except asyncio.CancelledError:
            # Expected: the second queue read cancels the listener.
            pass

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error when processing public order book updates from exchange"
            ))

    def test_listen_for_order_book_diffs_successful(self):
        """A diff event is converted into a message whose update id is ``sequenceEnd``."""
        mock_queue = AsyncMock()
        diff_event = {
            "type": "message",
            "topic": "/market/level2:BTC-USDT",
            "subject": "trade.l2update",
            "data": {
                "sequenceStart": 1545896669105,
                "sequenceEnd": 1545896669106,
                "symbol": f"{self.trading_pair}",
                "changes": {
                    "asks": [["6", "1", "1545896669105"]],
                    "bids": [["4", "1", "1545896669106"]]
                }
            }
        }
        mock_queue.get.side_effect = [diff_event, asyncio.CancelledError()]
        self.ob_data_source._message_queue[
            CONSTANTS.DIFF_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        # create_task only schedules the coroutine; it cannot raise CancelledError here.
        self.listening_task = self.ev_loop.create_task(
            self.ob_data_source.listen_for_order_book_diffs(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.async_run_with_timeout(msg_queue.get())

        # assertEqual, not assertTrue: assertTrue would treat the second argument as
        # the failure message and pass for any truthy first argument.
        self.assertEqual(diff_event["data"]["sequenceEnd"], msg.update_id)

    @aioresponses()
    def test_listen_for_order_book_snapshots_cancelled_when_fetching_snapshot(
            self, mock_api):
        """CancelledError raised by the snapshot request propagates to the caller."""
        url = web_utils.rest_url(path_url=CONSTANTS.SNAPSHOT_NO_AUTH_PATH_URL)
        mock_api.get(self._url_regex(url), exception=asyncio.CancelledError)

        with self.assertRaises(asyncio.CancelledError):
            self.async_run_with_timeout(
                self.ob_data_source.listen_for_order_book_snapshots(
                    self.ev_loop, asyncio.Queue()))

    @aioresponses()
    @patch(
        "hummingbot.connector.exchange.kucoin.kucoin_api_order_book_data_source"
        ".KucoinAPIOrderBookDataSource._sleep")
    def test_listen_for_order_book_snapshots_log_exception(
            self, mock_api, sleep_mock):
        """A failing snapshot request is logged with the affected trading pair."""
        msg_queue: asyncio.Queue = asyncio.Queue()
        # The patched sleep raises CancelledError to end the run after the first failure.
        sleep_mock.side_effect = asyncio.CancelledError

        url = web_utils.rest_url(path_url=CONSTANTS.SNAPSHOT_NO_AUTH_PATH_URL)
        mock_api.get(self._url_regex(url), exception=Exception)

        try:
            self.async_run_with_timeout(
                self.ob_data_source.listen_for_order_book_snapshots(
                    self.ev_loop, msg_queue))
        except asyncio.CancelledError:
            # Expected: raised by the patched sleep.
            pass

        self.assertTrue(
            self._is_logged(
                "ERROR",
                f"Unexpected error fetching order book snapshot for {self.trading_pair}."
            ))

    @aioresponses()
    def test_listen_for_order_book_snapshots_successful(
        self,
        mock_api,
    ):
        """A fetched snapshot is queued with the snapshot sequence as update id."""
        msg_queue: asyncio.Queue = asyncio.Queue()
        url = web_utils.rest_url(path_url=CONSTANTS.SNAPSHOT_NO_AUTH_PATH_URL)

        snapshot_data = {
            "code": "200000",
            "data": {
                "sequence": "3262786978",
                "time": 1550653727731,
                "bids": [["6500.12", "0.45054140"], ["6500.11", "0.45054140"]],
                "asks": [["6500.16", "0.57753524"], ["6500.15", "0.57753524"]]
            }
        }

        mock_api.get(self._url_regex(url), body=json.dumps(snapshot_data))

        self.listening_task = self.ev_loop.create_task(
            self.ob_data_source.listen_for_order_book_snapshots(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.async_run_with_timeout(msg_queue.get())

        self.assertEqual(int(snapshot_data["data"]["sequence"]), msg.update_id)
Ejemplo n.º 2
0
class LatokenUserStreamDataSourceUnitTests(unittest.TestCase):
    """Unit tests for ``LatokenAPIUserStreamDataSource``.

    The test case registers itself as a handler on the data source logger
    (see ``handle``) so tests can assert on emitted records via
    ``_is_logged``.
    """

    # the level is required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        """Define the event loop and the trading-pair fixtures shared by all tests."""
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        cls.base_asset = "d8ae67f2-f954-4014-98c8-64b1ac334c64"
        cls.quote_asset = "0c3a106d-bde3-4c13-a26e-3fd2394529e5"
        cls.trading_pair = "ETH-USDT"
        cls.trading_pairs = [cls.trading_pair]
        cls.ex_trading_pair = cls.base_asset + '/' + cls.quote_asset
        cls.domain = "com"

        cls.listen_key = 'ffffffff-ffff-ffff-ffff-ffffffffff'

    def setUp(self) -> None:
        """Build a fresh connector and data source and hook into its logger."""
        super().setUp()
        self.log_records = []
        self.listening_task: Optional[asyncio.Task] = None
        self.mocking_assistant = NetworkMockingAssistant()
        self.throttler = AsyncThrottler(rate_limits=CONSTANTS.RATE_LIMITS)
        self.mock_time_provider = MagicMock()
        self.mock_time_provider.time.return_value = 1000
        self.auth = LatokenAuth(api_key="TEST_API_KEY",
                                secret_key="TEST_SECRET",
                                time_provider=self.mock_time_provider)
        self.time_synchronizer = TimeSynchronizer()
        self.time_synchronizer.add_time_offset_ms_sample(0)
        client_config_map = ClientConfigAdapter(ClientConfigMap())
        self.connector = LatokenExchange(client_config_map=client_config_map,
                                         latoken_api_key="",
                                         latoken_api_secret="",
                                         trading_pairs=[],
                                         trading_required=False,
                                         domain=self.domain)
        self.connector._web_assistants_factory._auth = self.auth

        self.data_source = LatokenAPIUserStreamDataSource(
            auth=self.auth,
            trading_pairs=[self.trading_pair],
            connector=self.connector,
            api_factory=self.connector._web_assistants_factory,
            domain=self.domain)

        self.data_source.logger().setLevel(1)
        self.data_source.logger().addHandler(self)

        # Set by the ``_create_*_and_unlock_test_with_event`` helpers so a
        # test can wait until a mocked call has actually been reached.
        self.resume_test_event = asyncio.Event()

        self.connector._set_trading_pair_symbol_map(
            bidict({self.ex_trading_pair: self.trading_pair}))

    def tearDown(self) -> None:
        """Cancel the background listening task, if a test started one."""
        if self.listening_task:
            self.listening_task.cancel()
        super().tearDown()

    def handle(self, record):
        """Logging-handler hook: remember every record for later assertions."""
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        """Return True if a captured record has exactly this level name and message."""
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def _raise_exception(self, exception_class):
        """Raise ``exception_class``; usable as a mock side effect."""
        raise exception_class

    def _create_exception_and_unlock_test_with_event(self, exception):
        """Signal ``resume_test_event`` and then raise ``exception``."""
        self.resume_test_event.set()
        raise exception

    def _create_return_value_and_unlock_test_with_event(self, value):
        """Signal ``resume_test_event`` and then return ``value``."""
        self.resume_test_event.set()
        return value

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: float = 1):
        """Run ``coroutine`` on the shared loop, bounded by ``timeout`` seconds."""
        ret = self.ev_loop.run_until_complete(
            asyncio.wait_for(coroutine, timeout))
        return ret

    def _error_response(self) -> Dict[str, Any]:
        """Return the error payload used for mocked failing REST responses."""
        resp = {"code": "ERROR CODE", "msg": "ERROR MESSAGE"}

        return resp

    def _user_id_url_pattern(self):
        """Return a compiled regex matching the private user-id REST URL."""
        url = web_utils.private_rest_url(path_url=CONSTANTS.USER_ID_PATH_URL,
                                         domain=self.domain)
        return re.compile(f"^{url}".replace(".",
                                            r"\.").replace("?", r"\?"))

    def _user_id_response(self) -> Dict[str, Any]:
        """Return the mocked successful user-id payload; its ``id`` is ``listen_key``."""
        return {
            'id': self.listen_key,
            'status': 'ACTIVE',
            'role': 'INVESTOR',
            'email': '*****@*****.**',
            'phone': '',
            'authorities': [],
            'forceChangePassword': None,
            'authType': 'API_KEY',
            'socials': []
        }

    def _user_update_event(self):
        """Return a raw binary MESSAGE frame carrying an account balance update."""
        # Balance Update, so not the initial balance
        return b'MESSAGE\ndestination:/user/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/v1/account\nmessage-id:9e8188c8-682c-41cd-9a14-722bf6dfd99e\ncontent-length:346\nsubscription:2\n\n{"payload":[{"id":"44d36460-46dc-4828-a17c-63b1a047b054","status":"ACCOUNT_STATUS_ACTIVE","type":"ACCOUNT_TYPE_SPOT","timestamp":1650120265819,"currency":"620f2019-33c0-423b-8a9d-cde4d7f8ef7f","available":"34.001000000000000000","blocked":"0.999000000000000000","user":"******"}],"nonce":1,"timestamp":1650120265830}\x00'

    def _successfully_subscribed_event(self):
        """Return a raw binary CONNECTED frame."""
        return b'CONNECTED\nserver:vertx-stomp/3.9.6\nheart-beat:1000,1000\nsession:37a8e962-7fa7-4eab-b163-146eeafdef63\nversion:1.1\n\n\x00 '

    @aioresponses()
    def test_get_listen_key_log_exception(self, mock_api):
        """A 400 response to the user-id request makes ``_get_listen_key`` raise ``IOError``."""
        mock_api.get(self._user_id_url_pattern(),
                     status=400,
                     body=json.dumps(self._error_response()))

        with self.assertRaises(IOError):
            self.async_run_with_timeout(self.data_source._get_listen_key())

    @aioresponses()
    def test_get_listen_key_successful(self, mock_api):
        """``_get_listen_key`` returns the ``id`` of the user-id response."""
        mock_api.get(self._user_id_url_pattern(),
                     body=json.dumps(self._user_id_response()))

        result: str = self.async_run_with_timeout(
            self.data_source._get_listen_key())

        self.assertEqual(self.listen_key, result)

    @aioresponses()
    def test_ping_listen_key_log_exception(self, mock_api):
        """A failed ping logs a warning with the error payload and returns False."""
        mock_api.get(self._user_id_url_pattern(),
                     status=400,
                     body=json.dumps(self._error_response()))

        self.data_source._current_listen_key = self.listen_key
        result: bool = self.async_run_with_timeout(
            self.data_source._ping_listen_key())

        self.assertTrue(
            self._is_logged(
                "WARNING",
                f"Failed to refresh the listen key {self.listen_key}: {self._error_response()}"
            ))
        self.assertFalse(result)

    @aioresponses()
    def test_ping_listen_key_successful(self, mock_api):
        """A successful user-id response makes ``_ping_listen_key`` return True."""
        mock_api.get(self._user_id_url_pattern(),
                     body=json.dumps(self._user_id_response()))

        self.data_source._current_listen_key = self.listen_key
        result: bool = self.async_run_with_timeout(
            self.data_source._ping_listen_key())
        self.assertTrue(result)

    @patch(
        "hummingbot.connector.exchange.latoken.latoken_api_user_stream_data_source.LatokenAPIUserStreamDataSource"
        "._ping_listen_key",
        new_callable=AsyncMock)
    def test_manage_listen_key_task_loop_keep_alive_failed(
            self, mock_ping_listen_key):
        """A failed keep-alive ping logs an error and resets the listen-key state."""
        mock_ping_listen_key.side_effect = (
            lambda *args, **kwargs: self.
            _create_return_value_and_unlock_test_with_event(False))

        self.data_source._current_listen_key = self.listen_key

        # Simulate LISTEN_KEY_KEEP_ALIVE_INTERVAL reached
        self.data_source._last_listen_key_ping_ts = 0

        self.listening_task = self.ev_loop.create_task(
            self.data_source._manage_listen_key_task_loop())

        self.async_run_with_timeout(self.resume_test_event.wait())

        self.assertTrue(
            self._is_logged("ERROR", "Error occurred renewing listen key ..."))
        self.assertIsNone(self.data_source._current_listen_key)
        self.assertFalse(
            self.data_source._listen_key_initialized_event.is_set())

    @patch(
        "hummingbot.connector.exchange.latoken.latoken_api_user_stream_data_source.LatokenAPIUserStreamDataSource."
        "_ping_listen_key",
        new_callable=AsyncMock)
    def test_manage_listen_key_task_loop_keep_alive_successful(
            self, mock_ping_listen_key):
        """A successful keep-alive ping is logged and updates the last-ping timestamp."""
        mock_ping_listen_key.side_effect = (
            lambda *args, **kwargs: self.
            _create_return_value_and_unlock_test_with_event(True))

        # Simulate LISTEN_KEY_KEEP_ALIVE_INTERVAL reached
        self.data_source._current_listen_key = self.listen_key
        self.data_source._listen_key_initialized_event.set()
        self.data_source._last_listen_key_ping_ts = 0

        self.listening_task = self.ev_loop.create_task(
            self.data_source._manage_listen_key_task_loop())

        self.async_run_with_timeout(self.resume_test_event.wait())

        self.assertTrue(
            self._is_logged("INFO",
                            f"Refreshed listen key {self.listen_key}."))
        self.assertGreater(self.data_source._last_listen_key_ping_ts, 0)

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_get_listen_key_successful_with_user_update_event(
            self, mock_api, mock_ws):
        """A user update frame received over the websocket ends up in the output queue."""
        mock_api.get(self._user_id_url_pattern(),
                     body=json.dumps(self._user_id_response()))

        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        self.mocking_assistant.add_websocket_aiohttp_message(
            mock_ws.return_value,
            self._successfully_subscribed_event(),
            message_type=aiohttp.WSMsgType.BINARY)
        self.mocking_assistant.add_websocket_aiohttp_message(
            mock_ws.return_value,
            self._user_update_event(),
            message_type=aiohttp.WSMsgType.BINARY)

        msg_queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue))

        msg = self.async_run_with_timeout(msg_queue.get())
        # NOTE(review): this previously passed ``self._user_update_event`` (an
        # uncalled bound method) as assertTrue's failure message, which reads
        # like a comparison but never was one. Only non-emptiness is checked.
        # TODO: compare against the parsed payload once the queued message
        # format is confirmed.
        self.assertTrue(
            msg, "expected a non-empty user stream message in the queue")

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_does_not_queue_empty_payload(
            self, mock_api, mock_ws):
        """An empty websocket message must not be forwarded to the output queue."""
        mock_api.get(self._user_id_url_pattern(),
                     body=json.dumps(self._user_id_response()))

        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        self.mocking_assistant.add_websocket_aiohttp_message(
            mock_ws.return_value, "")

        msg_queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue))

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            mock_ws.return_value)

        self.assertEqual(0, msg_queue.qsize())

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_connection_failed(self, mock_api, mock_ws):
        """An exception raised while connecting the websocket is logged as an error."""
        mock_api.get(self._user_id_url_pattern(),
                     body=json.dumps(self._user_id_response()))

        mock_ws.side_effect = lambda *args, **kwargs: self._create_exception_and_unlock_test_with_event(
            Exception("TEST ERROR."))

        msg_queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue))

        self.async_run_with_timeout(self.resume_test_event.wait())

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error while listening to user stream. Retrying after 5 seconds..."
            ))

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_iter_message_throws_exception(
            self, mock_api, mock_ws):
        """An exception raised while receiving from the websocket is logged as an error."""
        mock_api.get(self._user_id_url_pattern(),
                     body=json.dumps(self._user_id_response()))

        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value.receive.side_effect = (
            lambda *args, **kwargs: self.
            _create_exception_and_unlock_test_with_event(
                Exception("TEST ERROR")))
        mock_ws.close.return_value = None

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue))

        self.async_run_with_timeout(self.resume_test_event.wait())

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error while listening to user stream. Retrying after 5 seconds..."
            ))
# Example no. 3
# 0
class BinanceAPIOrderBookDataSourceUnitTests(unittest.TestCase):
    # logging.Level required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        cls.trading_pair = f"{cls.base_asset}-{cls.quote_asset}"
        cls.ex_trading_pair = cls.base_asset + cls.quote_asset
        cls.domain = "com"

    def setUp(self) -> None:
        """Create a fresh data source, capture its logs and seed the symbol map."""
        super().setUp()
        self.log_records = []
        self.listening_task = None
        self.mocking_assistant = NetworkMockingAssistant()

        synchronizer = TimeSynchronizer()
        synchronizer.add_time_offset_ms_sample(1000)
        self.time_synchronizer = synchronizer

        self.throttler = AsyncThrottler(rate_limits=CONSTANTS.RATE_LIMITS)
        data_source = BinanceAPIOrderBookDataSource(
            trading_pairs=[self.trading_pair],
            throttler=self.throttler,
            domain=self.domain,
            time_synchronizer=self.time_synchronizer)
        # This test case acts as the logging handler (see ``handle``).
        data_source.logger().setLevel(1)
        data_source.logger().addHandler(self)
        self.data_source = data_source

        # Set by helpers so a test can wait until a mocked call was reached.
        self.resume_test_event = asyncio.Event()

        exchange_symbol = f"{self.base_asset}{self.quote_asset}"
        symbol_map = bidict({exchange_symbol: self.trading_pair})
        BinanceAPIOrderBookDataSource._trading_pair_symbol_map = {
            "com": symbol_map
        }

    def tearDown(self) -> None:
        """Cancel any running listener task and clear the class-level symbol map."""
        task = self.listening_task
        if task:
            task.cancel()
        BinanceAPIOrderBookDataSource._trading_pair_symbol_map = {}
        super().tearDown()

    def handle(self, record):
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def _create_exception_and_unlock_test_with_event(self, exception):
        self.resume_test_event.set()
        raise exception

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: float = 1):
        ret = self.ev_loop.run_until_complete(
            asyncio.wait_for(coroutine, timeout))
        return ret

    def _successfully_subscribed_event(self):
        resp = {"result": None, "id": 1}
        return resp

    def _trade_update_event(self):
        resp = {
            "e": "trade",
            "E": 123456789,
            "s": self.ex_trading_pair,
            "t": 12345,
            "p": "0.001",
            "q": "100",
            "b": 88,
            "a": 50,
            "T": 123456785,
            "m": True,
            "M": True
        }
        return resp

    def _order_diff_event(self):
        resp = {
            "e": "depthUpdate",
            "E": 123456789,
            "s": self.ex_trading_pair,
            "U": 157,
            "u": 160,
            "b": [["0.0024", "10"]],
            "a": [["0.0026", "100"]]
        }
        return resp

    def _snapshot_response(self):
        resp = {
            "lastUpdateId": 1027024,
            "bids": [["4.00000000", "431.00000000"]],
            "asks": [["4.00000200", "12.00000000"]]
        }
        return resp

    @aioresponses()
    def test_get_last_trade_prices(self, mock_api):
        """``get_last_traded_prices`` reports the mocked ``lastPrice`` for the pair."""
        ticker_url = web_utils.public_rest_url(
            path_url=CONSTANTS.TICKER_PRICE_CHANGE_PATH_URL,
            domain=self.domain)
        ticker_url = f"{ticker_url}?symbol={self.base_asset}{self.quote_asset}"
        # Escape the regex metacharacters that can appear in the URL.
        escaped_url = ticker_url.replace(".", r"\.").replace("?", r"\?")
        url_pattern = re.compile("^" + escaped_url)

        ticker_payload = {
            "symbol": "BNBBTC",
            "priceChange": "-94.99999800",
            "priceChangePercent": "-95.960",
            "weightedAvgPrice": "0.29628482",
            "prevClosePrice": "0.10002000",
            "lastPrice": "100.0",
            "lastQty": "200.00000000",
            "bidPrice": "4.00000000",
            "bidQty": "100.00000000",
            "askPrice": "4.00000200",
            "askQty": "100.00000000",
            "openPrice": "99.00000000",
            "highPrice": "100.00000000",
            "lowPrice": "0.10000000",
            "volume": "8913.30000000",
            "quoteVolume": "15.30000000",
            "openTime": 1499783499040,
            "closeTime": 1499869899040,
            "firstId": 28385,
            "lastId": 28460,
            "count": 76,
        }
        mock_api.get(url_pattern, body=json.dumps(ticker_payload))

        price_request = self.data_source.get_last_traded_prices(
            trading_pairs=[self.trading_pair],
            throttler=self.throttler,
            time_synchronizer=self.time_synchronizer)
        prices: Dict[str, float] = self.async_run_with_timeout(price_request)

        self.assertEqual(1, len(prices))
        self.assertEqual(100, prices[self.trading_pair])

    @aioresponses()
    def test_get_all_mid_prices(self, mock_api):
        """``get_all_mid_prices`` returns a mid price only for the recognised symbol."""
        server_time_url = web_utils.public_rest_url(
            CONSTANTS.SERVER_TIME_PATH_URL, domain=self.domain)
        escaped_url = server_time_url.replace(".", r"\.").replace("?", r"\?")
        server_time_payload = {"serverTime": 1640000003000}
        mock_api.get(re.compile("^" + escaped_url),
                     body=json.dumps(server_time_payload))

        ticker_url = web_utils.public_rest_url(
            path_url=CONSTANTS.TICKER_PRICE_CHANGE_PATH_URL,
            domain=self.domain)

        # Truncated responses: one for the mapped pair, one for an
        # unrecognized pair.
        known_ticker = {
            "symbol": self.ex_trading_pair,
            "bidPrice": "99",
            "askPrice": "101",
        }
        unknown_ticker = {
            "symbol": "BCCBTC",
            "bidPrice": "99",
            "askPrice": "101",
        }
        tickers: List[Dict[str, Any]] = [known_ticker, unknown_ticker]
        mock_api.get(ticker_url, body=json.dumps(tickers))

        mid_prices: Dict[str, float] = self.async_run_with_timeout(
            self.data_source.get_all_mid_prices())

        self.assertEqual(1, len(mid_prices))
        self.assertEqual(100, mid_prices[self.trading_pair])

    @aioresponses()
    def test_fetch_trading_pairs(self, mock_api):
        """``fetch_trading_pairs`` lists only symbols whose permissions include SPOT.

        Of the three mocked symbols, ETHBTC and LTCBTC allow SPOT and MARGIN,
        while BNBBTC allows MARGIN only and must be left out.
        """
        BinanceAPIOrderBookDataSource._trading_pair_symbol_map = {}
        url = web_utils.public_rest_url(
            path_url=CONSTANTS.EXCHANGE_INFO_PATH_URL, domain=self.domain)

        def symbol_info(symbol, base, quote, permissions):
            # The mocked entries differ only in these four fields.
            return {
                "symbol": symbol,
                "status": "TRADING",
                "baseAsset": base,
                "baseAssetPrecision": 8,
                "quoteAsset": quote,
                "quotePrecision": 8,
                "quoteAssetPrecision": 8,
                "baseCommissionPrecision": 8,
                "quoteCommissionPrecision": 8,
                "orderTypes": [
                    "LIMIT", "LIMIT_MAKER", "MARKET", "STOP_LOSS_LIMIT",
                    "TAKE_PROFIT_LIMIT"
                ],
                "icebergAllowed": True,
                "ocoAllowed": True,
                "quoteOrderQtyMarketAllowed": True,
                "isSpotTradingAllowed": True,
                "isMarginTradingAllowed": True,
                "filters": [],
                "permissions": permissions,
            }

        mock_response: Dict[str, Any] = {
            "timezone": "UTC",
            "serverTime": 1639598493658,
            "rateLimits": [],
            "exchangeFilters": [],
            "symbols": [
                symbol_info("ETHBTC", "ETH", "BTC", ["SPOT", "MARGIN"]),
                symbol_info("LTCBTC", "LTC", "BTC", ["SPOT", "MARGIN"]),
                symbol_info("BNBBTC", "BNB", "BTC", ["MARGIN"]),
            ]
        }

        mock_api.get(url, body=json.dumps(mock_response))

        # The result is used as a collection of pair names (len / membership).
        result: List[str] = self.async_run_with_timeout(
            self.data_source.fetch_trading_pairs(
                time_synchronizer=self.time_synchronizer))

        self.assertEqual(2, len(result))
        self.assertIn("ETH-BTC", result)
        self.assertIn("LTC-BTC", result)
        self.assertNotIn("BNB-BTC", result)

    @aioresponses()
    def test_fetch_trading_pairs_exception_raised(self, mock_api):
        """``fetch_trading_pairs`` returns an empty result when the request raises."""
        BinanceAPIOrderBookDataSource._trading_pair_symbol_map = {}

        url = web_utils.public_rest_url(
            path_url=CONSTANTS.EXCHANGE_INFO_PATH_URL, domain=self.domain)
        regex_url = re.compile(f"^{url}".replace(".",
                                                 r"\.").replace("?", r"\?"))

        mock_api.get(regex_url, exception=Exception)

        # The result is used as a collection of pair names, so the
        # annotation is ``List[str]`` rather than the invalid ``Dict[str]``.
        result: List[str] = self.async_run_with_timeout(
            self.data_source.fetch_trading_pairs(
                time_synchronizer=self.time_synchronizer))

        self.assertEqual(0, len(result))

    @aioresponses()
    def test_get_snapshot_successful(self, mock_api):
        """``get_snapshot`` returns the mocked snapshot payload unchanged."""
        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        escaped_url = snapshot_url.replace(".", r"\.").replace("?", r"\?")
        url_pattern = re.compile("^" + escaped_url)

        mock_api.get(url_pattern, body=json.dumps(self._snapshot_response()))

        snapshot_request = self.data_source.get_snapshot(self.trading_pair)
        snapshot: Dict[str, Any] = self.async_run_with_timeout(
            snapshot_request)

        self.assertEqual(self._snapshot_response(), snapshot)

    @aioresponses()
    def test_get_snapshot_catch_exception(self, mock_api):
        """A 400 response to the snapshot request makes ``get_snapshot`` raise ``IOError``."""
        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        escaped_url = snapshot_url.replace(".", r"\.").replace("?", r"\?")
        mock_api.get(re.compile("^" + escaped_url), status=400)

        with self.assertRaises(IOError):
            snapshot_request = self.data_source.get_snapshot(
                self.trading_pair)
            self.async_run_with_timeout(snapshot_request)

    @aioresponses()
    def test_get_new_order_book(self, mock_api):
        """The new order book's ``snapshot_uid`` equals the snapshot's ``lastUpdateId``."""
        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        escaped_url = snapshot_url.replace(".", r"\.").replace("?", r"\?")
        url_pattern = re.compile("^" + escaped_url)

        snapshot_payload: Dict[str, Any] = {
            "lastUpdateId": 1,
            "bids": [["4.00000000", "431.00000000"]],
            "asks": [["4.00000200", "12.00000000"]]
        }
        mock_api.get(url_pattern, body=json.dumps(snapshot_payload))

        order_book_request = self.data_source.get_new_order_book(
            self.trading_pair)
        order_book: OrderBook = self.async_run_with_timeout(
            order_book_request)

        self.assertEqual(1, order_book.snapshot_uid)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_subscriptions_subscribes_to_trades_and_order_diffs(
            self, ws_connect_mock):
        """The listener sends a trade and a depth SUBSCRIBE request and logs success."""
        websocket = self.mocking_assistant.create_websocket_mock()
        ws_connect_mock.return_value = websocket

        # Queue one acknowledgement per subscription request (ids 1 and 2).
        for request_id in (1, 2):
            acknowledgement = {"result": None, "id": request_id}
            self.mocking_assistant.add_websocket_aiohttp_message(
                websocket_mock=websocket,
                message=json.dumps(acknowledgement))

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_subscriptions())

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            websocket)

        sent_requests = self.mocking_assistant.json_messages_sent_through_websocket(
            websocket_mock=websocket)

        symbol = self.ex_trading_pair.lower()
        expected_trade_request = {
            "method": "SUBSCRIBE",
            "params": [f"{symbol}@trade"],
            "id": 1
        }
        expected_diff_request = {
            "method": "SUBSCRIBE",
            "params": [f"{symbol}@depth@100ms"],
            "id": 2
        }

        self.assertEqual(2, len(sent_requests))
        self.assertEqual(expected_trade_request, sent_requests[0])
        self.assertEqual(expected_diff_request, sent_requests[1])

        subscribed_log = "Subscribed to public order book and trade channels..."
        self.assertTrue(self._is_logged("INFO", subscribed_log))

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect")
    def test_listen_for_subscriptions_raises_cancel_exception(
            self, mock_ws, _: AsyncMock):
        """A ``CancelledError`` raised while connecting propagates out of the listener."""
        mock_ws.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            listener = self.data_source.listen_for_subscriptions()
            self.listening_task = self.ev_loop.create_task(listener)
            self.async_run_with_timeout(self.listening_task)

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_subscriptions_logs_exception_details(
            self, mock_ws, sleep_mock):
        """A connection error is logged before the listener reaches its sleep.

        The mocked sleep unlocks the test and raises ``CancelledError``.
        """
        mock_ws.side_effect = Exception("TEST ERROR.")
        sleep_mock.side_effect = lambda _: self._create_exception_and_unlock_test_with_event(
            asyncio.CancelledError())

        listener = self.data_source.listen_for_subscriptions()
        self.listening_task = self.ev_loop.create_task(listener)

        self.async_run_with_timeout(self.resume_test_event.wait())

        expected_log = (
            "Unexpected error occurred when listening to order book streams. Retrying in 5 seconds..."
        )
        self.assertTrue(self._is_logged("ERROR", expected_log))

    def test_subscribe_channels_raises_cancel_exception(self):
        """A CancelledError raised by the websocket ``send`` must propagate out of the task."""
        ws_stub = MagicMock()
        ws_stub.send.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            subscribe_coro = self.data_source._subscribe_channels(ws_stub)
            self.listening_task = self.ev_loop.create_task(subscribe_coro)
            self.async_run_with_timeout(self.listening_task)

    def test_subscribe_channels_raises_exception_and_logs_error(self):
        """A generic error raised by ``send`` must be re-raised and logged as ERROR."""
        ws_stub = MagicMock()
        ws_stub.send.side_effect = Exception("Test Error")

        with self.assertRaises(Exception):
            subscribe_coro = self.data_source._subscribe_channels(ws_stub)
            self.listening_task = self.ev_loop.create_task(subscribe_coro)
            self.async_run_with_timeout(self.listening_task)

        expected_message = "Unexpected error occurred subscribing to order book trading and delta streams..."
        self.assertTrue(self._is_logged("ERROR", expected_message))

    def test_listen_for_trades_cancelled_when_listening(self):
        """A CancelledError raised by the trade queue's ``get`` must propagate out of the task."""
        cancelling_queue = MagicMock()
        cancelling_queue.get.side_effect = asyncio.CancelledError()
        self.data_source._message_queue[CONSTANTS.TRADE_EVENT_TYPE] = cancelling_queue

        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            trades_coro = self.data_source.listen_for_trades(
                self.ev_loop, output_queue)
            self.listening_task = self.ev_loop.create_task(trades_coro)
            self.async_run_with_timeout(self.listening_task)

    def test_listen_for_trades_logs_exception(self):
        """An event missing the expected fields must be logged as ERROR."""
        # Only two keys are present; the test expects processing of this payload to fail.
        malformed_event = {"m": 1, "i": 2}

        # First get() yields the malformed payload, the second one ends the listener.
        source_queue = AsyncMock()
        source_queue.get.side_effect = [malformed_event, asyncio.CancelledError()]
        self.data_source._message_queue[CONSTANTS.TRADE_EVENT_TYPE] = source_queue

        output_queue: asyncio.Queue = asyncio.Queue()

        trades_coro = self.data_source.listen_for_trades(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(trades_coro)

        try:
            self.async_run_with_timeout(self.listening_task)
        except asyncio.CancelledError:
            # Expected: the second queue read cancels the listener.
            pass

        expected_message = "Unexpected error when processing public trade updates from exchange"
        self.assertTrue(self._is_logged("ERROR", expected_message))

    def test_listen_for_trades_successful(self):
        """A well-formed trade event must be parsed and forwarded to the output queue."""
        trade_event = self._trade_update_event()

        # First get() yields the trade payload, the second one ends the listener.
        mock_queue = AsyncMock()
        mock_queue.get.side_effect = [trade_event, asyncio.CancelledError()]
        self.data_source._message_queue[
            CONSTANTS.TRADE_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        # create_task only schedules the coroutine, so it cannot raise CancelledError
        # here; the task is cancelled (if still pending) in tearDown.
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.async_run_with_timeout(msg_queue.get())

        # The previous assertTrue(12345, msg.trade_id) always passed: the second
        # argument of assertTrue is only the failure message.
        # NOTE(review): assumes _trade_update_event() returns the raw Binance trade
        # payload as a dict whose "t" field is the trade id -- confirm against the helper.
        self.assertEqual(trade_event["t"], msg.trade_id)

    def test_listen_for_order_book_diffs_cancelled(self):
        """A CancelledError raised by the diff queue's ``get`` must propagate out of the task."""
        cancelling_queue = AsyncMock()
        cancelling_queue.get.side_effect = asyncio.CancelledError()
        self.data_source._message_queue[CONSTANTS.DIFF_EVENT_TYPE] = cancelling_queue

        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            diffs_coro = self.data_source.listen_for_order_book_diffs(
                self.ev_loop, output_queue)
            self.listening_task = self.ev_loop.create_task(diffs_coro)
            self.async_run_with_timeout(self.listening_task)

    def test_listen_for_order_book_diffs_logs_exception(self):
        """An event missing the expected fields must be logged as ERROR."""
        # Only two keys are present; the test expects processing of this payload to fail.
        malformed_event = {"m": 1, "i": 2}

        # First get() yields the malformed payload, the second one ends the listener.
        source_queue = AsyncMock()
        source_queue.get.side_effect = [malformed_event, asyncio.CancelledError()]
        self.data_source._message_queue[CONSTANTS.DIFF_EVENT_TYPE] = source_queue

        output_queue: asyncio.Queue = asyncio.Queue()

        diffs_coro = self.data_source.listen_for_order_book_diffs(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(diffs_coro)

        try:
            self.async_run_with_timeout(self.listening_task)
        except asyncio.CancelledError:
            # Expected: the second queue read cancels the listener.
            pass

        expected_message = "Unexpected error when processing public order book updates from exchange"
        self.assertTrue(self._is_logged("ERROR", expected_message))

    def test_listen_for_order_book_diffs_successful(self):
        """A well-formed diff event must be parsed and forwarded to the output queue."""
        diff_event = self._order_diff_event()

        # First get() yields the diff payload, the second one ends the listener.
        mock_queue = AsyncMock()
        mock_queue.get.side_effect = [diff_event, asyncio.CancelledError()]
        self.data_source._message_queue[CONSTANTS.DIFF_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        # create_task only schedules the coroutine, so it cannot raise CancelledError
        # here; the task is cancelled (if still pending) in tearDown.
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_diffs(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.async_run_with_timeout(msg_queue.get())

        # The previous assertTrue(12345, msg.update_id) always passed: the second
        # argument of assertTrue is only the failure message.
        # NOTE(review): assumes _order_diff_event() returns the raw Binance depth-update
        # payload as a dict and that update_id is taken from its "u" (final update id)
        # field -- confirm against the helper and the message parser.
        self.assertEqual(diff_event["u"], msg.update_id)

    @aioresponses()
    def test_listen_for_order_book_snapshots_cancelled_when_fetching_snapshot(
            self, mock_api):
        """A CancelledError raised by the snapshot request must propagate to the caller."""
        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        # Escape "." and "?" so the URL can be used as a prefix pattern.
        escaped_url = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        regex_url = re.compile(escaped_url)

        mock_api.get(regex_url, exception=asyncio.CancelledError)

        with self.assertRaises(asyncio.CancelledError):
            snapshots_coro = self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, asyncio.Queue())
            self.async_run_with_timeout(snapshots_coro)

    @aioresponses()
    @patch(
        "hummingbot.connector.exchange.binance.binance_api_order_book_data_source"
        ".BinanceAPIOrderBookDataSource._sleep")
    def test_listen_for_order_book_snapshots_log_exception(
            self, mock_api, sleep_mock):
        """A failing snapshot request must be logged as ERROR for the trading pair."""

        def cancel_on_sleep(_):
            # The helper raises the given exception; it is presumably also what sets
            # resume_test_event (awaited below) -- its definition is outside this view.
            return self._create_exception_and_unlock_test_with_event(
                asyncio.CancelledError())

        output_queue: asyncio.Queue = asyncio.Queue()
        sleep_mock.side_effect = cancel_on_sleep

        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        # Escape "." and "?" so the URL can be used as a prefix pattern.
        escaped_url = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        regex_url = re.compile(escaped_url)

        mock_api.get(regex_url, exception=Exception)

        snapshots_coro = self.data_source.listen_for_order_book_snapshots(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(snapshots_coro)
        self.async_run_with_timeout(self.resume_test_event.wait())

        expected_message = f"Unexpected error fetching order book snapshot for {self.trading_pair}."
        self.assertTrue(self._is_logged("ERROR", expected_message))

    @aioresponses()
    def test_listen_for_order_book_snapshots_successful(
        self,
        mock_api,
    ):
        """A successful snapshot response must be queued as a message with update id 1027024."""
        output_queue: asyncio.Queue = asyncio.Queue()

        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        # Escape "." and "?" so the URL can be used as a prefix pattern.
        escaped_url = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        regex_url = re.compile(escaped_url)

        snapshot_body = json.dumps(self._snapshot_response())
        mock_api.get(regex_url, body=snapshot_body)

        snapshots_coro = self.data_source.listen_for_order_book_snapshots(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(snapshots_coro)

        msg: OrderBookMessage = self.async_run_with_timeout(output_queue.get())

        self.assertEqual(1027024, msg.update_id)
# --- Example no. 4 ---
# 0
class BinancePerpetualUserStreamDataSourceUnitTests(unittest.TestCase):
    """Unit tests for ``BinancePerpetualUserStreamDataSource``.

    The test case instance is registered as a handler on the data source logger
    (``handle`` collects the emitted records) and is passed to ``BinancePerpetualAuth``
    as its time provider (``time`` returns a fixed timestamp).
    """

    # the level is required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        """Create the shared event loop and the constants used by every test."""
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        cls.trading_pair = f"{cls.base_asset}-{cls.quote_asset}"
        cls.ex_trading_pair = cls.base_asset + cls.quote_asset
        cls.domain = CONSTANTS.TESTNET_DOMAIN

        cls.api_key = "TEST_API_KEY"
        cls.secret_key = "TEST_SECRET_KEY"
        cls.listen_key = "TEST_LISTEN_KEY"

    def setUp(self) -> None:
        """Build a fresh data source and attach this test case as its log handler."""
        super().setUp()
        self.log_records = []
        self.listening_task: Optional[asyncio.Task] = None
        self.mocking_assistant = NetworkMockingAssistant()

        self.emulated_time = 1640001112.223
        self.auth = BinancePerpetualAuth(api_key=self.api_key,
                                         api_secret=self.secret_key,
                                         time_provider=self)
        self.throttler = AsyncThrottler(rate_limits=CONSTANTS.RATE_LIMITS)
        self.time_synchronizer = TimeSynchronizer()
        self.time_synchronizer.add_time_offset_ms_sample(0)
        self.data_source = BinancePerpetualUserStreamDataSource(
            auth=self.auth,
            domain=self.domain,
            throttler=self.throttler,
            time_synchronizer=self.time_synchronizer)

        self.data_source.logger().setLevel(1)
        self.data_source.logger().addHandler(self)

        self.mock_done_event = asyncio.Event()
        self.resume_test_event = asyncio.Event()

    def tearDown(self) -> None:
        """Cancel the listening task started by the test, if any."""
        if self.listening_task:
            self.listening_task.cancel()
        super().tearDown()

    def handle(self, record):
        """Logging-handler hook: store every record emitted by the data source logger."""
        self.log_records.append(record)

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: float = 1):
        """Run ``coroutine`` on the shared loop and return its result.

        ``asyncio.wait_for`` raises ``asyncio.TimeoutError`` when ``timeout`` seconds
        elapse first.
        """
        ret = self.ev_loop.run_until_complete(
            asyncio.wait_for(coroutine, timeout))
        return ret

    def _is_logged(self, log_level: str, message: str) -> bool:
        """Return True if a captured record has exactly this level name and message."""
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def _raise_exception(self, exception_class):
        """Raise ``exception_class``; usable as a mock side effect."""
        raise exception_class

    def _mock_responses_done_callback(self, *_, **__):
        """aioresponses callback: signal that the mocked request was served."""
        self.mock_done_event.set()

    def _create_exception_and_unlock_test_with_event(self, exception):
        """Set ``resume_test_event`` and then raise ``exception``."""
        self.resume_test_event.set()
        raise exception

    def _user_stream_regex_url(self):
        """Return a compiled prefix pattern for the user-stream (listen key) REST URL."""
        url = web_utils.rest_url(
            path_url=CONSTANTS.BINANCE_USER_STREAM_ENDPOINT,
            domain=self.domain)
        # Escape "." and "?" so the URL is matched literally.
        return re.compile(f"^{url}".replace(".", r"\.").replace("?", r"\?"))

    def _successful_get_listen_key_response(self) -> str:
        """Return the JSON body of a successful listen-key response."""
        resp = {"listenKey": self.listen_key}
        return ujson.dumps(resp)

    def _error_response(self) -> Dict[str, Any]:
        """Return the payload used as the body of mocked error responses."""
        resp = {"code": "ERROR CODE", "msg": "ERROR MESSAGE"}

        return resp

    def _simulate_user_update_event(self):
        """Return a JSON-encoded ORDER_TRADE_UPDATE event."""
        # Order Trade Update
        resp = {
            "e": "ORDER_TRADE_UPDATE",
            "E": 1591274595442,
            "T": 1591274595453,
            "i": "SfsR",
            "o": {
                "s": "BTCUSD_200925",
                "c": "TEST",
                "S": "SELL",
                "o": "TRAILING_STOP_MARKET",
                "f": "GTC",
                "q": "2",
                "p": "0",
                "ap": "0",
                "sp": "9103.1",
                "x": "NEW",
                "X": "NEW",
                "i": 8888888,
                "l": "0",
                "z": "0",
                "L": "0",
                "ma": "BTC",
                "N": "BTC",
                "n": "0",
                "T": 1591274595442,
                "t": 0,
                "rp": "0",
                "b": "0",
                "a": "0",
                "m": False,
                "R": False,
                "wt": "CONTRACT_PRICE",
                "ot": "TRAILING_STOP_MARKET",
                "ps": "LONG",
                "cp": False,
                "AP": "9476.8",
                "cr": "5.0",
                "pP": False,
            },
        }
        return ujson.dumps(resp)

    def time(self):
        """Return the fixed emulated timestamp."""
        # Implemented to emulate a TimeSynchronizer
        return self.emulated_time

    def test_last_recv_time(self):
        """A new data source reports 0 as its last receive time."""
        # Initial last_recv_time
        self.assertEqual(0, self.data_source.last_recv_time)

    @aioresponses()
    def test_get_listen_key_exception_raised(self, mock_api):
        """An HTTP 400 answer to the listen-key request must raise IOError."""
        regex_url = self._user_stream_regex_url()

        # _error_response must be called: the original passed the bound method itself
        # to ujson.dumps instead of the error payload.
        mock_api.post(regex_url,
                      status=400,
                      body=ujson.dumps(self._error_response()))

        with self.assertRaises(IOError):
            self.async_run_with_timeout(self.data_source.get_listen_key())

    @aioresponses()
    def test_get_listen_key_successful(self, mock_api):
        """A successful response must yield the listen key it contains."""
        regex_url = self._user_stream_regex_url()

        mock_api.post(regex_url,
                      body=self._successful_get_listen_key_response())

        result: str = self.async_run_with_timeout(
            self.data_source.get_listen_key())

        self.assertEqual(self.listen_key, result)

    @aioresponses()
    def test_ping_listen_key_failed_log_warning(self, mock_api):
        """A failed keep-alive must log a WARNING and return False."""
        regex_url = self._user_stream_regex_url()

        mock_api.put(regex_url,
                     status=400,
                     body=ujson.dumps(self._error_response()))

        self.data_source._current_listen_key = self.listen_key
        result: bool = self.async_run_with_timeout(
            self.data_source.ping_listen_key())

        self.assertTrue(
            self._is_logged(
                "WARNING",
                f"Failed to refresh the listen key {self.listen_key}: {self._error_response()}"
            ))
        self.assertFalse(result)

    @aioresponses()
    def test_ping_listen_key_successful(self, mock_api):
        """A successful keep-alive must return True."""
        regex_url = self._user_stream_regex_url()

        mock_api.put(regex_url, body=ujson.dumps({}))

        self.data_source._current_listen_key = self.listen_key
        result: bool = self.async_run_with_timeout(
            self.data_source.ping_listen_key())
        self.assertTrue(result)

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_create_websocket_connection_log_exception(self, mock_api,
                                                       mock_ws):
        """A websocket connection failure must be logged as ERROR with the error text."""
        regex_url = self._user_stream_regex_url()

        mock_api.post(regex_url,
                      body=self._successful_get_listen_key_response())

        mock_ws.side_effect = Exception("TEST ERROR.")

        msg_queue = asyncio.Queue()
        try:
            self.async_run_with_timeout(
                self.data_source.listen_for_user_stream(msg_queue))
        except asyncio.exceptions.TimeoutError:
            # A timeout of the listener is tolerated; only the log record is checked.
            pass

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error while listening to user stream. Retrying after 5 seconds... Error: TEST ERROR.",
            ))

    @aioresponses()
    def test_manage_listen_key_task_loop_keep_alive_failed(self, mock_api):
        """A failed keep-alive must log an ERROR and reset the listen-key state."""
        regex_url = self._user_stream_regex_url()

        mock_api.put(regex_url,
                     status=400,
                     body=ujson.dumps(self._error_response()),
                     callback=self._mock_responses_done_callback)

        self.data_source._current_listen_key = self.listen_key

        # Simulate LISTEN_KEY_KEEP_ALIVE_INTERVAL reached
        self.data_source._last_listen_key_ping_ts = 0

        self.listening_task = self.ev_loop.create_task(
            self.data_source._manage_listen_key_task_loop())

        self.async_run_with_timeout(self.mock_done_event.wait())

        self.assertTrue(
            self._is_logged("ERROR", "Error occurred renewing listen key... "))
        self.assertIsNone(self.data_source._current_listen_key)
        self.assertFalse(
            self.data_source._listen_key_initialized_event.is_set())

    @aioresponses()
    def test_manage_listen_key_task_loop_keep_alive_successful(self, mock_api):
        """A successful keep-alive must log the refresh and update the ping timestamp."""
        regex_url = self._user_stream_regex_url()

        mock_api.put(regex_url,
                     body=ujson.dumps({}),
                     callback=self._mock_responses_done_callback)

        self.data_source._current_listen_key = self.listen_key

        # Simulate LISTEN_KEY_KEEP_ALIVE_INTERVAL reached
        self.data_source._last_listen_key_ping_ts = 0

        self.listening_task = self.ev_loop.create_task(
            self.data_source._manage_listen_key_task_loop())

        self.async_run_with_timeout(self.mock_done_event.wait())

        self.assertTrue(
            self._is_logged("INFO",
                            f"Refreshed listen key {self.listen_key}."))
        self.assertGreater(self.data_source._last_listen_key_ping_ts, 0)

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_create_websocket_connection_failed(
            self, mock_api, mock_ws):
        """The listen key is obtained, then the connection failure is logged as ERROR."""
        regex_url = self._user_stream_regex_url()

        mock_api.post(regex_url,
                      body=self._successful_get_listen_key_response())

        mock_ws.side_effect = Exception("TEST ERROR.")

        msg_queue = asyncio.Queue()

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue))

        try:
            self.async_run_with_timeout(msg_queue.get())
        except asyncio.exceptions.TimeoutError:
            # A timeout waiting on the queue is tolerated; only the logs are checked.
            pass

        self.assertTrue(
            self._is_logged(
                "INFO", f"Successfully obtained listen key {self.listen_key}"))
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error while listening to user stream. Retrying after 5 seconds... Error: TEST ERROR.",
            ))

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.core.data_type.user_stream_tracker_data_source.UserStreamTrackerDataSource._sleep"
    )
    def test_listen_for_user_stream_iter_message_throws_exception(
            self, mock_api, _, mock_ws):
        """An error raised while receiving messages must be logged as ERROR."""
        regex_url = self._user_stream_regex_url()

        mock_api.post(regex_url,
                      body=self._successful_get_listen_key_response())

        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value.receive.side_effect = Exception("TEST ERROR")
        mock_ws.return_value.closed = False
        mock_ws.return_value.close.side_effect = Exception

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue))

        try:
            self.async_run_with_timeout(msg_queue.get())
        except Exception:
            # Any error (including a timeout) waiting on the queue is tolerated;
            # only the logs are checked.
            pass

        self.assertTrue(
            self._is_logged(
                "INFO", f"Successfully obtained listen key {self.listen_key}"))
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error while listening to user stream. Retrying after 5 seconds... Error: TEST ERROR",
            ))

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_successful(self, mock_api, mock_ws):
        """A user update delivered over the websocket must reach the output queue."""
        regex_url = self._user_stream_regex_url()

        mock_api.post(regex_url,
                      body=self._successful_get_listen_key_response())

        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()

        self.mocking_assistant.add_websocket_aiohttp_message(
            mock_ws.return_value, self._simulate_user_update_event())

        msg_queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue))

        msg = self.async_run_with_timeout(msg_queue.get())
        # The previous assertTrue(msg, <bound method>) only checked that msg was truthy.
        # NOTE(review): assumes the data source queues the decoded JSON payload -- confirm.
        self.assertEqual(ujson.loads(self._simulate_user_update_event()), msg)
        mock_ws.return_value.ping.assert_called()

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_does_not_queue_empty_payload(
            self, mock_api, mock_ws):
        """An empty websocket payload must not be put on the output queue."""
        regex_url = self._user_stream_regex_url()

        mock_api.post(regex_url,
                      body=self._successful_get_listen_key_response())

        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()

        self.mocking_assistant.add_websocket_aiohttp_message(
            mock_ws.return_value, "")

        msg_queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue))

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            mock_ws.return_value)

        self.assertEqual(0, msg_queue.qsize())
class BinanceUserStreamDataSourceUnitTests(unittest.TestCase):
    """Unit tests for ``BinanceAPIUserStreamDataSource``.

    The test case instance is registered as a handler on the data source logger;
    ``handle`` collects the emitted records so tests can assert on them.
    """

    # the level is required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        """Create the shared event loop and the constants used by every test."""
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        cls.trading_pair = f"{cls.base_asset}-{cls.quote_asset}"
        cls.ex_trading_pair = cls.base_asset + cls.quote_asset
        cls.domain = "com"

        cls.listen_key = "TEST_LISTEN_KEY"

    def setUp(self) -> None:
        """Build a fresh data source and attach this test case as its log handler."""
        super().setUp()
        self.log_records = []
        self.listening_task: Optional[asyncio.Task] = None
        self.mocking_assistant = NetworkMockingAssistant()

        self.throttler = AsyncThrottler(rate_limits=CONSTANTS.RATE_LIMITS)
        self.mock_time_provider = MagicMock()
        self.mock_time_provider.time.return_value = 1000

        self.time_synchronizer = TimeSynchronizer()
        self.time_synchronizer.add_time_offset_ms_sample(0)

        self.data_source = BinanceAPIUserStreamDataSource(
            auth=BinanceAuth(api_key="TEST_API_KEY", secret_key="TEST_SECRET", time_provider=self.mock_time_provider),
            domain=self.domain,
            throttler=self.throttler,
            time_synchronizer=self.time_synchronizer,
        )

        self.data_source.logger().setLevel(1)
        self.data_source.logger().addHandler(self)

        self.resume_test_event = asyncio.Event()

    def tearDown(self) -> None:
        """Cancel the listening task started by the test, if any."""
        if self.listening_task:
            self.listening_task.cancel()
        super().tearDown()

    def handle(self, record):
        """Logging-handler hook: store every record emitted by the data source logger."""
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        """Return True if a captured record has exactly this level name and message."""
        return any(record.levelname == log_level and record.getMessage() == message
                   for record in self.log_records)

    def _raise_exception(self, exception_class):
        """Raise ``exception_class``; usable as a mock side effect."""
        raise exception_class

    def _create_exception_and_unlock_test_with_event(self, exception):
        """Set ``resume_test_event`` and then raise ``exception``."""
        self.resume_test_event.set()
        raise exception

    def _create_return_value_and_unlock_test_with_event(self, value):
        """Set ``resume_test_event`` and then return ``value``."""
        self.resume_test_event.set()
        return value

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: float = 1):
        """Run ``coroutine`` on the shared loop and return its result.

        ``asyncio.wait_for`` raises ``asyncio.TimeoutError`` when ``timeout`` seconds
        elapse first.
        """
        ret = self.ev_loop.run_until_complete(asyncio.wait_for(coroutine, timeout))
        return ret

    def _listen_key_regex_url(self):
        """Return a compiled prefix pattern for the user-stream (listen key) REST URL."""
        url = web_utils.private_rest_url(path_url=CONSTANTS.BINANCE_USER_STREAM_PATH_URL, domain=self.domain)
        # Escape "." and "?" so the URL is matched literally.
        return re.compile(f"^{url}".replace(".", r"\.").replace("?", r"\?"))

    def _error_response(self) -> Dict[str, Any]:
        """Return the payload used as the body of mocked error responses."""
        resp = {
            "code": "ERROR CODE",
            "msg": "ERROR MESSAGE"
        }

        return resp

    def _user_update_event(self):
        """Return a JSON-encoded balanceUpdate event."""
        # Balance Update
        resp = {
            "e": "balanceUpdate",
            "E": 1573200697110,
            "a": "BTC",
            "d": "100.00000000",
            "T": 1573200697068
        }
        return json.dumps(resp)

    def _successfully_subscribed_event(self):
        """Return a payload with a null ``result`` and ``id`` 1."""
        resp = {
            "result": None,
            "id": 1
        }
        return resp

    @aioresponses()
    def test_get_listen_key_log_exception(self, mock_api):
        """An HTTP 400 answer to the listen-key request must raise IOError."""
        regex_url = self._listen_key_regex_url()

        mock_api.post(regex_url, status=400, body=json.dumps(self._error_response()))

        with self.assertRaises(IOError):
            self.async_run_with_timeout(self.data_source._get_listen_key())

    @aioresponses()
    def test_get_listen_key_successful(self, mock_api):
        """A successful response must yield the listen key it contains."""
        regex_url = self._listen_key_regex_url()

        mock_response = {
            "listenKey": self.listen_key
        }
        mock_api.post(regex_url, body=json.dumps(mock_response))

        result: str = self.async_run_with_timeout(self.data_source._get_listen_key())

        self.assertEqual(self.listen_key, result)

    @aioresponses()
    def test_ping_listen_key_log_exception(self, mock_api):
        """A failed keep-alive must log a WARNING and return False."""
        regex_url = self._listen_key_regex_url()

        mock_api.put(regex_url, status=400, body=json.dumps(self._error_response()))

        self.data_source._current_listen_key = self.listen_key
        result: bool = self.async_run_with_timeout(self.data_source._ping_listen_key())

        self.assertTrue(self._is_logged("WARNING", f"Failed to refresh the listen key {self.listen_key}: {self._error_response()}"))
        self.assertFalse(result)

    @aioresponses()
    def test_ping_listen_key_successful(self, mock_api):
        """A successful keep-alive must return True."""
        regex_url = self._listen_key_regex_url()
        mock_api.put(regex_url, body=json.dumps({}))

        self.data_source._current_listen_key = self.listen_key
        result: bool = self.async_run_with_timeout(self.data_source._ping_listen_key())
        self.assertTrue(result)

    @patch("hummingbot.connector.exchange.binance.binance_api_user_stream_data_source.BinanceAPIUserStreamDataSource"
           "._ping_listen_key",
           new_callable=AsyncMock)
    def test_manage_listen_key_task_loop_keep_alive_failed(self, mock_ping_listen_key):
        """A failed keep-alive must log an ERROR and reset the listen-key state."""
        mock_ping_listen_key.side_effect = (lambda *args, **kwargs:
                                            self._create_return_value_and_unlock_test_with_event(False))

        self.data_source._current_listen_key = self.listen_key

        # Simulate LISTEN_KEY_KEEP_ALIVE_INTERVAL reached
        self.data_source._last_listen_key_ping_ts = 0

        self.listening_task = self.ev_loop.create_task(self.data_source._manage_listen_key_task_loop())

        self.async_run_with_timeout(self.resume_test_event.wait())

        self.assertTrue(self._is_logged("ERROR", "Error occurred renewing listen key ..."))
        self.assertIsNone(self.data_source._current_listen_key)
        self.assertFalse(self.data_source._listen_key_initialized_event.is_set())

    @patch("hummingbot.connector.exchange.binance.binance_api_user_stream_data_source.BinanceAPIUserStreamDataSource."
           "_ping_listen_key",
           new_callable=AsyncMock)
    def test_manage_listen_key_task_loop_keep_alive_successful(self, mock_ping_listen_key):
        """A successful keep-alive must log the refresh and update the ping timestamp."""
        mock_ping_listen_key.side_effect = (lambda *args, **kwargs:
                                            self._create_return_value_and_unlock_test_with_event(True))

        # Simulate LISTEN_KEY_KEEP_ALIVE_INTERVAL reached
        self.data_source._current_listen_key = self.listen_key
        self.data_source._listen_key_initialized_event.set()
        self.data_source._last_listen_key_ping_ts = 0

        self.listening_task = self.ev_loop.create_task(self.data_source._manage_listen_key_task_loop())

        self.async_run_with_timeout(self.resume_test_event.wait())

        self.assertTrue(self._is_logged("INFO", f"Refreshed listen key {self.listen_key}."))
        self.assertGreater(self.data_source._last_listen_key_ping_ts, 0)

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_get_listen_key_successful_with_user_update_event(self, mock_api, mock_ws):
        """A user update delivered over the websocket must reach the output queue decoded."""
        regex_url = self._listen_key_regex_url()

        mock_response = {
            "listenKey": self.listen_key
        }
        mock_api.post(regex_url, body=json.dumps(mock_response))

        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        self.mocking_assistant.add_websocket_aiohttp_message(mock_ws.return_value, self._user_update_event())

        msg_queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue)
        )

        msg = self.async_run_with_timeout(msg_queue.get())
        self.assertEqual(json.loads(self._user_update_event()), msg)
        mock_ws.return_value.ping.assert_called()

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_does_not_queue_empty_payload(self, mock_api, mock_ws):
        """An empty websocket payload must not be put on the output queue."""
        regex_url = self._listen_key_regex_url()

        mock_response = {
            "listenKey": self.listen_key
        }
        mock_api.post(regex_url, body=json.dumps(mock_response))

        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        self.mocking_assistant.add_websocket_aiohttp_message(mock_ws.return_value, "")

        msg_queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue)
        )

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(mock_ws.return_value)

        self.assertEqual(0, msg_queue.qsize())

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_connection_failed(self, mock_api, mock_ws):
        """A websocket connection failure must be logged as ERROR."""
        regex_url = self._listen_key_regex_url()

        mock_response = {
            "listenKey": self.listen_key
        }
        mock_api.post(regex_url, body=json.dumps(mock_response))

        # Raising from ws_connect also sets resume_test_event, which the test awaits below.
        mock_ws.side_effect = lambda *args, **kwargs: self._create_exception_and_unlock_test_with_event(
            Exception("TEST ERROR."))

        msg_queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue)
        )

        self.async_run_with_timeout(self.resume_test_event.wait())

        self.assertTrue(
            self._is_logged("ERROR",
                            "Unexpected error while listening to user stream. Retrying after 5 seconds..."))

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_iter_message_throws_exception(self, mock_api, mock_ws):
        """An error raised while receiving messages must be logged as ERROR."""
        regex_url = self._listen_key_regex_url()

        mock_response = {
            "listenKey": self.listen_key
        }
        mock_api.post(regex_url, body=json.dumps(mock_response))

        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value.receive.side_effect = (lambda *args, **kwargs:
                                                    self._create_exception_and_unlock_test_with_event(
                                                        Exception("TEST ERROR")))
        # Configure close() on the websocket mock itself (mock_ws is the ws_connect
        # patch; the connection object is its return value).
        mock_ws.return_value.close.return_value = None

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue)
        )

        self.async_run_with_timeout(self.resume_test_event.wait())

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error while listening to user stream. Retrying after 5 seconds..."))
class TestKucoinAPIUserStreamDataSource(unittest.TestCase):
    """Unit tests for ``KucoinAPIUserStreamDataSource``.

    The test case registers itself as a handler on the data source logger
    (see ``setUp`` / ``handle``) so individual tests can assert on the log
    records emitted while the user stream listener runs against mocked REST
    and websocket endpoints.
    """

    # the level is required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        """Create the event loop, trading pair and dummy credentials shared by all tests."""
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        cls.trading_pair = f"{cls.base_asset}-{cls.quote_asset}"
        cls.api_key = "someKey"
        cls.api_passphrase = "somePassPhrase"
        cls.api_secret_key = "someSecretKey"

    def setUp(self) -> None:
        """Build a fresh data source (with mocked time provider) and start capturing its logs."""
        super().setUp()
        self.log_records = []
        self.listening_task: Optional[asyncio.Task] = None
        self.mocking_assistant = NetworkMockingAssistant()

        self.throttler = AsyncThrottler(CONSTANTS.RATE_LIMITS)
        self.mock_time_provider = MagicMock()
        self.mock_time_provider.time.return_value = 1000
        self.auth = KucoinAuth(self.api_key,
                               self.api_passphrase,
                               self.api_secret_key,
                               time_provider=self.mock_time_provider)
        self.time_synchronizer = TimeSynchronizer()
        self.time_synchronizer.add_time_offset_ms_sample(0)

        self.api_factory = web_utils.build_api_factory(
            throttler=self.throttler,
            time_synchronizer=self.time_synchronizer,
            auth=self.auth)

        self.data_source = KucoinAPIUserStreamDataSource(
            throttler=self.throttler, api_factory=self.api_factory)

        # This test case acts as the logging handler; records arrive in handle().
        self.data_source.logger().setLevel(1)
        self.data_source.logger().addHandler(self)

    def tearDown(self) -> None:
        """Cancel the listener task (if any) and detach this instance from the logger."""
        if self.listening_task is not None:
            self.listening_task.cancel()
        # setUp adds this instance as a handler; remove it again so finished
        # test instances do not keep receiving records (and stay referenced)
        # if the logger outlives this data source.
        self.data_source.logger().removeHandler(self)
        super().tearDown()

    def handle(self, record):
        """Store a log record delivered by the logger this instance is attached to."""
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        """Return True when a captured record matches both the level name and the exact message."""
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: int = 1):
        """Run ``coroutine`` on the class event loop, bounded by ``timeout`` seconds, and return its result."""
        ret = self.ev_loop.run_until_complete(
            asyncio.wait_for(coroutine, timeout))
        return ret

    @staticmethod
    def _private_ws_token_url_regex():
        """Return a compiled pattern anchored at the private websocket data REST URL."""
        url = web_utils.rest_url(path_url=CONSTANTS.PRIVATE_WS_DATA_PATH_URL)
        return re.compile(f"^{url}".replace(".", r"\.").replace("?", r"\?"))

    @staticmethod
    def get_listen_key_mock():
        """Return a canned successful response for the private websocket data endpoint."""
        listen_key = {
            "code": "200000",
            "data": {
                "token":
                "someToken",
                "instanceServers": [{
                    "endpoint": "wss://someEndpoint",
                    "encrypt": True,
                    "protocol": "websocket",
                    "pingInterval": 18000,
                    "pingTimeout": 10000,
                }]
            }
        }
        return listen_key

    def test_last_recv_time(self):
        """``last_recv_time`` starts at 0 and mirrors the websocket connection's ``_last_recv_time``."""
        # Initial last_recv_time
        self.assertEqual(0, self.data_source.last_recv_time)

        ws_assistant = self.async_run_with_timeout(
            self.data_source._get_ws_assistant())
        ws_assistant._connection._last_recv_time = 1000
        self.assertEqual(1000, self.data_source.last_recv_time)

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.connector.exchange.kucoin.kucoin_web_utils.next_message_id"
    )
    def test_listen_for_user_stream_subscribes_to_orders_and_balances_events(
            self, mock_api, id_mock, ws_connect_mock):
        """Two subscription requests (trade orders, account balance) are sent and an info log is emitted."""
        id_mock.side_effect = [1, 2]
        url = web_utils.rest_url(path_url=CONSTANTS.PRIVATE_WS_DATA_PATH_URL)

        resp = {
            "code": "200000",
            "data": {
                "instanceServers": [{
                    "endpoint": "wss://test.url/endpoint",
                    "protocol": "websocket",
                    "encrypt": True,
                    "pingInterval": 50000,
                    "pingTimeout": 10000
                }],
                "token":
                "testToken"
            }
        }
        mock_api.post(url, body=json.dumps(resp))

        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )

        result_subscribe_trades = {"type": "ack", "id": 1}
        result_subscribe_diffs = {"type": "ack", "id": 2}

        self.mocking_assistant.add_websocket_aiohttp_message(
            websocket_mock=ws_connect_mock.return_value,
            message=json.dumps(result_subscribe_trades))
        self.mocking_assistant.add_websocket_aiohttp_message(
            websocket_mock=ws_connect_mock.return_value,
            message=json.dumps(result_subscribe_diffs))

        output_queue = asyncio.Queue()

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(output=output_queue))

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            ws_connect_mock.return_value)

        sent_subscription_messages = self.mocking_assistant.json_messages_sent_through_websocket(
            websocket_mock=ws_connect_mock.return_value)

        self.assertEqual(2, len(sent_subscription_messages))
        expected_orders_subscription = {
            "id": 1,
            "type": "subscribe",
            "topic": "/spotMarket/tradeOrders",
            "privateChannel": True,
            "response": False
        }
        self.assertEqual(expected_orders_subscription,
                         sent_subscription_messages[0])
        expected_balances_subscription = {
            "id": 2,
            "type": "subscribe",
            "topic": "/account/balance",
            "privateChannel": True,
            "response": False
        }
        self.assertEqual(expected_balances_subscription,
                         sent_subscription_messages[1])

        self.assertTrue(
            self._is_logged(
                "INFO",
                "Subscribed to private order changes and balance updates channels..."
            ))

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_get_listen_key_successful_with_user_update_event(
            self, mock_api, mock_ws):
        """An order-change event delivered over the websocket is queued unchanged and a ping is sent."""
        regex_url = self._private_ws_token_url_regex()

        mock_response = self.get_listen_key_mock()
        mock_api.post(regex_url, body=json.dumps(mock_response))

        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        order_event = {
            "type": "message",
            "topic": "/spotMarket/tradeOrders",
            "subject": "orderChange",
            "channelType": "private",
            "data": {
                "symbol": "KCS-USDT",
                "orderType": "limit",
                "side": "buy",
                "orderId": "5efab07953bdea00089965d2",
                "type": "open",
                "orderTime": 1593487481683297666,
                "size": "0.1",
                "filledSize": "0",
                "price": "0.937",
                "clientOid": "1593487481000906",
                "remainSize": "0.1",
                "status": "open",
                "ts": 1593487481683297666
            }
        }
        self.mocking_assistant.add_websocket_aiohttp_message(
            mock_ws.return_value, json.dumps(order_event))

        msg_queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue))

        msg = self.async_run_with_timeout(msg_queue.get())
        self.assertEqual(order_event, msg)
        mock_ws.return_value.ping.assert_called()

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_does_not_queue_pong_payload(
            self, mock_api, mock_ws):
        """A ``pong`` payload received over the websocket must not reach the output queue."""
        regex_url = self._private_ws_token_url_regex()

        mock_response = self.get_listen_key_mock()

        mock_pong = {"id": "1545910590801", "type": "pong"}
        mock_api.post(regex_url, body=json.dumps(mock_response))

        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        self.mocking_assistant.add_websocket_aiohttp_message(
            mock_ws.return_value, json.dumps(mock_pong))

        msg_queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(msg_queue))

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            mock_ws.return_value)

        self.assertEqual(0, msg_queue.qsize())

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.core.data_type.user_stream_tracker_data_source.UserStreamTrackerDataSource._sleep"
    )
    def test_listen_for_user_stream_connection_failed(self, mock_api,
                                                      sleep_mock, mock_ws):
        """A failing websocket connection attempt is reported through the retry error log."""
        regex_url = self._private_ws_token_url_regex()

        mock_response = self.get_listen_key_mock()
        mock_api.post(regex_url, body=json.dumps(mock_response))

        mock_ws.side_effect = Exception("TEST ERROR.")
        sleep_mock.side_effect = asyncio.CancelledError  # to finish the task execution

        msg_queue = asyncio.Queue()
        try:
            self.async_run_with_timeout(
                self.data_source.listen_for_user_stream(msg_queue))
        except asyncio.CancelledError:
            pass

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error occurred when listening to user streams. Retrying in 5 seconds..."
            ))

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.core.data_type.user_stream_tracker_data_source.UserStreamTrackerDataSource._sleep"
    )
    def test_listen_for_user_stream_iter_message_throws_exception(
            self, mock_api, sleep_mock, mock_ws):
        """An exception raised by the websocket ``receive`` call is reported through the retry error log."""
        regex_url = self._private_ws_token_url_regex()

        mock_response = self.get_listen_key_mock()
        mock_api.post(regex_url, body=json.dumps(mock_response))

        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value.receive.side_effect = Exception("TEST ERROR")
        sleep_mock.side_effect = asyncio.CancelledError  # to finish the task execution

        try:
            self.async_run_with_timeout(
                self.data_source.listen_for_user_stream(msg_queue))
        except asyncio.CancelledError:
            pass

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error occurred when listening to user streams. Retrying in 5 seconds..."
            ))

    @aioresponses()
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.connector.exchange.kucoin.kucoin_web_utils.next_message_id"
    )
    @patch(
        "hummingbot.connector.exchange.kucoin.kucoin_api_user_stream_data_source.KucoinAPIUserStreamDataSource"
        "._time")
    def test_listen_for_user_stream_sends_ping_message_before_ping_interval_finishes(
            self, mock_api, time_mock, id_mock, ws_connect_mock):
        """With the mocked clock jumping ahead, the last message sent over the websocket is a ping with id 3."""

        id_mock.side_effect = [1, 2, 3, 4]
        time_mock.side_effect = [
            1000, 1100, 1101, 1102
        ]  # Simulate first ping interval is already due
        url = web_utils.rest_url(path_url=CONSTANTS.PRIVATE_WS_DATA_PATH_URL)

        resp = {
            "code": "200000",
            "data": {
                "instanceServers": [{
                    "endpoint": "wss://test.url/endpoint",
                    "protocol": "websocket",
                    "encrypt": True,
                    "pingInterval": 20000,
                    "pingTimeout": 10000
                }],
                "token":
                "testToken"
            }
        }
        mock_api.post(url, body=json.dumps(resp))

        ws_connect_mock.return_value = self.mocking_assistant.create_websocket_mock(
        )

        result_subscribe_trades = {"type": "ack", "id": 1}
        result_subscribe_diffs = {"type": "ack", "id": 2}

        self.mocking_assistant.add_websocket_aiohttp_message(
            websocket_mock=ws_connect_mock.return_value,
            message=json.dumps(result_subscribe_trades))
        self.mocking_assistant.add_websocket_aiohttp_message(
            websocket_mock=ws_connect_mock.return_value,
            message=json.dumps(result_subscribe_diffs))

        output_queue = asyncio.Queue()

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(output=output_queue))

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            ws_connect_mock.return_value)

        sent_messages = self.mocking_assistant.json_messages_sent_through_websocket(
            websocket_mock=ws_connect_mock.return_value)

        expected_ping_message = {
            "id": 3,
            "type": "ping",
        }
        self.assertEqual(expected_ping_message, sent_messages[-1])
class BinancePerpetualAPIOrderBookDataSourceUnitTests(unittest.TestCase):
    # logging.Level required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        cls.trading_pair = f"{cls.base_asset}-{cls.quote_asset}"
        cls.ex_trading_pair = f"{cls.base_asset}{cls.quote_asset}"
        cls.domain = "binance_perpetual_testnet"

    def setUp(self) -> None:
        """Prepare a fresh data source, log capture and symbol map before each test."""
        super().setUp()
        self.log_records = []
        self.listening_task = None
        self.async_tasks: List[asyncio.Task] = []

        synchronizer = TimeSynchronizer()
        synchronizer.add_time_offset_ms_sample(0)
        self.time_synchronizer = synchronizer
        self.data_source = BinancePerpetualAPIOrderBookDataSource(
            time_synchronizer=synchronizer,
            trading_pairs=[self.trading_pair],
            domain=self.domain,
        )

        # The test case itself is attached as the logging handler; records
        # are collected by handle().
        self.data_source.logger().setLevel(1)
        self.data_source.logger().addHandler(self)

        self.mocking_assistant = NetworkMockingAssistant()
        self.resume_test_event = asyncio.Event()

        # Pre-populate the class-level symbol map for the test domain.
        domain_symbols = bidict({self.ex_trading_pair: self.trading_pair})
        BinancePerpetualAPIOrderBookDataSource._trading_pair_symbol_map = {
            self.domain: domain_symbols,
        }

    def tearDown(self) -> None:
        """Cancel outstanding tasks, detach the log handler and reset the class-level symbol map."""
        if self.listening_task is not None:
            self.listening_task.cancel()
        for task in self.async_tasks:
            task.cancel()
        # setUp registers this test instance as a handler on the data source
        # logger; remove it so finished test instances stop receiving records
        # (and are not kept referenced) if the logger outlives the data source.
        self.data_source.logger().removeHandler(self)
        BinancePerpetualAPIOrderBookDataSource._trading_pair_symbol_map = {}
        super().tearDown()

    def handle(self, record):
        self.log_records.append(record)

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: float = 1):
        ret = self.ev_loop.run_until_complete(
            asyncio.wait_for(coroutine, timeout))
        return ret

    def resume_test_callback(self, *_, **__):
        self.resume_test_event.set()
        return None

    def _is_logged(self, log_level: str, message: str) -> bool:
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def _raise_exception(self, exception_class):
        raise exception_class

    def _raise_exception_and_unlock_test_with_event(self, exception):
        self.resume_test_event.set()
        raise exception

    def _orderbook_update_event(self):
        resp = {
            "stream": f"{self.ex_trading_pair.lower()}@depth",
            "data": {
                "e": "depthUpdate",
                "E": 1631591424198,
                "T": 1631591424189,
                "s": self.ex_trading_pair,
                "U": 752409354963,
                "u": 752409360466,
                "pu": 752409354901,
                "b": [
                    ["43614.31", "0.000"],
                ],
                "a": [
                    ["45277.14", "0.257"],
                ],
            },
        }
        return resp

    def _orderbook_trade_event(self):
        resp = {
            "stream": f"{self.ex_trading_pair.lower()}@aggTrade",
            "data": {
                "e": "aggTrade",
                "E": 1631594403486,
                "a": 817295132,
                "s": self.ex_trading_pair,
                "p": "45266.16",
                "q": "2.206",
                "f": 1437689393,
                "l": 1437689407,
                "T": 1631594403330,
                "m": False,
            },
        }
        return resp

    def _funding_info_event(self):
        resp = {
            "stream": f"{self.ex_trading_pair.lower()}@markPrice",
            "data": {
                "e": "markPriceUpdate",
                "E": 1641288864000,
                "s": self.ex_trading_pair,
                "p": "46353.99600757",
                "P": "46507.47845460",
                "i": "46358.63622407",
                "r": "0.00010000",
                "T": 1641312000000,
            },
        }
        return resp

    @aioresponses()
    def test_get_last_traded_prices(self, mock_api):
        """A mocked ticker response with ``lastPrice`` "10.0" yields 10.0 for the trading pair."""

        def anchored_pattern(raw_url):
            # Escape dots and question marks so the URL can be used as a prefix regex.
            return re.compile(f"^{raw_url}".replace(".", r"\.").replace("?", r"\?"))

        server_time_url = web_utils.rest_url(CONSTANTS.SERVER_TIME_PATH_URL,
                                             domain=self.domain)
        server_time_payload = {"serverTime": 1640000003000}
        mock_api.get(anchored_pattern(server_time_url),
                     body=json.dumps(server_time_payload))

        ticker_url = web_utils.rest_url(path_url=CONSTANTS.TICKER_PRICE_CHANGE_URL,
                                        domain=self.domain)
        # Truncated responses
        ticker_payload: Dict[str, Any] = {"lastPrice": "10.0"}
        mock_api.get(anchored_pattern(ticker_url), body=json.dumps(ticker_payload))

        prices: Dict[str, Any] = self.async_run_with_timeout(
            self.data_source.get_last_traded_prices(
                trading_pairs=[self.trading_pair], domain=self.domain))

        self.assertIn(self.trading_pair, prices)
        self.assertEqual(10.0, prices[self.trading_pair])

    @aioresponses()
    def test_init_trading_pair_symbols_failure(self, mock_api):
        """With a 400 response from the exchange info endpoint the returned symbol map is empty."""
        # Start from an uninitialised class-level map.
        BinancePerpetualAPIOrderBookDataSource._trading_pair_symbol_map = {}

        info_url = web_utils.rest_url(path_url=CONSTANTS.EXCHANGE_INFO_URL,
                                      domain=self.domain)
        escaped = f"^{info_url}".replace(".", r"\.").replace("?", r"\?")
        mock_api.get(re.compile(escaped), status=400, body=json.dumps(["ERROR"]))

        symbol_map = self.async_run_with_timeout(
            self.data_source.trading_pair_symbol_map(
                domain=self.domain,
                time_synchronizer=self.data_source._time_synchronizer))

        self.assertEqual(0, len(symbol_map))

    @aioresponses()
    def test_init_trading_pair_symbols_successful(self, mock_api):
        """After initialising from a mocked exchange info response the class-level map has one entry."""
        info_url = web_utils.rest_url(path_url=CONSTANTS.EXCHANGE_INFO_URL,
                                      domain=self.domain)
        escaped = f"^{info_url}".replace(".", r"\.").replace("?", r"\?")

        active_market = {
            "symbol": self.ex_trading_pair,
            "pair": self.ex_trading_pair,
            "baseAsset": self.base_asset,
            "quoteAsset": self.quote_asset,
            "status": "TRADING",
        }
        inactive_market = {"symbol": "INACTIVEMARKET", "status": "INACTIVE"}
        # Truncated Responses
        exchange_info: Dict[str, Any] = {"symbols": [active_market, inactive_market]}
        mock_api.get(re.compile(escaped), status=200, body=json.dumps(exchange_info))

        self.async_run_with_timeout(
            self.data_source.init_trading_pair_symbols(
                domain=self.domain,
                time_synchronizer=self.data_source._time_synchronizer))

        self.assertEqual(1, len(self.data_source._trading_pair_symbol_map))

    @aioresponses()
    def test_trading_pair_symbol_map_dictionary_not_initialized(
            self, mock_api):
        """Requesting the symbol map while it is empty leaves one entry in the class-level map."""
        BinancePerpetualAPIOrderBookDataSource._trading_pair_symbol_map = {}

        info_url = web_utils.rest_url(path_url=CONSTANTS.EXCHANGE_INFO_URL,
                                      domain=self.domain)
        escaped = f"^{info_url}".replace(".", r"\.").replace("?", r"\?")

        trading_market = {
            "symbol": self.ex_trading_pair,
            "pair": self.ex_trading_pair,
            "baseAsset": self.base_asset,
            "quoteAsset": self.quote_asset,
            "status": "TRADING",
        }
        # Truncated Responses
        exchange_info: Dict[str, Any] = {"symbols": [trading_market]}
        mock_api.get(re.compile(escaped), status=200, body=json.dumps(exchange_info))

        self.async_run_with_timeout(
            self.data_source.trading_pair_symbol_map(
                domain=self.domain,
                time_synchronizer=self.data_source._time_synchronizer))

        self.assertEqual(1, len(self.data_source._trading_pair_symbol_map))

    def test_trading_pair_symbol_map_dictionary_initialized(self):
        """The map pre-populated in ``setUp`` is returned with its single entry."""
        lookup = self.data_source.trading_pair_symbol_map(
            domain=self.domain,
            time_synchronizer=self.data_source._time_synchronizer)
        symbol_map = self.async_run_with_timeout(lookup)
        self.assertEqual(1, len(symbol_map))

    def test_convert_from_exchange_trading_pair_not_found(self):
        """Converting an unmapped exchange symbol raises ``ValueError`` with the expected message."""
        missing_pair = "UNKNOWN-PAIR"
        expected_error = f"There is no symbol mapping for exchange trading pair {missing_pair}"
        with self.assertRaisesRegex(ValueError, expected_error):
            conversion = self.data_source.convert_from_exchange_trading_pair(
                missing_pair, domain=self.domain)
            self.async_run_with_timeout(conversion)

    def test_convert_from_exchange_trading_pair_successful(self):
        """The exchange symbol registered in ``setUp`` converts to the hyphenated trading pair."""
        conversion = self.data_source.convert_from_exchange_trading_pair(
            self.ex_trading_pair, domain=self.domain)
        converted = self.async_run_with_timeout(conversion)
        self.assertEqual(converted, self.trading_pair)

    def test_convert_to_exchange_trading_pair_not_found(self):
        """Converting an unmapped trading pair raises ``ValueError`` with the expected message."""
        missing_pair = "UNKNOWN-PAIR"
        expected_error = f"There is no symbol mapping for trading pair {missing_pair}"
        with self.assertRaisesRegex(ValueError, expected_error):
            conversion = self.data_source.convert_to_exchange_trading_pair(
                missing_pair, domain=self.domain)
            self.async_run_with_timeout(conversion)

    def test_convert_to_exchange_trading_pair_successful(self):
        """The trading pair registered in ``setUp`` converts to the exchange symbol."""
        conversion = self.data_source.convert_to_exchange_trading_pair(
            self.trading_pair, domain=self.domain)
        converted = self.async_run_with_timeout(conversion)
        self.assertEqual(converted, self.ex_trading_pair)

    @aioresponses()
    def test_get_snapshot_exception_raised(self, mock_api):
        """A 400 response from the snapshot endpoint surfaces as ``IOError`` with the expected text."""
        snapshot_url = web_utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL,
                                          domain=self.domain)
        escaped = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        mock_api.get(re.compile(escaped), status=400, body=json.dumps(["ERROR"]))

        with self.assertRaises(IOError) as raised:
            snapshot_request = self.data_source.get_snapshot(
                trading_pair=self.trading_pair,
                domain=self.domain,
                time_synchronizer=self.data_source._time_synchronizer)
            self.async_run_with_timeout(snapshot_request)

        expected_text = "Error executing request GET /depth. HTTP status is 400. Error: [\"ERROR\"]"
        self.assertEqual(expected_text, str(raised.exception))

    @aioresponses()
    def test_get_snapshot_successful(self, mock_api):
        """The snapshot returned equals the mocked REST payload."""
        snapshot_url = web_utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL,
                                          domain=self.domain)
        escaped = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        snapshot_payload = {
            "lastUpdateId": 1027024,
            "E": 1589436922972,
            "T": 1589436922959,
            "bids": [["10", "1"]],
            "asks": [["11", "1"]],
        }
        mock_api.get(re.compile(escaped), status=200, body=json.dumps(snapshot_payload))

        snapshot_request = self.data_source.get_snapshot(
            trading_pair=self.trading_pair,
            domain=self.domain,
            time_synchronizer=self.data_source._time_synchronizer)
        snapshot: Dict[str, Any] = self.async_run_with_timeout(snapshot_request)

        self.assertEqual(snapshot_payload, snapshot)

    @aioresponses()
    def test_get_new_order_book(self, mock_api):
        """A new order book is an ``OrderBook`` whose snapshot uid equals the payload's ``lastUpdateId``."""
        snapshot_url = web_utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL,
                                          domain=self.domain)
        escaped = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        snapshot_payload = {
            "lastUpdateId": 1027024,
            "E": 1589436922972,
            "T": 1589436922959,
            "bids": [["10", "1"]],
            "asks": [["11", "1"]],
        }
        mock_api.get(re.compile(escaped), status=200, body=json.dumps(snapshot_payload))

        order_book = self.async_run_with_timeout(
            self.data_source.get_new_order_book(trading_pair=self.trading_pair))

        self.assertIsInstance(order_book, OrderBook)
        self.assertEqual(1027024, order_book.snapshot_uid)

    @aioresponses()
    def test_get_funding_info_from_exchange_error_response(self, mock_api):
        """With a 400 response from the mark price endpoint no funding info is returned."""
        mark_price_url = web_utils.rest_url(CONSTANTS.MARK_PRICE_URL, domain=self.domain)
        escaped = f"^{mark_price_url}".replace(".", r"\.").replace("?", r"\?")
        mock_api.get(re.compile(escaped), status=400)

        funding_info = self.async_run_with_timeout(
            self.data_source._get_funding_info_from_exchange(self.trading_pair))

        self.assertIsNone(funding_info)
        # NOTE(review): the lookup result below is discarded, so the log message
        # is not actually verified. Confirm the exact message emitted by the data
        # source before wrapping this call in an assertion.
        expected_log = f"Unable to fetch FundingInfo for {self.trading_pair}. Error: None"
        self._is_logged("ERROR", expected_log)

    @aioresponses()
    def test_get_funding_info_from_exchange_successful(self, mock_api):
        """The ``FundingInfo`` fetched from the exchange reflects the mocked mark price payload."""
        mark_price_url = web_utils.rest_url(CONSTANTS.MARK_PRICE_URL, domain=self.domain)
        escaped = f"^{mark_price_url}".replace(".", r"\.").replace("?", r"\?")

        mark_price_payload = {
            "symbol": self.ex_trading_pair,
            "markPrice": "46382.32704603",
            "indexPrice": "46385.80064948",
            "estimatedSettlePrice": "46510.13598963",
            "lastFundingRate": "0.00010000",
            "interestRate": "0.00010000",
            "nextFundingTime": 1641312000000,
            "time": 1641288825000,
        }
        mock_api.get(re.compile(escaped), body=json.dumps(mark_price_payload))

        funding_info = self.async_run_with_timeout(
            self.data_source._get_funding_info_from_exchange(self.trading_pair))

        self.assertIsInstance(funding_info, FundingInfo)
        self.assertEqual(funding_info.trading_pair, self.trading_pair)

        expected_index_price = Decimal(mark_price_payload["indexPrice"])
        self.assertEqual(funding_info.index_price, expected_index_price)

        expected_mark_price = Decimal(mark_price_payload["markPrice"])
        self.assertEqual(funding_info.mark_price, expected_mark_price)

        self.assertEqual(funding_info.next_funding_utc_timestamp,
                         mark_price_payload["nextFundingTime"])

        expected_rate = Decimal(mark_price_payload["lastFundingRate"])
        self.assertEqual(funding_info.rate, expected_rate)

    @aioresponses()
    def test_get_funding_info(self, mock_api):
        """``get_funding_info`` returns data matching the mocked payload for a pair not yet stored."""
        # Precondition: nothing stored for the pair yet.
        self.assertNotIn(self.trading_pair, self.data_source._funding_info)

        mark_price_url = web_utils.rest_url(CONSTANTS.MARK_PRICE_URL, domain=self.domain)
        escaped = f"^{mark_price_url}".replace(".", r"\.").replace("?", r"\?")

        mark_price_payload = {
            "symbol": self.ex_trading_pair,
            "markPrice": "46382.32704603",
            "indexPrice": "46385.80064948",
            "estimatedSettlePrice": "46510.13598963",
            "lastFundingRate": "0.00010000",
            "interestRate": "0.00010000",
            "nextFundingTime": 1641312000000,
            "time": 1641288825000,
        }
        mock_api.get(re.compile(escaped), body=json.dumps(mark_price_payload))

        funding_info = self.async_run_with_timeout(
            self.data_source.get_funding_info(trading_pair=self.trading_pair))

        self.assertIsInstance(funding_info, FundingInfo)
        self.assertEqual(funding_info.trading_pair, self.trading_pair)

        expected_index_price = Decimal(mark_price_payload["indexPrice"])
        self.assertEqual(funding_info.index_price, expected_index_price)

        expected_mark_price = Decimal(mark_price_payload["markPrice"])
        self.assertEqual(funding_info.mark_price, expected_mark_price)

        self.assertEqual(funding_info.next_funding_utc_timestamp,
                         mark_price_payload["nextFundingTime"])

        expected_rate = Decimal(mark_price_payload["lastFundingRate"])
        self.assertEqual(funding_info.rate, expected_rate)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    def test_listen_for_subscriptions_cancelled_when_connecting(
            self, _, mock_ws):
        """A ``CancelledError`` raised by the websocket connect call propagates out of the listener."""
        # NOTE(review): this queue is never handed to the data source, so the
        # final size check cannot fail.
        untouched_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            listener = self.data_source.listen_for_subscriptions()
            self.listening_task = self.ev_loop.create_task(listener)
            self.async_run_with_timeout(self.listening_task)

        self.assertEqual(untouched_queue.qsize(), 0)

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_subscriptions_logs_exception(self, mock_ws, *_):
        """After queuing an incomplete payload and a depth update, the websocket error log must be present."""
        websocket_mock = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value = websocket_mock
        mock_ws.close.return_value = None

        # First a payload without the usual stream fields, then a regular depth update.
        incomplete_payload = {"m": 1, "i": 2}
        for payload in (incomplete_payload, self._orderbook_update_event()):
            self.mocking_assistant.add_websocket_aiohttp_message(
                websocket_mock, json.dumps(payload))

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_subscriptions())

        try:
            self.async_run_with_timeout(self.listening_task)
        except asyncio.exceptions.TimeoutError:
            # The listener is not expected to finish on its own within the timeout.
            pass

        expected_log = "Unexpected error with Websocket connection. Retrying after 30 seconds..."
        self.assertTrue(self._is_logged("ERROR", expected_log))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_subscriptions_successful(self, mock_ws):
        """Diff, trade and funding events each reach their own consumer."""
        diffs_queue: asyncio.Queue = asyncio.Queue()
        trades_queue: asyncio.Queue = asyncio.Queue()
        ws_mock = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value = ws_mock
        mock_ws.close.return_value = None

        # One message per stream type, delivered in this order.
        add_message = self.mocking_assistant.add_websocket_aiohttp_message
        add_message(ws_mock, json.dumps(self._orderbook_update_event()))
        add_message(ws_mock, json.dumps(self._orderbook_trade_event()))
        add_message(ws_mock, json.dumps(self._funding_info_event()))

        schedule = self.ev_loop.create_task
        self.listening_task = schedule(
            self.data_source.listen_for_subscriptions())
        self.listening_task_diffs = schedule(
            self.data_source.listen_for_order_book_diffs(
                self.ev_loop, diffs_queue))
        self.listening_task_trades = schedule(
            self.data_source.listen_for_trades(self.ev_loop, trades_queue))
        self.listening_task_funding_info = schedule(
            self.data_source.listen_for_funding_info())

        # Order book diff.
        diff_message: OrderBookMessage = self.async_run_with_timeout(
            diffs_queue.get())
        self.assertIsInstance(diff_message, OrderBookMessage)
        self.assertEqual(OrderBookMessageType.DIFF, diff_message.type)
        self.assertTrue(diff_message.has_update_id)
        self.assertEqual(diff_message.update_id, 752409360466)
        self.assertEqual(self.trading_pair,
                         diff_message.content["trading_pair"])
        self.assertEqual(1, len(diff_message.content["bids"]))
        self.assertEqual(1, len(diff_message.content["asks"]))

        # Trade.
        trade_message: OrderBookMessage = self.async_run_with_timeout(
            trades_queue.get())
        self.assertIsInstance(trade_message, OrderBookMessage)
        self.assertEqual(OrderBookMessageType.TRADE, trade_message.type)
        self.assertTrue(trade_message.has_trade_id)
        self.assertEqual(trade_message.trade_id, 817295132)
        self.assertEqual(self.trading_pair,
                         trade_message.content["trading_pair"])

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            ws_mock)

        # Funding info.
        self.assertIn(self.trading_pair, self.data_source.funding_info)

        funding_info: FundingInfo = self.data_source.funding_info[
            self.trading_pair]
        expected = self._funding_info_event()["data"]
        self.assertTrue(self.data_source.is_funding_info_initialized)
        self.assertEqual(funding_info.trading_pair, self.trading_pair)
        self.assertEqual(funding_info.index_price, Decimal(expected["i"]))
        self.assertEqual(funding_info.mark_price, Decimal(expected["p"]))
        self.assertEqual(funding_info.next_funding_utc_timestamp,
                         int(expected["T"]))
        self.assertEqual(funding_info.rate, Decimal(expected["r"]))

    @aioresponses()
    def test_listen_for_order_book_snapshots_cancelled_error_raised(
            self, mock_api):
        """A cancellation raised by the snapshot request propagates."""
        snapshot_url = web_utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL,
                                          domain=self.domain)
        url_pattern = f"^{snapshot_url}".replace(".", r"\.").replace(
            "?", r"\?")
        mock_api.get(re.compile(url_pattern),
                     exception=asyncio.CancelledError)

        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            snapshots = self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, output_queue)
            self.listening_task = self.ev_loop.create_task(snapshots)
            self.async_run_with_timeout(self.listening_task)

        self.assertEqual(0, output_queue.qsize())

    @aioresponses()
    def test_listen_for_order_book_snapshots_logs_exception_error_with_response(
            self, mock_api):
        """A snapshot response without the expected fields is logged as ERROR."""
        snapshot_url = web_utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL,
                                          domain=self.domain)
        url_pattern = f"^{snapshot_url}".replace(".", r"\.").replace(
            "?", r"\?")

        # Body that does not look like an order book snapshot; the callback
        # lets the test resume once the request has been served.
        malformed_body = json.dumps({"m": 1, "i": 2})
        mock_api.get(re.compile(url_pattern),
                     body=malformed_body,
                     callback=self.resume_test_callback)

        output_queue: asyncio.Queue = asyncio.Queue()
        snapshots = self.data_source.listen_for_order_book_snapshots(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(snapshots)

        self.async_run_with_timeout(self.resume_test_event.wait())

        expected_log = (
            "Unexpected error occurred fetching orderbook snapshots. Retrying in 5 seconds..."
        )
        self.assertTrue(self._is_logged("ERROR", expected_log))

    @aioresponses()
    def test_listen_for_order_book_snapshots_successful(self, mock_api):
        """A valid REST snapshot is published as a SNAPSHOT message."""
        snapshot_url = web_utils.rest_url(CONSTANTS.SNAPSHOT_REST_URL,
                                          domain=self.domain)
        url_pattern = f"^{snapshot_url}".replace(".", r"\.").replace(
            "?", r"\?")

        snapshot_payload = {
            "lastUpdateId": 1027024,
            "E": 1589436922972,
            "T": 1589436922959,
            "bids": [["10", "1"]],
            "asks": [["11", "1"]],
        }
        mock_api.get(re.compile(url_pattern),
                     body=json.dumps(snapshot_payload))

        output_queue: asyncio.Queue = asyncio.Queue()
        snapshots = self.data_source.listen_for_order_book_snapshots(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(snapshots)

        snapshot_message = self.async_run_with_timeout(output_queue.get())

        self.assertIsInstance(snapshot_message, OrderBookMessage)
        self.assertEqual(OrderBookMessageType.SNAPSHOT, snapshot_message.type)
        self.assertTrue(snapshot_message.has_update_id)
        self.assertEqual(snapshot_message.update_id, 1027024)
        self.assertEqual(self.trading_pair,
                         snapshot_message.content["trading_pair"])

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_funding_info_invalid_trading_pair(self, mock_ws):
        """A mark price event for an unknown symbol stores no funding info."""
        ws_mock = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value = ws_mock
        mock_ws.close.return_value = None

        unknown_symbol = "unknown_pair"
        mark_price_event = {
            "stream": "unknown_pair@markPrice",
            "data": {
                "e": "markPriceUpdate",
                "E": 1641288864000,
                "s": unknown_symbol,
                "p": "46353.99600757",
                "P": "46507.47845460",
                "i": "46358.63622407",
                "r": "0.00010000",
                "T": 1641312000000,
            },
        }
        self.mocking_assistant.add_websocket_aiohttp_message(
            ws_mock, json.dumps(mark_price_event))

        subscriptions = self.data_source.listen_for_subscriptions()
        self.listening_task = self.ev_loop.create_task(subscriptions)

        funding_listener = self.data_source.listen_for_funding_info()
        self.listening_task_funding_info = self.ev_loop.create_task(
            funding_listener)

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            ws_mock)

        self.assertNotIn(self.trading_pair, self.data_source.funding_info)

    def test_listen_for_funding_info_cancelled_error_raised(self):
        """A cancellation while waiting for funding events propagates."""
        cancelling_queue = AsyncMock()
        cancelling_queue.get.side_effect = asyncio.CancelledError
        stream_queues = self.data_source._message_queue
        stream_queues[CONSTANTS.FUNDING_INFO_STREAM_ID] = cancelling_queue

        with self.assertRaises(asyncio.CancelledError):
            funding_listener = self.data_source.listen_for_funding_info()
            self.async_run_with_timeout(funding_listener)

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    def test_listen_for_funding_info_logs_exception(self, mock_sleep):
        """An unexpected error while reading funding events is logged.

        The funding info queue is replaced by a mock whose ``get`` goes
        through ``_raise_exception_and_unlock_test_with_event`` with
        ``Exception("TEST ERROR")``. Once ``resume_test_event`` is set, the
        ERROR log is asserted.
        """
        mock_sleep.side_effect = lambda _: (self.ev_loop.run_until_complete(
            asyncio.sleep(0.5)))

        mock_queue = AsyncMock()
        mock_queue.get.side_effect = lambda: (
            self._raise_exception_and_unlock_test_with_event(
                Exception("TEST ERROR")))
        self.data_source._message_queue[
            CONSTANTS.FUNDING_INFO_STREAM_ID] = mock_queue

        self.listening_task_funding_info = self.ev_loop.create_task(
            self.data_source.listen_for_funding_info())
        self.async_run_with_timeout(self.resume_test_event.wait())

        # The lookup result has to be asserted: as a bare call its return
        # value was discarded and the test could never fail.
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error occured updating funding information. Retrying in 5 seconds... Error: TEST ERROR"
            ))


# Example no. 8
# 0
class TestBybitAPIOrderBookDataSource(unittest.TestCase):
    """Unit tests for ``BybitAPIOrderBookDataSource``.

    The test case adds itself as a handler to the data source logger in
    ``setUp`` and stores the records received through ``handle`` so that
    tests can check logged messages with ``_is_logged``.
    """

    # Handler level attribute: required so that this test case, when added as
    # a handler, receives the logs from the data source logger.
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        """Create the event loop and trading pair fixtures shared by all tests."""
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()

        base, quote = "COINALPHA", "HBOT"
        cls.base_asset = base
        cls.quote_asset = quote
        # Hyphenated pair and the concatenated exchange symbol.
        cls.trading_pair = "-".join((base, quote))
        cls.ex_trading_pair = f"{base}{quote}"
        cls.domain = CONSTANTS.DEFAULT_DOMAIN

    def setUp(self) -> None:
        """Build a fresh data source and attach this test case to its logger."""
        super().setUp()
        self.log_records = []
        self.async_task = None

        self.mocking_assistant = NetworkMockingAssistant()
        self.throttler = AsyncThrottler(CONSTANTS.RATE_LIMITS)

        synchronizer = TimeSynchronizer()
        synchronizer.add_time_offset_ms_sample(1000)
        self.time_synchronnizer = synchronizer

        self.ob_data_source = BybitAPIOrderBookDataSource(
            trading_pairs=[self.trading_pair],
            throttler=self.throttler,
            time_synchronizer=synchronizer)

        # Capture every record emitted by the data source logger.
        self.ob_data_source.logger().setLevel(1)
        self.ob_data_source.logger().addHandler(self)

        # Pre-populate the symbol map for the domain used by the tests.
        symbol_map = bidict({self.ex_trading_pair: self.trading_pair})
        BybitAPIOrderBookDataSource._trading_pair_symbol_map = {
            self.domain: symbol_map
        }

    def tearDown(self) -> None:
        """Cancel the pending task, if any, and clear the shared symbol map."""
        pending_task = self.async_task
        if pending_task:
            pending_task.cancel()
        BybitAPIOrderBookDataSource._trading_pair_symbol_map = {}
        super().tearDown()

    def handle(self, record):
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: int = 1):
        ret = self.ev_loop.run_until_complete(
            asyncio.wait_for(coroutine, timeout))
        return ret

    # TRADING PAIRS
    @aioresponses()
    def test_fetch_trading_pairs(self, mock_api):
        """Exchange info symbols are returned as hyphenated trading pairs."""
        BybitAPIOrderBookDataSource._trading_pair_symbol_map = {}
        exchange_info_url = web_utils.rest_url(
            path_url=CONSTANTS.EXCHANGE_INFO_PATH_URL)

        btc_usdt = {
            "name": "BTCUSDT",
            "alias": "BTCUSDT",
            "baseCurrency": "BTC",
            "quoteCurrency": "USDT",
            "basePrecision": "0.000001",
            "quotePrecision": "0.01",
            "minTradeQuantity": "0.0001",
            "minTradeAmount": "10",
            "minPricePrecision": "0.01",
            "maxTradeQuantity": "2",
            "maxTradeAmount": "200",
            "category": 1
        }
        eth_usdt = {
            "name": "ETHUSDT",
            "alias": "ETHUSDT",
            "baseCurrency": "ETH",
            "quoteCurrency": "USDT",
            "basePrecision": "0.0001",
            "quotePrecision": "0.01",
            "minTradeQuantity": "0.0001",
            "minTradeAmount": "10",
            "minPricePrecision": "0.01",
            "maxTradeQuantity": "2",
            "maxTradeAmount": "200",
            "category": 1
        }
        exchange_info = {
            "ret_code": 0,
            "ret_msg": "",
            "ext_code": None,
            "ext_info": None,
            "result": [btc_usdt, eth_usdt]
        }
        mock_api.get(exchange_info_url, body=json.dumps(exchange_info))

        fetch = BybitAPIOrderBookDataSource.fetch_trading_pairs(
            domain=self.domain,
            throttler=self.throttler,
            time_synchronizer=self.time_synchronnizer,
        )
        pairs = self.async_run_with_timeout(coroutine=fetch)

        self.assertEqual(2, len(pairs))
        self.assertEqual("BTC-USDT", pairs[0])
        self.assertEqual("ETH-USDT", pairs[1])

    @aioresponses()
    def test_fetch_trading_pairs_exception_raised(self, mock_api):
        """A failing exchange info request yields an empty result."""
        BybitAPIOrderBookDataSource._trading_pair_symbol_map = {}

        exchange_info_url = web_utils.rest_url(
            path_url=CONSTANTS.EXCHANGE_INFO_PATH_URL)
        url_pattern = f"^{exchange_info_url}".replace(".", r"\.").replace(
            "?", r"\?")
        mock_api.get(re.compile(url_pattern), exception=Exception)

        fetch = self.ob_data_source.fetch_trading_pairs()
        pairs = self.async_run_with_timeout(fetch)

        self.assertEqual(0, len(pairs))

    # LAST TRADED PRICES
    @aioresponses()
    def test_get_last_traded_prices(self, mock_api):
        """Prices are requested once per symbol and returned keyed by pair."""
        BybitAPIOrderBookDataSource._trading_pair_symbol_map[
            CONSTANTS.DEFAULT_DOMAIN]["TKN1TKN2"] = "TKN1-TKN2"

        def ticker_payload(symbol, price):
            # Body of a last-traded-price response for one symbol.
            return {
                "ret_code": 0,
                "ret_msg": None,
                "result": {
                    "symbol": symbol,
                    "price": price
                },
                "ext_code": None,
                "ext_info": None
            }

        def url_regex(url):
            # Anchored pattern with the dots and question marks escaped.
            return re.compile(f"^{url}".replace(".",
                                                r"\.").replace("?", r"\?"))

        url1 = web_utils.rest_url(path_url=CONSTANTS.LAST_TRADED_PRICE_PATH,
                                  domain=CONSTANTS.DEFAULT_DOMAIN)
        url1 = f"{url1}?symbol={self.ex_trading_pair}"
        first_body = json.dumps(ticker_payload(self.ex_trading_pair, "50008"))
        mock_api.get(url_regex(url1), body=first_body)

        url2 = web_utils.rest_url(path_url=CONSTANTS.LAST_TRADED_PRICE_PATH,
                                  domain=CONSTANTS.DEFAULT_DOMAIN)
        url2 = f"{url2}?symbol=TKN1TKN2"
        second_body = json.dumps(ticker_payload("TKN1TKN2", "2050"))
        mock_api.get(url_regex(url2), body=second_body)

        prices = self.async_run_with_timeout(
            coroutine=BybitAPIOrderBookDataSource.get_last_traded_prices(
                [self.trading_pair, "TKN1-TKN2"]))

        # Keep only the recorded requests that hit one of the two ticker URLs.
        ticker_requests = [
            (key, value) for key, value in mock_api.requests.items()
            if key[1].human_repr().startswith((url1, url2))
        ]
        first_params = ticker_requests[0][1][0].kwargs["params"]
        self.assertEqual(self.ex_trading_pair, first_params["symbol"])
        second_params = ticker_requests[1][1][0].kwargs["params"]
        self.assertEqual("TKN1TKN2", second_params["symbol"])

        self.assertEqual(prices[self.trading_pair], 50008)
        self.assertEqual(prices["TKN1-TKN2"], 2050)

    # ORDER BOOK SNAPSHOT
    @staticmethod
    def _snapshot_response() -> Dict:
        snapshot = {
            "ret_code": 0,
            "ret_msg": None,
            "result": {
                "time": 1620886105740,
                "bids": [["50005.12", "403.0416"]],
                "asks": [["50006.34", "0.2297"]]
            },
            "ext_code": None,
            "ext_info": None
        }
        return snapshot

    @staticmethod
    def _snapshot_response_processed() -> Dict:
        snapshot_processed = {
            "time": 1620886105740,
            "bids": [["50005.12", "403.0416"]],
            "asks": [["50006.34", "0.2297"]]
        }
        return snapshot_processed

    @aioresponses()
    def test_get_snapshot(self, mock_api):
        """``get_snapshot`` returns the ``result`` section of the response."""
        snapshot_url = web_utils.rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL)
        url_pattern = f"^{snapshot_url}".replace(".", r"\.").replace(
            "?", r"\?")
        raw_snapshot = self._snapshot_response()
        mock_api.get(re.compile(url_pattern), body=json.dumps(raw_snapshot))

        request = self.ob_data_source.get_snapshot(self.trading_pair)
        snapshot = self.async_run_with_timeout(coroutine=request)

        # A shallow comparison is sufficient here.
        self.assertEqual(snapshot, self._snapshot_response_processed())

    @aioresponses()
    def test_get_snapshot_raises(self, mock_api):
        """A 500 response to the snapshot request raises ``IOError``."""
        snapshot_url = web_utils.rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL)
        url_pattern = f"^{snapshot_url}".replace(".", r"\.").replace(
            "?", r"\?")
        mock_api.get(re.compile(url_pattern), status=500)

        with self.assertRaises(IOError):
            request = self.ob_data_source.get_snapshot(self.trading_pair)
            self.async_run_with_timeout(coroutine=request)

    @aioresponses()
    def test_get_new_order_book(self, mock_api):
        """The new order book holds the snapshot's single bid and ask."""
        snapshot_url = web_utils.rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL)
        url_pattern = f"^{snapshot_url}".replace(".", r"\.").replace(
            "?", r"\?")
        resp = self._snapshot_response()
        mock_api.get(re.compile(url_pattern), body=json.dumps(resp))

        order_book = self.async_run_with_timeout(
            coroutine=self.ob_data_source.get_new_order_book(
                self.trading_pair))
        bids = list(order_book.bid_entries())
        asks = list(order_book.ask_entries())

        self.assertEqual(1, len(bids))
        top_bid = bids[0]
        self.assertEqual(50005.12, top_bid.price)
        self.assertEqual(403.0416, top_bid.amount)
        self.assertEqual(int(resp["result"]["time"]), top_bid.update_id)

        self.assertEqual(1, len(asks))
        top_ask = asks[0]
        self.assertEqual(50006.34, top_ask.price)
        self.assertEqual(0.2297, top_ask.amount)
        self.assertEqual(int(resp["result"]["time"]), top_ask.update_id)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_subscriptions_subscribes_to_trades_and_depth(
            self, ws_connect_mock):
        """Connecting sends one trade and one diffDepth subscription request."""
        ws_mock = self.mocking_assistant.create_websocket_mock()
        ws_connect_mock.return_value = ws_mock

        def subscription_ack(topic):
            # Confirmation message the mocked websocket delivers for a topic.
            return {
                'topic': topic,
                'event': 'sub',
                'symbol': self.ex_trading_pair,
                'params': {
                    'binary': 'false',
                    'symbolName': self.ex_trading_pair
                },
                'code': '0',
                'msg': 'Success'
            }

        for acked_topic in ('trade', 'depth'):
            self.mocking_assistant.add_websocket_aiohttp_message(
                websocket_mock=ws_mock,
                message=json.dumps(subscription_ack(acked_topic)))

        subscriptions = self.ob_data_source.listen_for_subscriptions()
        self.listening_task = self.ev_loop.create_task(subscriptions)

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            ws_mock)

        sent_subscription_messages = self.mocking_assistant.json_messages_sent_through_websocket(
            websocket_mock=ws_mock)

        self.assertEqual(2, len(sent_subscription_messages))
        # Trades are requested first, order book diffs second.
        requested_topics = ("trade", "diffDepth")
        for sent_message, topic in zip(sent_subscription_messages,
                                       requested_topics):
            expected_subscription = {
                "topic": topic,
                "event": "sub",
                "symbol": self.ex_trading_pair,
                "params": {
                    "binary": False
                }
            }
            self.assertEqual(expected_subscription, sent_message)

        expected_log = f"Subscribed to public order book and trade channels of {self.trading_pair}..."
        self.assertTrue(self._is_logged("INFO", expected_log))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.connector.exchange.bybit.bybit_api_order_book_data_source.BybitAPIOrderBookDataSource._time"
    )
    def test_listen_for_subscriptions_sends_ping_message_before_ping_interval_finishes(
            self, time_mock, ws_connect_mock):
        """The last message sent through the websocket is the expected ping."""
        # Simulate that the first ping interval is already due.
        mocked_times = [1000, 1100, 1101, 1102]
        time_mock.side_effect = mocked_times

        ws_mock = self.mocking_assistant.create_websocket_mock()
        ws_connect_mock.return_value = ws_mock

        def subscription_ack(topic):
            # Confirmation message the mocked websocket delivers for a topic.
            return {
                'topic': topic,
                'event': 'sub',
                'symbol': self.ex_trading_pair,
                'params': {
                    'binary': 'false',
                    'symbolName': self.ex_trading_pair
                },
                'code': '0',
                'msg': 'Success'
            }

        for acked_topic in ('trade', 'depth'):
            self.mocking_assistant.add_websocket_aiohttp_message(
                websocket_mock=ws_mock,
                message=json.dumps(subscription_ack(acked_topic)))

        subscriptions = self.ob_data_source.listen_for_subscriptions()
        self.listening_task = self.ev_loop.create_task(subscriptions)

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            ws_mock)
        sent_messages = self.mocking_assistant.json_messages_sent_through_websocket(
            websocket_mock=ws_mock)

        expected_ping_message = {"ping": int(1101 * 1e3)}
        self.assertEqual(expected_ping_message, sent_messages[-1])

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    def test_listen_for_subscriptions_raises_cancel_exception(
            self, _, ws_connect_mock):
        """A cancellation raised while connecting propagates to the caller."""
        ws_connect_mock.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            subscriptions = self.ob_data_source.listen_for_subscriptions()
            self.listening_task = self.ev_loop.create_task(subscriptions)
            self.async_run_with_timeout(self.listening_task)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    def test_listen_for_subscriptions_logs_exception_details(
            self, sleep_mock, ws_connect_mock):
        """A connection error is logged before the retry sleep is reached."""
        ws_connect_mock.side_effect = Exception("TEST ERROR.")
        # Cancelling from the patched sleep ends the listener after the log.
        sleep_mock.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            subscriptions = self.ob_data_source.listen_for_subscriptions()
            self.listening_task = self.ev_loop.create_task(subscriptions)
            self.async_run_with_timeout(self.listening_task)

        expected_log = (
            "Unexpected error occurred when listening to order book streams. Retrying in 5 seconds..."
        )
        self.assertTrue(self._is_logged("ERROR", expected_log))

    def test_listen_for_trades_cancelled_when_listening(self):
        """A cancellation raised by the trade queue propagates."""
        cancelling_queue = MagicMock()
        cancelling_queue.get.side_effect = asyncio.CancelledError()
        stream_queues = self.ob_data_source._message_queue
        stream_queues[CONSTANTS.TRADE_EVENT_TYPE] = cancelling_queue

        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            trades_listener = self.ob_data_source.listen_for_trades(
                self.ev_loop, output_queue)
            self.listening_task = self.ev_loop.create_task(trades_listener)
            self.async_run_with_timeout(self.listening_task)

    def test_listen_for_trades_logs_exception(self):
        """A trade event without its ``t`` field is logged as an ERROR."""
        subscription_params = {
            "symbol": self.ex_trading_pair,
            "binary": "false",
            "symbolName": self.ex_trading_pair
        }
        # The "t" entry (1582001735462 in a full event) is left out on purpose.
        trade_without_time = {
            "v": "564265886622695424",
            "p": "9787.5",
            "q": "0.195009",
            "m": True
        }
        incomplete_event = {
            "topic": "trade",
            "params": subscription_params,
            "data": trade_without_time
        }

        events_queue = AsyncMock()
        events_queue.get.side_effect = [
            incomplete_event, asyncio.CancelledError()
        ]
        stream_queues = self.ob_data_source._message_queue
        stream_queues[CONSTANTS.TRADE_EVENT_TYPE] = events_queue

        output_queue: asyncio.Queue = asyncio.Queue()
        trades_listener = self.ob_data_source.listen_for_trades(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(trades_listener)

        try:
            self.async_run_with_timeout(self.listening_task)
        except asyncio.CancelledError:
            # Raised by the second queue read; only the log matters here.
            pass

        expected_log = (
            "Unexpected error when processing public trade updates from exchange"
        )
        self.assertTrue(self._is_logged("ERROR", expected_log))

    def test_listen_for_trades_successful(self):
        """A trade event is published as a message carrying the event's ``t``.

        The trade queue of the data source is replaced by a mock that yields
        one trade event and then raises ``asyncio.CancelledError``. The
        message put on the output queue is compared against the event.
        """
        mock_queue = AsyncMock()
        trade_event = {
            "symbol": self.ex_trading_pair,
            "symbolName": self.ex_trading_pair,
            "topic": "trade",
            "params": {
                "realtimeInterval": "24h",
                "binary": "false"
            },
            "data": [{
                "v": "929681067596857345",
                "t": 1625562619577,
                "p": "34924.15",
                "q": "0.00027",
                "m": True
            }],
            "f": True,
            "sendTime": 1626249138535,
            "shared": False
        }
        mock_queue.get.side_effect = [trade_event, asyncio.CancelledError()]
        self.ob_data_source._message_queue[
            CONSTANTS.TRADE_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        # create_task only schedules the coroutine and cannot raise
        # CancelledError itself, so no try/except is needed around it.
        self.listening_task = self.ev_loop.create_task(
            self.ob_data_source.listen_for_trades(self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.async_run_with_timeout(msg_queue.get())

        # assertTrue(a, b) uses b only as the failure message and passes for
        # any truthy a; assertEqual performs the intended comparison.
        self.assertEqual(trade_event["data"][0]["t"], msg.trade_id)

    def test_listen_for_order_book_diffs_cancelled(self):
        """A cancellation raised by the diff queue propagates."""
        cancelling_queue = AsyncMock()
        cancelling_queue.get.side_effect = asyncio.CancelledError()
        stream_queues = self.ob_data_source._message_queue
        stream_queues[CONSTANTS.DIFF_EVENT_TYPE] = cancelling_queue

        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            diffs_listener = self.ob_data_source.listen_for_order_book_diffs(
                self.ev_loop, output_queue)
            self.listening_task = self.ev_loop.create_task(diffs_listener)
            self.async_run_with_timeout(self.listening_task)

    def test_listen_for_order_book_diffs_logs_exception(self):
        """A diff event without its ``symbol`` field is logged as an ERROR."""
        bid_levels = [["11371.49", "0.0014"], ["11371.12", "0.2"],
                      ["11369.97", "0.3523"], ["11369.96", "0.5"],
                      ["11369.95", "0.0934"], ["11369.94", "1.6809"],
                      ["11369.6", "0.0047"], ["11369.17", "0.3"],
                      ["11369.16", "0.2"], ["11369.04", "1.3203"]]
        ask_levels = [["11375.41", "0.0053"], ["11375.42", "0.0043"],
                      ["11375.48", "0.0052"], ["11375.58", "0.0541"],
                      ["11375.7", "0.0386"], ["11375.71", "2"],
                      ["11377", "2.0691"], ["11377.01", "0.0167"],
                      ["11377.12", "1.5"], ["11377.61", "0.3"]]
        depth_update = {
            "e": 301,
            "s": self.ex_trading_pair,
            "t": 1565600357643,
            "v": "112801745_18",
            "b": bid_levels,
            "a": ask_levels,
            "o": 0
        }
        # The top-level "symbol" entry is left out on purpose.
        incomplete_event = {
            "symbolName": self.ex_trading_pair,
            "topic": "diffDepth",
            "params": {
                "realtimeInterval": "24h",
                "binary": "false"
            },
            "data": [depth_update],
            "f": False,
            "sendTime": 1626253839401,
            "shared": False
        }

        events_queue = AsyncMock()
        events_queue.get.side_effect = [
            incomplete_event, asyncio.CancelledError()
        ]
        stream_queues = self.ob_data_source._message_queue
        stream_queues[CONSTANTS.DIFF_EVENT_TYPE] = events_queue

        output_queue: asyncio.Queue = asyncio.Queue()
        diffs_listener = self.ob_data_source.listen_for_order_book_diffs(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(diffs_listener)

        try:
            self.async_run_with_timeout(self.listening_task)
        except asyncio.CancelledError:
            # Raised by the second queue read; only the log matters here.
            pass

        expected_log = (
            "Unexpected error when processing public order book updates from exchange"
        )
        self.assertTrue(self._is_logged("ERROR", expected_log))

    def test_listen_for_order_book_diffs_successful(self):
        """A diff event is published as a message whose update id is its ``t``.

        The diff queue of the data source is replaced by a mock that yields
        one diff event and then raises ``asyncio.CancelledError``. The
        message put on the output queue is compared against the event.
        """
        mock_queue = AsyncMock()
        diff_event = {
            "symbol": self.ex_trading_pair,
            "symbolName": self.ex_trading_pair,
            "topic": "diffDepth",
            "params": {
                "realtimeInterval": "24h",
                "binary": "false"
            },
            "data": [{
                "e": 301,
                "s": self.ex_trading_pair,
                "t": 1565600357643,
                "v": "112801745_18",
                "b": [["11371.49", "0.0014"], ["11371.12", "0.2"],
                      ["11369.97", "0.3523"], ["11369.96", "0.5"],
                      ["11369.95", "0.0934"], ["11369.94", "1.6809"],
                      ["11369.6", "0.0047"], ["11369.17", "0.3"],
                      ["11369.16", "0.2"], ["11369.04", "1.3203"]],
                "a": [["11375.41", "0.0053"], ["11375.42", "0.0043"],
                      ["11375.48", "0.0052"], ["11375.58", "0.0541"],
                      ["11375.7", "0.0386"], ["11375.71", "2"],
                      ["11377", "2.0691"], ["11377.01", "0.0167"],
                      ["11377.12", "1.5"], ["11377.61", "0.3"]],
                "o": 0
            }],
            "f": False,
            "sendTime": 1626253839401,
            "shared": False
        }
        mock_queue.get.side_effect = [diff_event, asyncio.CancelledError()]
        self.ob_data_source._message_queue[
            CONSTANTS.DIFF_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        # create_task only schedules the coroutine and cannot raise
        # CancelledError itself, so no try/except is needed around it.
        self.listening_task = self.ev_loop.create_task(
            self.ob_data_source.listen_for_order_book_diffs(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.async_run_with_timeout(msg_queue.get())

        # assertTrue(a, b) uses b only as the failure message and passes for
        # any truthy a; assertEqual performs the intended comparison.
        self.assertEqual(diff_event["data"][0]["t"], msg.update_id)

    def test_listen_for_order_book_snapshots_cancelled_when_fetching_snapshot(
            self):
        """Check that cancellation while reading the snapshot queue propagates.

        The internal snapshot queue is replaced by a mock whose ``get`` raises
        ``asyncio.CancelledError``; running the snapshot listener is expected
        to raise that same exception.
        """
        cancelling_queue = AsyncMock()
        cancelling_queue.get.side_effect = asyncio.CancelledError()
        internal_queues = self.ob_data_source._message_queue
        internal_queues[CONSTANTS.SNAPSHOT_EVENT_TYPE] = cancelling_queue

        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            self.async_run_with_timeout(
                self.ob_data_source.listen_for_order_book_snapshots(
                    self.ev_loop, output_queue))

    @aioresponses()
    @patch("hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep")
    def test_listen_for_order_book_snapshots_log_exception(
            self, mock_api, sleep_mock):
        """Check that a snapshot-processing failure is logged at ERROR level.

        The internal snapshot queue yields the string ``'ERROR'``, the mocked
        REST snapshot endpoint raises ``Exception``, and the patched
        ``_sleep`` raises ``asyncio.CancelledError``. The cancellation is
        swallowed here so the log assertion can run.
        """
        # Mocked REST endpoint: any request to the snapshot URL fails.
        snapshot_url = web_utils.rest_url(path_url=CONSTANTS.SNAPSHOT_PATH_URL)
        escaped_url = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        mock_api.get(re.compile(escaped_url), exception=Exception)

        # Internal queue hands out one malformed item, then the
        # CancelledError class.
        failing_queue = AsyncMock()
        failing_queue.get.side_effect = ['ERROR', asyncio.CancelledError]
        self.ob_data_source._message_queue[
            CONSTANTS.SNAPSHOT_EVENT_TYPE] = failing_queue

        sleep_mock.side_effect = [asyncio.CancelledError]
        output_queue: asyncio.Queue = asyncio.Queue()

        try:
            self.async_run_with_timeout(
                self.ob_data_source.listen_for_order_book_snapshots(
                    self.ev_loop, output_queue))
        except asyncio.CancelledError:
            pass

        error_was_logged = self._is_logged(
            "ERROR",
            "Unexpected error when processing public order book updates from exchange"
        )
        self.assertTrue(error_was_logged)

    @aioresponses()
    @patch("hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep")
    def test_listen_for_order_book_snapshots_successful_rest(
            self, mock_api, _):
        """Check that a snapshot served by the mocked REST endpoint is emitted.

        The internal snapshot queue always raises ``asyncio.TimeoutError``,
        and the mocked REST endpoint returns ``self._snapshot_response()``.
        The message read from the output queue must use the response's
        ``result.time`` as its ``update_id``.
        """
        timing_out_queue = AsyncMock()
        timing_out_queue.get.side_effect = asyncio.TimeoutError
        self.ob_data_source._message_queue[
            CONSTANTS.SNAPSHOT_EVENT_TYPE] = timing_out_queue

        output_queue: asyncio.Queue = asyncio.Queue()

        # Register the canned snapshot for every request to the snapshot URL.
        snapshot_url = web_utils.rest_url(path_url=CONSTANTS.SNAPSHOT_PATH_URL)
        escaped_url = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        snapshot_data = self._snapshot_response()
        mock_api.get(re.compile(escaped_url), body=json.dumps(snapshot_data))

        self.listening_task = self.ev_loop.create_task(
            self.ob_data_source.listen_for_order_book_snapshots(
                self.ev_loop, output_queue))

        received: OrderBookMessage = self.async_run_with_timeout(
            output_queue.get())

        expected_update_id = int(snapshot_data["result"]["time"])
        self.assertEqual(expected_update_id, received.update_id)

    def test_listen_for_order_book_snapshots_successful_ws(self):
        """Check that a full-book websocket event is emitted with its update id.

        The mocked internal queue yields one event flagged ``"f": True`` and
        then raises ``asyncio.CancelledError``. The message read from the
        output queue must carry the event's ``"t"`` value as its
        ``update_id``.

        NOTE(review): despite the test name, the event is placed on the
        ``DIFF_EVENT_TYPE`` queue and consumed by
        ``listen_for_order_book_diffs`` - confirm this routing is intended.
        """
        snapshot_event = {
            "symbol": self.ex_trading_pair,
            "symbolName": self.ex_trading_pair,
            "topic": "diffDepth",
            "params": {
                "realtimeInterval": "24h",
                "binary": "false",
            },
            "data": [{
                "e": 301,
                "s": self.ex_trading_pair,
                "t": 1565600357643,
                "v": "112801745_18",
                "b": [["11371.49", "0.0014"], ["11371.12", "0.2"],
                      ["11369.97", "0.3523"], ["11369.96", "0.5"],
                      ["11369.95", "0.0934"], ["11369.94", "1.6809"],
                      ["11369.6", "0.0047"], ["11369.17", "0.3"],
                      ["11369.16", "0.2"], ["11369.04", "1.3203"]],
                "a": [["11375.41", "0.0053"], ["11375.42", "0.0043"],
                      ["11375.48", "0.0052"], ["11375.58", "0.0541"],
                      ["11375.7", "0.0386"], ["11375.71", "2"],
                      ["11377", "2.0691"], ["11377.01", "0.0167"],
                      ["11377.12", "1.5"], ["11377.61", "0.3"]],
                "o": 0,
            }],
            "f": True,
            "sendTime": 1626253839401,
            "shared": False,
        }

        mock_queue = AsyncMock()
        mock_queue.get.side_effect = [snapshot_event, asyncio.CancelledError()]
        self.ob_data_source._message_queue[
            CONSTANTS.DIFF_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        # create_task() only schedules the coroutine; it cannot raise
        # CancelledError itself, so no try/except is needed around it.
        self.listening_task = self.ev_loop.create_task(
            self.ob_data_source.listen_for_order_book_diffs(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.async_run_with_timeout(msg_queue.get(),
                                                            timeout=6)

        # assertTrue(a, b) uses b as the failure message and never compares
        # the two values; assertEqual performs the intended check.
        self.assertEqual(snapshot_event["data"][0]["t"], msg.update_id)
class TestGateIoAPIUserStreamDataSource(unittest.TestCase):
    # the level is required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        cls.trading_pair = f"{cls.base_asset}-{cls.quote_asset}"
        cls.ex_trading_pair = f"{cls.base_asset}_{cls.quote_asset}"
        cls.api_key = "someKey"
        cls.api_secret_key = "someSecretKey"

    def setUp(self) -> None:
        """Build fresh per-test fixtures: auth, connector, user stream data source and log capture."""
        super().setUp()
        self.log_records = []
        self.listening_task: Optional[asyncio.Task] = None
        self.mocking_assistant = NetworkMockingAssistant()

        self.throttler = AsyncThrottler(CONSTANTS.RATE_LIMITS)

        # Time provider whose time() always returns 1000.
        time_provider = MagicMock()
        time_provider.time.return_value = 1000
        self.mock_time_provider = time_provider
        self.auth = GateIoAuth(
            api_key=self.api_key,
            secret_key=self.api_secret_key,
            time_provider=time_provider,
        )

        synchronizer = TimeSynchronizer()
        synchronizer.add_time_offset_ms_sample(0)
        self.time_synchronizer = synchronizer

        # Connector built with empty credentials; the test auth object is
        # then assigned to its web assistants factory.
        config_adapter = ClientConfigAdapter(ClientConfigMap())
        self.connector = GateIoExchange(
            client_config_map=config_adapter,
            gate_io_api_key="",
            gate_io_secret_key="",
            trading_pairs=[],
            trading_required=False,
        )
        self.connector._web_assistants_factory._auth = self.auth

        self.data_source = GateIoAPIUserStreamDataSource(
            self.auth,
            trading_pairs=[self.trading_pair],
            connector=self.connector,
            api_factory=self.connector._web_assistants_factory,
        )

        # The test case itself is registered as a log handler (see handle()).
        self.data_source.logger().setLevel(1)
        self.data_source.logger().addHandler(self)

        symbol_map = bidict({self.ex_trading_pair: self.trading_pair})
        self.connector._set_trading_pair_symbol_map(symbol_map)

    def tearDown(self) -> None:
        self.listening_task and self.listening_task.cancel()
        super().tearDown()

    def handle(self, record):
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: int = 1):
        ret = self.ev_loop.run_until_complete(
            asyncio.wait_for(coroutine, timeout))
        return ret

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.connector.exchange.gate_io.gate_io_api_user_stream_data_source.GateIoAPIUserStreamDataSource"
        "._time")
    def test_listen_for_user_stream_subscribes_to_orders_and_balances_events(
            self, time_mock, ws_connect_mock):
        """Check the three subscription requests sent when the user stream starts.

        The mocked websocket delivers subscribe confirmations for the orders,
        trades and balance channels. The test then checks that exactly three
        requests were sent, in that order, each with the expected payload and
        signature, and that the subscription INFO message was logged.
        """
        time_mock.return_value = 1000
        websocket = self.mocking_assistant.create_websocket_mock()
        ws_connect_mock.return_value = websocket

        # Queue one subscribe confirmation per private channel.
        private_channels = (
            CONSTANTS.USER_ORDERS_ENDPOINT_NAME,
            CONSTANTS.USER_TRADES_ENDPOINT_NAME,
            CONSTANTS.USER_BALANCE_ENDPOINT_NAME,
        )
        confirmations = [{
            "time": 1611541000,
            "channel": channel,
            "event": "subscribe",
            "error": None,
            "result": {
                "status": "success"
            }
        } for channel in private_channels]
        for confirmation in confirmations:
            self.mocking_assistant.add_websocket_aiohttp_message(
                websocket_mock=websocket, message=json.dumps(confirmation))

        output_queue = asyncio.Queue()

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(output=output_queue))

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            websocket)

        sent_subscription_messages = self.mocking_assistant.json_messages_sent_through_websocket(
            websocket_mock=websocket)

        self.assertEqual(3, len(sent_subscription_messages))

        # Expected requests in sending order: orders, trades, balances.
        # The balances request carries no "payload" entry.
        expected_subscriptions = [
            {
                "time": int(self.mock_time_provider.time()),
                "channel": CONSTANTS.USER_ORDERS_ENDPOINT_NAME,
                "event": "subscribe",
                "payload": [self.ex_trading_pair],
                "auth": {
                    "KEY": self.api_key,
                    "SIGN":
                    '005d2e6996fa7783459453d36ff871d8d5cfe225a098f37ac234543811c79e3c'  # noqa: mock
                    'db8f41684f3ad9491f65c15ed880ce7baee81f402eb1df56b1bba188c0e7838c',  # noqa: mock
                    "method": "api_key"
                },
            },
            {
                "time": int(self.mock_time_provider.time()),
                "channel": CONSTANTS.USER_TRADES_ENDPOINT_NAME,
                "event": "subscribe",
                "payload": [self.ex_trading_pair],
                "auth": {
                    "KEY": self.api_key,
                    "SIGN":
                    '0f34bf79558905d2b5bc7790febf1099d38ff1aa39525a077db32bcbf9135268'  # noqa: mock
                    'caf23cdf2d62315841500962f788f7c5f4c3f4b8a057b2184366687b1f74af69',  # noqa: mock
                    "method": "api_key"
                },
            },
            {
                "time": int(self.mock_time_provider.time()),
                "channel": CONSTANTS.USER_BALANCE_ENDPOINT_NAME,
                "event": "subscribe",
                "auth": {
                    "KEY": self.api_key,
                    "SIGN":
                    '90f5e732fc586d09c4a1b7de13f65b668c7ce90678b30da87aa137364bac0b97'  # noqa: mock
                    '16b34219b689fb754e821872933a0e12b1d415867b9fbb8ec441bc86e77fb79c',  # noqa: mock
                    "method": "api_key"
                },
            },
        ]
        for position, expected in enumerate(expected_subscriptions):
            self.assertEqual(expected, sent_subscription_messages[position])

        subscription_logged = self._is_logged(
            "INFO",
            "Subscribed to private order changes and balance updates channels..."
        )
        self.assertTrue(subscription_logged)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch(
        "hummingbot.connector.exchange.gate_io.gate_io_api_user_stream_data_source.GateIoAPIUserStreamDataSource"
        "._time")
    def test_listen_for_user_stream_skips_subscribe_unsubscribe_messages(
            self, time_mock, ws_connect_mock):
        """Check that subscribe confirmations are not forwarded to the output queue.

        The mocked websocket delivers only the three subscribe confirmations
        (orders, trades, balance); after delivery the output queue must be
        empty.
        """
        time_mock.return_value = 1000
        websocket = self.mocking_assistant.create_websocket_mock()
        ws_connect_mock.return_value = websocket

        # One confirmation per private channel, delivered in this order.
        private_channels = (
            CONSTANTS.USER_ORDERS_ENDPOINT_NAME,
            CONSTANTS.USER_TRADES_ENDPOINT_NAME,
            CONSTANTS.USER_BALANCE_ENDPOINT_NAME,
        )
        confirmations = [{
            "time": 1611541000,
            "channel": channel,
            "event": "subscribe",
            "error": None,
            "result": {
                "status": "success"
            }
        } for channel in private_channels]
        for confirmation in confirmations:
            self.mocking_assistant.add_websocket_aiohttp_message(
                websocket_mock=websocket, message=json.dumps(confirmation))

        output_queue = asyncio.Queue()

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(output=output_queue))

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            websocket)

        self.assertTrue(output_queue.empty())

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_user_stream_does_not_queue_pong_payload(self, mock_ws):
        """Check that a message on the pong channel is not put on the output queue."""
        websocket = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value = websocket

        pong_payload = json.dumps({
            "time": 1545404023,
            "channel": CONSTANTS.PONG_CHANNEL_NAME,
            "event": "",
            "error": None,
            "result": None
        })
        self.mocking_assistant.add_websocket_aiohttp_message(
            websocket, pong_payload)

        output_queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_user_stream(output_queue))

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            websocket)

        self.assertEqual(0, output_queue.qsize())

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch("hummingbot.core.data_type.user_stream_tracker_data_source.UserStreamTrackerDataSource._sleep")
    def test_listen_for_user_stream_connection_failed(self, sleep_mock,
                                                      mock_ws):
        """Check that a websocket connection failure is logged at ERROR level.

        ``ws_connect`` raises ``Exception("TEST ERROR.")``, and the patched
        ``_sleep`` raises ``asyncio.CancelledError`` to finish the task.
        """
        # The sleep mock raises CancelledError to finish the task execution.
        sleep_mock.side_effect = asyncio.CancelledError
        mock_ws.side_effect = Exception("TEST ERROR.")

        output_queue = asyncio.Queue()
        try:
            self.async_run_with_timeout(
                self.data_source.listen_for_user_stream(output_queue))
        except asyncio.CancelledError:
            pass

        retry_logged = self._is_logged(
            "ERROR",
            "Unexpected error while listening to user stream. Retrying after 5 seconds..."
        )
        self.assertTrue(retry_logged)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    @patch("hummingbot.core.data_type.user_stream_tracker_data_source.UserStreamTrackerDataSource._sleep")
    def test_listen_for_user_stream_iter_message_throws_exception(
            self, sleep_mock, mock_ws):
        """Check that a failure while receiving websocket messages is logged at ERROR level.

        The mocked websocket's ``receive`` raises ``Exception("TEST ERROR")``,
        and the patched ``_sleep`` raises ``asyncio.CancelledError`` to finish
        the task.
        """
        websocket = self.mocking_assistant.create_websocket_mock()
        websocket.receive.side_effect = Exception("TEST ERROR")
        mock_ws.return_value = websocket
        # The sleep mock raises CancelledError to finish the task execution.
        sleep_mock.side_effect = asyncio.CancelledError

        output_queue: asyncio.Queue = asyncio.Queue()

        try:
            self.async_run_with_timeout(
                self.data_source.listen_for_user_stream(output_queue))
        except asyncio.CancelledError:
            pass

        retry_logged = self._is_logged(
            "ERROR",
            "Unexpected error while listening to user stream. Retrying after 5 seconds..."
        )
        self.assertTrue(retry_logged)