Exemplo n.º 1
0
class BinanceAPIOrderBookDataSourceUnitTests(unittest.TestCase):
    """Unit tests for ``BinanceAPIOrderBookDataSource``.

    The test case registers itself as a handler on the data source logger
    (see ``setUp``) and collects the emitted records through ``handle`` so
    that tests can assert on log output with ``_is_logged``.
    """

    # logging.Level required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        """Set up the event loop and the trading pair constants shared by all tests."""
        super().setUpClass()
        cls.domain = "com"
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        # Symbol without separator, as used in the mocked exchange payloads.
        cls.ex_trading_pair = f"{cls.base_asset}{cls.quote_asset}"
        # Dash-separated pair, as passed to the data source.
        cls.trading_pair = "-".join((cls.base_asset, cls.quote_asset))
        cls.ev_loop = asyncio.get_event_loop()

    def setUp(self) -> None:
        """Create a fresh data source, its collaborators and the log capture."""
        super().setUp()
        self.listening_task = None
        self.log_records = []
        self.resume_test_event = asyncio.Event()
        self.mocking_assistant = NetworkMockingAssistant()

        self.time_synchronizer = TimeSynchronizer()
        self.time_synchronizer.add_time_offset_ms_sample(1000)
        self.throttler = AsyncThrottler(rate_limits=CONSTANTS.RATE_LIMITS)

        self.data_source = BinanceAPIOrderBookDataSource(
            trading_pairs=[self.trading_pair],
            throttler=self.throttler,
            domain=self.domain,
            time_synchronizer=self.time_synchronizer,
        )

        # Route the data source log records into ``self.handle``.
        data_source_logger = self.data_source.logger()
        data_source_logger.setLevel(1)
        data_source_logger.addHandler(self)

        # Seed the class-level symbol map with the single test pair.
        symbol_map = bidict(
            {f"{self.base_asset}{self.quote_asset}": self.trading_pair})
        BinanceAPIOrderBookDataSource._trading_pair_symbol_map = {
            "com": symbol_map,
        }

    def tearDown(self) -> None:
        """Cancel a still-referenced listener task and clear the symbol map."""
        pending_task = self.listening_task
        if pending_task:
            pending_task.cancel()
        BinanceAPIOrderBookDataSource._trading_pair_symbol_map = {}
        super().tearDown()

    def handle(self, record):
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def _create_exception_and_unlock_test_with_event(self, exception):
        self.resume_test_event.set()
        raise exception

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: float = 1):
        ret = self.ev_loop.run_until_complete(
            asyncio.wait_for(coroutine, timeout))
        return ret

    def _successfully_subscribed_event(self):
        resp = {"result": None, "id": 1}
        return resp

    def _trade_update_event(self):
        resp = {
            "e": "trade",
            "E": 123456789,
            "s": self.ex_trading_pair,
            "t": 12345,
            "p": "0.001",
            "q": "100",
            "b": 88,
            "a": 50,
            "T": 123456785,
            "m": True,
            "M": True
        }
        return resp

    def _order_diff_event(self):
        resp = {
            "e": "depthUpdate",
            "E": 123456789,
            "s": self.ex_trading_pair,
            "U": 157,
            "u": 160,
            "b": [["0.0024", "10"]],
            "a": [["0.0026", "100"]]
        }
        return resp

    def _snapshot_response(self):
        resp = {
            "lastUpdateId": 1027024,
            "bids": [["4.00000000", "431.00000000"]],
            "asks": [["4.00000200", "12.00000000"]]
        }
        return resp

    @aioresponses()
    def test_get_last_trade_prices(self, mock_api):
        """``get_last_traded_prices`` yields one price (100) for the test pair."""
        ticker_url = web_utils.public_rest_url(
            path_url=CONSTANTS.TICKER_PRICE_CHANGE_PATH_URL,
            domain=self.domain)
        ticker_url = f"{ticker_url}?symbol={self.base_asset}{self.quote_asset}"
        # Escape the regex metacharacters that occur in the URL.
        escaped_url = f"^{ticker_url}".replace(".", r"\.").replace("?", r"\?")
        url_pattern = re.compile(escaped_url)

        ticker_payload = {
            "symbol": "BNBBTC",
            "priceChange": "-94.99999800",
            "priceChangePercent": "-95.960",
            "weightedAvgPrice": "0.29628482",
            "prevClosePrice": "0.10002000",
            "lastPrice": "100.0",
            "lastQty": "200.00000000",
            "bidPrice": "4.00000000",
            "bidQty": "100.00000000",
            "askPrice": "4.00000200",
            "askQty": "100.00000000",
            "openPrice": "99.00000000",
            "highPrice": "100.00000000",
            "lowPrice": "0.10000000",
            "volume": "8913.30000000",
            "quoteVolume": "15.30000000",
            "openTime": 1499783499040,
            "closeTime": 1499869899040,
            "firstId": 28385,
            "lastId": 28460,
            "count": 76,
        }
        mock_api.get(url_pattern, body=json.dumps(ticker_payload))

        request = self.data_source.get_last_traded_prices(
            trading_pairs=[self.trading_pair],
            throttler=self.throttler,
            time_synchronizer=self.time_synchronizer)
        result: Dict[str, float] = self.async_run_with_timeout(request)

        self.assertEqual(1, len(result))
        self.assertEqual(100, result[self.trading_pair])

    @aioresponses()
    def test_get_all_mid_prices(self, mock_api):
        """``get_all_mid_prices`` yields a single mid price of 100 for the test pair.

        The mocked ticker list also contains a pair that is not in the symbol
        map; the result is expected to hold only one entry.
        """
        server_time_url = web_utils.public_rest_url(
            CONSTANTS.SERVER_TIME_PATH_URL, domain=self.domain)
        escaped_url = f"^{server_time_url}".replace(".", r"\.").replace(
            "?", r"\?")
        server_time_pattern = re.compile(escaped_url)
        server_time_payload = {"serverTime": 1640000003000}
        mock_api.get(server_time_pattern, body=json.dumps(server_time_payload))

        ticker_url = web_utils.public_rest_url(
            path_url=CONSTANTS.TICKER_PRICE_CHANGE_PATH_URL,
            domain=self.domain)

        # Truncated Response
        known_pair_ticker = {
            "symbol": self.ex_trading_pair,
            "bidPrice": "99",
            "askPrice": "101",
        }
        # Truncated Response for unrecognized pair
        unknown_pair_ticker = {
            "symbol": "BCCBTC",
            "bidPrice": "99",
            "askPrice": "101",
        }
        tickers: List[Dict[str, Any]] = [known_pair_ticker, unknown_pair_ticker]
        mock_api.get(ticker_url, body=json.dumps(tickers))

        request = self.data_source.get_all_mid_prices()
        result: Dict[str, float] = self.async_run_with_timeout(request)

        self.assertEqual(1, len(result))
        self.assertEqual(100, result[self.trading_pair])

    @aioresponses()
    def test_fetch_trading_pairs(self, mock_api):
        """``fetch_trading_pairs`` returns ETH-BTC and LTC-BTC but not BNB-BTC.

        The three mocked symbols are identical except for their base asset and
        their permissions: BNBBTC is the only one without ``"SPOT"``.
        """
        BinanceAPIOrderBookDataSource._trading_pair_symbol_map = {}
        url = web_utils.public_rest_url(
            path_url=CONSTANTS.EXCHANGE_INFO_PATH_URL, domain=self.domain)

        def symbol_entry(base_asset, permissions):
            # Every mocked symbol is quoted in BTC and shares all other fields,
            # so only the base asset and the permissions vary.
            return {
                "symbol": f"{base_asset}BTC",
                "status": "TRADING",
                "baseAsset": base_asset,
                "baseAssetPrecision": 8,
                "quoteAsset": "BTC",
                "quotePrecision": 8,
                "quoteAssetPrecision": 8,
                "baseCommissionPrecision": 8,
                "quoteCommissionPrecision": 8,
                "orderTypes": [
                    "LIMIT", "LIMIT_MAKER", "MARKET", "STOP_LOSS_LIMIT",
                    "TAKE_PROFIT_LIMIT"
                ],
                "icebergAllowed": True,
                "ocoAllowed": True,
                "quoteOrderQtyMarketAllowed": True,
                "isSpotTradingAllowed": True,
                "isMarginTradingAllowed": True,
                "filters": [],
                "permissions": permissions,
            }

        mock_response: Dict[str, Any] = {
            "timezone": "UTC",
            "serverTime": 1639598493658,
            "rateLimits": [],
            "exchangeFilters": [],
            "symbols": [
                symbol_entry("ETH", ["SPOT", "MARGIN"]),
                symbol_entry("LTC", ["SPOT", "MARGIN"]),
                symbol_entry("BNB", ["MARGIN"]),
            ],
        }

        mock_api.get(url, body=json.dumps(mock_response))

        # The original ``Dict[str]`` annotation was malformed (``Dict`` takes a
        # key and a value type), so the result is left unannotated.
        result = self.async_run_with_timeout(
            self.data_source.fetch_trading_pairs(
                time_synchronizer=self.time_synchronizer))

        self.assertEqual(2, len(result))
        self.assertIn("ETH-BTC", result)
        self.assertIn("LTC-BTC", result)
        self.assertNotIn("BNB-BTC", result)

    @aioresponses()
    def test_fetch_trading_pairs_exception_raised(self, mock_api):
        """``fetch_trading_pairs`` returns an empty result when the request raises."""
        BinanceAPIOrderBookDataSource._trading_pair_symbol_map = {}

        url = web_utils.public_rest_url(
            path_url=CONSTANTS.EXCHANGE_INFO_PATH_URL, domain=self.domain)
        regex_url = re.compile(f"^{url}".replace(".",
                                                 r"\.").replace("?", r"\?"))

        mock_api.get(regex_url, exception=Exception)

        # The original ``Dict[str]`` annotation was malformed (``Dict`` takes a
        # key and a value type), so the result is left unannotated.
        result = self.async_run_with_timeout(
            self.data_source.fetch_trading_pairs(
                time_synchronizer=self.time_synchronizer))

        self.assertEqual(0, len(result))

    @aioresponses()
    def test_get_snapshot_successful(self, mock_api):
        """``get_snapshot`` returns the mocked snapshot payload unchanged."""
        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        escaped_url = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        url_pattern = re.compile(escaped_url)

        expected_snapshot = self._snapshot_response()
        mock_api.get(url_pattern, body=json.dumps(expected_snapshot))

        request = self.data_source.get_snapshot(self.trading_pair)
        result: Dict[str, Any] = self.async_run_with_timeout(request)

        self.assertEqual(expected_snapshot, result)

    @aioresponses()
    def test_get_snapshot_catch_exception(self, mock_api):
        """``get_snapshot`` raises ``IOError`` when the endpoint answers with 400."""
        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        escaped_url = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        url_pattern = re.compile(escaped_url)
        mock_api.get(url_pattern, status=400)

        with self.assertRaises(IOError):
            self.async_run_with_timeout(
                self.data_source.get_snapshot(self.trading_pair))

    @aioresponses()
    def test_get_new_order_book(self, mock_api):
        """The new order book carries the snapshot's ``lastUpdateId`` as ``snapshot_uid``."""
        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        escaped_url = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        url_pattern = re.compile(escaped_url)

        snapshot_payload: Dict[str, Any] = {
            "lastUpdateId": 1,
            "bids": [["4.00000000", "431.00000000"]],
            "asks": [["4.00000200", "12.00000000"]],
        }
        mock_api.get(url_pattern, body=json.dumps(snapshot_payload))

        request = self.data_source.get_new_order_book(self.trading_pair)
        result: OrderBook = self.async_run_with_timeout(request)

        self.assertEqual(1, result.snapshot_uid)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_subscriptions_subscribes_to_trades_and_order_diffs(
            self, ws_connect_mock):
        """Two SUBSCRIBE requests (trades, then diffs) are sent and the success is logged."""
        ws_mock = self.mocking_assistant.create_websocket_mock()
        ws_connect_mock.return_value = ws_mock

        # Queue the confirmations for the trade (id 1) and diff (id 2) requests.
        for request_id in (1, 2):
            confirmation = {"result": None, "id": request_id}
            self.mocking_assistant.add_websocket_aiohttp_message(
                websocket_mock=ws_connect_mock.return_value,
                message=json.dumps(confirmation))

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_subscriptions())

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(
            ws_connect_mock.return_value)

        sent_messages = self.mocking_assistant.json_messages_sent_through_websocket(
            websocket_mock=ws_connect_mock.return_value)
        self.assertEqual(2, len(sent_messages))

        symbol = self.ex_trading_pair.lower()
        expected_trade_subscription = {
            "method": "SUBSCRIBE",
            "params": [f"{symbol}@trade"],
            "id": 1,
        }
        expected_diff_subscription = {
            "method": "SUBSCRIBE",
            "params": [f"{symbol}@depth@100ms"],
            "id": 2,
        }
        self.assertEqual(expected_trade_subscription, sent_messages[0])
        self.assertEqual(expected_diff_subscription, sent_messages[1])

        subscribed_logged = self._is_logged(
            "INFO", "Subscribed to public order book and trade channels...")
        self.assertTrue(subscribed_logged)

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect")
    def test_listen_for_subscriptions_raises_cancel_exception(
            self, mock_ws, _: AsyncMock):
        """A ``CancelledError`` from ``ws_connect`` propagates out of the listener."""
        mock_ws.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            subscriptions = self.data_source.listen_for_subscriptions()
            self.listening_task = self.ev_loop.create_task(subscriptions)
            self.async_run_with_timeout(self.listening_task)

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_subscriptions_logs_exception_details(
            self, mock_ws, sleep_mock):
        """A failing ``ws_connect`` is logged as ERROR by the listener."""

        def unlock_and_cancel(_):
            # The patched sleep releases the test and stops the listener.
            return self._create_exception_and_unlock_test_with_event(
                asyncio.CancelledError())

        mock_ws.side_effect = Exception("TEST ERROR.")
        sleep_mock.side_effect = unlock_and_cancel

        subscriptions = self.data_source.listen_for_subscriptions()
        self.listening_task = self.ev_loop.create_task(subscriptions)

        self.async_run_with_timeout(self.resume_test_event.wait())

        error_logged = self._is_logged(
            "ERROR",
            "Unexpected error occurred when listening to order book streams. Retrying in 5 seconds..."
        )
        self.assertTrue(error_logged)

    def test_subscribe_channels_raises_cancel_exception(self):
        """A ``CancelledError`` from ``ws.send`` propagates out of ``_subscribe_channels``."""
        ws_stub = MagicMock()
        ws_stub.send.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            subscription = self.data_source._subscribe_channels(ws_stub)
            self.listening_task = self.ev_loop.create_task(subscription)
            self.async_run_with_timeout(self.listening_task)

    def test_subscribe_channels_raises_exception_and_logs_error(self):
        """A failing ``ws.send`` is re-raised by ``_subscribe_channels`` and logged as ERROR."""
        ws_stub = MagicMock()
        ws_stub.send.side_effect = Exception("Test Error")

        with self.assertRaises(Exception):
            subscription = self.data_source._subscribe_channels(ws_stub)
            self.listening_task = self.ev_loop.create_task(subscription)
            self.async_run_with_timeout(self.listening_task)

        error_logged = self._is_logged(
            "ERROR",
            "Unexpected error occurred subscribing to order book trading and delta streams..."
        )
        self.assertTrue(error_logged)

    def test_listen_for_trades_cancelled_when_listening(self):
        """A ``CancelledError`` from the trade queue propagates out of ``listen_for_trades``."""
        trade_queue = MagicMock()
        trade_queue.get.side_effect = asyncio.CancelledError()
        self.data_source._message_queue[CONSTANTS.TRADE_EVENT_TYPE] = trade_queue

        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            listener = self.data_source.listen_for_trades(
                self.ev_loop, output_queue)
            self.listening_task = self.ev_loop.create_task(listener)
            self.async_run_with_timeout(self.listening_task)

    def test_listen_for_trades_logs_exception(self):
        """An incomplete trade message makes ``listen_for_trades`` log an ERROR."""
        malformed_event = {"m": 1, "i": 2}

        trade_queue = AsyncMock()
        # Second ``get`` cancels the listener so the task terminates.
        trade_queue.get.side_effect = [malformed_event, asyncio.CancelledError()]
        self.data_source._message_queue[CONSTANTS.TRADE_EVENT_TYPE] = trade_queue

        output_queue: asyncio.Queue = asyncio.Queue()
        listener = self.data_source.listen_for_trades(self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(listener)

        try:
            self.async_run_with_timeout(self.listening_task)
        except asyncio.CancelledError:
            pass

        error_logged = self._is_logged(
            "ERROR",
            "Unexpected error when processing public trade updates from exchange"
        )
        self.assertTrue(error_logged)

    def test_listen_for_trades_successful(self):
        """A queued trade event is emitted as a message carrying the event's trade id."""
        trade_event = self._trade_update_event()
        expected_trade_id = trade_event["t"]

        mock_queue = AsyncMock()
        # Second ``get`` cancels the listener once the trade has been consumed.
        mock_queue.get.side_effect = [trade_event, asyncio.CancelledError()]
        self.data_source._message_queue[
            CONSTANTS.TRADE_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        # ``create_task`` only schedules the coroutine, so it cannot raise
        # ``CancelledError`` here; no guard is needed around it.
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.async_run_with_timeout(msg_queue.get())

        # ``assertTrue(12345, msg.trade_id)`` always passed because the second
        # argument is only the failure message; compare the values instead.
        self.assertEqual(expected_trade_id, msg.trade_id)

    def test_listen_for_order_book_diffs_cancelled(self):
        """A ``CancelledError`` from the diff queue propagates out of the diff listener."""
        diff_queue = AsyncMock()
        diff_queue.get.side_effect = asyncio.CancelledError()
        self.data_source._message_queue[CONSTANTS.DIFF_EVENT_TYPE] = diff_queue

        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            listener = self.data_source.listen_for_order_book_diffs(
                self.ev_loop, output_queue)
            self.listening_task = self.ev_loop.create_task(listener)
            self.async_run_with_timeout(self.listening_task)

    def test_listen_for_order_book_diffs_logs_exception(self):
        """An incomplete diff message makes the diff listener log an ERROR."""
        malformed_event = {"m": 1, "i": 2}

        diff_queue = AsyncMock()
        # Second ``get`` cancels the listener so the task terminates.
        diff_queue.get.side_effect = [malformed_event, asyncio.CancelledError()]
        self.data_source._message_queue[CONSTANTS.DIFF_EVENT_TYPE] = diff_queue

        output_queue: asyncio.Queue = asyncio.Queue()
        listener = self.data_source.listen_for_order_book_diffs(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(listener)

        try:
            self.async_run_with_timeout(self.listening_task)
        except asyncio.CancelledError:
            pass

        error_logged = self._is_logged(
            "ERROR",
            "Unexpected error when processing public order book updates from exchange"
        )
        self.assertTrue(error_logged)

    def test_listen_for_order_book_diffs_successful(self):
        """A queued diff event is emitted as a message carrying the event's update id."""
        diff_event = self._order_diff_event()
        # NOTE(review): assumes the message's ``update_id`` is the final update
        # id ``"u"`` of the diff event - confirm against the order book parser.
        expected_update_id = diff_event["u"]

        mock_queue = AsyncMock()
        # Second ``get`` cancels the listener once the diff has been consumed.
        mock_queue.get.side_effect = [diff_event, asyncio.CancelledError()]
        self.data_source._message_queue[CONSTANTS.DIFF_EVENT_TYPE] = mock_queue

        msg_queue: asyncio.Queue = asyncio.Queue()

        # ``create_task`` only schedules the coroutine, so it cannot raise
        # ``CancelledError`` here; no guard is needed around it.
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_diffs(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.async_run_with_timeout(msg_queue.get())

        # ``assertTrue(12345, msg.update_id)`` always passed (the second
        # argument is only the failure message) and 12345 is the trade id of
        # the trade fixture, not a value present in the diff fixture.
        self.assertEqual(expected_update_id, msg.update_id)

    @aioresponses()
    def test_listen_for_order_book_snapshots_cancelled_when_fetching_snapshot(
            self, mock_api):
        """A ``CancelledError`` from the snapshot request propagates out of the listener."""
        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        escaped_url = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        url_pattern = re.compile(escaped_url)
        mock_api.get(url_pattern, exception=asyncio.CancelledError)

        with self.assertRaises(asyncio.CancelledError):
            self.async_run_with_timeout(
                self.data_source.listen_for_order_book_snapshots(
                    self.ev_loop, asyncio.Queue()))

    @aioresponses()
    @patch(
        "hummingbot.connector.exchange.binance.binance_api_order_book_data_source"
        ".BinanceAPIOrderBookDataSource._sleep")
    def test_listen_for_order_book_snapshots_log_exception(
            self, mock_api, sleep_mock):
        """A failing snapshot request is logged as ERROR by the snapshot listener."""

        def unlock_and_cancel(_):
            # The patched sleep releases the test and stops the listener.
            return self._create_exception_and_unlock_test_with_event(
                asyncio.CancelledError())

        output_queue: asyncio.Queue = asyncio.Queue()
        sleep_mock.side_effect = unlock_and_cancel

        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        escaped_url = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        url_pattern = re.compile(escaped_url)
        mock_api.get(url_pattern, exception=Exception)

        listener = self.data_source.listen_for_order_book_snapshots(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(listener)
        self.async_run_with_timeout(self.resume_test_event.wait())

        error_logged = self._is_logged(
            "ERROR",
            f"Unexpected error fetching order book snapshot for {self.trading_pair}."
        )
        self.assertTrue(error_logged)

    @aioresponses()
    def test_listen_for_order_book_snapshots_successful(
        self,
        mock_api,
    ):
        """The snapshot listener emits a message whose ``update_id`` is the snapshot's ``lastUpdateId``."""
        output_queue: asyncio.Queue = asyncio.Queue()

        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        escaped_url = f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?")
        url_pattern = re.compile(escaped_url)
        snapshot_body = json.dumps(self._snapshot_response())
        mock_api.get(url_pattern, body=snapshot_body)

        listener = self.data_source.listen_for_order_book_snapshots(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(listener)

        msg: OrderBookMessage = self.async_run_with_timeout(output_queue.get())

        self.assertEqual(1027024, msg.update_id)
class BinanceAPIOrderBookDataSourceUnitTests(unittest.TestCase):
    """Unit tests for ``BinanceAPIOrderBookDataSource``.

    The test case registers itself as a handler on the data source logger
    (see ``setUp``) and collects the emitted records through ``handle`` so
    that tests can assert on log output with ``_is_logged``.
    """

    # logging.Level required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        """Set up the event loop and the trading pair constants shared by all tests."""
        super().setUpClass()
        cls.domain = "com"
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        # Symbol without separator, as used in the mocked exchange payloads.
        cls.ex_trading_pair = f"{cls.base_asset}{cls.quote_asset}"
        # Dash-separated pair, as passed to the data source.
        cls.trading_pair = "-".join((cls.base_asset, cls.quote_asset))
        cls.ev_loop = asyncio.get_event_loop()

    def setUp(self) -> None:
        """Create a fresh data source, throttler and log capture for each test."""
        super().setUp()
        self.listening_task = None
        self.log_records = []
        self.mocking_assistant = NetworkMockingAssistant()
        self.throttler = AsyncThrottler(rate_limits=CONSTANTS.RATE_LIMITS)

        self.data_source = BinanceAPIOrderBookDataSource(
            trading_pairs=[self.trading_pair],
            throttler=self.throttler,
            domain=self.domain,
        )

        # Route the data source log records into ``self.handle``.
        data_source_logger = self.data_source.logger()
        data_source_logger.setLevel(1)
        data_source_logger.addHandler(self)

    def tearDown(self) -> None:
        """Cancel a still-referenced listener task before the base teardown."""
        pending_task = self.listening_task
        if pending_task:
            pending_task.cancel()
        super().tearDown()

    def handle(self, record):
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def _raise_exception(self, exception_class):
        raise exception_class

    def _trade_update_event(self):
        """Return a sample trade event for ``self.ex_trading_pair``, serialized with ``ujson``."""
        trade_event = dict(
            e="trade",
            E=123456789,
            s=self.ex_trading_pair,
            t=12345,
            p="0.001",
            q="100",
            b=88,
            a=50,
            T=123456785,
            m=True,
            M=True,
        )
        return ujson.dumps(trade_event)

    def _order_diff_event(self):
        """Return a sample ``depthUpdate`` event for ``self.ex_trading_pair``, serialized with ``ujson``."""
        bid_levels = [["0.0024", "10"]]
        ask_levels = [["0.0026", "100"]]
        diff_event = dict(
            e="depthUpdate",
            E=123456789,
            s=self.ex_trading_pair,
            U=157,
            u=160,
            b=bid_levels,
            a=ask_levels,
        )
        return ujson.dumps(diff_event)

    def _snapshot_response(self):
        resp = {
            "lastUpdateId": 1027024,
            "bids": [["4.00000000", "431.00000000"]],
            "asks": [["4.00000200", "12.00000000"]]
        }
        return resp

    @patch("aiohttp.ClientSession.get", new_callable=AsyncMock)
    def test_get_last_trade_prices(self, mock_api):
        """``get_last_traded_prices`` yields one price (100) for the test pair."""
        self.mocking_assistant.configure_http_request_mock(mock_api)

        # Truncated Response
        ticker_payload: Dict[str, Any] = {"lastPrice": "100"}
        self.mocking_assistant.add_http_response(mock_api, 200, ticker_payload)

        request = self.data_source.get_last_traded_prices(
            trading_pairs=[self.trading_pair], throttler=self.throttler)
        result: Dict[str, float] = self.ev_loop.run_until_complete(request)

        self.assertEqual(1, len(result))
        self.assertEqual(100, result[self.trading_pair])

    @patch(
        "hummingbot.connector.exchange.binance.binance_utils.convert_from_exchange_trading_pair"
    )
    @patch("aiohttp.ClientSession.get", new_callable=AsyncMock)
    def test_get_all_mid_prices(self, mock_api, mock_utils):
        """``get_all_mid_prices`` yields a mid price of 100 for a 99/101 ticker."""
        # The patched pair conversion always yields the test trading pair.
        mock_utils.return_value = self.trading_pair
        self.mocking_assistant.configure_http_request_mock(mock_api)

        # Truncated Response
        pair_ticker = {
            "symbol": self.ex_trading_pair,
            "bidPrice": "99",
            "askPrice": "101",
        }
        tickers: List[Dict[str, Any]] = [pair_ticker]
        self.mocking_assistant.add_http_response(mock_api, 200, tickers)

        request = self.data_source.get_all_mid_prices()
        result: Dict[str, float] = self.ev_loop.run_until_complete(request)

        self.assertEqual(1, len(result))
        self.assertEqual(100, result[self.trading_pair])

    @patch(
        "hummingbot.connector.exchange.binance.binance_utils.convert_from_exchange_trading_pair"
    )
    @patch("aiohttp.ClientSession.get")
    def test_fetch_trading_pairs(self, mock_api, mock_utils):
        """``fetch_trading_pairs`` returns the single TRADING pair of the mocked response."""
        # The patched pair conversion always yields the test trading pair.
        mock_utils.return_value = self.trading_pair
        self.mocking_assistant.configure_http_request_mock(mock_api)

        mock_response: Dict[str, Any] = {
            # Truncated Response
            "symbols": [
                {
                    "symbol": self.ex_trading_pair,
                    "status": "TRADING",
                    "baseAsset": self.base_asset,
                    "quoteAsset": self.quote_asset,
                },
            ]
        }

        self.mocking_assistant.add_http_response(mock_api, 200, mock_response)

        # The original ``Dict[str]`` annotation was malformed (``Dict`` takes a
        # key and a value type), so the result is left unannotated.
        result = self.ev_loop.run_until_complete(
            self.data_source.fetch_trading_pairs())

        self.assertEqual(1, len(result))
        # assertIn reports the container contents on failure, unlike
        # assertTrue(x in y).
        self.assertIn(self.trading_pair, result)

    def test_get_throttler_instance(self):
        """``_get_throttler_instance`` returns an ``AsyncThrottler``."""
        throttler = BinanceAPIOrderBookDataSource._get_throttler_instance()
        self.assertIsInstance(throttler, AsyncThrottler)

    @patch("aiohttp.ClientSession.get")
    def test_get_snapshot_successful(self, mock_api):
        """``get_snapshot`` returns the mocked snapshot payload unchanged."""
        self.mocking_assistant.configure_http_request_mock(mock_api)

        served_snapshot = self._snapshot_response()
        self.mocking_assistant.add_http_response(mock_api, 200, served_snapshot)

        request = self.data_source.get_snapshot(self.trading_pair)
        result: Dict[str, Any] = self.ev_loop.run_until_complete(request)

        # Compare against a freshly built payload, independent of the served one.
        self.assertEqual(self._snapshot_response(), result)

    @patch("aiohttp.ClientSession.get")
    def test_get_snapshot_catch_exception(self, mock_api):
        """``get_snapshot`` raises ``IOError`` when the mocked response has status 400."""
        assistant = self.mocking_assistant
        assistant.configure_http_request_mock(mock_api)
        assistant.add_http_response(mock_api, 400, {})

        with self.assertRaises(IOError):
            self.ev_loop.run_until_complete(
                self.data_source.get_snapshot(self.trading_pair))

    @patch("aiohttp.ClientSession.get")
    def test_get_new_order_book(self, mock_api):
        """The new order book carries the snapshot's ``lastUpdateId`` as ``snapshot_uid``."""
        self.mocking_assistant.configure_http_request_mock(mock_api)

        snapshot_payload: Dict[str, Any] = {
            "lastUpdateId": 1,
            "bids": [["4.00000000", "431.00000000"]],
            "asks": [["4.00000200", "12.00000000"]],
        }
        self.mocking_assistant.add_http_response(mock_api, 200,
                                                 snapshot_payload)

        request = self.data_source.get_new_order_book(self.trading_pair)
        result: OrderBook = self.ev_loop.run_until_complete(request)

        self.assertEqual(1, result.snapshot_uid)

    @patch("websockets.connect")
    def test_listen_for_trades_cancelled_when_connecting(self, mock_ws):
        """A ``CancelledError`` from ``websockets.connect`` propagates out of ``listen_for_trades``."""
        output_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            listener = self.data_source.listen_for_trades(
                self.ev_loop, output_queue)
            self.listening_task = self.ev_loop.create_task(listener)
            self.ev_loop.run_until_complete(self.listening_task)

    @patch("websockets.connect", new_callable=AsyncMock)
    def test_listen_for_trades_cancelled_when_listening(self, mock_ws):
        """A ``CancelledError`` from the websocket's ``recv`` propagates out of ``listen_for_trades``."""

        def cancel_on_recv():
            return self._raise_exception(asyncio.CancelledError)

        output_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value.recv.side_effect = cancel_on_recv

        with self.assertRaises(asyncio.CancelledError):
            listener = self.data_source.listen_for_trades(
                self.ev_loop, output_queue)
            self.listening_task = self.ev_loop.create_task(listener)
            self.ev_loop.run_until_complete(self.listening_task)

    @patch("websockets.connect", new_callable=AsyncMock)
    def test_listen_for_trades_logs_exception(self, mock_ws):
        """An incomplete trade payload makes the listener log the websocket error message.

        The listener task is expected to still be pending after one second,
        hence the ``TimeoutError`` from ``asyncio.wait_for``.
        """
        websocket = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value = websocket
        mock_ws.close.return_value = None

        malformed_payload = {"m": 1, "i": 2}
        self.mocking_assistant.add_websocket_text_message(
            websocket, ujson.dumps(malformed_payload))

        output_queue: asyncio.Queue = asyncio.Queue()
        listener = self.data_source.listen_for_trades(self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(listener)

        with self.assertRaises(asyncio.TimeoutError):
            bounded_wait = asyncio.wait_for(self.listening_task, 1)
            self.ev_loop.run_until_complete(bounded_wait)

        expected_log = "Unexpected error with WebSocket connection. Retrying after 30 seconds..."
        self.assertTrue(self._is_logged("ERROR", expected_log))

    @patch("websockets.connect", new_callable=AsyncMock)
    def test_listen_for_trades_successful(self, mock_ws):
        """A trade event received over the websocket is queued with its trade id.

        The previous check, ``assertTrue(12345, msg.trade_id)``, could never
        fail: ``assertTrue`` only tests the truthiness of its first argument
        and uses the second one as the failure message.
        """
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.close.return_value = None

        self.mocking_assistant.add_websocket_text_message(
            mock_ws.return_value, self._trade_update_event())
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.ev_loop.run_until_complete(
            msg_queue.get())

        # NOTE(review): 12345 is the value the original assertion named; the
        # fixture is defined outside this excerpt -- confirm it is the trade
        # id carried by ``_trade_update_event()``.
        self.assertEqual(12345, msg.trade_id)

    @patch("websockets.connect")
    def test_listen_for_order_book_diffs_cancelled_when_connecting(
            self, mock_ws):
        """A ``CancelledError`` raised while connecting propagates out of the diff listener."""
        mock_ws.side_effect = asyncio.CancelledError
        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            listener = self.data_source.listen_for_order_book_diffs(
                self.ev_loop, output_queue)
            self.listening_task = self.ev_loop.create_task(listener)
            self.ev_loop.run_until_complete(self.listening_task)

    @patch("websockets.connect", new_callable=AsyncMock)
    def test_listen_for_order_book_diffs_cancelled_when_listening(
            self, mock_ws):
        """A ``CancelledError`` raised while receiving propagates out of the diff listener."""
        def cancel_on_receive():
            return self._raise_exception(asyncio.CancelledError)

        websocket = self.mocking_assistant.create_websocket_mock()
        websocket.recv.side_effect = cancel_on_receive
        mock_ws.return_value = websocket
        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            listener = self.data_source.listen_for_order_book_diffs(
                self.ev_loop, output_queue)
            self.listening_task = self.ev_loop.create_task(listener)
            self.ev_loop.run_until_complete(self.listening_task)

    @patch("websockets.connect", new_callable=AsyncMock)
    def test_listen_for_order_book_diffs_logs_exception(self, mock_ws):
        """An incomplete diff payload makes the listener log the websocket error message.

        The listener task is expected to still be pending after one second,
        hence the ``TimeoutError`` from ``asyncio.wait_for``.
        """
        websocket = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value = websocket
        mock_ws.close.return_value = None

        malformed_payload = {"m": 1, "i": 2}
        self.mocking_assistant.add_websocket_text_message(
            websocket, ujson.dumps(malformed_payload))

        output_queue: asyncio.Queue = asyncio.Queue()
        listener = self.data_source.listen_for_order_book_diffs(
            self.ev_loop, output_queue)
        self.listening_task = self.ev_loop.create_task(listener)

        with self.assertRaises(asyncio.TimeoutError):
            bounded_wait = asyncio.wait_for(self.listening_task, 1)
            self.ev_loop.run_until_complete(bounded_wait)

        expected_log = "Unexpected error with WebSocket connection. Retrying after 30 seconds..."
        self.assertTrue(self._is_logged("ERROR", expected_log))

    @patch("websockets.connect", new_callable=AsyncMock)
    def test_listen_for_order_book_diffs_successful(self, mock_ws):
        """A depth-update event received over the websocket is queued with its update id.

        The previous check, ``assertTrue(12345, msg.update_id)``, could never
        fail: ``assertTrue`` only tests the truthiness of its first argument
        and uses the second one as the failure message.
        """
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.close.return_value = None

        diff_event = self._order_diff_event()
        self.mocking_assistant.add_websocket_text_message(
            mock_ws.return_value, diff_event)
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_diffs(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.ev_loop.run_until_complete(
            msg_queue.get())

        # The fixture is handed to the text-message mock unchanged, so it may
        # already be serialized JSON; decode it before reading the expected id.
        if isinstance(diff_event, (str, bytes)):
            diff_event = ujson.loads(diff_event)
        # NOTE(review): assumes the message's update_id mirrors the event's
        # final update id ("u") -- confirm against the data source.
        self.assertEqual(diff_event["u"], msg.update_id)

    @patch("aiohttp.ClientSession.get")
    def test_listen_for_order_book_snapshots_cancelled_when_fetching_snapshot(
            self, mock_api):
        """A ``CancelledError`` raised by the HTTP request propagates out of the snapshot listener."""
        mock_api.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            output_queue = asyncio.Queue()
            listener = self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, output_queue)
            self.ev_loop.run_until_complete(listener)

    @patch("aiohttp.ClientSession.get")
    def test_listen_for_order_book_snapshots_log_exception(self, mock_api):
        """A failing snapshot request is logged as an error for the trading pair.

        The listener task is expected to still be pending after one second,
        hence the ``TimeoutError`` from ``asyncio.wait_for``.
        """
        msg_queue: asyncio.Queue = asyncio.Queue()
        # The side effect receives whatever arguments the mocked ``get`` is
        # called with. The previous zero-argument lambda would fail with a
        # ``TypeError`` on any call that passes arguments, so the intended
        # exception was never the one raised; accept any arguments instead.
        mock_api.side_effect = lambda *args, **kwargs: self._raise_exception(
            Exception)

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, msg_queue))
        with self.assertRaises(asyncio.TimeoutError):
            self.ev_loop.run_until_complete(
                asyncio.wait_for(self.listening_task, 1))

        self.assertTrue(
            self._is_logged(
                "ERROR",
                f"Unexpected error fetching order book snapshot for {self.trading_pair}."
            ))

    @patch("aiohttp.ClientSession.get")
    def test_listen_for_order_book_snapshots_successful(
        self,
        mock_api,
    ):
        """A fetched snapshot is queued as a message carrying the snapshot's update id.

        The previous check, ``assertTrue(12345, msg.update_id)``, could never
        fail: ``assertTrue`` only tests the truthiness of its first argument
        and uses the second one as the failure message.
        """
        msg_queue: asyncio.Queue = asyncio.Queue()
        self.mocking_assistant.configure_http_request_mock(mock_api)

        snapshot = self._snapshot_response()
        self.mocking_assistant.add_http_response(mock_api, 200, snapshot)

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.ev_loop.run_until_complete(
            msg_queue.get())

        # NOTE(review): assumes the message's update_id mirrors the
        # snapshot's "lastUpdateId" -- confirm against the data source.
        self.assertEqual(snapshot["lastUpdateId"], msg.update_id)
Exemplo n.º 3
0
class BinanceAPIOrderBookDataSourceUnitTests(unittest.TestCase):
    """Tests for ``BinanceAPIOrderBookDataSource`` with mocked HTTP and websocket layers.

    The test case registers itself as a handler on the data source logger
    (it provides ``level`` and ``handle``), so tests can assert on the log
    records emitted while the data source runs.
    """

    # logging.Level required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        """Create the event loop and the trading-pair constants shared by all tests."""
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        cls.trading_pair = f"{cls.base_asset}-{cls.quote_asset}"
        cls.ex_trading_pair = cls.base_asset + cls.quote_asset
        cls.domain = "com"

    def setUp(self) -> None:
        """Build a fresh data source and start capturing its log records."""
        super().setUp()
        self.log_records = []
        self.listening_task = None
        self.mocking_assistant = NetworkMockingAssistant()

        self.throttler = AsyncThrottler(rate_limits=CONSTANTS.RATE_LIMITS)
        self.data_source = BinanceAPIOrderBookDataSource(
            trading_pairs=[self.trading_pair],
            throttler=self.throttler,
            domain=self.domain)
        self.data_source.logger().setLevel(1)
        self.data_source.logger().addHandler(self)

        self.resume_test_event = asyncio.Event()

    def tearDown(self) -> None:
        """Cancel the listener task a test may have left running."""
        if self.listening_task is not None:
            self.listening_task.cancel()
        super().tearDown()

    def handle(self, record):
        """Logging-handler hook: remember every record sent to this test case."""
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        """Return True if a captured record has exactly this level name and message."""
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def _raise_exception(self, exception_class):
        """Raise ``exception_class``; used as a mock side effect."""
        raise exception_class

    def _create_exception_and_unlock_test_with_event(self, exception):
        """Set ``resume_test_event`` so a waiting test can continue, then raise ``exception``."""
        self.resume_test_event.set()
        raise exception

    def _regex_url(self, path_url: str):
        """Return a compiled pattern matching URLs that start with the public REST URL of ``path_url``.

        Only ``.`` and ``?`` are escaped, exactly as the individual tests
        did before this helper was extracted.
        """
        url = utils.public_rest_url(path_url=path_url, domain=self.domain)
        return re.compile(f"^{url}".replace(".", r"\.").replace("?", r"\?"))

    def _successfully_subscribed_event(self):
        """Return the payload used as the subscription acknowledgement in websocket tests."""
        resp = {"result": None, "id": 1}
        return resp

    def _trade_update_event(self):
        """Return a sample ``trade`` websocket event for the test trading pair."""
        resp = {
            "e": "trade",
            "E": 123456789,
            "s": self.ex_trading_pair,
            "t": 12345,
            "p": "0.001",
            "q": "100",
            "b": 88,
            "a": 50,
            "T": 123456785,
            "m": True,
            "M": True
        }
        return resp

    def _order_diff_event(self):
        """Return a sample ``depthUpdate`` websocket event for the test trading pair."""
        resp = {
            "e": "depthUpdate",
            "E": 123456789,
            "s": self.ex_trading_pair,
            "U": 157,
            "u": 160,
            "b": [["0.0024", "10"]],
            "a": [["0.0026", "100"]]
        }
        return resp

    def _snapshot_response(self):
        """Return a sample order book snapshot REST response."""
        resp = {
            "lastUpdateId": 1027024,
            "bids": [["4.00000000", "431.00000000"]],
            "asks": [["4.00000200", "12.00000000"]]
        }
        return resp

    @aioresponses()
    def test_get_last_trade_prices(self, mock_api):
        """The ticker's ``lastPrice`` is returned keyed by trading pair."""
        regex_url = self._regex_url(CONSTANTS.TICKER_PRICE_CHANGE_PATH_URL)

        mock_response: Dict[str, Any] = {
            # Truncated Response
            "lastPrice": "100",
        }

        mock_api.get(regex_url, body=ujson.dumps(mock_response))

        result: Dict[str, float] = self.ev_loop.run_until_complete(
            self.data_source.get_last_traded_prices(
                trading_pairs=[self.trading_pair], throttler=self.throttler))

        self.assertEqual(1, len(result))
        self.assertEqual(100, result[self.trading_pair])

    @aioresponses()
    @patch(
        "hummingbot.connector.exchange.binance.binance_utils.convert_from_exchange_trading_pair"
    )
    def test_get_all_mid_prices(self, mock_api, mock_utils):
        """The mid price of a 99/101 bid/ask ticker is reported as 100."""
        # Stub the exchange-symbol conversion so it always yields the test trading pair.
        mock_utils.return_value = self.trading_pair
        regex_url = self._regex_url(CONSTANTS.TICKER_PRICE_CHANGE_PATH_URL)

        mock_response: List[Dict[str, Any]] = [{
            # Truncated Response
            "symbol": self.ex_trading_pair,
            "bidPrice": "99",
            "askPrice": "101",
        }]

        mock_api.get(regex_url, body=ujson.dumps(mock_response))

        result: Dict[str, float] = self.ev_loop.run_until_complete(
            self.data_source.get_all_mid_prices())

        self.assertEqual(1, len(result))
        self.assertEqual(100, result[self.trading_pair])

    @aioresponses()
    @patch(
        "hummingbot.connector.exchange.binance.binance_utils.convert_from_exchange_trading_pair"
    )
    def test_fetch_trading_pairs(self, mock_api, mock_utils):
        """A ``TRADING`` symbol in the exchange info is returned as a trading pair."""
        # Stub the exchange-symbol conversion so it always yields the test trading pair.
        mock_utils.return_value = self.trading_pair
        regex_url = self._regex_url(CONSTANTS.EXCHANGE_INFO_PATH_URL)

        mock_response: Dict[str, Any] = {
            # Truncated Response
            "symbols": [
                {
                    "symbol": self.ex_trading_pair,
                    "status": "TRADING",
                    "baseAsset": self.base_asset,
                    "quoteAsset": self.quote_asset,
                },
            ]
        }

        mock_api.get(regex_url, body=ujson.dumps(mock_response))

        result: List[str] = self.ev_loop.run_until_complete(
            self.data_source.fetch_trading_pairs())

        self.assertEqual(1, len(result))
        self.assertIn(self.trading_pair, result)

    @aioresponses()
    @patch(
        "hummingbot.connector.exchange.binance.binance_utils.convert_from_exchange_trading_pair"
    )
    def test_fetch_trading_pairs_exception_raised(self, mock_api, mock_utils):
        """A failing exchange-info request yields an empty result instead of raising."""
        # Stub the exchange-symbol conversion so it always yields the test trading pair.
        mock_utils.return_value = self.trading_pair
        regex_url = self._regex_url(CONSTANTS.EXCHANGE_INFO_PATH_URL)

        mock_api.get(regex_url, exception=Exception)

        result: List[str] = self.ev_loop.run_until_complete(
            self.data_source.fetch_trading_pairs())

        self.assertEqual(0, len(result))

    def test_get_throttler_instance(self):
        """``_get_throttler_instance`` returns an ``AsyncThrottler``."""
        self.assertIsInstance(
            BinanceAPIOrderBookDataSource._get_throttler_instance(),
            AsyncThrottler)

    @aioresponses()
    def test_get_snapshot_successful(self, mock_api):
        """``get_snapshot`` returns the decoded snapshot response unchanged."""
        regex_url = self._regex_url(CONSTANTS.SNAPSHOT_PATH_URL)

        mock_api.get(regex_url, body=ujson.dumps(self._snapshot_response()))

        result: Dict[str, Any] = self.ev_loop.run_until_complete(
            self.data_source.get_snapshot(self.trading_pair))

        self.assertEqual(self._snapshot_response(), result)

    @aioresponses()
    def test_get_snapshot_catch_exception(self, mock_api):
        """A 400 response to the snapshot request must surface as ``IOError``."""
        regex_url = self._regex_url(CONSTANTS.SNAPSHOT_PATH_URL)

        mock_api.get(regex_url, status=400)
        with self.assertRaises(IOError):
            self.ev_loop.run_until_complete(
                self.data_source.get_snapshot(self.trading_pair))

    @aioresponses()
    def test_get_new_order_book(self, mock_api):
        """The order book built from a snapshot reports the snapshot's ``lastUpdateId``."""
        regex_url = self._regex_url(CONSTANTS.SNAPSHOT_PATH_URL)

        mock_response: Dict[str, Any] = {
            "lastUpdateId": 1,
            "bids": [["4.00000000", "431.00000000"]],
            "asks": [["4.00000200", "12.00000000"]]
        }
        mock_api.get(regex_url, body=ujson.dumps(mock_response))

        result: OrderBook = self.ev_loop.run_until_complete(
            self.data_source.get_new_order_book(self.trading_pair))

        self.assertEqual(1, result.snapshot_uid)

    @patch("aiohttp.ClientSession.ws_connect")
    def test_create_websocket_connection_cancelled_when_connecting(
            self, mock_ws):
        """A ``CancelledError`` raised while connecting propagates to the caller."""
        mock_ws.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            self.ev_loop.run_until_complete(
                self.data_source._create_websocket_connection())

    @patch("aiohttp.ClientSession.ws_connect")
    def test_create_websocket_connection_exception_raised(self, mock_ws):
        """A connection failure is raised to the caller and logged at NETWORK level."""
        mock_ws.side_effect = Exception("TEST ERROR.")

        with self.assertRaises(Exception):
            self.ev_loop.run_until_complete(
                self.data_source._create_websocket_connection())

        self.assertTrue(
            self._is_logged(
                "NETWORK",
                "Unexpected error occured when connecting to WebSocket server. Error: TEST ERROR."
            ))

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect")
    def test_listen_for_trades_cancelled_when_connecting(
            self, mock_ws, _: AsyncMock):
        """A ``CancelledError`` raised while connecting propagates out of the trade listener."""
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.data_source.listen_for_trades(self.ev_loop, msg_queue))
            self.ev_loop.run_until_complete(self.listening_task)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_trades_exception_raised_when_connecting(self, mock_ws):
        """A connection failure in the trade listener is logged at NETWORK and ERROR level."""
        msg_queue: asyncio.Queue = asyncio.Queue()
        # The side effect also sets ``resume_test_event`` so the test can stop
        # waiting as soon as the connection attempt has been made.
        mock_ws.side_effect = lambda **_: self._create_exception_and_unlock_test_with_event(
            Exception("TEST ERROR."))

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        self.ev_loop.run_until_complete(self.resume_test_event.wait())

        self.assertTrue(
            self._is_logged(
                "NETWORK",
                "Unexpected error occured when connecting to WebSocket server. Error: TEST ERROR."
            ))
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with WebSocket connection. Retrying after 30 seconds..."
            ))

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_trades_cancelled_when_listening(
            self, mock_ws, _: AsyncMock):
        """A ``CancelledError`` raised while receiving propagates out of the trade listener."""
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value.receive_json.side_effect = lambda: (
            self._raise_exception(asyncio.CancelledError))
        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.data_source.listen_for_trades(self.ev_loop, msg_queue))
            self.ev_loop.run_until_complete(self.listening_task)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_trades_logs_exception(self, mock_ws):
        """An incomplete trade payload makes the listener log the websocket error message."""
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.close.return_value = None

        incomplete_resp = {
            "m": 1,
            "i": 2,
        }
        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, incomplete_resp)
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        # The listener is expected to still be pending after one second.
        with self.assertRaises(asyncio.TimeoutError):
            self.ev_loop.run_until_complete(
                asyncio.wait_for(self.listening_task, 1))

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with WebSocket connection. Retrying after 30 seconds..."
            ))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_trades_iter_message_throws_exception(self, mock_ws):
        """A failure while receiving a message is logged at NETWORK and ERROR level."""
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value.receive_json.side_effect = lambda: self._raise_exception(
            Exception("TEST ERROR"))
        mock_ws.close.return_value = None

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        # The listener is expected to still be pending after one second.
        with self.assertRaises(asyncio.TimeoutError):
            self.ev_loop.run_until_complete(
                asyncio.wait_for(self.listening_task, 1))
        self.assertTrue(
            self._is_logged(
                "NETWORK",
                "Unexpected error occured when parsing websocket payload. Error: TEST ERROR"
            ))
        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with WebSocket connection. Retrying after 30 seconds..."
            ))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_trades_successful(self, mock_ws):
        """A trade event received over the websocket is queued with its trade id.

        The previous check, ``assertTrue(12345, msg.trade_id)``, could never
        fail: ``assertTrue`` only tests the truthiness of its first argument
        and uses the second one as the failure message.
        """
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.close.return_value = None

        trade_event = self._trade_update_event()
        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, self._successfully_subscribed_event())
        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, trade_event)
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.ev_loop.run_until_complete(
            msg_queue.get())

        # "t" holds 12345 in the fixture, the value the old assertion named.
        self.assertEqual(trade_event["t"], msg.trade_id)

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect")
    def test_listen_for_order_book_diffs_cancelled_when_connecting(
            self, mock_ws, _: AsyncMock):
        """A ``CancelledError`` raised while connecting propagates out of the diff listener."""
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.data_source.listen_for_order_book_diffs(
                    self.ev_loop, msg_queue))
            self.ev_loop.run_until_complete(self.listening_task)

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_order_book_diffs_cancelled_when_listening(
            self, mock_ws, _: AsyncMock):
        """A ``CancelledError`` raised while receiving propagates out of the diff listener."""
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.return_value.receive_json.side_effect = lambda: (
            self._raise_exception(asyncio.CancelledError))
        with self.assertRaises(asyncio.CancelledError):
            self.listening_task = self.ev_loop.create_task(
                self.data_source.listen_for_order_book_diffs(
                    self.ev_loop, msg_queue))
            self.ev_loop.run_until_complete(self.listening_task)

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_order_book_diffs_logs_exception(self, mock_ws):
        """An incomplete diff payload makes the listener log the websocket error message."""
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.close.return_value = None

        incomplete_resp = {
            "m": 1,
            "i": 2,
        }
        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, incomplete_resp)
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_diffs(
                self.ev_loop, msg_queue))

        # The listener is expected to still be pending after one second.
        with self.assertRaises(asyncio.TimeoutError):
            self.ev_loop.run_until_complete(
                asyncio.wait_for(self.listening_task, 1))

        self.assertTrue(
            self._is_logged(
                "ERROR",
                "Unexpected error with WebSocket connection. Retrying after 30 seconds..."
            ))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_order_book_diffs_successful(self, mock_ws):
        """A depth-update event received over the websocket is queued with its update id.

        The previous check, ``assertTrue(12345, msg.update_id)``, could never
        fail: ``assertTrue`` only tests the truthiness of its first argument
        and uses the second one as the failure message.
        """
        msg_queue: asyncio.Queue = asyncio.Queue()
        mock_ws.return_value = self.mocking_assistant.create_websocket_mock()
        mock_ws.close.return_value = None

        diff_event = self._order_diff_event()
        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, self._successfully_subscribed_event())
        self.mocking_assistant.add_websocket_json_message(
            mock_ws.return_value, diff_event)
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_diffs(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.ev_loop.run_until_complete(
            msg_queue.get())

        # NOTE(review): assumes the message's update_id mirrors the event's
        # final update id ("u") -- confirm against the data source.
        self.assertEqual(diff_event["u"], msg.update_id)

    @aioresponses()
    def test_listen_for_order_book_snapshots_cancelled_when_fetching_snapshot(
            self, mock_api):
        """A ``CancelledError`` raised by the HTTP request propagates out of the snapshot listener."""
        regex_url = self._regex_url(CONSTANTS.SNAPSHOT_PATH_URL)

        mock_api.get(regex_url, exception=asyncio.CancelledError)

        with self.assertRaises(asyncio.CancelledError):
            self.ev_loop.run_until_complete(
                self.data_source.listen_for_order_book_snapshots(
                    self.ev_loop, asyncio.Queue()))

    @aioresponses()
    def test_listen_for_order_book_snapshots_log_exception(self, mock_api):
        """A failing snapshot request is logged as an error for the trading pair."""
        msg_queue: asyncio.Queue = asyncio.Queue()

        regex_url = self._regex_url(CONSTANTS.SNAPSHOT_PATH_URL)

        mock_api.get(regex_url, exception=Exception)

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, msg_queue))
        # The listener is expected to still be pending after one second.
        with self.assertRaises(asyncio.TimeoutError):
            self.ev_loop.run_until_complete(
                asyncio.wait_for(self.listening_task, 1))

        self.assertTrue(
            self._is_logged(
                "ERROR",
                f"Unexpected error fetching order book snapshot for {self.trading_pair}."
            ))

    @aioresponses()
    def test_listen_for_order_book_snapshots_successful(
        self,
        mock_api,
    ):
        """A fetched snapshot is queued as a message carrying the snapshot's update id.

        The previous check, ``assertTrue(12345, msg.update_id)``, could never
        fail: ``assertTrue`` only tests the truthiness of its first argument
        and uses the second one as the failure message.
        """
        msg_queue: asyncio.Queue = asyncio.Queue()
        regex_url = self._regex_url(CONSTANTS.SNAPSHOT_PATH_URL)

        snapshot = self._snapshot_response()
        mock_api.get(regex_url, body=ujson.dumps(snapshot))

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, msg_queue))

        msg: OrderBookMessage = self.ev_loop.run_until_complete(
            msg_queue.get())

        # NOTE(review): assumes the message's update_id mirrors the
        # snapshot's "lastUpdateId" -- confirm against the data source.
        self.assertEqual(snapshot["lastUpdateId"], msg.update_id)
