Esempio n. 1
0
    async def test_always_callback_sensor(self):
        """Check that ``always_callback`` gates notification for repeated writes.

        The sensor is pre-loaded with a value and then receives a
        GroupValueWrite: listeners are not notified while ``always_callback``
        is False, are notified once after it is set to True, and a
        GroupValueResponse does not notify even with ``always_callback`` on.
        """
        xknx = XKNX()
        sensor = Sensor(
            xknx,
            "TestSensor",
            group_address_state="1/2/3",
            always_callback=False,
            value_type="volume_liquid_litre",
        )
        on_update = AsyncMock()
        sensor.register_device_updated_cb(on_update)

        raw_value = DPTArray((0x00, 0x00, 0x01, 0x00))
        # pre-load the sensor with an initial value before any telegram arrives
        sensor.sensor_value.value = 256
        write_telegram = Telegram(
            destination_address=GroupAddress("1/2/3"),
            payload=GroupValueWrite(raw_value),
        )
        read_response = Telegram(
            destination_address=GroupAddress("1/2/3"),
            payload=GroupValueResponse(raw_value),
        )

        # always_callback disabled: the write must not notify
        await sensor.process(write_telegram)
        on_update.assert_not_called()
        on_update.reset_mock()

        # always_callback enabled: the same write notifies exactly once
        sensor.always_callback = True
        await sensor.process(write_telegram)
        on_update.assert_called_once()
        on_update.reset_mock()

        # read responses never notify
        await sensor.process(read_response)
        on_update.assert_not_called()
Esempio n. 2
0
    async def test_process_passive(self):
        """Telegrams to secondary (passive) addresses update the switch.

        The switch lists two entries each for ``group_address`` and
        ``group_address_state``; writes to the second entries must change the
        state and fire the device-updated callback exactly once per telegram.
        """
        xknx = XKNX()
        device_updated = AsyncMock()

        switch1 = Switch(
            xknx,
            "TestOutlet",
            group_address=["1/2/3", "4/4/4"],
            group_address_state=["1/2/30", "5/5/5"],
            device_updated_cb=device_updated,
        )
        assert switch1.state is None
        device_updated.assert_not_called()

        # (destination address, payload bit, expected switch state)
        steps = (
            ("4/4/4", 1, True),
            ("5/5/5", 0, False),
        )
        for address, bit, expected_state in steps:
            telegram = Telegram(
                destination_address=GroupAddress(address),
                payload=GroupValueWrite(DPTBinary(bit)),
            )
            await switch1.process(telegram)
            assert switch1.state is expected_state
            device_updated.assert_called_once()
            device_updated.reset_mock()
Esempio n. 3
0
async def test_on_join_with_actions_success(cog: MemberJoins, member: Mock) -> None:
    """A member join triggers both actions: one channel message and one DM."""
    dm_sender = AsyncMock()
    channel_sender = AsyncMock()
    patched_senders = patch.multiple(
        cog, _send_dm=dm_sender, _send_channel=channel_sender
    )
    with patched_senders:
        await cog.on_member_join(member)
        channel_sender.assert_called_once()
        dm_sender.assert_called_once()
    async def test_valid_jwt_without_bearer(self):
        """A token header without the ``Bearer`` prefix is passed to ``call_next``."""
        app = MockFastAPI()
        auth_middleware.add_auth_middleware(app, 'development')
        request = Mock()
        request.headers.get.return_value = 'some_jwt'
        next_handler = AsyncMock()

        await app.middleware_func(request, next_handler)

        next_handler.assert_called_once()
Esempio n. 5
0
    def test_run_callable(self):
        """``run()`` starts the app when ``args.app`` names ``app_callable``."""
        run_app_mock = AsyncMock()
        cli_args = Mock()
        cli_args.app = "tests.unit.test_app:app_callable"
        cli_args.kafka_settings = None
        with patch("kafkaesk.app.run_app", run_app_mock), \
                patch("kafkaesk.app.cli_parser") as cli_parser:
            cli_parser.parse_args.return_value = cli_args

            run()

            run_app_mock.assert_called_once()
    async def test_exclude_path(self):
        """A request to a path listed in ``excludes`` is forwarded to ``call_next``.

        The request carries a bearer token. Because its path is excluded, the
        middleware must hand it to the next handler exactly once.
        """
        app = MockFastAPI()
        auth_middleware.add_auth_middleware(app, 'development', excludes=['/some/path'])
        call_next = AsyncMock()
        request = Mock()
        request.headers.get.return_value = 'Bearer some_jwt'
        request.url.path = '/some/path'

        # the middleware's return value is not inspected in this test
        await app.middleware_func(request, call_next)

        call_next.assert_called_once()
