def test_run_context_provider_registry_register():
    """Registering a provider class exposes the object it produces.

    After ``register`` is called with a mock provider class, iterating the
    registry must yield exactly the value returned by calling that class.
    """
    provider_cls = mock.Mock()
    provider_registry = RunContextProviderRegistry()

    provider_registry.register(provider_cls)

    registered = set(provider_registry)
    expected = {provider_cls.return_value}
    assert registered == expected
def test_run_context_provider_registry_register_entrypoints():
    """Providers discovered through entrypoints end up in the registry.

    ``entrypoints.get_group_all`` is patched to return a single mock
    entrypoint whose ``load()`` gives back a mock provider class. After
    ``register_entrypoints`` the registry must contain the value returned by
    calling that class, the entrypoint must have been loaded exactly once,
    and the lookup must have used the ``mlflow.run_context_provider`` group.
    """
    provider_cls = mock.Mock()
    # Configure the nested ``load`` mock directly through constructor kwargs.
    entrypoint = mock.Mock(**{"load.return_value": provider_cls})

    with mock.patch(
        "entrypoints.get_group_all", return_value=[entrypoint]
    ) as patched_get_group_all:
        provider_registry = RunContextProviderRegistry()
        provider_registry.register_entrypoints()

    # The mocks keep their call records after the patch is undone, so the
    # verification can happen outside the patch context.
    registered = set(provider_registry)
    assert registered == {provider_cls.return_value}
    entrypoint.load.assert_called_once_with()
    patched_get_group_all.assert_called_once_with("mlflow.run_context_provider")
def test_run_context_provider_registry_register_entrypoints_handles_exception(
    exception,
):
    """A failing entrypoint load is reported as a warning, not raised.

    The mock entrypoint's ``load()`` raises ``exception``; calling
    ``register_entrypoints`` must then emit a ``UserWarning`` whose message
    matches "test exception". The entrypoint must still have been loaded
    exactly once and the ``mlflow.run_context_provider`` group queried once.

    NOTE(review): ``exception`` is presumably supplied by a parametrize
    decorator or a fixture that is not visible in this block -- confirm.
    """
    # Make ``load()`` raise the supplied exception when invoked.
    entrypoint = mock.Mock(**{"load.side_effect": exception})

    with mock.patch(
        "entrypoints.get_group_all", return_value=[entrypoint]
    ) as patched_get_group_all:
        provider_registry = RunContextProviderRegistry()
        # The emitted warning must carry the message of the original exception.
        with pytest.warns(UserWarning, match="test exception"):
            provider_registry.register_entrypoints()

    # Call records survive the end of the patch, so verify afterwards.
    entrypoint.load.assert_called_once_with()
    patched_get_group_all.assert_called_once_with("mlflow.run_context_provider")