def _handle_er_threading_fail(self, method: TEventHandler):
    """Check that ``method`` reports an expand error when its worker thread cannot be created.

    ``threading.Thread`` in the OE service module is patched to raise. The handler is then
    expected to reply True, send an ExpandCompleted error notification for '/', and leave
    no expand or refresh task registered on the session.
    """
    # Setup: an OE service that already holds a session
    oe, session, session_uri = self._preloaded_oe_service()

    # ... make thread construction inside the OE service module blow up
    failing_thread_ctor = mock.MagicMock(side_effect=Exception('Boom!'))
    thread_target = 'pgsqltoolsservice.object_explorer.object_explorer_service.threading.Thread'

    with mock.patch(thread_target, failing_thread_ctor):
        # If: I expand the root node while thread creation fails
        validator = RequestFlowValidator()
        validator.add_expected_response(bool, self.assertTrue)
        validator.add_expected_notification(
            ExpandCompletedParameters,
            EXPAND_COMPLETED_METHOD,
            lambda notification: self._validate_expand_error(notification, session_uri, '/'))
        expand_params = ExpandParameters.from_dict({
            'session_id': session_uri,
            'node_path': '/'
        })
        method(oe, validator.request_context, expand_params)

    # Then:
    # ... the response and the error notification must both have been sent
    validator.validate()

    # ... and no task may have been left attached to the session
    self.assertDictEqual(session.expand_tasks, {})
    self.assertDictEqual(session.refresh_tasks, {})
def test_handle_create_session_successful(self):
    """Creating a session over a successful connection yields a ready session and a success notification."""
    # Setup:
    # ... Create OE service with mock connection service that returns a successful connection response
    mock_connection = MockPGServerConnection(cur=None, host='myserver', name='postgres', user='******', port=123)
    cs = ConnectionService()
    cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    cs.get_connection = mock.MagicMock(return_value=mock_connection)
    oe = ObjectExplorerService()
    oe._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: cs})
    oe._provider = constants.PG_PROVIDER_NAME
    # The service is handed the Server class itself; the assertions below expect
    # the session's server to be an instance of it.
    oe._server = Server

    # ... Create parameters, session, request context validator
    params, session_uri = _connection_details()

    # ... Create validation of success notification
    def validate_success_notification(response: SessionCreatedParameters):
        self.assertTrue(response.success)
        self.assertEqual(response.session_id, session_uri)
        self.assertIsNone(response.error_message)
        self.assertIsInstance(response.root_node, NodeInfo)
        self.assertEqual(response.root_node.label, TEST_DBNAME)
        self.assertEqual(response.root_node.node_path, session_uri)
        self.assertEqual(response.root_node.node_type, 'Database')
        self.assertIsInstance(response.root_node.metadata, ObjectMetadata)
        self.assertEqual(response.root_node.metadata.urn, oe._session_map[session_uri].server.urn_base)
        self.assertEqual(
            response.root_node.metadata.name,
            oe._session_map[session_uri].server.maintenance_db_name)
        self.assertEqual(response.root_node.metadata.metadata_type_name, 'Database')
        self.assertFalse(response.root_node.is_leaf)

    rc = RequestFlowValidator()
    rc.add_expected_response(
        CreateSessionResponse,
        lambda param: self.assertEqual(param.session_id, session_uri))
    rc.add_expected_notification(SessionCreatedParameters, SESSION_CREATED_METHOD, validate_success_notification)

    # If: I create a session
    oe._handle_create_session_request(rc.request_context, params)
    # ... wait for the session's init task to finish before validating
    # (it is presumably what sends the SessionCreated notification)
    oe._session_map[session_uri].init_task.join()

    # Then:
    # ... The create-session response and the success notification should have been sent
    rc.validate()

    # ... The session should still exist and should have connection and server setup
    self.assertIn(session_uri, oe._session_map)
    self.assertIsInstance(oe._session_map[session_uri].server, Server)
    self.assertTrue(oe._session_map[session_uri].is_ready)
def _handle_er_incomplete_params(method: TEventHandler):
    """Check that ``method`` answers with an error for absent or partially filled expand parameters."""
    # Setup: an OE service with an empty mock service provider
    service = ObjectExplorerService()
    service._service_provider = utils.get_mock_service_provider({})

    # Each case is either no params at all, or params missing one required field
    invalid_inputs = [
        None,
        ExpandParameters.from_dict({'session_id': None, 'node_path': '/'}),
        ExpandParameters.from_dict({'session_id': 'session', 'node_path': None}),
    ]

    for invalid in invalid_inputs:
        # If: the handler is invoked with an invalid parameter set
        validator = RequestFlowValidator().add_expected_error(
            type(None), RequestFlowValidator.basic_error_validation)
        method(service, validator.request_context, invalid)

        # Then: an error response must have been sent
        validator.validate()
def test_init_session_failed_connection(self):
    """A failed connection during session initialization sends an error and drops the session."""
    # Setup: a connection service whose connect call reports an error message
    conn_service = ConnectionService()
    failed_result = ConnectionCompleteParams()
    failed_result.error_message = 'Boom!'
    conn_service.connect = mock.MagicMock(return_value=failed_result)

    oe = ObjectExplorerService()
    oe._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: conn_service})

    # If: I initialize a session directly
    # NOTE: the request handler is bypassed on purpose to avoid threading issues
    params, session_uri = _connection_details()
    session = ObjectExplorerSession(session_uri, params)
    oe._session_map[session_uri] = session

    validator = RequestFlowValidator()
    validator.add_expected_notification(
        SessionCreatedParameters,
        SESSION_CREATED_METHOD,
        lambda created: self._validate_init_error(created, session_uri))
    oe._initialize_session(validator.request_context, session)

    # Then: the error notification went out and the session map is empty again
    validator.validate()
    self.assertDictEqual(oe._session_map, {})