Esempio n. 7
0
 async def test_retries_on_connection_failure(self):
     """A Kafka connection error makes the consumer sleep once, then call ``_run`` again."""
     sub = SubscriptionConsumer(
         Application(),
         Subscription("foo", lambda record: 1, "group"),
     )
     # first _run attempt fails to connect, the second stops the consumer loop
     run_mock = AsyncMock(
         side_effect=[aiokafka.errors.KafkaConnectionError, StopConsumer]
     )
     sleep_mock = AsyncMock()
     with patch.object(sub, "initialize", AsyncMock()), \
             patch.object(sub, "finalize", AsyncMock()), \
             patch.object(sub, "_run", run_mock), \
             patch("kafkaesk.subscription.asyncio.sleep", sleep_mock):
         await sub()
         sleep_mock.assert_called_once()
         assert len(run_mock.mock_calls) == 2
    async def test_exclude_regex_path(self, path, expected):
        """Paths matching an ``excludes_regex`` pattern bypass authentication.

        :param path: request path to send (presumably supplied by a
            parametrize decorator outside this view).
        :param expected: True if ``path`` should match the exclusion pattern and
            be forwarded to ``call_next``. False if the request should be
            rejected with status 401.
        """
        app = MockFastAPI()
        # Raw string: "\w" and "\d" are invalid escapes in a normal literal
        # (DeprecationWarning, SyntaxWarning since Python 3.12). The resulting
        # pattern text is identical.
        auth_middleware.add_auth_middleware(
            app,
            'development',
            excludes_regex=[r'^/some/path/(\w|\d|-)+/something/?$'],
        )
        call_next = AsyncMock()
        request = Mock()
        request.headers.get.return_value = 'Bearer some_jwt'
        request.url.path = path

        response = await app.middleware_func(request, call_next)

        if expected:
            call_next.assert_called_once()
        else:
            assert response.status_code == 401
Esempio n. 9
0
    def test_player_send_json_rpc_request_throws_connection_closed(self):
        """Player.send_json_rpc_request() tolerates ConnectionClosed from websocket.send().

        The websocket's ``send`` is mocked to raise ``ConnectionClosed``. The
        request must complete without propagating it, and ``send`` must have
        been attempted exactly once.
        """
        websocket = MagicMock()
        send = AsyncMock(side_effect=ConnectionClosed(0, 'Connection Closed'))
        websocket.attach_mock(send, 'send')

        player = Player(None, websocket)
        # exercise Player's string conversion; the result is not otherwise used
        str([player])

        params = {'a': 'x'}
        asyncio.run(player.send_json_rpc_request("general_request", params))
        # No exception should propagate, as expected
        send.assert_called_once()
Esempio n. 10
0
async def test_callbacks():
    """A plain callback and a coroutine callback are each invoked once per datagram."""
    sync_cb = Mock()
    async_cb = AsyncMock()

    with patch(
        "pysqueezebox.discovery._unpack_discovery_response",
        return_value=RESPONSE,
    ):
        protocols = [
            pysqueezebox.discovery.ServerDiscoveryProtocol(cb)
            for cb in (sync_cb, async_cb)
        ]
        for protocol in protocols:
            protocol.datagram_received(ADDR, DATA)

    sync_cb.assert_called_once()
    async_cb.assert_called_once()
