async def test_always_callback_sensor(self):
    """Test the ``always_callback`` flag of a sensor.

    A GroupValueWrite whose payload matches the current value only fires the
    device-updated callback once ``always_callback`` is enabled; a
    GroupValueResponse with that same payload is expected not to fire it.
    """
    xknx = XKNX()
    sensor = Sensor(
        xknx,
        "TestSensor",
        group_address_state="1/2/3",
        always_callback=False,
        value_type="volume_liquid_litre",
    )
    on_update = AsyncMock()
    sensor.register_device_updated_cb(on_update)

    raw_value = DPTArray((0x00, 0x00, 0x01, 0x00))
    # pre-seed the sensor with the value the payload carries (0x100 == 256)
    sensor.sensor_value.value = 256
    write_telegram = Telegram(
        destination_address=GroupAddress("1/2/3"),
        payload=GroupValueWrite(raw_value),
    )
    response_telegram = Telegram(
        destination_address=GroupAddress("1/2/3"),
        payload=GroupValueResponse(raw_value),
    )

    # always_callback disabled: an unchanged value stays silent
    await sensor.process(write_telegram)
    on_update.assert_not_called()
    on_update.reset_mock()

    # always_callback enabled: the very same write now fires the callback
    sensor.always_callback = True
    await sensor.process(write_telegram)
    on_update.assert_called_once()
    on_update.reset_mock()

    # read responses must not fire the callback
    await sensor.process(response_telegram)
    on_update.assert_not_called()
async def test_process_passive(self):
    """Test process / reading telegrams from telegram queue. Test if device was updated.

    Writes to the secondary (passive) entries of ``group_address`` and
    ``group_address_state`` must both update the switch state and fire the
    device-updated callback exactly once each.
    """
    xknx = XKNX()
    on_update = AsyncMock()
    switch = Switch(
        xknx,
        "TestOutlet",
        group_address=["1/2/3", "4/4/4"],
        group_address_state=["1/2/30", "5/5/5"],
        device_updated_cb=on_update,
    )
    assert switch.state is None
    on_update.assert_not_called()

    # (telegram, expected state after processing it)
    passive_cases = [
        (
            Telegram(
                destination_address=GroupAddress("4/4/4"),
                payload=GroupValueWrite(DPTBinary(1)),
            ),
            True,
        ),
        (
            Telegram(
                destination_address=GroupAddress("5/5/5"),
                payload=GroupValueWrite(DPTBinary(0)),
            ),
            False,
        ),
    ]
    for telegram, expected_state in passive_cases:
        await switch.process(telegram)
        assert switch.state is expected_state
        on_update.assert_called_once()
        on_update.reset_mock()
async def test_on_join_with_actions_success(cog: MemberJoins, member: Mock) -> None:
    """On Join event with two actions: a channel message and a DM are both sent."""
    dm_sender = AsyncMock()
    channel_sender = AsyncMock()
    replacements = {"_send_dm": dm_sender, "_send_channel": channel_sender}
    with patch.multiple(cog, **replacements):
        await cog.on_member_join(member)
        # each action must have been triggered exactly once
        channel_sender.assert_called_once()
        dm_sender.assert_called_once()
async def test_valid_jwt_without_bearer(self):
    """In 'development' mode, a header token without the ``Bearer`` prefix
    still lets the request through to the next handler."""
    app = MockFastAPI()
    auth_middleware.add_auth_middleware(app, 'development')
    next_handler = AsyncMock()
    fake_request = Mock()
    fake_request.headers.get.return_value = 'some_jwt'
    await app.middleware_func(fake_request, next_handler)
    next_handler.assert_called_once()
def test_run_callable(self):
    """``run()`` starts the app when ``--app`` points at ``app_callable``."""
    run_app_mock = AsyncMock()
    with patch("kafkaesk.app.run_app", run_app_mock), patch("kafkaesk.app.cli_parser") as parser:
        cli_args = Mock()
        cli_args.app = "tests.unit.test_app:app_callable"
        cli_args.kafka_settings = None
        parser.parse_args.return_value = cli_args
        run()
        run_app_mock.assert_called_once()