def test_handle_create_session_threading_fail(self):
    """If the init thread cannot be created, the session request reports an error and cleans up."""
    # Setup: a bare OE service
    oe = ObjectExplorerService()
    oe._service_provider = utils.get_mock_service_provider({})

    # ... thread construction in the OE service module is made to raise
    failing_thread_ctor = mock.MagicMock(side_effect=Exception('Boom!'))
    thread_target = 'pgsqltoolsservice.object_explorer.object_explorer_service.threading.Thread'

    with mock.patch(thread_target, failing_thread_ctor):
        # If: I request a new session
        params, session_uri = _connection_details()
        validator = RequestFlowValidator()
        validator.add_expected_response(
            CreateSessionResponse,
            lambda response: self.assertEqual(response.session_id, session_uri))
        validator.add_expected_notification(
            SessionCreatedParameters,
            SESSION_CREATED_METHOD,
            lambda created: self._validate_init_error(created, session_uri))
        oe._handle_create_session_request(validator.request_context, params)

    # Then:
    # ... the response and the error notification were both sent
    validator.validate()

    # ... and the half-created session was removed
    self.assertDictEqual(oe._session_map, {})
def _handle_er_exception_expanding(self, method: TEventHandler, get_tasks: TGetTask):
    """Check that an exception raised while expanding is reported as an ExpandCompleted error.

    ``route_request`` in the OE service module is patched to raise. The handler should still
    reply True, register exactly one task (as returned by ``get_tasks``), and that task should
    send an error notification for '/'.
    """
    # Setup: Create an OE service with a session preloaded
    oe, session, session_uri = self._preloaded_oe_service()

    # ... Patch route_request in the OE service module to throw
    patch_mock = mock.MagicMock(side_effect=Exception('Boom!'))
    patch_path = 'pgsqltoolsservice.object_explorer.object_explorer_service.route_request'
    with mock.patch(patch_path, patch_mock):
        # If: I expand a node (with route_request that throws)
        rc = RequestFlowValidator()
        rc.add_expected_response(bool, self.assertTrue)
        rc.add_expected_notification(
            ExpandCompletedParameters,
            EXPAND_COMPLETED_METHOD,
            lambda param: self._validate_expand_error(param, session_uri, '/'))
        params = ExpandParameters.from_dict({
            'session_id': session_uri,
            'node_path': '/'
        })
        method(oe, rc.request_context, params)

        # Join the worker threads before validating so their notification has been sent.
        # This stays inside the patch context so route_request is still patched while they run.
        for task in session.expand_tasks.values():
            task.join()
        for task in session.refresh_tasks.values():
            task.join()

    # Then:
    # ... An error notification should have been sent
    rc.validate()

    # ... The thread should be attached to the session
    self.assertEqual(len(get_tasks(session)), 1)
def test_handle_close_session_missing_params(self):
    """Closing a session without any parameters yields an error response."""
    # If: the close handler receives no params at all
    validator = RequestFlowValidator().add_expected_error(
        type(None), RequestFlowValidator.basic_error_validation)
    self.oe._handle_close_session_request(validator.request_context, None)

    # Then: an error response must have been sent
    validator.validate()
def test_handle_close_session_incomplete_params(self):
    """Closing a session with empty connection details yields an error response."""
    # If: I close an OE session with empty connection details
    # NOTE: it is enough that session-URI generation throws here; throwing in every
    #       scenario is covered by a different test
    validator = RequestFlowValidator().add_expected_error(
        type(None), RequestFlowValidator.basic_error_validation)
    empty_details = ConnectionDetails.from_data({})
    self.oe._handle_close_session_request(validator.request_context, empty_details)

    # Then:
    # ... an error response must have been sent
    validator.validate()
def test_handle_close_session_nosession(self):
    """Closing a session that is not in the session map responds with False."""
    # Setup: no sessions are registered
    self.oe._session_map = {}

    # If: I close a session id that is not present
    validator = RequestFlowValidator().add_expected_response(bool, self.assertFalse)
    close_params = _close_session_params()
    close_params.session_id = _connection_details()[1]
    self.oe._handle_close_session_request(validator.request_context, close_params)

    # Then: the response should have been sent with a value of False
    validator.validate()
def test_handle_close_session_throwsException(self):
    """If disconnecting throws while closing a session, an error response is sent and logged once."""
    # Setup: make the connection service throw on disconnect
    self.cs.disconnect = mock.MagicMock(side_effect=Exception)

    # If: I close an OE session and the disconnect call throws
    rc = RequestFlowValidator().add_expected_error(type(None))
    session_id = _connection_details()[1]
    params = _close_session_params()
    params.session_id = session_id
    self.oe._handle_close_session_request(rc.request_context, params)

    # Then:
    # ... I should get an error response
    rc.validate()

    # ... and the failure should have been logged exactly once
    self.oe._service_provider.logger.error.assert_called_once()
def _handle_er_session_not_ready(self, method: TEventHandler):
    """Check that ``method`` refuses to expand on a session that is not ready yet."""
    # Setup: a preloaded session explicitly flagged as not ready
    oe, session, session_uri = self._preloaded_oe_service()
    session.is_ready = False

    # If: the handler is asked to expand on that session
    validator = RequestFlowValidator().add_expected_error(
        type(None), RequestFlowValidator.basic_error_validation)
    expand_params = ExpandParameters.from_dict({'session_id': session_uri, 'node_path': None})
    method(oe, validator.request_context, expand_params)

    # Then: an error response must have been sent
    validator.validate()