Esempio n. 11
0
async def test_failed_task_returns_exceptions(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    gpu_image: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
):
    """A remote function that raises surfaces the error to the caller.

    Awaiting the task's future must re-raise the ``ValueError``. The
    completion callback must be called once with a FAILED event for the job,
    and that event's message must contain the remote traceback.
    """
    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    def fake_failing_sidecar_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:
        """Stand-in for the sidecar entry point that always raises ValueError."""
        raise ValueError(
            "sadly we are failing to execute anything cause we are dumb...")

    await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=gpu_image.fake_task,
        callback=mocked_user_completed_cb,
        remote_fct=fake_failing_sidecar_fct,
    )
    assert (len(dask_client._taskid_to_future_map) == 1
            ), "dask client did not store the future of the task sent"

    job_id, future = next(iter(dask_client._taskid_to_future_map.items()))
    # this waits for the computation to run
    with pytest.raises(ValueError):
        await future.result(timeout=_ALLOW_TIME_FOR_GATEWAY_TO_CREATE_WORKERS)
    await _wait_for_call(mocked_user_completed_cb)
    mocked_user_completed_cb.assert_called_once()
    event = mocked_user_completed_cb.call_args[0][0]
    assert event.job_id == job_id
    assert event.state == RunningState.FAILED
    # str.find() results were previously discarded, so these checks never
    # ran. Assert membership explicitly so a missing traceback fails the test.
    assert "Traceback" in event.msg
    assert "raise ValueError" in event.msg
Esempio n. 12
0
 async def test_retries_on_connection_failure(self, subscription):
     """A Kafka connection error makes the consumer sleep once, then call ``_consume`` again."""
     # first _consume attempt fails to connect, the second stops the consumer
     consume_mock = AsyncMock(
         side_effect=[aiokafka.errors.KafkaConnectionError, StopConsumer]
     )
     sleep_mock = AsyncMock()
     subscription._consumer = MagicMock()
     with patch.object(subscription, "initialize", AsyncMock()), \
             patch.object(subscription, "finalize", AsyncMock()), \
             patch.object(subscription, "_consume", consume_mock), \
             patch("kafkaesk.consumer.asyncio.sleep", sleep_mock):
         await subscription()
         sleep_mock.assert_called_once()
         assert len(consume_mock.mock_calls) == 2
Esempio n. 13
0
    def test_run(self):
        """``run()`` copies the parsed CLI options onto the target app and starts it."""
        run_app_mock = AsyncMock()
        cli_args = Mock()
        cli_args.app = "tests.unit.test_app:test_app"
        cli_args.kafka_servers = "foo,bar"
        cli_args.kafka_settings = json.dumps({"foo": "bar"})
        cli_args.topic_prefix = "prefix"
        cli_args.api_version = "api_version"
        with patch("kafkaesk.app.run_app", run_app_mock), \
                patch("kafkaesk.app.cli_parser") as cli_parser:
            cli_parser.parse_args.return_value = cli_args

            run()

            run_app_mock.assert_called_once()
            # attribute on test_app -> value expected after applying the CLI args
            expected_settings = {
                "_kafka_servers": ["foo", "bar"],
                "_kafka_settings": {"foo": "bar"},
                "_topic_prefix": "prefix",
                "_kafka_api_version": "api_version",
            }
            for attr_name, expected in expected_settings.items():
                assert getattr(test_app, attr_name) == expected, attr_name
Esempio n. 14
0
 async def _async_test_effect(name, target=None, called=True):
     """Turn the light on with effect ``name`` and check the flow sent to the bulb.

     :param name: effect name passed as ``ATTR_EFFECT``.
     :param target: expected flow. When given, its ``count``, ``action`` and
         ``transitions`` are compared with the first positional argument of
         ``async_start_flow``.
     :param called: when False, the service is called but nothing is asserted.
     """
     start_flow = AsyncMock()
     mocked_bulb.async_start_flow = start_flow
     await hass.services.async_call(
         "light",
         SERVICE_TURN_ON,
         {ATTR_ENTITY_ID: ENTITY_LIGHT, ATTR_EFFECT: name},
         blocking=True,
     )
     if called:
         start_flow.assert_called_once()
         if target is not None:
             flow = start_flow.call_args[0][0]
             assert flow.count == target.count
             assert flow.action == target.action
             # transitions are compared through their string form
             assert str(flow.transitions) == str(target.transitions)
