예제 #1
0
    def _handle_er_exception_expanding(self, method: TEventHandler,
                                       get_tasks: TGetTask):
        """Verify that an exception raised by route_request is reported as an expand error.

        ``route_request`` in the object explorer service module is patched to raise, then
        ``method`` is invoked on the root node of a preloaded session.

        :param method: Handler under test, called as ``method(oe, request_context, params)``
            (presumably the expand or refresh handler)
        :param get_tasks: Returns the session's task dict that ``method`` is expected to populate
        """
        # Setup: Create an OE service with a session preloaded
        oe, session, session_uri = self._preloaded_oe_service()

        # ... Patch the route_request to throw
        patch_mock = mock.MagicMock(side_effect=Exception('Boom!'))
        patch_path = 'pgsqltoolsservice.object_explorer.object_explorer_service.route_request'
        with mock.patch(patch_path, patch_mock):
            # If: I expand a node (with route_request that throws)
            rc = RequestFlowValidator()
            rc.add_expected_response(bool, self.assertTrue)
            rc.add_expected_notification(
                ExpandCompletedParameters, EXPAND_COMPLETED_METHOD,
                lambda param: self._validate_expand_error(
                    param, session_uri, '/'))
            params = ExpandParameters.from_dict({
                'session_id': session_uri,
                'node_path': '/'
            })
            method(oe, rc.request_context, params)

            # Join the worker threads while the patch is still active. route_request is
            # presumably called from the worker thread; if the `with` block exited first,
            # a thread that had not reached it yet would call the real implementation
            # instead of the throwing mock, making this test flaky.
            for task in session.expand_tasks.values():
                task.join()
            for task in session.refresh_tasks.values():
                task.join()

        # Then:
        # ... An error notification should have been sent
        rc.validate()

        # ... The thread should be attached to the session
        self.assertEqual(len(get_tasks(session)), 1)
예제 #2
0
    def _handle_er_threading_fail(self, method: TEventHandler):
        """Check that a failure to spawn the worker thread is reported cleanly.

        ``threading.Thread`` (as seen from the object explorer service module) is patched
        to raise, so ``method`` cannot start its background work. The handler must still
        answer ``True``, emit an error completion notification for the root node, and
        leave no expand/refresh task registered on the session.

        :param method: Handler under test, called as ``method(oe, request_context, params)``
        """
        oe, session, session_uri = self._preloaded_oe_service()

        thread_path = 'pgsqltoolsservice.object_explorer.object_explorer_service.threading.Thread'
        exploding_thread = mock.MagicMock(side_effect=Exception('Boom!'))

        def check_error(param):
            return self._validate_expand_error(param, session_uri, '/')

        with mock.patch(thread_path, exploding_thread):
            # Expand the root node while thread creation is broken
            validator = RequestFlowValidator()
            validator.add_expected_response(bool, self.assertTrue)
            validator.add_expected_notification(
                ExpandCompletedParameters, EXPAND_COMPLETED_METHOD, check_error)
            expand_params = ExpandParameters.from_dict(
                {'session_id': session_uri, 'node_path': '/'})
            method(oe, validator.request_context, expand_params)

        # The error notification must have been delivered ...
        validator.validate()

        # ... and neither task map may hold a leftover entry
        for task_map in (session.expand_tasks, session.refresh_tasks):
            self.assertDictEqual(task_map, {})