def _handle_er_no_session_match(method: TEventHandler):
    """Check that ``method`` answers with an error when the session id is unknown."""
    # Setup: an OE service with no sessions
    service = ObjectExplorerService()
    service._service_provider = utils.get_mock_service_provider({})

    # If: the handler is asked to expand on a session that was never created
    validator = RequestFlowValidator().add_expected_error(
        type(None), RequestFlowValidator.basic_error_validation)
    expand_params = ExpandParameters.from_dict({'session_id': 'session', 'node_path': None})
    method(service, validator.request_context, expand_params)

    # Then: an error response must have been sent
    validator.validate()
def test_handle_create_session_missing_params(self):
    """Creating a session without parameters yields an error and creates nothing."""
    # Setup: a bare OE service
    service = ObjectExplorerService()
    service._service_provider = utils.get_mock_service_provider({})

    # If: the create handler receives no params
    validator = RequestFlowValidator().add_expected_error(
        type(None), RequestFlowValidator.basic_error_validation)
    service._handle_create_session_request(validator.request_context, None)

    # Then:
    # ... an error response must have been sent
    validator.validate()

    # ... and the session map must still be empty
    self.assertDictEqual(service._session_map, {})
def test_handle_close_session_unsuccessful(self):
    """If disconnect reports failure, closing responds False and logs an info message."""
    # Setup: make the connection service report an unsuccessful disconnect
    self.cs.disconnect = mock.MagicMock(return_value=False)

    # If: I close an OE session and the disconnect does not succeed
    rc = RequestFlowValidator().add_expected_response(bool, self.assertFalse)
    session_id = _connection_details()[1]
    params = _close_session_params()
    params.session_id = session_id
    self.oe._handle_close_session_request(rc.request_context, params)

    # Then:
    # ... I should get a response with a value of False
    rc.validate()

    # ... and the failure to close should have been logged at info level
    self.oe._service_provider.logger.info.assert_called_with(
        'Could not close the OE session with Id objectexplorer://testuser@testhost:testdb/'
    )
def test_handle_close_session_successful(self):
    """Closing an existing session responds True and removes it from the session map."""
    # If: I close a session
    rc = RequestFlowValidator().add_expected_response(bool, self.assertTrue)
    session_id = _connection_details()[1]
    params = _close_session_params()
    params.session_id = session_id
    self.oe._handle_close_session_request(rc.request_context, params)

    # Then:
    # ... I should get a successful response
    rc.validate()

    # ... The session should no longer be in the session map
    self.assertDictEqual(self.oe._session_map, {})
def test_handle_create_session_session_exists(self):
    """Requesting a session that already exists responds False and keeps the original session."""
    # Setup: an OE service already holding a session for these connection details
    oe = ObjectExplorerService()
    oe._service_provider = utils.get_mock_service_provider({})
    params, session_uri = _connection_details()
    existing_session = ObjectExplorerSession(session_uri, params)
    oe._session_map[session_uri] = existing_session

    # If: I ask to create the same session again
    validator = RequestFlowValidator().add_expected_response(bool, self.assertFalse)
    oe._handle_create_session_request(validator.request_context, params)

    # Then:
    # ... the response should have been False
    validator.validate()

    # ... and the original session object must still be registered
    self.assertIs(oe._session_map[session_uri], existing_session)
def test_handle_create_session_incomplete_params(self):
    """Creating a session from empty connection details yields an error and creates nothing."""
    # Setup: a bare OE service
    service = ObjectExplorerService()
    service._service_provider = utils.get_mock_service_provider({})

    # If: I create an OE session with empty connection details
    # NOTE: it is enough that session-URI generation throws here; throwing in every
    #       scenario is covered by a different test
    validator = RequestFlowValidator().add_expected_error(
        type(None), RequestFlowValidator.basic_error_validation)
    empty_details = ConnectionDetails.from_data({})
    service._handle_create_session_request(validator.request_context, empty_details)

    # Then:
    # ... an error response must have been sent
    validator.validate()

    # ... and the session map must still be empty
    self.assertDictEqual(service._session_map, {})
def _handle_er_node_alivetasksuccessful(self, method: TEventHandler, get_tasks: TGetTask):
    """Check expand/refresh handling while a task for the same node is still running.

    A still-alive placeholder thread is registered as both the expand and the refresh task
    for '/'. The handler is expected to respond True, and exactly one task should be attached
    to the session (as returned by ``get_tasks``). No ExpandCompleted notification is
    registered as expected.
    """
    # Setup: Create an OE service with a session preloaded
    oe, session, session_uri = self._preloaded_oe_service()

    # ... A placeholder task that stays alive until release_event is set.
    # Blocking on Event.wait() avoids the CPU-burning spin on the deprecated isSet().
    # daemon=True ensures the thread cannot keep the interpreter alive if cleanup is skipped.
    release_event = threading.Event()
    alive_task = threading.Thread(target=release_event.wait, daemon=True)

    params = ExpandParameters.from_dict({
        'session_id': session_uri,
        'node_path': '/'
    })
    session.expand_tasks[params.node_path] = alive_task
    session.refresh_tasks[params.node_path] = alive_task
    alive_task.start()

    try:
        # If: I expand a node while its previous task is still running
        rc = RequestFlowValidator()
        rc.add_expected_response(bool, self.assertTrue)
        method(oe, rc.request_context, params)

        # Then:
        # ... I should have gotten a True response
        rc.validate()

        # ... Exactly one task should be attached to the session
        self.assertEqual(len(get_tasks(session)), 1)
    finally:
        # Always release the placeholder, even if an assertion above failed,
        # so the thread does not outlive the test
        release_event.set()
        alive_task.join()