Esempio n. 15
0
    async def _async_test_service(
        service,
        data,
        method,
        payload=None,
        domain=DOMAIN,
        failure_side_effect=BulbException,
    ):
        """Call a service and check the mocked bulb method it should invoke.

        The check runs in two phases:

        * Success: the bulb ``method`` is mocked and must be called once. If
          ``payload`` is None, only the call itself is checked. If it is a
          list or tuple, it is matched against the positional arguments;
          otherwise it is matched against the keyword arguments. No new ERROR
          log record may appear.
        * Failure: if ``failure_side_effect`` is set, the mocked method raises
          it and exactly one new ERROR record must be logged.
        """

        def _error_count():
            # number of ERROR-level records captured so far
            return len([x for x in caplog.records if x.levelno == logging.ERROR])

        def _mock_for(side_effect=None):
            # "async_*" bulb methods are awaited, so they need an AsyncMock
            mock_cls = AsyncMock if method.startswith("async_") else MagicMock
            return mock_cls(side_effect=side_effect)

        err_count = _error_count()

        # success
        mocked_method = _mock_for()
        setattr(mocked_bulb, method, mocked_method)
        await hass.services.async_call(domain, service, data, blocking=True)
        if payload is None:
            mocked_method.assert_called_once()
        elif isinstance(payload, (list, tuple)):
            mocked_method.assert_called_once_with(*payload)
        else:
            mocked_method.assert_called_once_with(**payload)
        assert _error_count() == err_count

        # failure
        if failure_side_effect:
            setattr(mocked_bulb, method, _mock_for(failure_side_effect))
            await hass.services.async_call(domain, service, data, blocking=True)
            assert _error_count() == err_count + 1
class TestReadDatafeedStrategy:
    """Tests for the ``strategy.read_datafeed_retry`` retry strategy.

    Each test sets instance attributes before calling
    ``_retryable_coroutine``: ``_retry_config``, plus ``_auth_session`` and/or
    ``recreate_datafeed`` where relevant. The strategy presumably reads these
    from the decorated instance (TODO confirm against ``strategy``).
    """

    @retry(retry=strategy.read_datafeed_retry)
    async def _retryable_coroutine(self, thing):
        """Call ``thing.go()`` under ``read_datafeed_retry`` and return its result."""
        await asyncio.sleep(0.00001)
        return thing.go()

    @pytest.mark.asyncio
    async def test_client_error_recreates_datafeed_and_and_tries_again(self):
        """A status-400 failure recreates the datafeed once; the retry returns True."""
        self._retry_config = minimal_retry_config_with_attempts(2)
        self._auth_session = Mock()
        self.recreate_datafeed = AsyncMock()
        thing = NoApiExceptionAfterCount(1, status=400)

        value = await self._retryable_coroutine(thing)

        self.recreate_datafeed.assert_called_once()
        assert value is True

    @pytest.mark.asyncio
    async def test_unauthorized_error_refreshes_session_and_and_tries_again(self):
        """A status-401 failure refreshes the auth session once; the retry returns True."""
        self._retry_config = minimal_retry_config_with_attempts(2)
        self._auth_session = Mock()
        self._auth_session.refresh = AsyncMock()
        thing = NoApiExceptionAfterCount(1, status=401)

        value = await self._retryable_coroutine(thing)

        self._auth_session.refresh.assert_called_once()
        assert value is True

    @pytest.mark.asyncio
    async def test_unexpected_api_exception_is_raised(self):
        """A status-404 ApiException is raised after a single attempt."""
        self._retry_config = minimal_retry_config_with_attempts(1)
        thing = NoApiExceptionAfterCount(2, status=404)
        with pytest.raises(ApiException):
            await self._retryable_coroutine(thing)

        assert thing.call_count == 1
