def _handle_er_exception_expanding(self, method: TEventHandler, get_tasks: TGetTask):
    """Shared scenario: the expand/refresh handler runs while route_request raises.

    :param method: Handler under test, invoked as method(oe, request_context, params)
    :param get_tasks: Callable that returns the session's task collection to inspect
    """
    # Setup: Create an OE service with a session preloaded
    oe, session, session_uri = self._preloaded_oe_service()

    def check_error_notification(param):
        return self._validate_expand_error(param, session_uri, '/')

    # ... Patch route_request so that it throws when invoked
    target = 'pgsqltoolsservice.object_explorer.object_explorer_service.route_request'
    with mock.patch(target, mock.MagicMock(side_effect=Exception('Boom!'))):
        # If: I expand a node (with route_request that throws)
        rc = RequestFlowValidator()
        rc.add_expected_response(bool, self.assertTrue)
        rc.add_expected_notification(ExpandCompletedParameters, EXPAND_COMPLETED_METHOD, check_error_notification)
        request_params = ExpandParameters.from_dict({'session_id': session_uri, 'node_path': '/'})
        method(oe, rc.request_context, request_params)

        # Wait for every worker thread so that rc.validate sees the notification
        for task_map in (session.expand_tasks, session.refresh_tasks):
            for worker in task_map.values():
                worker.join()

    # Then:
    # ... An error notification should have been sent
    rc.validate()

    # ... The thread should be attached to the session
    attached_tasks = get_tasks(session)
    self.assertEqual(len(attached_tasks), 1)
def _handle_er_threading_fail(self, method: TEventHandler):
    """Shared scenario: the expand/refresh handler runs while threading.Thread raises.

    :param method: Handler under test, invoked as method(oe, request_context, params)
    """
    # Setup: Create an OE service with a session preloaded
    oe, session, session_uri = self._preloaded_oe_service()

    def check_error_notification(param):
        return self._validate_expand_error(param, session_uri, '/')

    # ... Patch the Thread constructor so that creating a worker throws
    target = 'pgsqltoolsservice.object_explorer.object_explorer_service.threading.Thread'
    with mock.patch(target, mock.MagicMock(side_effect=Exception('Boom!'))):
        # If: I expand a node (with threading that throws)
        rc = RequestFlowValidator()
        rc.add_expected_response(bool, self.assertTrue)
        rc.add_expected_notification(ExpandCompletedParameters, EXPAND_COMPLETED_METHOD, check_error_notification)
        request_params = ExpandParameters.from_dict({'session_id': session_uri, 'node_path': '/'})
        method(oe, rc.request_context, request_params)

    # Then:
    # ... The error notification should have been returned
    rc.validate()

    # ... Neither task map of the session should hold a task
    self.assertDictEqual(session.expand_tasks, {})
    self.assertDictEqual(session.refresh_tasks, {})
def test_handle_create_session_threading_fail(self):
    """Creating a session while threading.Thread raises must report an init error and leave no session behind"""
    # Setup:
    # ... Create an OE service
    oe = ObjectExplorerService()
    oe._service_provider = utils.get_mock_service_provider({})

    # ... Patch the Thread constructor so that creating a worker throws
    target = 'pgsqltoolsservice.object_explorer.object_explorer_service.threading.Thread'
    with mock.patch(target, mock.MagicMock(side_effect=Exception('Boom!'))):
        # If: I create a new session
        params, session_uri = _connection_details()

        def check_response(param):
            return self.assertEqual(param.session_id, session_uri)

        def check_error_notification(param):
            return self._validate_init_error(param, session_uri)

        rc = RequestFlowValidator()
        rc.add_expected_response(CreateSessionResponse, check_response)
        rc.add_expected_notification(SessionCreatedParameters, SESSION_CREATED_METHOD, check_error_notification)
        oe._handle_create_session_request(rc.request_context, params)

    # Then:
    # ... The error notification should have been returned
    rc.validate()

    # ... The session should have been cleaned up
    self.assertDictEqual(oe._session_map, {})