async def test_exclude_path(self):
    """A request to a path listed in ``excludes`` is forwarded to the next handler."""
    app = MockFastAPI()
    auth_middleware.add_auth_middleware(app, 'development', excludes=['/some/path'])
    next_handler = AsyncMock()
    fake_request = Mock()
    fake_request.headers.get.return_value = 'Bearer some_jwt'
    fake_request.url.path = '/some/path'
    await app.middleware_func(fake_request, next_handler)
    next_handler.assert_called_once()
async def test_retries_on_connection_failure(self):
    """A Kafka connection error makes the consumer sleep once and call ``_run`` again."""
    consumer = SubscriptionConsumer(
        Application(),
        Subscription("foo", lambda record: 1, "group"),
    )
    # 1st _run fails to connect; the 2nd raises StopConsumer, which presumably
    # ends the consumer loop so that `await consumer()` returns.
    run_mock = AsyncMock(side_effect=[aiokafka.errors.KafkaConnectionError, StopConsumer])
    sleep_mock = AsyncMock()
    with patch.object(consumer, "initialize", AsyncMock()), \
            patch.object(consumer, "finalize", AsyncMock()), \
            patch.object(consumer, "_run", run_mock), \
            patch("kafkaesk.subscription.asyncio.sleep", sleep_mock):
        await consumer()
        sleep_mock.assert_called_once()
        assert len(run_mock.mock_calls) == 2
async def test_exclude_regex_path(self, path, expected):
    """Paths matching an ``excludes_regex`` pattern bypass authentication.

    Args:
        path: request path sent through the middleware.
        expected: True if ``path`` should match the exclusion pattern and be
            forwarded to ``call_next``; False if the middleware should answer
            with a 401 response instead.
    """
    app = MockFastAPI()
    # Raw string: in a plain literal "\w" and "\d" are invalid escape
    # sequences (DeprecationWarning, SyntaxWarning since Python 3.12). The
    # resulting pattern text is unchanged.
    auth_middleware.add_auth_middleware(
        app, 'development', excludes_regex=[r'^/some/path/(\w|\d|-)+/something/?$']
    )
    call_next = AsyncMock()
    request = Mock()
    request.headers.get.return_value = 'Bearer some_jwt'
    request.url.path = path
    response = await app.middleware_func(request, call_next)
    if expected:
        call_next.assert_called_once()
    else:
        assert response.status_code == 401
def test_player_send_json_rpc_request_throws_connection_closed(self):
    """Tests Player.send_json_rpc_request() when websocket send() raises
    ConnectionClosed: the call must complete without propagating it.
    """
    websocket = MagicMock()
    closed_send = AsyncMock(side_effect=ConnectionClosed(0, 'Connection Closed'))
    websocket.attach_mock(closed_send, 'send')
    player = Player(None, websocket)
    # str() of a list renders the player via its repr; the text itself is unused.
    _description = "Player is " + str([player])
    request_params = {'a': 'x'}
    # No exception should propagate out of this call.
    asyncio.run(player.send_json_rpc_request("general_request", request_params))
    closed_send.assert_called_once()
async def test_callbacks():
    """Test detection and handling of both sync and async callbacks."""
    sync_cb = Mock()
    async_cb = AsyncMock()
    with patch("pysqueezebox.discovery._unpack_discovery_response", return_value=RESPONSE):
        # one discovery protocol per callback flavour, each fed the same datagram
        protocols = [
            pysqueezebox.discovery.ServerDiscoveryProtocol(cb) for cb in (sync_cb, async_cb)
        ]
        for protocol in protocols:
            protocol.datagram_received(ADDR, DATA)
    sync_cb.assert_called_once()
    async_cb.assert_called_once()