Esempio n. 17
0
    async def test_process_reset_after(self, time_travel):
        """``reset_after`` turns the input back off once the period has elapsed.

        Repeated 'on' telegrams within the period do not fire the update
        callback again. After the period the state is off, and the callback
        has run twice: once for 'on' and once for 'off'.
        """
        xknx = XKNX()
        reset_seconds = 1
        on_update = AsyncMock()
        binary_input = BinarySensor(
            xknx,
            "TestInput",
            "1/2/3",
            reset_after=reset_seconds,
            device_updated_cb=on_update,
        )
        telegram_on = Telegram(
            destination_address=GroupAddress("1/2/3"),
            payload=GroupValueWrite(DPTBinary(1)),
        )

        async def expect_reset():
            # advance past reset_after: state drops, callback ran for 'on' and 'off'
            await time_travel(reset_seconds)
            assert not binary_input.state
            assert on_update.call_count == 2

        await binary_input.process(telegram_on)
        assert binary_input.state
        await expect_reset()

        on_update.reset_mock()
        # multiple telegrams during reset_after time period shall reset timer
        await binary_input.process(telegram_on)
        on_update.assert_called_once()
        for _ in range(2):
            await binary_input.process(telegram_on)
        # the repeated telegrams do not run the callback again
        on_update.assert_called_once()
        assert binary_input.state
        await expect_reset()
Esempio n. 18
0
    async def test_process_state(self):
        """GroupValueResponse telegrams on the state address update the switch.

        Two switches share the same addresses and are driven through on/off
        sequences. Every telegram must set the expected state and fire the
        shared callback exactly once. ``switch2`` receives 'off' as its very
        first telegram.
        """
        xknx = XKNX()
        device_updated = AsyncMock()

        def make_switch():
            return Switch(
                xknx,
                "TestOutlet",
                group_address="1/2/3",
                group_address_state="1/2/4",
                device_updated_cb=device_updated,
            )

        switch1 = make_switch()
        switch2 = make_switch()
        assert switch1.state is None
        assert switch2.state is None
        device_updated.assert_not_called()

        def state_response(bit):
            return Telegram(
                destination_address=GroupAddress("1/2/4"),
                payload=GroupValueResponse(DPTBinary(bit)),
            )

        telegram_on = state_response(1)
        telegram_off = state_response(0)

        # (device, telegram, expected state afterwards)
        sequence = (
            (switch1, telegram_on, True),
            (switch1, telegram_off, False),
            (switch2, telegram_off, False),
            (switch2, telegram_on, True),
        )
        for device, telegram, expected_state in sequence:
            await device.process(telegram)
            assert device.state is expected_state
            device_updated.assert_called_once()
            device_updated.reset_mock()
Esempio n. 19
0
async def test_abort_send_computation_task(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    image_params: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
):
    """Aborting a running task cancels its future and reports ABORTED.

    The remote function iterates over ``TaskCancelEvent`` messages until one
    targets its own task key. After ``abort_computation_tasks``:

    * the future must be cancelled;
    * the completion callback must receive exactly one ABORTED event;
    * the client's future map must be empty again.
    """
    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    def fake_sidecar_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
        expected_annotations: Dict[str, Any],
    ) -> TaskOutputData:
        """Stand-in sidecar that runs until a cancel event for its own task arrives."""
        sub = Sub(TaskCancelEvent.topic_name())
        # get the task data
        worker = get_worker()
        task = worker.tasks.get(worker.get_current_task())
        assert task is not None
        print(f"--> task {task=} started")
        assert task.annotations == expected_annotations
        # consume cancel events until one targets this task, then abort
        print("--> waiting for task to be aborted...")
        for msg in sub:
            assert msg
            print(f"--> received cancellation msg: {msg=}")
            cancel_event = TaskCancelEvent.parse_raw(msg)  # type: ignore
            assert cancel_event
            if cancel_event.job_id == task.key:
                print("--> raising cancellation error now")
                raise asyncio.CancelledError("task cancelled")

        return TaskOutputData.parse_obj({"some_output_key": 123})

    await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=image_params.fake_task,
        callback=mocked_user_completed_cb,
        remote_fct=functools.partial(
            fake_sidecar_fct,
            expected_annotations=image_params.expected_annotations),
    )
    assert (len(dask_client._taskid_to_future_map) == 1
            ), "dask client did not store the future of the task sent"
    # let the task start
    await asyncio.sleep(2)

    # now let's abort the computation
    job_id, future = next(iter(dask_client._taskid_to_future_map.items()))
    assert future.key == job_id
    await dask_client.abort_computation_tasks([job_id])
    assert future.cancelled()
    await _wait_for_call(mocked_user_completed_cb)
    mocked_user_completed_cb.assert_called_once()
    mocked_user_completed_cb.assert_called_with(
        TaskStateEvent(
            job_id=job_id,
            msg=None,
            state=RunningState.ABORTED,
        ))
    assert (len(dask_client._taskid_to_future_map) == 0
            ), "the list of futures was not cleaned correctly"