def _handle_er_node_successful(self, method: TEventHandler, get_tasks: TGetTask):
    """Check that ``method`` expands '/' successfully and leaves one task on the session.

    The handler must respond True and send an ExpandCompleted notification with no error,
    the matching session id and node path, and a list of NodeInfo nodes.
    """
    # Setup: an OE service that already holds a session
    oe, session, session_uri = self._preloaded_oe_service()

    # ... checks applied to the ExpandCompleted notification
    def check_completed(completed: ExpandCompletedParameters):
        self.assertIsNone(completed.error_message)
        self.assertEqual(completed.session_id, session_uri)
        self.assertEqual(completed.node_path, '/')
        self.assertIsInstance(completed.nodes, list)
        for child in completed.nodes:
            self.assertIsInstance(child, NodeInfo)

    # If: I expand the root node
    validator = RequestFlowValidator()
    validator.add_expected_response(bool, self.assertTrue)
    validator.add_expected_notification(ExpandCompletedParameters, EXPAND_COMPLETED_METHOD, check_completed)
    expand_params = ExpandParameters.from_dict({
        'session_id': session_uri,
        'node_path': '/'
    })
    method(oe, validator.request_context, expand_params)

    # Wait for the worker threads so their notification is in place before validating
    for pending in session.expand_tasks.values():
        pending.join()
    for pending in session.refresh_tasks.values():
        pending.join()

    # Then:
    # ... the response and the successful completion notification were sent
    validator.validate()

    # ... and a single task is attached to the session
    self.assertEqual(len(get_tasks(session)), 1)
class TaskServiceTests(unittest.TestCase):
    """Methods for testing the task service"""

    def setUp(self):
        self.task_service = TaskService()
        self.service_provider = ServiceProviderMock({constants.TASK_SERVICE_NAME: self.task_service})
        self.request_validator = RequestFlowValidator()
        # The validator expects one 'tasks/newtaskcreated' notification per Task created here.
        # Each task's start method is mocked so starting a task does no real work.
        self.mock_task_1 = Task(None, None, None, None, None, self.request_validator.request_context,
                                mock.Mock(), mock.Mock())
        self.request_validator.add_expected_notification(TaskInfo, 'tasks/newtaskcreated')
        self.mock_task_1.start = mock.Mock()
        self.mock_task_2 = Task(None, None, None, None, None, self.request_validator.request_context,
                                mock.Mock(), mock.Mock())
        self.request_validator.add_expected_notification(TaskInfo, 'tasks/newtaskcreated')
        self.mock_task_2.start = mock.Mock()

    def test_registration(self):
        """Test that the service registers its cancel and list methods correctly"""
        # If I call the task service's register method
        self.task_service.register(self.service_provider)

        # Then CANCEL_TASK_REQUEST and LIST_TASKS_REQUEST should have been registered
        self.service_provider.server.set_request_handler.assert_has_calls(
            [mock.call(CANCEL_TASK_REQUEST, self.task_service.handle_cancel_request),
             mock.call(LIST_TASKS_REQUEST, self.task_service.handle_list_request)],
            any_order=True)

    def test_start_task(self):
        """Test that the service can start tasks"""
        # If I start both tasks
        self.task_service.start_task(self.mock_task_1)
        self.task_service.start_task(self.mock_task_2)

        # Then the task service is aware of them
        self.assertIs(self.task_service._task_map[self.mock_task_1.id], self.mock_task_1)
        self.assertIs(self.task_service._task_map[self.mock_task_2.id], self.mock_task_2)

        # And the tasks' start methods were called
        self.mock_task_1.start.assert_called_once()
        self.mock_task_2.start.assert_called_once()

    def test_cancel_request(self):
        """Test that sending a cancellation request attempts to cancel the task"""
        # Set up a task
        self.mock_task_1.cancel = mock.Mock(return_value=True)
        self.mock_task_1.status = TaskStatus.IN_PROGRESS
        self.task_service.start_task(self.mock_task_1)

        # Set up the request flow validator
        self.request_validator.add_expected_response(bool, self.assertTrue)

        # If I call the cancellation handler
        params = CancelTaskParameters()
        params.task_id = self.mock_task_1.id
        self.task_service.handle_cancel_request(self.request_validator.request_context, params)

        # Then the task's cancel method should have been called and a positive response should have been sent
        self.mock_task_1.cancel.assert_called_once()
        self.request_validator.validate()

    def test_cancel_request_no_task(self):
        """Test that the cancellation handler returns false when there is no task to cancel"""
        # Set up the request flow validator
        self.request_validator.add_expected_response(bool, self.assertFalse)

        # If I call the cancellation handler without a corresponding task
        params = CancelTaskParameters()
        params.task_id = self.mock_task_1.id
        self.task_service.handle_cancel_request(self.request_validator.request_context, params)

        # Then a negative response should have been sent
        self.request_validator.validate()

    def test_list_all_tasks(self):
        """Test that the list task handler returns information for every task"""
        self._test_list_tasks(False)

    def test_list_active_tasks(self):
        """Test that the list task handler returns information for only the active tasks"""
        self._test_list_tasks(True)

    def _test_list_tasks(self, active_tasks_only: bool):
        """Start both tasks, mark task 1 in progress and task 2 succeeded, then list them.

        :param active_tasks_only: forwarded as ListTasksParameters.list_active_tasks_only;
            when True only task 1 is expected in the response, otherwise both tasks
        """
        # If I start both tasks, with task 1 in progress and task 2 complete
        self.task_service.start_task(self.mock_task_1)
        self.task_service.start_task(self.mock_task_2)
        self.mock_task_1.status = TaskStatus.IN_PROGRESS
        self.mock_task_2.status = TaskStatus.SUCCEEDED

        # Set up the request validator (response order is not checked)
        def validate_list_response(response_params: List[TaskInfo]):
            actual_response_dict = [info.__dict__ for info in response_params]
            expected_response_dict = [self.mock_task_1.task_info.__dict__]
            if not active_tasks_only:
                expected_response_dict.append(self.mock_task_2.task_info.__dict__)
            self.assertEqual(len(actual_response_dict), len(expected_response_dict))
            for expected_info in expected_response_dict:
                self.assertIn(expected_info, actual_response_dict)
        self.request_validator.add_expected_response(list, validate_list_response)

        # And then list them
        params = ListTasksParameters()
        params.list_active_tasks_only = active_tasks_only
        self.task_service.handle_list_request(self.request_validator.request_context, params)

        # Then the service responds with TaskInfo for task 1, plus task 2 unless only active tasks were requested
        self.request_validator.validate()