async def test_failed_task_returns_exceptions(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    gpu_image: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
):
    """A task that raises in the worker reports the failure back.

    Awaiting the task's future must re-raise the worker's ValueError, and the
    completion callback must be called once with a FAILED event for the job
    whose message carries the remote traceback.
    """
    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    def fake_failing_sidecar_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
    ) -> TaskOutputData:
        raise ValueError(
            "sadly we are failing to execute anything cause we are dumb...")

    await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=gpu_image.fake_task,
        callback=mocked_user_completed_cb,
        remote_fct=fake_failing_sidecar_fct,
    )
    assert (
        len(dask_client._taskid_to_future_map) == 1
    ), "dask client did not store the future of the task sent"

    job_id, future = list(dask_client._taskid_to_future_map.items())[0]
    # this waits for the computation to run
    with pytest.raises(ValueError):
        await future.result(timeout=_ALLOW_TIME_FOR_GATEWAY_TO_CREATE_WORKERS)

    await _wait_for_call(mocked_user_completed_cb)
    mocked_user_completed_cb.assert_called_once()
    task_event = mocked_user_completed_cb.call_args[0][0]
    assert task_event.job_id == job_id
    assert task_event.state == RunningState.FAILED
    # Previously these were bare `msg.find(...)` calls whose results were
    # discarded, so the traceback content was never actually checked.
    assert task_event.msg is not None
    assert "Traceback" in task_event.msg
    assert "raise ValueError" in task_event.msg
async def test_retries_on_connection_failure(self, subscription):
    """A Kafka connection error makes the subscription sleep once and retry ``_consume``."""
    # 1st _consume fails to connect; the 2nd raises StopConsumer so that
    # `await subscription()` presumably returns.
    consume_mock = AsyncMock(side_effect=[aiokafka.errors.KafkaConnectionError, StopConsumer])
    sleep_mock = AsyncMock()
    subscription._consumer = MagicMock()
    with patch.object(subscription, "initialize", AsyncMock()), \
            patch.object(subscription, "finalize", AsyncMock()), \
            patch.object(subscription, "_consume", consume_mock), \
            patch("kafkaesk.consumer.asyncio.sleep", sleep_mock):
        await subscription()
        sleep_mock.assert_called_once()
        assert len(consume_mock.mock_calls) == 2
def test_run(self):
    """``run()`` copies the parsed CLI options onto the app and starts it."""
    run_app_mock = AsyncMock()
    with patch("kafkaesk.app.run_app", run_app_mock), patch("kafkaesk.app.cli_parser") as parser:
        cli_args = Mock()
        cli_args.app = "tests.unit.test_app:test_app"
        cli_args.kafka_servers = "foo,bar"
        cli_args.kafka_settings = json.dumps({"foo": "bar"})
        cli_args.topic_prefix = "prefix"
        cli_args.api_version = "api_version"
        parser.parse_args.return_value = cli_args
        run()
        run_app_mock.assert_called_once()

        # attribute on the app -> value it must have after run()
        expected_attrs = {
            "_kafka_servers": ["foo", "bar"],
            "_kafka_settings": {"foo": "bar"},
            "_topic_prefix": "prefix",
            "_kafka_api_version": "api_version",
        }
        for attr_name, expected_value in expected_attrs.items():
            assert getattr(test_app, attr_name) == expected_value
async def _async_test_effect(name, target=None, called=True):
    """Turn the light on with effect ``name`` and inspect the started flow.

    Args:
        name: effect name sent as ATTR_EFFECT.
        target: expected flow; when given, its ``count``, ``action`` and
            ``transitions`` are compared with the flow passed to
            ``async_start_flow``.
        called: when False, nothing is asserted after the service call.
    """
    start_flow = AsyncMock()
    mocked_bulb.async_start_flow = start_flow
    await hass.services.async_call(
        "light",
        SERVICE_TURN_ON,
        {ATTR_ENTITY_ID: ENTITY_LIGHT, ATTR_EFFECT: name},
        blocking=True,
    )
    if called:
        start_flow.assert_called_once()
        if target is not None:
            # first positional argument of the (single) start_flow call
            flow = start_flow.call_args[0][0]
            assert flow.count == target.count
            assert flow.action == target.action
            assert str(flow.transitions) == str(target.transitions)