예제 #3
0
    def test_handle_create_session_threading_fail(self):
        """Session creation must report an error and clean up when no thread can be started."""
        # Setup: an OE service whose mock service provider is built from an empty mapping
        service = ObjectExplorerService()
        service._service_provider = utils.get_mock_service_provider({})

        # ... make construction of threading.Thread raise inside the OE service module
        thread_path = 'pgsqltoolsservice.object_explorer.object_explorer_service.threading.Thread'
        with mock.patch(thread_path, mock.MagicMock(side_effect=Exception('Boom!'))):
            # If: a new session is requested
            params, session_uri = _connection_details()

            def check_response(param):
                return self.assertEqual(param.session_id, session_uri)

            def check_notification(param):
                return self._validate_init_error(param, session_uri)

            validator = RequestFlowValidator()
            validator.add_expected_response(CreateSessionResponse, check_response)
            validator.add_expected_notification(
                SessionCreatedParameters, SESSION_CREATED_METHOD, check_notification)
            service._handle_create_session_request(validator.request_context, params)

        # Then: the error notification went out ...
        validator.validate()

        # ... and the half-created session was removed from the map
        self.assertDictEqual(service._session_map, {})
    def test_handle_create_session_successful(self):
        """Creating a session over a successful connection yields a ready session.

        The connection service is mocked to report a successful connect. The test expects a
        ``CreateSessionResponse`` echoing the session id, followed by a success
        ``SessionCreatedParameters`` notification whose root node describes the database.
        """
        # Setup:
        # ... Create OE service with mock connection service that returns a successful connection response
        mock_connection = MockPGServerConnection(cur=None,
                                                 host='myserver',
                                                 name='postgres',
                                                 user='******',
                                                 port=123)
        cs = ConnectionService()
        cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
        cs.get_connection = mock.MagicMock(return_value=mock_connection)
        oe = ObjectExplorerService()
        oe._service_provider = utils.get_mock_service_provider(
            {constants.CONNECTION_SERVICE_NAME: cs})
        oe._provider = constants.PG_PROVIDER_NAME
        # The class itself (not an instance) is assigned; the isinstance check below suggests
        # the OE service instantiates it for the session -- presumably, confirm in the service
        oe._server = Server

        # ... Create parameters, session, request context validator
        params, session_uri = _connection_details()

        # ... Create validation of success notification
        def validate_success_notification(response: SessionCreatedParameters):
            self.assertTrue(response.success)
            self.assertEqual(response.session_id, session_uri)
            self.assertIsNone(response.error_message)

            self.assertIsInstance(response.root_node, NodeInfo)
            self.assertEqual(response.root_node.label, TEST_DBNAME)
            self.assertEqual(response.root_node.node_path, session_uri)
            self.assertEqual(response.root_node.node_type, 'Database')
            self.assertIsInstance(response.root_node.metadata, ObjectMetadata)
            self.assertEqual(response.root_node.metadata.urn,
                             oe._session_map[session_uri].server.urn_base)
            self.assertEqual(
                response.root_node.metadata.name,
                oe._session_map[session_uri].server.maintenance_db_name)
            self.assertEqual(response.root_node.metadata.metadata_type_name,
                             'Database')
            self.assertFalse(response.root_node.is_leaf)

        rc = RequestFlowValidator()
        rc.add_expected_response(
            CreateSessionResponse,
            lambda param: self.assertEqual(param.session_id, session_uri))
        rc.add_expected_notification(SessionCreatedParameters,
                                     SESSION_CREATED_METHOD,
                                     validate_success_notification)

        # If: I create a session
        oe._handle_create_session_request(rc.request_context, params)
        # ... and wait for the session's init_task to finish so the notification has been sent
        oe._session_map[session_uri].init_task.join()

        # Then:
        # ... A success notification should have been returned
        rc.validate()

        # ... The session should still exist and should have connection and server setup
        self.assertIn(session_uri, oe._session_map)
        self.assertIsInstance(oe._session_map[session_uri].server, Server)
        self.assertTrue(oe._session_map[session_uri].is_ready)