Exemplo n.º 4
0
class BinanceAPIOrderBookDataSourceUnitTests(unittest.TestCase):
    # logging.Level required to receive logs from the data source logger
    level = 0

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        cls.ev_loop = asyncio.get_event_loop()
        cls.base_asset = "COINALPHA"
        cls.quote_asset = "HBOT"
        cls.trading_pair = f"{cls.base_asset}-{cls.quote_asset}"
        cls.ex_trading_pair = cls.base_asset + cls.quote_asset
        cls.domain = "com"

    def setUp(self) -> None:
        """Create a fresh connector and data source and register this test case as log handler."""
        super().setUp()
        self.log_records = []
        self.listening_task = None
        self.mocking_assistant = NetworkMockingAssistant()

        connector = BinanceExchange(
            client_config_map=ClientConfigAdapter(ClientConfigMap()),
            binance_api_key="",
            binance_api_secret="",
            trading_pairs=[],
            trading_required=False,
            domain=self.domain)
        self.connector = connector

        data_source = BinanceAPIOrderBookDataSource(
            trading_pairs=[self.trading_pair],
            connector=connector,
            api_factory=connector._web_assistants_factory,
            domain=self.domain)
        self.data_source = data_source

        # The test case itself is installed as the handler; emitted records
        # reach handle() and are collected in self.log_records.
        data_source.logger().setLevel(1)
        data_source.logger().addHandler(self)

        self.resume_test_event = asyncio.Event()

        symbol_map = bidict({self.ex_trading_pair: self.trading_pair})
        connector._set_trading_pair_symbol_map(symbol_map)

    def tearDown(self) -> None:
        self.listening_task and self.listening_task.cancel()
        super().tearDown()

    def handle(self, record):
        self.log_records.append(record)

    def _is_logged(self, log_level: str, message: str) -> bool:
        return any(
            record.levelname == log_level and record.getMessage() == message
            for record in self.log_records)

    def _create_exception_and_unlock_test_with_event(self, exception):
        self.resume_test_event.set()
        raise exception

    def async_run_with_timeout(self, coroutine: Awaitable, timeout: float = 1):
        ret = self.ev_loop.run_until_complete(
            asyncio.wait_for(coroutine, timeout))
        return ret

    def _successfully_subscribed_event(self):
        resp = {"result": None, "id": 1}
        return resp

    def _trade_update_event(self):
        resp = {
            "e": "trade",
            "E": 123456789,
            "s": self.ex_trading_pair,
            "t": 12345,
            "p": "0.001",
            "q": "100",
            "b": 88,
            "a": 50,
            "T": 123456785,
            "m": True,
            "M": True
        }
        return resp

    def _order_diff_event(self):
        resp = {
            "e": "depthUpdate",
            "E": 123456789,
            "s": self.ex_trading_pair,
            "U": 157,
            "u": 160,
            "b": [["0.0024", "10"]],
            "a": [["0.0026", "100"]]
        }
        return resp

    def _snapshot_response(self):
        resp = {
            "lastUpdateId": 1027024,
            "bids": [["4.00000000", "431.00000000"]],
            "asks": [["4.00000200", "12.00000000"]]
        }
        return resp

    @aioresponses()
    def test_get_new_order_book_successful(self, mock_api):
        """get_new_order_book turns the mocked snapshot into an order book with matching entries."""
        snapshot = self._snapshot_response()

        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        url_pattern = re.compile(
            f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?"))
        mock_api.get(url_pattern, body=json.dumps(snapshot))

        order_book: OrderBook = self.async_run_with_timeout(
            self.data_source.get_new_order_book(self.trading_pair))

        snapshot_id = snapshot["lastUpdateId"]
        self.assertEqual(snapshot_id, order_book.snapshot_uid)

        bid_rows = list(order_book.bid_entries())
        ask_rows = list(order_book.ask_entries())

        # Exactly one bid level, carrying the snapshot's price, size and id.
        self.assertEqual(1, len(bid_rows))
        self.assertEqual(4, bid_rows[0].price)
        self.assertEqual(431, bid_rows[0].amount)
        self.assertEqual(snapshot_id, bid_rows[0].update_id)

        # Exactly one ask level, carrying the snapshot's price, size and id.
        self.assertEqual(1, len(ask_rows))
        self.assertEqual(4.000002, ask_rows[0].price)
        self.assertEqual(12, ask_rows[0].amount)
        self.assertEqual(snapshot_id, ask_rows[0].update_id)

    @aioresponses()
    def test_get_new_order_book_raises_exception(self, mock_api):
        """An HTTP 400 answer from the snapshot endpoint must surface as IOError."""
        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        url_pattern = re.compile(
            f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?"))
        mock_api.get(url_pattern, status=400)

        with self.assertRaises(IOError):
            self.async_run_with_timeout(
                self.data_source.get_new_order_book(self.trading_pair))

    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_subscriptions_subscribes_to_trades_and_order_diffs(
            self, ws_connect_mock):
        """The listener sends one trade and one depth subscription request and logs success."""
        ws_mock = self.mocking_assistant.create_websocket_mock()
        ws_connect_mock.return_value = ws_mock

        # Queue the acknowledgements for the trade (id 1) and diff (id 2) requests.
        for ack_id in (1, 2):
            self.mocking_assistant.add_websocket_aiohttp_message(
                websocket_mock=ws_mock,
                message=json.dumps({"result": None, "id": ack_id}))

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_subscriptions())

        self.mocking_assistant.run_until_all_aiohttp_messages_delivered(ws_mock)

        sent_requests = self.mocking_assistant.json_messages_sent_through_websocket(
            websocket_mock=ws_mock)
        self.assertEqual(2, len(sent_requests))

        symbol = self.ex_trading_pair.lower()
        trade_request = {
            "method": "SUBSCRIBE",
            "params": [f"{symbol}@trade"],
            "id": 1
        }
        diff_request = {
            "method": "SUBSCRIBE",
            "params": [f"{symbol}@depth@100ms"],
            "id": 2
        }
        self.assertEqual(trade_request, sent_requests[0])
        self.assertEqual(diff_request, sent_requests[1])

        subscribed_log = "Subscribed to public order book and trade channels..."
        self.assertTrue(self._is_logged("INFO", subscribed_log))

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect")
    def test_listen_for_subscriptions_raises_cancel_exception(
            self, mock_ws, _: AsyncMock):
        """A CancelledError raised while connecting must propagate out of the listener."""
        mock_ws.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            listener = self.ev_loop.create_task(
                self.data_source.listen_for_subscriptions())
            self.listening_task = listener
            self.async_run_with_timeout(listener)

    @patch(
        "hummingbot.core.data_type.order_book_tracker_data_source.OrderBookTrackerDataSource._sleep"
    )
    @patch("aiohttp.ClientSession.ws_connect", new_callable=AsyncMock)
    def test_listen_for_subscriptions_logs_exception_details(
            self, mock_ws, sleep_mock):
        """A connection failure must be logged as an error before the listener sleeps."""

        def _unlock_and_cancel(_):
            # Reaching the sleep releases the waiting test and stops the
            # listener with a cancellation.
            return self._create_exception_and_unlock_test_with_event(
                asyncio.CancelledError())

        mock_ws.side_effect = Exception("TEST ERROR.")
        sleep_mock.side_effect = _unlock_and_cancel

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_subscriptions())

        self.async_run_with_timeout(self.resume_test_event.wait())

        expected_log = "Unexpected error occurred when listening to order book streams. Retrying in 5 seconds..."
        self.assertTrue(self._is_logged("ERROR", expected_log))

    def test_subscribe_channels_raises_cancel_exception(self):
        """A CancelledError raised by the websocket send must propagate out of _subscribe_channels."""
        ws_stub = MagicMock()
        ws_stub.send.side_effect = asyncio.CancelledError

        with self.assertRaises(asyncio.CancelledError):
            subscription = self.ev_loop.create_task(
                self.data_source._subscribe_channels(ws_stub))
            self.listening_task = subscription
            self.async_run_with_timeout(subscription)

    def test_subscribe_channels_raises_exception_and_logs_error(self):
        """A failing websocket send must be re-raised and reported through an ERROR log entry."""
        ws_stub = MagicMock()
        ws_stub.send.side_effect = Exception("Test Error")

        with self.assertRaises(Exception):
            subscription = self.ev_loop.create_task(
                self.data_source._subscribe_channels(ws_stub))
            self.listening_task = subscription
            self.async_run_with_timeout(subscription)

        expected_log = "Unexpected error occurred subscribing to order book trading and delta streams..."
        self.assertTrue(self._is_logged("ERROR", expected_log))

    def test_listen_for_trades_cancelled_when_listening(self):
        """A CancelledError raised by the internal trade queue must propagate out of listen_for_trades."""
        trade_events = MagicMock()
        trade_events.get.side_effect = asyncio.CancelledError()
        self.data_source._message_queue[CONSTANTS.TRADE_EVENT_TYPE] = trade_events

        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            listener = self.ev_loop.create_task(
                self.data_source.listen_for_trades(self.ev_loop, output_queue))
            self.listening_task = listener
            self.async_run_with_timeout(listener)

    def test_listen_for_trades_logs_exception(self):
        """A trade event that cannot be processed must be reported through an ERROR log entry."""
        malformed_event = {"m": 1, "i": 2}

        # First the malformed event is delivered, then the listener is stopped
        # by a cancellation.
        trade_events = AsyncMock()
        trade_events.get.side_effect = [malformed_event, asyncio.CancelledError()]
        self.data_source._message_queue[CONSTANTS.TRADE_EVENT_TYPE] = trade_events

        output_queue: asyncio.Queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, output_queue))

        try:
            self.async_run_with_timeout(self.listening_task)
        except asyncio.CancelledError:
            # Expected: the queued cancellation ends the listener.
            pass

        expected_log = "Unexpected error when processing public trade updates from exchange"
        self.assertTrue(self._is_logged("ERROR", expected_log))

    def test_listen_for_trades_successful(self):
        """A valid trade event is forwarded to the output queue with its trade id."""
        trade_events = AsyncMock()
        trade_events.get.side_effect = [
            self._trade_update_event(),
            asyncio.CancelledError(),
        ]
        self.data_source._message_queue[CONSTANTS.TRADE_EVENT_TYPE] = trade_events

        output_queue: asyncio.Queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_trades(self.ev_loop, output_queue))

        trade_msg: OrderBookMessage = self.async_run_with_timeout(
            output_queue.get())

        self.assertEqual(12345, trade_msg.trade_id)

    def test_listen_for_order_book_diffs_cancelled(self):
        """A CancelledError raised by the internal diff queue must propagate out of the listener."""
        diff_events = AsyncMock()
        diff_events.get.side_effect = asyncio.CancelledError()
        self.data_source._message_queue[CONSTANTS.DIFF_EVENT_TYPE] = diff_events

        output_queue: asyncio.Queue = asyncio.Queue()

        with self.assertRaises(asyncio.CancelledError):
            listener = self.ev_loop.create_task(
                self.data_source.listen_for_order_book_diffs(
                    self.ev_loop, output_queue))
            self.listening_task = listener
            self.async_run_with_timeout(listener)

    def test_listen_for_order_book_diffs_logs_exception(self):
        """A diff event that cannot be processed must be reported through an ERROR log entry."""
        malformed_event = {"m": 1, "i": 2}

        # First the malformed event is delivered, then the listener is stopped
        # by a cancellation.
        diff_events = AsyncMock()
        diff_events.get.side_effect = [malformed_event, asyncio.CancelledError()]
        self.data_source._message_queue[CONSTANTS.DIFF_EVENT_TYPE] = diff_events

        output_queue: asyncio.Queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_diffs(
                self.ev_loop, output_queue))

        try:
            self.async_run_with_timeout(self.listening_task)
        except asyncio.CancelledError:
            # Expected: the queued cancellation ends the listener.
            pass

        expected_log = "Unexpected error when processing public order book updates from exchange"
        self.assertTrue(self._is_logged("ERROR", expected_log))

    def test_listen_for_order_book_diffs_successful(self):
        """A valid diff event is forwarded to the output queue with the event's ``u`` value as update id."""
        sample_diff = self._order_diff_event()

        diff_events = AsyncMock()
        diff_events.get.side_effect = [sample_diff, asyncio.CancelledError()]
        self.data_source._message_queue[CONSTANTS.DIFF_EVENT_TYPE] = diff_events

        output_queue: asyncio.Queue = asyncio.Queue()
        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_diffs(
                self.ev_loop, output_queue))

        diff_msg: OrderBookMessage = self.async_run_with_timeout(
            output_queue.get())

        self.assertEqual(sample_diff["u"], diff_msg.update_id)

    @aioresponses()
    def test_listen_for_order_book_snapshots_cancelled_when_fetching_snapshot(
            self, mock_api):
        """A CancelledError raised by the snapshot request must propagate out of the listener."""
        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        url_pattern = re.compile(
            f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?"))
        mock_api.get(url_pattern, exception=asyncio.CancelledError)

        with self.assertRaises(asyncio.CancelledError):
            self.async_run_with_timeout(
                self.data_source.listen_for_order_book_snapshots(
                    self.ev_loop, asyncio.Queue()))

    @aioresponses()
    @patch(
        "hummingbot.connector.exchange.binance.binance_api_order_book_data_source"
        ".BinanceAPIOrderBookDataSource._sleep")
    def test_listen_for_order_book_snapshots_log_exception(
            self, mock_api, sleep_mock):
        """A failing snapshot request must be reported through an ERROR log entry."""

        def _unlock_and_cancel(_):
            # Reaching the sleep releases the waiting test and stops the
            # listener with a cancellation.
            return self._create_exception_and_unlock_test_with_event(
                asyncio.CancelledError())

        output_queue: asyncio.Queue = asyncio.Queue()
        sleep_mock.side_effect = _unlock_and_cancel

        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        url_pattern = re.compile(
            f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?"))
        mock_api.get(url_pattern, exception=Exception)

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, output_queue))
        self.async_run_with_timeout(self.resume_test_event.wait())

        expected_log = f"Unexpected error fetching order book snapshot for {self.trading_pair}."
        self.assertTrue(self._is_logged("ERROR", expected_log))

    @aioresponses()
    def test_listen_for_order_book_snapshots_successful(
        self,
        mock_api,
    ):
        """A fetched snapshot is forwarded to the output queue with the snapshot's update id."""
        output_queue: asyncio.Queue = asyncio.Queue()

        snapshot_url = web_utils.public_rest_url(
            path_url=CONSTANTS.SNAPSHOT_PATH_URL, domain=self.domain)
        url_pattern = re.compile(
            f"^{snapshot_url}".replace(".", r"\.").replace("?", r"\?"))
        mock_api.get(url_pattern, body=json.dumps(self._snapshot_response()))

        self.listening_task = self.ev_loop.create_task(
            self.data_source.listen_for_order_book_snapshots(
                self.ev_loop, output_queue))

        snapshot_msg: OrderBookMessage = self.async_run_with_timeout(
            output_queue.get())

        # 1027024 is the "lastUpdateId" served by _snapshot_response().
        self.assertEqual(1027024, snapshot_msg.update_id)