async def _async_test_service(
    service,
    data,
    method,
    payload=None,
    domain=DOMAIN,
    failure_side_effect=BulbException,
):
    """Call a service and check the bulb method it is expected to invoke.

    Success phase: ``method`` on ``mocked_bulb`` is replaced with a mock, the
    service is called, and the mock must have been called exactly once (with
    ``payload`` if given) without any new ERROR log record.
    Failure phase (only if ``failure_side_effect`` is truthy): the mock raises
    ``failure_side_effect`` and exactly one new ERROR record must be logged.

    Args:
        service: service name to call.
        data: service call data.
        method: bulb method name; names starting with ``async_`` are replaced
            by an ``AsyncMock``, others by a ``MagicMock``.
        payload: expected call arguments. ``None`` checks only the call
            count, a list or tuple is expanded positionally, anything else
            is expanded as keyword arguments.
        domain: service domain.
        failure_side_effect: exception raised in the failure phase.
    """
    def error_count():
        # number of ERROR-level records captured so far
        return sum(1 for record in caplog.records if record.levelno == logging.ERROR)

    mock_cls = AsyncMock if method.startswith("async_") else MagicMock
    errors_before = error_count()

    # success
    mocked_method = mock_cls()
    setattr(mocked_bulb, method, mocked_method)
    await hass.services.async_call(domain, service, data, blocking=True)
    if payload is None:
        mocked_method.assert_called_once()
    elif isinstance(payload, (list, tuple)):
        mocked_method.assert_called_once_with(*payload)
    else:
        mocked_method.assert_called_once_with(**payload)
    assert error_count() == errors_before

    # failure
    if failure_side_effect:
        setattr(mocked_bulb, method, mock_cls(side_effect=failure_side_effect))
        await hass.services.async_call(domain, service, data, blocking=True)
        assert error_count() == errors_before + 1
class TestReadDatafeedStrategy:
    """Tests for the ``read_datafeed_retry`` retry strategy.

    ``_retryable_coroutine`` is decorated with
    ``retry(retry=strategy.read_datafeed_retry)``. Each test sets the
    attributes on ``self`` that the strategy presumably reads
    (``_retry_config``, ``_auth_session``, ``recreate_datafeed``) before
    calling it.
    """

    @retry(retry=strategy.read_datafeed_retry)
    async def _retryable_coroutine(self, thing):
        """Yield to the event loop briefly, then return ``thing.go()``."""
        await asyncio.sleep(0.00001)
        return thing.go()

    @pytest.mark.asyncio
    async def test_client_error_recreates_datafeed_and_and_tries_again(self):
        """A 400 error calls ``recreate_datafeed`` once and the retry then succeeds."""
        self._retry_config = minimal_retry_config_with_attempts(2)
        self._auth_session = Mock()
        self.recreate_datafeed = AsyncMock()
        thing = NoApiExceptionAfterCount(1, status=400)

        value = await self._retryable_coroutine(thing)

        self.recreate_datafeed.assert_called_once()
        assert value is True

    @pytest.mark.asyncio
    async def test_unauthorized_error_refreshes_session_and_and_tries_again(self):
        """A 401 error refreshes the auth session once and the retry then succeeds."""
        self._retry_config = minimal_retry_config_with_attempts(2)
        self._auth_session = Mock()
        self._auth_session.refresh = AsyncMock()
        thing = NoApiExceptionAfterCount(1, status=401)

        value = await self._retryable_coroutine(thing)

        self._auth_session.refresh.assert_called_once()
        assert value is True

    @pytest.mark.asyncio
    async def test_unexpected_api_exception_is_raised(self):
        """A 404 ApiException propagates to the caller after a single call."""
        self._retry_config = minimal_retry_config_with_attempts(1)
        thing = NoApiExceptionAfterCount(2, status=404)
        with pytest.raises(ApiException):
            await self._retryable_coroutine(thing)
        assert thing.call_count == 1