def test_handle_create_session_successful(self):
    """Creating a session against a mocked, successful connection must yield a ready session and a success notification"""
    # Setup:
    # ... Create a connection service whose connect/get_connection are mocked to succeed
    mock_connection = MockPGServerConnection(cur=None, host='myserver', name='postgres', user='******', port=123)
    cs = ConnectionService()
    cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    cs.get_connection = mock.MagicMock(return_value=mock_connection)

    # ... Create the OE service on top of that connection service
    oe = ObjectExplorerService()
    oe._service_provider = utils.get_mock_service_provider({constants.CONNECTION_SERVICE_NAME: cs})
    oe._provider = constants.PG_PROVIDER_NAME
    oe._server = Server

    # ... Create parameters, session, request context validator
    params, session_uri = _connection_details()

    # ... Create validation of success notification
    def validate_success_notification(response: SessionCreatedParameters):
        self.assertTrue(response.success)
        self.assertEqual(response.session_id, session_uri)
        self.assertIsNone(response.error_message)

        # The root node describes the database the session is connected to
        self.assertIsInstance(response.root_node, NodeInfo)
        self.assertEqual(response.root_node.label, TEST_DBNAME)
        self.assertEqual(response.root_node.node_path, session_uri)
        self.assertEqual(response.root_node.node_type, 'Database')

        # The root node metadata is compared against the server stored on the session
        self.assertIsInstance(response.root_node.metadata, ObjectMetadata)
        self.assertEqual(response.root_node.metadata.urn, oe._session_map[session_uri].server.urn_base)
        self.assertEqual(response.root_node.metadata.name, oe._session_map[session_uri].server.maintenance_db_name)
        self.assertEqual(response.root_node.metadata.metadata_type_name, 'Database')
        self.assertFalse(response.root_node.is_leaf)

    def validate_response(param):
        return self.assertEqual(param.session_id, session_uri)

    rc = RequestFlowValidator()
    rc.add_expected_response(CreateSessionResponse, validate_response)
    rc.add_expected_notification(SessionCreatedParameters, SESSION_CREATED_METHOD, validate_success_notification)

    # If: I create a session (and wait for its init task to finish)
    oe._handle_create_session_request(rc.request_context, params)
    oe._session_map[session_uri].init_task.join()

    # Then:
    # ... The response and the success notification should have been returned
    rc.validate()

    # ... The session should still exist and should have connection and server setup
    self.assertIn(session_uri, oe._session_map)
    created_session = oe._session_map[session_uri]
    self.assertIsInstance(created_session.server, Server)
    self.assertTrue(created_session.is_ready)
def _handle_er_node_alivetasksuccessful(self, method: TEventHandler, get_tasks: TGetTask):
    """Shared scenario: the expand/refresh handler runs while a task for the node is still alive.

    A thread that blocks on an event is registered as both the expand and the
    refresh task of node '/' before the handler is invoked. Only a boolean
    response is expected from the handler; no notification is registered.

    :param method: Handler under test, invoked as method(oe, request_context, params)
    :param get_tasks: Callable that returns the session's task collection to inspect
    """
    # Setup: Create an OE service with a session preloaded
    oe, session, session_uri = self._preloaded_oe_service()

    rc = RequestFlowValidator()
    rc.add_expected_response(bool, self.assertTrue)
    params = ExpandParameters.from_dict({'session_id': session_uri, 'node_path': '/'})

    # ... Register a live task for the node. Event.wait parks the thread
    #     (no busy loop) until the event is set.
    testevent = threading.Event()
    testtask = threading.Thread(target=testevent.wait)
    session.expand_tasks[params.node_path] = testtask
    session.refresh_tasks[params.node_path] = testtask
    testtask.start()

    try:
        # If: I expand a node that already has a running task
        method(oe, rc.request_context, params)

        # Then:
        # ... Only the expected response should have been sent
        rc.validate()

        # ... The thread should be attached to the session
        self.assertEqual(len(get_tasks(session)), 1)
    finally:
        # Always release the blocker thread, even when an assertion above
        # fails; otherwise the non-daemon thread keeps the test run alive.
        testevent.set()
        testtask.join()
def _handle_er_node_successful(self, method: TEventHandler, get_tasks: TGetTask):
    """Shared scenario: the expand/refresh handler completes successfully for node '/'.

    :param method: Handler under test, invoked as method(oe, request_context, params)
    :param get_tasks: Callable that returns the session's task collection to inspect
    """
    # Setup: Create an OE service with a session preloaded
    oe, session, session_uri = self._preloaded_oe_service()

    # ... Define validation for the return notification
    def validate_success_notification(response: ExpandCompletedParameters):
        self.assertIsNone(response.error_message)
        self.assertEqual(response.session_id, session_uri)
        self.assertEqual(response.node_path, '/')
        self.assertIsInstance(response.nodes, list)
        for child in response.nodes:
            self.assertIsInstance(child, NodeInfo)

    # If: I expand a node
    rc = RequestFlowValidator()
    rc.add_expected_response(bool, self.assertTrue)
    rc.add_expected_notification(ExpandCompletedParameters, EXPAND_COMPLETED_METHOD, validate_success_notification)
    request_params = ExpandParameters.from_dict({'session_id': session_uri, 'node_path': '/'})
    method(oe, rc.request_context, request_params)

    # Wait for every worker thread so that rc.validate sees the notification
    for task_map in (session.expand_tasks, session.refresh_tasks):
        for worker in task_map.values():
            worker.join()

    # Then:
    # ... I should have gotten a completed successfully message
    rc.validate()

    # ... The thread should be attached to the session
    attached_tasks = get_tasks(session)
    self.assertEqual(len(attached_tasks), 1)