class TestLanguageService(unittest.TestCase): """Methods for testing the language service""" def setUp(self): """Constructor""" self.default_uri = 'file://my.sql' self.flow_validator = RequestFlowValidator() self.mock_server_set_request = mock.MagicMock() self.mock_server = JSONRPCServer(None, None) self.mock_server.set_request_handler = self.mock_server_set_request self.mock_workspace_service = WorkspaceService() self.mock_connection_service = ConnectionService() self.mock_service_provider = ServiceProvider(self.mock_server, {}, PG_PROVIDER_NAME, None) self.mock_service_provider._services[ constants.WORKSPACE_SERVICE_NAME] = self.mock_workspace_service self.mock_service_provider._services[ constants.CONNECTION_SERVICE_NAME] = self.mock_connection_service self.mock_service_provider._is_initialized = True self.default_text_position = TextDocumentPosition.from_dict({ 'text_document': { 'uri': self.default_uri }, 'position': { 'line': 3, 'character': 10 } }) self.default_text_document_id = TextDocumentIdentifier.from_dict( {'uri': self.default_uri}) def test_register(self): """Test registration of the service""" # Setup: # ... Create a mock service provider server: JSONRPCServer = JSONRPCServer(None, None) server.set_notification_handler = mock.MagicMock() server.set_request_handler = mock.MagicMock() provider: ServiceProvider = ServiceProvider( server, {constants.CONNECTION_SERVICE_NAME: ConnectionService}, PG_PROVIDER_NAME, utils.get_mock_logger()) provider._is_initialized = True conn_service: ConnectionService = provider[ constants.CONNECTION_SERVICE_NAME] self.assertEqual(0, len(conn_service._on_connect_callbacks)) # If: I register a language service service: LanguageService = LanguageService() service.register(provider) # Then: # ... 
The notifications should have been registered server.set_notification_handler.assert_called() server.set_request_handler.assert_called() self.assertEqual(1, len(conn_service._on_connect_callbacks)) self.assertEqual(1, server.count_shutdown_handlers()) # ... The service provider should have been stored self.assertIs(service._service_provider, provider) # noqa def test_handle_shutdown(self): # Given a language service service: LanguageService = self._init_service( stop_operations_queue=False) self.assertFalse(service.operations_queue.stop_requested) # When I shutdown the service service._handle_shutdown() # Then the language service should be cleaned up self.assertTrue(service.operations_queue.stop_requested) def test_completion_intellisense_off(self): """ Test that the completion handler returns empty if the intellisense is disabled """ # If: intellisense is disabled context: RequestContext = utils.MockRequestContext() config = Configuration() config.sql.intellisense.enable_intellisense = False self.mock_workspace_service._configuration = config service: LanguageService = self._init_service() # When: I request completion item service.handle_completion_request(context, self.default_text_position) # Then: # ... An empty completion should be sent over the notification context.send_response.assert_called_once() self.assertEqual(context.last_response_params, []) def test_completion_file_not_found(self): """ Test that the completion handler returns empty if the intellisense is disabled """ # If: The script file doesn't exist (there is an empty workspace) context: RequestContext = utils.MockRequestContext() self.mock_workspace_service._workspace = Workspace() service: LanguageService = self._init_service() # When: I request completion item service.handle_completion_request(context, self.default_text_position) # Then: # ... 
An empty completion should be sent over the notification context.send_response.assert_called_once() self.assertEqual(context.last_response_params, []) def test_default_completion_items(self): """ Test that the completion handler returns a set of default values when not connected to any URI """ # If: The script file exists input_text = 'create tab' doc_position = TextDocumentPosition.from_dict({ 'text_document': { 'uri': self.default_uri }, 'position': { 'line': 0, 'character': 10 # end of 'tab' word } }) context: RequestContext = utils.MockRequestContext() config = Configuration() config.sql.intellisense.enable_intellisense = True self.mock_workspace_service._configuration = config workspace, script_file = self._get_test_workspace(True, input_text) self.mock_workspace_service._workspace = workspace service: LanguageService = self._init_service() service._valid_uri.add(doc_position.text_document.uri) # When: I request completion item service.handle_completion_request(context, doc_position) # Then: # ... 
An default completion set should be sent over the notification context.send_response.assert_called_once() completions: List[CompletionItem] = context.last_response_params self.assertTrue(len(completions) > 0) self.verify_match('TABLE', completions, Range.from_data(0, 7, 0, 10)) def test_pg_language_flavor(self): """ Test that if provider is PGSQL, the service ignores files registered as being for non-PGSQL flavors """ # If: I create a new language service pgsql_params = LanguageFlavorChangeParams.from_data( 'file://pguri.sql', 'sql', PG_PROVIDER_NAME) mysql_params = LanguageFlavorChangeParams.from_data( 'file://mysqluri.sql', 'sql', MYSQL_PROVIDER_NAME) mssql_params = LanguageFlavorChangeParams.from_data( 'file://msuri.sql', 'sql', MSSQL_PROVIDER_NAME) other_params = LanguageFlavorChangeParams.from_data( 'file://other.doc', 'doc', '') provider = utils.get_mock_service_provider() service = LanguageService() service._service_provider = provider # When: I notify of language preferences context: NotificationContext = utils.get_mock_notification_context() service.handle_flavor_change(context, pgsql_params) service.handle_flavor_change(context, mssql_params) service.handle_flavor_change(context, mysql_params) service.handle_flavor_change(context, other_params) # Then: # ... 
Only non-PGSQL SQL files should be ignored context.send_notification.assert_not_called() self.assertFalse(service.is_valid_uri(mssql_params.uri)) self.assertTrue(service.is_valid_uri(pgsql_params.uri)) self.assertFalse(service.is_valid_uri(other_params.uri)) self.assertFalse(service.is_valid_uri(mysql_params.uri)) # When: I change from MSSQL to PGSQL mssql_params = LanguageFlavorChangeParams.from_data( 'file://msuri.sql', 'sql', PG_PROVIDER_NAME) service.handle_flavor_change(context, mssql_params) # Then: the service is updated to allow intellisense self.assertTrue(service.is_valid_uri(mssql_params.uri)) # When: I change from PGSQL to MYSQL mssql_params = LanguageFlavorChangeParams.from_data( 'file://msuri.sql', 'sql', MYSQL_PROVIDER_NAME) service.handle_flavor_change(context, mssql_params) # Then: the service is updated to not allow intellisense self.assertFalse(service.is_valid_uri(mssql_params.uri)) def test_mysql_language_flavor(self): """ Test that if provider is MySQL, the service ignores files registered as being for non-MySQL flavors """ # If: I create a new language service pgsql_params = LanguageFlavorChangeParams.from_data( 'file://pguri.sql', 'sql', PG_PROVIDER_NAME) mysql_params = LanguageFlavorChangeParams.from_data( 'file://mysqluri.sql', 'sql', MYSQL_PROVIDER_NAME) mssql_params = LanguageFlavorChangeParams.from_data( 'file://msuri.sql', 'sql', MSSQL_PROVIDER_NAME) other_params = LanguageFlavorChangeParams.from_data( 'file://other.doc', 'doc', '') # create a mock mysql service provider provider = utils.get_mock_service_provider( provider_name=MYSQL_PROVIDER_NAME) service = LanguageService() service._service_provider = provider # When: I notify of language preferences context: NotificationContext = utils.get_mock_notification_context() service.handle_flavor_change(context, pgsql_params) service.handle_flavor_change(context, mssql_params) service.handle_flavor_change(context, mysql_params) service.handle_flavor_change(context, other_params) # Then: # 
... Only non-MySQL SQL files should be ignored context.send_notification.assert_not_called() self.assertFalse(service.is_valid_uri(mssql_params.uri)) self.assertFalse(service.is_valid_uri(pgsql_params.uri)) self.assertFalse(service.is_valid_uri(other_params.uri)) self.assertTrue(service.is_valid_uri(mysql_params.uri)) # When: I change from MSSQL to PGSQL mssql_params = LanguageFlavorChangeParams.from_data( 'file://msuri.sql', 'sql', PG_PROVIDER_NAME) service.handle_flavor_change(context, mssql_params) # Then: the service is updated to not allow intellisense self.assertFalse(service.is_valid_uri(mssql_params.uri)) # When: I change from PGSQL to MYSQL mssql_params = LanguageFlavorChangeParams.from_data( 'file://msuri.sql', 'sql', MYSQL_PROVIDER_NAME) service.handle_flavor_change(context, mssql_params) # Then: the service is updated to allow intellisense self.assertTrue(service.is_valid_uri(mssql_params.uri)) def test_on_connect_sends_notification(self): """ Test that the service sends an intellisense ready notification after handling an on connect notification from the connection service. 
This is a slightly more end-to-end test that verifies calling through to the queue layer """ # If: I create a new language service service: LanguageService = self._init_service_with_flow_validator() conn_info = ConnectionInfo( 'file://msuri.sql', ConnectionDetails.from_data({ 'host': None, 'dbname': 'TEST_DBNAME', 'user': '******' })) connect_result = mock.MagicMock() connect_result.error_message = None self.mock_connection_service.get_connection = mock.Mock( return_value=mock.MagicMock()) self.mock_connection_service.connect = mock.MagicMock( return_value=connect_result) def validate_success_notification(response: IntelliSenseReadyParams): self.assertEqual(response.owner_uri, conn_info.owner_uri) # When: I notify of a connection complete for a given URI self.flow_validator.add_expected_notification( IntelliSenseReadyParams, INTELLISENSE_READY_NOTIFICATION, validate_success_notification) refresher_mock = mock.MagicMock() refresh_method_mock = mock.MagicMock() refresher_mock.refresh = refresh_method_mock patch_path = 'ossdbtoolsservice.language.operations_queue.CompletionRefresher' with mock.patch(patch_path) as refresher_patch: refresher_patch.return_value = refresher_mock task: threading.Thread = service.on_connect(conn_info) # And when refresh is "complete" refresh_method_mock.assert_called_once() callback = refresh_method_mock.call_args[0][0] self.assertIsNotNone(callback) callback(None) # Wait for task to return task.join() # Then: # an intellisense ready notification should be sent for that URI self.flow_validator.validate() # ... and the scriptparseinfo should be created info: ScriptParseInfo = service.get_script_parse_info( conn_info.owner_uri) self.assertIsNotNone(info) # ... 
and the info should have the connection key set self.assertEqual(info.connection_key, OperationsQueue.create_key(conn_info)) def test_format_doc_no_pgsql_format(self): """ Test that the format codepath succeeds even if the configuration options aren't defined """ input_text = 'select * from foo where id in (select id from bar);' context: RequestContext = utils.MockRequestContext() self.mock_workspace_service._configuration = None workspace, script_file = self._get_test_workspace(True, input_text) self.mock_workspace_service._workspace = workspace service: LanguageService = self._init_service() format_options = FormattingOptions() format_options.insert_spaces = False format_params = DocumentFormattingParams() format_params.options = format_options format_params.text_document = self.default_text_document_id # add uri to valid uri set ensure request passes uri check # normally done in flavor change handler, but we are not testing that here service._valid_uri.add(format_params.text_document.uri) # When: I have no useful formatting defaults defined service.handle_doc_format_request(context, format_params) # Then: # ... There should be no changes to the doc context.send_response.assert_called_once() edits: List[TextEdit] = context.last_response_params self.assertTrue(len(edits) > 0) self.assert_range_equals(edits[0].range, Range.from_data(0, 0, 0, len(input_text))) self.assertEqual(edits[0].new_text, input_text) def test_format_doc(self): """ Test that the format document codepath works as expected """ # If: We have a basic string to be formatted input_text = 'select * from foo where id in (select id from bar);' # Note: sqlparse always uses '\n\ for line separator even on windows. 
# For now, respecting this behavior and leaving as-is expected_output = '\n'.join([ 'SELECT *', 'FROM foo', 'WHERE id IN', '\t\t\t\t(SELECT id', '\t\t\t\t\tFROM bar);' ]) context: RequestContext = utils.MockRequestContext() config = Configuration() config.pgsql = PGSQLConfiguration() config.pgsql.format.keyword_case = 'upper' self.mock_workspace_service._configuration = config workspace, script_file = self._get_test_workspace(True, input_text) self.mock_workspace_service._workspace = workspace service: LanguageService = self._init_service() format_options = FormattingOptions() format_options.insert_spaces = False format_params = DocumentFormattingParams() format_params.options = format_options format_params.text_document = self.default_text_document_id # add uri to valid uri set ensure request passes uri check # normally done in flavor change handler, but we are not testing that here service._valid_uri.add(format_params.text_document.uri) # When: I request document formatting service.handle_doc_format_request(context, format_params) # Then: # ... 
The entire document text should be formatted context.send_response.assert_called_once() edits: List[TextEdit] = context.last_response_params self.assertTrue(len(edits) > 0) self.assert_range_equals(edits[0].range, Range.from_data(0, 0, 0, len(input_text))) self.assertEqual(edits[0].new_text, expected_output) def test_format_doc_range(self): """ Test that the format document range codepath works as expected """ # If: The script file doesn't exist (there is an empty workspace) input_lines: List[str] = [ 'select * from t1', 'select * from foo where id in (select id from bar);' ] input_text = '\n'.join(input_lines) expected_output = '\n'.join([ 'SELECT *', 'FROM foo', 'WHERE id IN', '\t\t\t\t(SELECT id', '\t\t\t\t\tFROM bar);' ]) context: RequestContext = utils.MockRequestContext() config = Configuration() config.pgsql = PGSQLConfiguration() config.pgsql.format.keyword_case = 'upper' self.mock_workspace_service._configuration = config workspace, script_file = self._get_test_workspace(True, input_text) self.mock_workspace_service._workspace = workspace service: LanguageService = self._init_service() format_options = FormattingOptions() format_options.insert_spaces = False format_params = DocumentRangeFormattingParams() format_params.options = format_options format_params.text_document = self.default_text_document_id # add uri to valid uri set ensure request passes uri check # normally done in flavor change handler, but we are not testing that here service._valid_uri.add(format_params.text_document.uri) # When: I request format the 2nd line of a document format_params.range = Range.from_data(1, 0, 1, len(input_lines[1])) service.handle_doc_range_format_request(context, format_params) # Then: # ... 
only the 2nd line should be formatted context.send_response.assert_called_once() edits: List[TextEdit] = context.last_response_params self.assertTrue(len(edits) > 0) self.assert_range_equals(edits[0].range, format_params.range) self.assertEqual(edits[0].new_text, expected_output) def test_format_mysql_doc(self): """ Test that the format document codepath works as expected with a mysql doc """ # set up service provider with mysql connection self.mock_service_provider = ServiceProvider(self.mock_server, {}, MYSQL_PROVIDER_NAME, None) self.mock_service_provider._services[ constants.WORKSPACE_SERVICE_NAME] = self.mock_workspace_service self.mock_service_provider._services[ constants.CONNECTION_SERVICE_NAME] = self.mock_connection_service self.mock_service_provider._is_initialized = True # If: We have a basic string to be formatted input_text = 'select * from foo where id in (select id from bar);' # Note: sqlparse always uses '\n\ for line separator even on windows. # For now, respecting this behavior and leaving as-is expected_output = '\n'.join([ 'SELECT *', 'FROM foo', 'WHERE id IN', '\t\t\t\t(SELECT id', '\t\t\t\t\tFROM bar);' ]) context: RequestContext = utils.MockRequestContext() config = Configuration() config.my_sql = MySQLConfiguration() config.my_sql.format.keyword_case = 'upper' self.mock_workspace_service._configuration = config workspace, script_file = self._get_test_workspace(True, input_text) self.mock_workspace_service._workspace = workspace service: LanguageService = self._init_service() format_options = FormattingOptions() format_options.insert_spaces = False format_params = DocumentFormattingParams() format_params.options = format_options format_params.text_document = self.default_text_document_id # add uri to valid uri set ensure request passes uri check # normally done in flavor change handler, but we are not testing that here service._valid_uri.add(format_params.text_document.uri) # When: I request document formatting 
service.handle_doc_format_request(context, format_params) # Then: # ... The entire document text should be formatted context.send_response.assert_called_once() edits: List[TextEdit] = context.last_response_params self.assertTrue(len(edits) > 0) self.assert_range_equals(edits[0].range, Range.from_data(0, 0, 0, len(input_text))) self.assertEqual(edits[0].new_text, expected_output) def test_format_mysql_doc_range(self): """ Test that the format document range codepath works as expected with a mysql doc """ # set up service provider with mysql connection self.mock_service_provider = ServiceProvider(self.mock_server, {}, MYSQL_PROVIDER_NAME, None) self.mock_service_provider._services[ constants.WORKSPACE_SERVICE_NAME] = self.mock_workspace_service self.mock_service_provider._services[ constants.CONNECTION_SERVICE_NAME] = self.mock_connection_service self.mock_service_provider._is_initialized = True # If: The script file doesn't exist (there is an empty workspace) input_lines: List[str] = [ 'select * from t1', 'select * from foo where id in (select id from bar);' ] input_text = '\n'.join(input_lines) expected_output = '\n'.join([ 'SELECT *', 'FROM foo', 'WHERE id IN', '\t\t\t\t(SELECT id', '\t\t\t\t\tFROM bar);' ]) context: RequestContext = utils.MockRequestContext() config = Configuration() config.my_sql = MySQLConfiguration() config.my_sql.format.keyword_case = 'upper' self.mock_workspace_service._configuration = config workspace, script_file = self._get_test_workspace(True, input_text) self.mock_workspace_service._workspace = workspace service: LanguageService = self._init_service() format_options = FormattingOptions() format_options.insert_spaces = False format_params = DocumentRangeFormattingParams() format_params.options = format_options format_params.text_document = self.default_text_document_id # add uri to valid uri set ensure request passes uri check # normally done in flavor change handler, but we are not testing that here 
service._valid_uri.add(format_params.text_document.uri) # When: I request format the 2nd line of a document format_params.range = Range.from_data(1, 0, 1, len(input_lines[1])) service.handle_doc_range_format_request(context, format_params) # Then: # ... only the 2nd line should be formatted context.send_response.assert_called_once() edits: List[TextEdit] = context.last_response_params self.assertTrue(len(edits) > 0) self.assert_range_equals(edits[0].range, format_params.range) self.assertEqual(edits[0].new_text, expected_output) @parameterized.expand([ (0, 10), (-2, 8), ]) def test_completion_to_completion_item(self, relative_start_pos, expected_start_char): """ Tests that PGCompleter's Completion objects get converted to CompletionItems as expected """ text = 'item' display = 'item is a table' display_meta = 'table' completion = Completion(text, relative_start_pos, display, display_meta) completion_item: CompletionItem = LanguageService.to_completion_item( completion, self.default_text_position) self.assertEqual(completion_item.label, text) self.assertEqual(completion_item.text_edit.new_text, text) text_pos: Position = self.default_text_position.position # pylint: disable=maybe-no-member self.assertEqual(completion_item.text_edit.range.start.line, text_pos.line) self.assertEqual(completion_item.text_edit.range.start.character, expected_start_char) self.assertEqual(completion_item.text_edit.range.end.line, text_pos.line) self.assertEqual(completion_item.text_edit.range.end.character, text_pos.character) self.assertEqual(completion_item.detail, display) self.assertEqual(completion_item.label, text) def test_handle_definition_request_should_return_empty_if_query_file_do_not_exist( self): # If: The script file doesn't exist (there is an empty workspace) context: RequestContext = utils.MockRequestContext() self.mock_workspace_service._workspace = Workspace() service: LanguageService = self._init_service() service.handle_definition_request(context, 
self.default_text_position) context.send_response.assert_called_once() self.assertEqual(context.last_response_params, []) def test_handle_definition_request_intellisense_off(self): request_context: RequestContext = utils.MockRequestContext() config = Configuration() config.sql.intellisense.enable_intellisense = False self.mock_workspace_service._configuration = config language_service = self._init_service() language_service.handle_definition_request(request_context, self.default_text_position) request_context.send_response.assert_called_once() self.assertEqual(request_context.last_response_params, []) def test_completion_keyword_completion_sort_text(self): """ Tests that a Keyword Completion is converted with sort text that puts it after other objects """ text = 'item' display = 'item is something' # Given I have anything other than a keyword, I expect label to match key table_completion = Completion(text, 0, display, 'table') completion_item: CompletionItem = LanguageService.to_completion_item( table_completion, self.default_text_position) self.assertEqual(completion_item.sort_text, text) # Given I have a keyword, I expect keyword_completion = Completion(text, 0, display, 'keyword') completion_item: CompletionItem = LanguageService.to_completion_item( keyword_completion, self.default_text_position) self.assertEqual(completion_item.sort_text, '~' + text) def _init_service(self, stop_operations_queue=True) -> LanguageService: """ Initializes a simple service instance. 
By default stops the threaded queue since this could cause issues debugging multiple tests, and the class can be tested without this running the queue """ service = LanguageService() service.register(self.mock_service_provider) if stop_operations_queue: service.operations_queue.stop() return service def _init_service_with_flow_validator(self) -> LanguageService: self.mock_server.send_notification = self.flow_validator.request_context.send_notification return self._init_service() def _get_test_workspace( self, script_file: bool = True, buffer: str = '') -> Tuple[Workspace, Optional[ScriptFile]]: workspace: Workspace = Workspace() file: Optional[ScriptFile] = None if script_file: file = ScriptFile(self.default_uri, buffer, '') workspace._workspace_files[self.default_uri] = file return workspace, file def verify_match(self, word: str, matches: List[CompletionItem], text_range: Range): """Verifies match against its label and other properties""" match: CompletionItem = next( iter(obj for obj in matches if obj.label == word), None) self.assertIsNotNone(match) self.assertEqual(word, match.label) self.assertEqual(CompletionItemKind.Keyword, match.kind) self.assertEqual(word, match.insert_text) self.assert_range_equals(text_range, match.text_edit.range) self.assertEqual(word, match.text_edit.new_text) def assert_range_equals(self, first: Range, second: Range): self.assertEqual(first.start.line, second.start.line) self.assertEqual(first.start.character, second.start.character) self.assertEqual(first.end.line, second.end.line) self.assertEqual(first.end.character, second.end.character)