async def test_process_reset_after(self, time_travel):
    """Test process / reading telegrams from telegram queue.

    With ``reset_after`` set, an ON telegram is reverted to OFF after the
    delay. Further ON telegrams inside the delay reset the timer without
    firing the callback again.
    """
    xknx = XKNX()
    reset_delay = 1
    on_update = AsyncMock()
    binary_input = BinarySensor(
        xknx,
        "TestInput",
        "1/2/3",
        reset_after=reset_delay,
        device_updated_cb=on_update,
    )
    telegram_on = Telegram(
        destination_address=GroupAddress("1/2/3"),
        payload=GroupValueWrite(DPTBinary(1)),
    )

    await binary_input.process(telegram_on)
    assert binary_input.state
    await time_travel(reset_delay)
    assert not binary_input.state
    # callback fired for 'on' and again for the automatic 'off'
    assert on_update.call_count == 2
    on_update.reset_mock()

    # several telegrams within the reset_after window reset the timer
    await binary_input.process(telegram_on)
    on_update.assert_called_once()
    for _ in range(2):
        await binary_input.process(telegram_on)
    # the repeated telegrams reset the timer but do not fire the callback
    on_update.assert_called_once()
    assert binary_input.state
    await time_travel(reset_delay)
    assert not binary_input.state
    # again: once for 'on', once for 'off'
    assert on_update.call_count == 2
async def test_process_state(self):
    """Test process / reading telegrams from telegram queue. Test if device was updated.

    GroupValueResponse telegrams on the state address update the switch, and
    the very first telegram may set the state to False.
    """
    xknx = XKNX()
    on_update = AsyncMock()
    switch1, switch2 = (
        Switch(
            xknx,
            "TestOutlet",
            group_address="1/2/3",
            group_address_state="1/2/4",
            device_updated_cb=on_update,
        )
        for _ in range(2)
    )
    assert switch1.state is None
    assert switch2.state is None
    on_update.assert_not_called()

    telegram_on = Telegram(
        destination_address=GroupAddress("1/2/4"),
        payload=GroupValueResponse(DPTBinary(1)),
    )
    telegram_off = Telegram(
        destination_address=GroupAddress("1/2/4"),
        payload=GroupValueResponse(DPTBinary(0)),
    )

    # switch1 goes on then off; switch2 receives off as its first telegram
    steps = (
        (switch1, telegram_on, True),
        (switch1, telegram_off, False),
        (switch2, telegram_off, False),
        (switch2, telegram_on, True),
    )
    for device, telegram, expected_state in steps:
        await device.process(telegram)
        assert device.state is expected_state
        on_update.assert_called_once()
        on_update.reset_mock()
async def test_abort_send_computation_task(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    image_params: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
):
    """Aborting a running task cancels its future and reports ABORTED.

    The remote function blocks on the TaskCancelEvent topic until it sees a
    cancellation for its own task key. The test then aborts the job and
    checks the future state, the completion callback and the cleanup of the
    client's future map.
    """
    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    def fake_sidecar_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
        expected_annotations: Dict[str, Any],
    ) -> TaskOutputData:
        sub = Sub(TaskCancelEvent.topic_name())
        # get the task data
        worker = get_worker()
        task = worker.tasks.get(worker.get_current_task())
        assert task is not None
        print(f"--> task {task=} started")
        assert task.annotations == expected_annotations
        # block on the cancel-event subscription until this task is aborted
        print("--> waiting for task to be aborted...")
        for msg in sub:
            assert msg
            print(f"--> received cancellation msg: {msg=}")
            cancel_event = TaskCancelEvent.parse_raw(msg)  # type: ignore
            assert cancel_event
            if cancel_event.job_id == task.key:
                print("--> raising cancellation error now")
                raise asyncio.CancelledError("task cancelled")

        return TaskOutputData.parse_obj({"some_output_key": 123})

    await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=image_params.fake_task,
        callback=mocked_user_completed_cb,
        remote_fct=functools.partial(
            fake_sidecar_fct, expected_annotations=image_params.expected_annotations
        ),
    )
    assert (
        len(dask_client._taskid_to_future_map) == 1
    ), "dask client did not store the future of the task sent"

    # let the task start
    await asyncio.sleep(2)

    # now let's abort the computation
    job_id, future = list(dask_client._taskid_to_future_map.items())[0]
    assert future.key == job_id
    await dask_client.abort_computation_tasks([job_id])
    assert future.cancelled()
    await _wait_for_call(mocked_user_completed_cb)
    mocked_user_completed_cb.assert_called_once()
    mocked_user_completed_cb.assert_called_with(
        TaskStateEvent(
            job_id=job_id,
            msg=None,
            state=RunningState.ABORTED,
        )
    )
    assert (
        len(dask_client._taskid_to_future_map) == 0
    ), "the list of futures was not cleaned correctly"