class TaskServiceTests(unittest.TestCase):
    """Methods for testing the task service"""

    def setUp(self):
        """Create a task service, a mock service provider and two tasks whose start methods are mocked"""
        self.task_service = TaskService()
        self.service_provider = ServiceProviderMock({constants.TASK_SERVICE_NAME: self.task_service})
        self.request_validator = RequestFlowValidator()

        # Each task creation is paired with one expected 'tasks/newtaskcreated'
        # notification on the shared validator.
        # NOTE(review): presumably the Task constructor sends it - confirm against Task.
        self.mock_task_1 = Task(None, None, None, None, None, self.request_validator.request_context, mock.Mock(), mock.Mock())
        self.request_validator.add_expected_notification(TaskInfo, 'tasks/newtaskcreated')
        self.mock_task_1.start = mock.Mock()

        self.mock_task_2 = Task(None, None, None, None, None, self.request_validator.request_context, mock.Mock(), mock.Mock())
        self.request_validator.add_expected_notification(TaskInfo, 'tasks/newtaskcreated')
        self.mock_task_2.start = mock.Mock()

    def test_registration(self):
        """Test that the service registers its cancel and list methods correctly"""
        # If I call the task service's register method
        self.task_service.register(self.service_provider)

        # Then CANCEL_TASK_REQUEST and LIST_TASKS_REQUEST should have been registered
        self.service_provider.server.set_request_handler.assert_has_calls(
            [mock.call(CANCEL_TASK_REQUEST, self.task_service.handle_cancel_request),
             mock.call(LIST_TASKS_REQUEST, self.task_service.handle_list_request)],
            any_order=True)

    def test_start_task(self):
        """Test that the service can start tasks"""
        # If I start both tasks
        self.task_service.start_task(self.mock_task_1)
        self.task_service.start_task(self.mock_task_2)

        # Then the task service is aware of them
        self.assertIs(self.task_service._task_map[self.mock_task_1.id], self.mock_task_1)
        self.assertIs(self.task_service._task_map[self.mock_task_2.id], self.mock_task_2)

        # And the tasks' start methods were called
        self.mock_task_1.start.assert_called_once()
        self.mock_task_2.start.assert_called_once()

    def test_cancel_request(self):
        """Test that sending a cancellation request attempts to cancel the task"""
        # Set up a task
        self.mock_task_1.cancel = mock.Mock(return_value=True)
        self.mock_task_1.status = TaskStatus.IN_PROGRESS
        self.task_service.start_task(self.mock_task_1)

        # Set up the request flow validator
        self.request_validator.add_expected_response(bool, self.assertTrue)

        # If I call the cancellation handler
        params = CancelTaskParameters()
        params.task_id = self.mock_task_1.id
        self.task_service.handle_cancel_request(self.request_validator.request_context, params)

        # Then the task's cancel method should have been called and a positive response should have been sent
        self.mock_task_1.cancel.assert_called_once()
        self.request_validator.validate()

    def test_cancel_request_no_task(self):
        """Test that the cancellation handler returns false when there is no task to cancel"""
        # Set up the request flow validator
        self.request_validator.add_expected_response(bool, self.assertFalse)

        # If I call the cancellation handler without a corresponding task
        params = CancelTaskParameters()
        params.task_id = self.mock_task_1.id
        self.task_service.handle_cancel_request(self.request_validator.request_context, params)

        # Then a negative response should have been sent
        self.request_validator.validate()

    def test_list_all_tasks(self):
        """Test that the list task handler returns every task when not restricted to active tasks"""
        self._test_list_tasks(False)

    def test_list_active_tasks(self):
        """Test that the list task handler returns only the in-progress task when restricted to active tasks"""
        self._test_list_tasks(True)

    def _test_list_tasks(self, active_tasks_only: bool):
        """Start both tasks, list them through the handler and validate the returned task infos.

        :param active_tasks_only: Value assigned to the request's list_active_tasks_only
            flag; when True only task 1 (in progress) is expected, otherwise both tasks
        """
        # Start both tasks once, then set task 1 to be in progress and task 2 to be complete
        self.task_service.start_task(self.mock_task_1)
        self.task_service.start_task(self.mock_task_2)
        self.mock_task_1.status = TaskStatus.IN_PROGRESS
        self.mock_task_2.status = TaskStatus.SUCCEEDED

        # Set up the request validator
        def validate_list_response(response_params: List[TaskInfo]):
            actual_response_dict = [info.__dict__ for info in response_params]
            expected_response_dict = [self.mock_task_1.task_info.__dict__]
            if not active_tasks_only:
                expected_response_dict.append(self.mock_task_2.task_info.__dict__)

            # Same number of entries, and every expected entry is present (order-insensitive)
            self.assertEqual(len(actual_response_dict), len(expected_response_dict))
            for expected_info in expected_response_dict:
                self.assertIn(expected_info, actual_response_dict)

        self.request_validator.add_expected_response(list, validate_list_response)

        # If I list the started tasks
        params = ListTasksParameters()
        params.list_active_tasks_only = active_tasks_only
        self.task_service.handle_list_request(self.request_validator.request_context, params)

        # Then the service responds with TaskInfo for task 1, plus task 2 unless only active tasks were requested
        self.request_validator.validate()