Esempio n. 20
0
class AsyncMockAssert(unittest.TestCase):
    """Tests for the call/await assertion helpers of ``AsyncMock``.

    Calling an ``AsyncMock`` records a *call* immediately, but an *await* is
    only recorded once the returned coroutine is actually awaited. These tests
    pin that distinction for ``assert_called*``, ``assert_awaited*``,
    ``assert_any_await``, ``assert_has_awaits`` and the recorded call lists.
    """
    def setUp(self):
        # fresh, spec-less AsyncMock for every test
        self.mock = AsyncMock()

    async def _runnable_test(self, *args, **kwargs):
        # call *and* await self.mock, recording both a call and an await
        await self.mock(*args, **kwargs)

    async def _await_coroutine(self, coroutine):
        # await a coroutine created earlier (e.g. returned by a mock call)
        return await coroutine

    def test_assert_called_but_not_awaited(self):
        """A call whose coroutine is never awaited counts as a call, not an await."""
        mock = AsyncMock(AsyncClass)
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            mock.async_method()
        self.assertTrue(asyncio.iscoroutinefunction(mock.async_method))
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.assert_awaited()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()

    def test_assert_called_then_awaited(self):
        """Awaiting a previously created coroutine records an await without a new call."""
        mock = AsyncMock(AsyncClass)
        mock_coroutine = mock.async_method()
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()

        asyncio.run(self._await_coroutine(mock_coroutine))
        # Assert we haven't re-called the function
        mock.async_method.assert_called_once()
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()
        mock.async_method.assert_awaited_once_with()

    def test_assert_called_and_awaited_at_same_time(self):
        """Calling and awaiting in one step records exactly one call and one await."""
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()

        with self.assertRaises(AssertionError):
            self.mock.assert_called()

        asyncio.run(self._runnable_test())
        self.mock.assert_called_once()
        self.mock.assert_awaited_once()

    def test_assert_called_twice_and_awaited_once(self):
        """Two calls with only one coroutine awaited record a single await."""
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        with self.assertWarns(RuntimeWarning):
            # The first call will be awaited so no warning there
            # But this call will never get awaited, so it will warn here
            mock.async_method()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()
        mock.async_method.assert_called()
        asyncio.run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()

    def test_assert_called_once_and_awaited_twice(self):
        """Re-awaiting an already awaited coroutine raises RuntimeError."""
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        mock.async_method.assert_called_once()
        asyncio.run(self._await_coroutine(coroutine))
        with self.assertRaises(RuntimeError):
            # Cannot reuse already awaited coroutine
            asyncio.run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()

    def test_assert_awaited_but_not_called(self):
        """Awaiting the mock object itself raises TypeError and records nothing."""
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()
        with self.assertRaises(TypeError):
            # You cannot await an AsyncMock, it must be a coroutine
            asyncio.run(self._await_coroutine(self.mock))

        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()

    def test_assert_has_calls_not_awaits(self):
        """assert_has_calls sees un-awaited calls; assert_has_awaits does not."""
        kalls = [call('foo')]
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock('foo')
        self.mock.assert_has_calls(kalls)
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(kalls)

    def test_assert_has_mock_calls_on_async_mock_no_spec(self):
        """mock_calls on a spec-less AsyncMock records calls even when not awaited."""
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock()
        kalls_empty = [('', (), {})]
        self.assertEqual(self.mock.mock_calls, kalls_empty)

        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock('foo')
            self.mock('baz')
        mock_kalls = ([call(), call('foo'), call('baz')])
        self.assertEqual(self.mock.mock_calls, mock_kalls)

    def test_assert_has_mock_calls_on_async_mock_with_spec(self):
        """With a spec, method calls appear on both the method mock and the parent."""
        a_class_mock = AsyncMock(AsyncClass)
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            a_class_mock.async_method()
        kalls_empty = [('', (), {})]
        self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
        self.assertEqual(a_class_mock.mock_calls, [call.async_method()])

        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            a_class_mock.async_method(1, 2, 3, a=4, b=5)
        method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
        mock_kalls = [
            call.async_method(),
            call.async_method(1, 2, 3, a=4, b=5)
        ]
        self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
        self.assertEqual(a_class_mock.mock_calls, mock_kalls)

    def test_async_method_calls_recorded(self):
        """Calls on child attributes are recorded in method_calls under dotted names."""
        with self.assertWarns(RuntimeWarning):
            # Will raise warnings because never awaited
            self.mock.something(3, fish=None)
            self.mock.something_else.something(6, cake=sentinel.Cake)

        self.assertEqual(self.mock.method_calls, [("something", (3, ), {
            'fish': None
        }), ("something_else.something", (6, ), {
            'cake': sentinel.Cake
        })], "method calls not recorded correctly")
        self.assertEqual(self.mock.something_else.method_calls,
                         [("something", (6, ), {
                             'cake': sentinel.Cake
                         })], "method calls not recorded correctly")

    def test_async_arg_lists(self):
        """Call-list attributes are _CallList instances that reset_mock() empties."""
        def assert_attrs(mock):
            # each recorded-call attribute must be an empty _CallList (a list subclass)
            names = ('call_args_list', 'method_calls', 'mock_calls')
            for name in names:
                attr = getattr(mock, name)
                self.assertIsInstance(attr, _CallList)
                self.assertIsInstance(attr, list)
                self.assertEqual(attr, [])

        assert_attrs(self.mock)
        with self.assertWarns(RuntimeWarning):
            # Will raise warnings because never awaited
            self.mock()
            self.mock(1, 2)
            self.mock(a=3)

        self.mock.reset_mock()
        assert_attrs(self.mock)

        a_mock = AsyncMock(AsyncClass)
        with self.assertWarns(RuntimeWarning):
            # Will raise warnings because never awaited
            a_mock.async_method()
            a_mock.async_method(1, a=3)

        a_mock.reset_mock()
        assert_attrs(a_mock)

    def test_assert_awaited(self):
        """assert_awaited fails before, and passes after, the mock is awaited."""
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()

        asyncio.run(self._runnable_test())
        self.mock.assert_awaited()

    def test_assert_awaited_once(self):
        """assert_awaited_once passes after one await and fails after a second."""
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once()

        asyncio.run(self._runnable_test())
        self.mock.assert_awaited_once()

        asyncio.run(self._runnable_test())
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once()

    def test_assert_awaited_with(self):
        """assert_awaited_with checks the arguments of the most recent await."""
        msg = 'Not awaited'
        with self.assertRaisesRegex(AssertionError, msg):
            self.mock.assert_awaited_with('foo')

        asyncio.run(self._runnable_test())
        msg = 'expected await not found'
        with self.assertRaisesRegex(AssertionError, msg):
            self.mock.assert_awaited_with('foo')

        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_awaited_with('foo')

        asyncio.run(self._runnable_test('SomethingElse'))
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_with('foo')

    def test_assert_awaited_once_with(self):
        """assert_awaited_once_with requires exactly one await with those arguments."""
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once_with('foo')

        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_awaited_once_with('foo')

        asyncio.run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once_with('foo')

    # NOTE(review): name presumably intended as test_assert_any_await
    def test_assert_any_wait(self):
        """assert_any_await passes once any past await matched the arguments."""
        with self.assertRaises(AssertionError):
            self.mock.assert_any_await('foo')

        asyncio.run(self._runnable_test('baz'))
        with self.assertRaises(AssertionError):
            self.mock.assert_any_await('foo')

        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_any_await('foo')

        asyncio.run(self._runnable_test('SomethingElse'))
        self.mock.assert_any_await('foo')

    # NOTE(review): despite the name, this test uses the default
    # any_order=False; the "ordered" test below is the one passing
    # any_order=True. The two names look swapped.
    def test_assert_has_awaits_no_order(self):
        """assert_has_awaits (default ordering) passes once 'foo' then 'baz' were awaited."""
        calls = [call('foo'), call('baz')]

        with self.assertRaises(AssertionError) as cm:
            self.mock.assert_has_awaits(calls)
        self.assertEqual(len(cm.exception.args), 1)

        asyncio.run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls)

        asyncio.run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls)

        asyncio.run(self._runnable_test('baz'))
        self.mock.assert_has_awaits(calls)

        asyncio.run(self._runnable_test('SomethingElse'))
        self.mock.assert_has_awaits(calls)

    # NOTE(review): despite the name, every check here uses any_order=True.
    def test_assert_has_awaits_ordered(self):
        """assert_has_awaits(any_order=True) passes once all expected awaits happened."""
        calls = [call('foo'), call('baz')]
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        asyncio.run(self._runnable_test('baz'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        asyncio.run(self._runnable_test('bamf'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_has_awaits(calls, any_order=True)

        asyncio.run(self._runnable_test('qux'))
        self.mock.assert_has_awaits(calls, any_order=True)

    def test_assert_not_awaited(self):
        """assert_not_awaited passes until the mock has been awaited."""
        self.mock.assert_not_awaited()

        asyncio.run(self._runnable_test())
        with self.assertRaises(AssertionError):
            self.mock.assert_not_awaited()

    def test_assert_has_awaits_not_matching_spec_error(self):
        """Mismatched awaits report both lists; spec-binding errors become __cause__."""
        async def f(x=None):
            pass

        self.mock = AsyncMock(spec=f)
        asyncio.run(self._runnable_test(1))

        # plain mismatch: no underlying exception is chained
        with self.assertRaisesRegex(
                AssertionError, '^{}$'.format(
                    re.escape('Awaits not found.\n'
                              'Expected: [call()]\n'
                              'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call()])
        self.assertIsNone(cm.exception.__cause__)

        # call(1, 2) cannot bind to f's signature: the TypeError is chained
        with self.assertRaisesRegex(
                AssertionError, '^{}$'.format(
                    re.escape('Error processing expected awaits.\n'
                              "Errors: [None, TypeError('too many positional "
                              "arguments')]\n"
                              'Expected: [call(), call(1, 2)]\n'
                              'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call(), call(1, 2)])
        self.assertIsInstance(cm.exception.__cause__, TypeError)
Esempio n. 21
0
async def test_send_computation_task(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    image_params: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
):
    """A successful task returns its output and reports SUCCESS to the callback.

    The future's result must be the ``TaskOutputData`` produced remotely. The
    completion callback must receive exactly one SUCCESS event whose message
    is the JSON-encoded output, and the client must drop the future afterwards.
    """
    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    def fake_sidecar_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
        expected_annotations: Dict[str, Any],
    ) -> TaskOutputData:
        """Stand-in sidecar: wait 1s, verify the task annotations, return fixed output."""
        # give a possible abort request time to arrive
        time.sleep(1)
        # look up the dask task this function is running as
        current_worker = get_worker()
        running_task = current_worker.tasks.get(current_worker.get_current_task())
        assert running_task is not None
        assert running_task.annotations == expected_annotations
        return TaskOutputData.parse_obj({"some_output_key": 123})

    # NOTE: We pass another fct so it can run in our locally created dask cluster
    remote_fct = functools.partial(
        fake_sidecar_fct, expected_annotations=image_params.expected_annotations
    )
    await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=image_params.fake_task,
        callback=mocked_user_completed_cb,
        remote_fct=remote_fct,
    )
    assert len(dask_client._taskid_to_future_map) == 1, (
        "dask client did not store the future of the task sent"
    )

    job_id, future = next(iter(dask_client._taskid_to_future_map.items()))
    # this waits for the computation to run
    result = await future.result(timeout=_ALLOW_TIME_FOR_GATEWAY_TO_CREATE_WORKERS)
    assert isinstance(result, TaskOutputData)
    assert result["some_output_key"] == 123
    assert future.key == job_id

    await _wait_for_call(mocked_user_completed_cb)
    mocked_user_completed_cb.assert_called_once()
    expected_event = TaskStateEvent(
        job_id=job_id,
        msg=json.dumps({"some_output_key": 123}),
        state=RunningState.SUCCESS,
    )
    mocked_user_completed_cb.assert_called_with(expected_event)
    assert len(dask_client._taskid_to_future_map) == 0, (
        "the list of futures was not cleaned correctly"
    )