class AsyncMockAssert(unittest.TestCase):
    """Tests for AsyncMock's call vs. await bookkeeping and its assert_* helpers.

    Calling an AsyncMock (or an async method of an AsyncMock built from
    ``AsyncClass``) records a *call* immediately and returns a coroutine; the
    *await* is only recorded once that coroutine is actually awaited.

    Several tests call the mock inside ``assertWarns(RuntimeWarning)`` without
    binding the returned coroutine. They rely on that coroutine being
    discarded (and warning "never awaited") inside the context manager.
    NOTE(review): this presumably depends on CPython's immediate refcount
    collection, so do not bind those coroutines to names.
    """

    def setUp(self):
        # fresh, spec-less AsyncMock for every test
        self.mock = AsyncMock()

    async def _runnable_test(self, *args, **kwargs):
        # calls AND awaits self.mock, recording both a call and an await
        await self.mock(*args, **kwargs)

    async def _await_coroutine(self, coroutine):
        # awaits an already-created coroutine and returns its result
        return await coroutine

    def test_assert_called_but_not_awaited(self):
        # a call that is never awaited counts as called, but not as awaited
        mock = AsyncMock(AsyncClass)
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            mock.async_method()
        self.assertTrue(asyncio.iscoroutinefunction(mock.async_method))
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.assert_awaited()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()

    def test_assert_called_then_awaited(self):
        # the await is recorded only when the stored coroutine is awaited
        mock = AsyncMock(AsyncClass)
        mock_coroutine = mock.async_method()
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()

        asyncio.run(self._await_coroutine(mock_coroutine))
        # Assert we haven't re-called the function
        mock.async_method.assert_called_once()
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()
        mock.async_method.assert_awaited_once_with()

    def test_assert_called_and_awaited_at_same_time(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()

        with self.assertRaises(AssertionError):
            self.mock.assert_called()

        asyncio.run(self._runnable_test())
        self.mock.assert_called_once()
        self.mock.assert_awaited_once()

    def test_assert_called_twice_and_awaited_once(self):
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        with self.assertWarns(RuntimeWarning):
            # The first call will be awaited so no warning there
            # But this call will never get awaited, so it will warn here
            mock.async_method()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()
        mock.async_method.assert_called()
        asyncio.run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()

    def test_assert_called_once_and_awaited_twice(self):
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        mock.async_method.assert_called_once()
        asyncio.run(self._await_coroutine(coroutine))
        with self.assertRaises(RuntimeError):
            # Cannot reuse already awaited coroutine
            asyncio.run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()

    def test_assert_awaited_but_not_called(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()
        with self.assertRaises(TypeError):
            # You cannot await an AsyncMock, it must be a coroutine
            asyncio.run(self._await_coroutine(self.mock))

        # the failed await above must not have been recorded as call or await
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()

    def test_assert_has_calls_not_awaits(self):
        kalls = [call('foo')]
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock('foo')
        self.mock.assert_has_calls(kalls)
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(kalls)

    def test_assert_has_mock_calls_on_async_mock_no_spec(self):
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock()
        kalls_empty = [('', (), {})]
        self.assertEqual(self.mock.mock_calls, kalls_empty)

        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock('foo')
            self.mock('baz')
        mock_kalls = ([call(), call('foo'), call('baz')])
        self.assertEqual(self.mock.mock_calls, mock_kalls)

    def test_assert_has_mock_calls_on_async_mock_with_spec(self):
        # calls are recorded both on the method mock and, prefixed with the
        # method name, on the parent mock
        a_class_mock = AsyncMock(AsyncClass)
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            a_class_mock.async_method()
        kalls_empty = [('', (), {})]
        self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
        self.assertEqual(a_class_mock.mock_calls, [call.async_method()])

        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            a_class_mock.async_method(1, 2, 3, a=4, b=5)
        method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
        mock_kalls = [
            call.async_method(),
            call.async_method(1, 2, 3, a=4, b=5)
        ]
        self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
        self.assertEqual(a_class_mock.mock_calls, mock_kalls)

    def test_async_method_calls_recorded(self):
        with self.assertWarns(RuntimeWarning):
            # Will raise warnings because never awaited
            self.mock.something(3, fish=None)
            self.mock.something_else.something(6, cake=sentinel.Cake)

        self.assertEqual(self.mock.method_calls, [("something", (3, ), {
            'fish': None
        }), ("something_else.something", (6, ), {
            'cake': sentinel.Cake
        })], "method calls not recorded correctly")
        self.assertEqual(self.mock.something_else.method_calls,
                         [("something", (6, ), {
                             'cake': sentinel.Cake
                         })], "method calls not recorded correctly")

    def test_async_arg_lists(self):
        def assert_attrs(mock):
            # every call-tracking list exists, is a (_CallList) list, and is empty
            names = ('call_args_list', 'method_calls', 'mock_calls')
            for name in names:
                attr = getattr(mock, name)
                self.assertIsInstance(attr, _CallList)
                self.assertIsInstance(attr, list)
                self.assertEqual(attr, [])

        assert_attrs(self.mock)
        with self.assertWarns(RuntimeWarning):
            # Will raise warnings because never awaited
            self.mock()
            self.mock(1, 2)
            self.mock(a=3)

        # reset_mock() must empty all the lists again
        self.mock.reset_mock()
        assert_attrs(self.mock)

        a_mock = AsyncMock(AsyncClass)
        with self.assertWarns(RuntimeWarning):
            # Will raise warnings because never awaited
            a_mock.async_method()
            a_mock.async_method(1, a=3)

        a_mock.reset_mock()
        assert_attrs(a_mock)

    def test_assert_awaited(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()

        asyncio.run(self._runnable_test())
        self.mock.assert_awaited()

    def test_assert_awaited_once(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once()

        asyncio.run(self._runnable_test())
        self.mock.assert_awaited_once()

        # a second await breaks the "once" guarantee
        asyncio.run(self._runnable_test())
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once()

    def test_assert_awaited_with(self):
        # the failure message differs between "never awaited" and
        # "awaited with other arguments"
        msg = 'Not awaited'
        with self.assertRaisesRegex(AssertionError, msg):
            self.mock.assert_awaited_with('foo')

        asyncio.run(self._runnable_test())
        msg = 'expected await not found'
        with self.assertRaisesRegex(AssertionError, msg):
            self.mock.assert_awaited_with('foo')

        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_awaited_with('foo')

        # only the most recent await is compared
        asyncio.run(self._runnable_test('SomethingElse'))
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_with('foo')

    def test_assert_awaited_once_with(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once_with('foo')

        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_awaited_once_with('foo')

        asyncio.run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once_with('foo')

    # NOTE(review): the name presumably means "any_await"; kept unchanged so
    # the test id stays stable.
    def test_assert_any_wait(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_any_await('foo')

        asyncio.run(self._runnable_test('baz'))
        with self.assertRaises(AssertionError):
            self.mock.assert_any_await('foo')

        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_any_await('foo')

        # a later, different await does not hide the earlier matching one
        asyncio.run(self._runnable_test('SomethingElse'))
        self.mock.assert_any_await('foo')

    # NOTE(review): despite its name, this test exercises the default
    # (order-sensitive) mode of assert_has_awaits; the "ordered" test below
    # is the one that passes any_order=True.
    def test_assert_has_awaits_no_order(self):
        calls = [call('foo'), call('baz')]

        with self.assertRaises(AssertionError) as cm:
            self.mock.assert_has_awaits(calls)
        self.assertEqual(len(cm.exception.args), 1)

        asyncio.run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls)

        asyncio.run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls)

        asyncio.run(self._runnable_test('baz'))
        self.mock.assert_has_awaits(calls)

        asyncio.run(self._runnable_test('SomethingElse'))
        self.mock.assert_has_awaits(calls)

    # NOTE(review): despite its name, this test uses any_order=True.
    def test_assert_has_awaits_ordered(self):
        calls = [call('foo'), call('baz')]
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        asyncio.run(self._runnable_test('baz'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        asyncio.run(self._runnable_test('bamf'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        # 'foo' awaited after 'baz' still satisfies any_order=True
        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_has_awaits(calls, any_order=True)

        asyncio.run(self._runnable_test('qux'))
        self.mock.assert_has_awaits(calls, any_order=True)

    def test_assert_not_awaited(self):
        self.mock.assert_not_awaited()

        asyncio.run(self._runnable_test())
        with self.assertRaises(AssertionError):
            self.mock.assert_not_awaited()

    def test_assert_has_awaits_not_matching_spec_error(self):
        # with a spec, expected awaits are bound against f's signature; a
        # binding failure is reported in the message and chained as __cause__
        async def f(x=None): pass

        self.mock = AsyncMock(spec=f)
        asyncio.run(self._runnable_test(1))

        with self.assertRaisesRegex(
                AssertionError,
                '^{}$'.format(
                    re.escape('Awaits not found.\n'
                              'Expected: [call()]\n'
                              'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call()])
        self.assertIsNone(cm.exception.__cause__)

        with self.assertRaisesRegex(
                AssertionError,
                '^{}$'.format(
                    re.escape('Error processing expected awaits.\n'
                              "Errors: [None, TypeError('too many positional "
                              "arguments')]\n"
                              'Expected: [call(), call(1, 2)]\n'
                              'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call(), call(1, 2)])
        self.assertIsInstance(cm.exception.__cause__, TypeError)
async def test_send_computation_task(
    dask_client: DaskClient,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    image_params: ImageParams,
    mocked_node_ports: None,
    mocked_user_completed_cb: mock.AsyncMock,
):
    """A task sent to the dask cluster runs to completion and reports SUCCESS.

    Checks the task output, that the future key equals the job id, that the
    completion callback receives a SUCCESS event carrying the JSON-encoded
    output, and that the client's future map is emptied afterwards.
    """
    # NOTE: this must be inlined so that the test works,
    # the dask-worker must be able to import the function
    def fake_sidecar_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
        expected_annotations: Dict[str, Any],
    ) -> TaskOutputData:
        # keep the task busy briefly so an abort would have a chance to arrive
        time.sleep(1)
        current_worker = get_worker()
        current_task = current_worker.tasks.get(current_worker.get_current_task())
        assert current_task is not None
        assert current_task.annotations == expected_annotations
        return TaskOutputData.parse_obj({"some_output_key": 123})

    # NOTE: We pass another fct so it can run in our locally created dask cluster
    await dask_client.send_computation_tasks(
        user_id=user_id,
        project_id=project_id,
        cluster_id=cluster_id,
        tasks=image_params.fake_task,
        callback=mocked_user_completed_cb,
        remote_fct=functools.partial(
            fake_sidecar_fct, expected_annotations=image_params.expected_annotations
        ),
    )
    assert (
        len(dask_client._taskid_to_future_map) == 1
    ), "dask client did not store the future of the task sent"

    job_id, future = next(iter(dask_client._taskid_to_future_map.items()))
    # this waits for the computation to run
    task_result = await future.result(timeout=_ALLOW_TIME_FOR_GATEWAY_TO_CREATE_WORKERS)
    assert isinstance(task_result, TaskOutputData)
    assert task_result["some_output_key"] == 123
    assert future.key == job_id

    await _wait_for_call(mocked_user_completed_cb)
    mocked_user_completed_cb.assert_called_once()
    success_event = TaskStateEvent(
        job_id=job_id,
        msg=json.dumps({"some_output_key": 123}),
        state=RunningState.SUCCESS,
    )
    mocked_user_completed_cb.assert_called_with(success_event)
    assert (
        len(dask_client._taskid_to_future_map) == 0
    ), "the list of futures was not cleaned correctly"