예제 #5
0
    def _handle_er_node_alivetasksuccessful(self, method: TEventHandler,
                                            get_tasks: TGetTask):
        """Verify the handler still responds when a task for the node is already running.

        A placeholder thread that stays alive until released is registered as both the
        expand and the refresh task for the root node before ``method`` is invoked. The
        test expects a ``True`` response (no completion notification is registered) and
        that the session still holds exactly one task.

        :param method: Handler under test, called as ``method(oe, request_context, params)``
        :param get_tasks: Returns the session's task dict that ``method`` operates on
        """
        # Setup: Create an OE service with a session preloaded
        oe, session, session_uri = self._preloaded_oe_service()

        params = ExpandParameters.from_dict({
            'session_id': session_uri,
            'node_path': '/'
        })

        # ... Register a placeholder task that stays alive until released. Event.wait blocks
        # without spinning the CPU (the previous isSet() polling loop busy-waited), and the
        # daemon flag ensures a leaked thread can never keep the test process alive.
        release_event = threading.Event()
        alive_task = threading.Thread(target=release_event.wait, daemon=True)
        session.expand_tasks[params.node_path] = alive_task
        session.refresh_tasks[params.node_path] = alive_task
        alive_task.start()

        try:
            # If: I expand a node whose task is still alive
            rc = RequestFlowValidator()
            rc.add_expected_response(bool, self.assertTrue)
            method(oe, rc.request_context, params)

            # Then:
            # ... I should have gotten a successful response
            rc.validate()

            # ... The thread should be attached to the session
            self.assertEqual(len(get_tasks(session)), 1)
        finally:
            # Always release and reap the placeholder, even when an assertion above fails
            release_event.set()
            alive_task.join()
예제 #6
0
    def _handle_er_node_successful(self, method: TEventHandler,
                                   get_tasks: TGetTask):
        """Run ``method`` on the root node and expect a successful completion.

        :param method: Handler under test, called as ``method(oe, request_context, params)``
        :param get_tasks: Returns the session's task dict; exactly one task is expected in it
        """
        oe, session, session_uri = self._preloaded_oe_service()

        def check_completion(response: ExpandCompletedParameters):
            # A successful completion has no error and carries a list of NodeInfo children
            self.assertIsNone(response.error_message)
            self.assertEqual(response.session_id, session_uri)
            self.assertEqual(response.node_path, '/')
            self.assertIsInstance(response.nodes, list)
            for child in response.nodes:
                self.assertIsInstance(child, NodeInfo)

        validator = RequestFlowValidator()
        validator.add_expected_response(bool, self.assertTrue)
        validator.add_expected_notification(
            ExpandCompletedParameters, EXPAND_COMPLETED_METHOD, check_completion)
        root_params = ExpandParameters.from_dict({'session_id': session_uri, 'node_path': '/'})

        method(oe, validator.request_context, root_params)

        # Wait for the background work to finish so every notification has been sent
        for running in session.expand_tasks.values():
            running.join()
        for running in session.refresh_tasks.values():
            running.join()

        # The success notification must have arrived ...
        validator.validate()

        # ... and exactly one task must be tracked on the session
        self.assertEqual(len(get_tasks(session)), 1)
예제 #7
0
class TaskServiceTests(unittest.TestCase):
    """Methods for testing the task service"""

    def setUp(self):
        self.task_service = TaskService()
        self.service_provider = ServiceProviderMock({constants.TASK_SERVICE_NAME: self.task_service})
        self.request_validator = RequestFlowValidator()
        # One 'tasks/newtaskcreated' notification is expected for each Task built here
        # (presumably sent by the Task constructor through the shared request context)
        self.mock_task_1 = Task(None, None, None, None, None, self.request_validator.request_context, mock.Mock(), mock.Mock())
        self.request_validator.add_expected_notification(TaskInfo, 'tasks/newtaskcreated')
        # start() is replaced by a mock so start_task never runs real work
        self.mock_task_1.start = mock.Mock()
        self.mock_task_2 = Task(None, None, None, None, None, self.request_validator.request_context, mock.Mock(), mock.Mock())
        self.request_validator.add_expected_notification(TaskInfo, 'tasks/newtaskcreated')
        self.mock_task_2.start = mock.Mock()

    def test_registration(self):
        """Test that the service registers its cancel and list methods correctly"""
        # If I call the task service's register method
        self.task_service.register(self.service_provider)

        # Then CANCEL_TASK_REQUEST and LIST_TASKS_REQUEST should have been registered
        self.service_provider.server.set_request_handler.assert_has_calls(
            [mock.call(CANCEL_TASK_REQUEST, self.task_service.handle_cancel_request), mock.call(LIST_TASKS_REQUEST, self.task_service.handle_list_request)],
            any_order=True)

    def test_start_task(self):
        """Test that the service can start tasks"""
        # If I start both tasks
        self.task_service.start_task(self.mock_task_1)
        self.task_service.start_task(self.mock_task_2)

        # Then the task service is aware of them
        self.assertIs(self.task_service._task_map[self.mock_task_1.id], self.mock_task_1)
        self.assertIs(self.task_service._task_map[self.mock_task_2.id], self.mock_task_2)

        # And the tasks' start methods were called
        self.mock_task_1.start.assert_called_once()
        self.mock_task_2.start.assert_called_once()

    def test_cancel_request(self):
        """Test that sending a cancellation request attempts to cancel the task"""
        # Set up a task
        self.mock_task_1.cancel = mock.Mock(return_value=True)
        self.mock_task_1.status = TaskStatus.IN_PROGRESS
        self.task_service.start_task(self.mock_task_1)

        # Set up the request flow validator
        self.request_validator.add_expected_response(bool, self.assertTrue)

        # If I call the cancellation handler
        params = CancelTaskParameters()
        params.task_id = self.mock_task_1.id
        self.task_service.handle_cancel_request(self.request_validator.request_context, params)

        # Then the task's cancel method should have been called and a positive response should have been sent
        self.mock_task_1.cancel.assert_called_once()
        self.request_validator.validate()

    def test_cancel_request_no_task(self):
        """Test that the cancellation handler returns false when there is no task to cancel"""
        # Set up the request flow validator
        self.request_validator.add_expected_response(bool, self.assertFalse)

        # If I call the cancellation handler without a corresponding task
        params = CancelTaskParameters()
        params.task_id = self.mock_task_1.id
        self.task_service.handle_cancel_request(self.request_validator.request_context, params)

        # Then a negative response should have been sent
        self.request_validator.validate()

    def test_list_all_tasks(self):
        """Test that listing all tasks returns info for both the in-progress and the completed task"""
        self._test_list_tasks(False)

    def test_list_active_tasks(self):
        """Test that listing only active tasks returns info for just the in-progress task"""
        self._test_list_tasks(True)

    def _test_list_tasks(self, active_tasks_only: bool):
        """Start two tasks, request the task list, and check the returned TaskInfo entries.

        :param active_tasks_only: Sent as ``list_active_tasks_only``. When True only the
            in-progress task 1 is expected; otherwise task 1 and the completed task 2 are expected.
        """
        # Set up task 1 to be in progress and task 2 to be complete
        self.mock_task_1.status = TaskStatus.IN_PROGRESS
        self.mock_task_2.status = TaskStatus.SUCCEEDED

        # Set up the request validator
        def validate_list_response(response_params: List[TaskInfo]):
            actual_response_dict = [info.__dict__ for info in response_params]
            expected_response_dict = [self.mock_task_1.task_info.__dict__]
            if not active_tasks_only:
                expected_response_dict.append(self.mock_task_2.task_info.__dict__)
            # Order of the returned entries is not significant; compare as a multiset
            self.assertEqual(len(actual_response_dict), len(expected_response_dict))
            for expected_info in expected_response_dict:
                self.assertIn(expected_info, actual_response_dict)
        self.request_validator.add_expected_response(list, validate_list_response)

        # If I start the tasks (once each) and then list them
        self.task_service.start_task(self.mock_task_1)
        self.task_service.start_task(self.mock_task_2)
        params = ListTasksParameters()
        params.list_active_tasks_only = active_tasks_only
        self.task_service.handle_list_request(self.request_validator.request_context, params)

        # Then the service responds with TaskInfo for task 1 and, unless only active tasks were requested, task 2
        self.request_validator.validate()