Ejemplo n.º 1
0
    def test_init_session_failed_connection(self):
        # Setup:
        # ... A connection service whose connect() always reports the error 'Boom!'
        conn_service = ConnectionService()
        failed_response = ConnectionCompleteParams()
        failed_response.error_message = 'Boom!'
        conn_service.connect = mock.MagicMock(return_value=failed_response)

        # ... An OE service that resolves the connection service above from its provider
        explorer = ObjectExplorerService()
        explorer._service_provider = utils.get_mock_service_provider(
            {constants.CONNECTION_SERVICE_NAME: conn_service})

        # If: I initialize a session directly
        # (NOTE: the request handler is bypassed on purpose to avoid threading issues)
        conn_params, uri = _connection_details()
        oe_session = ObjectExplorerSession(uri, conn_params)
        explorer._session_map[uri] = oe_session

        validator = RequestFlowValidator()
        validator.add_expected_notification(
            SessionCreatedParameters,
            SESSION_CREATED_METHOD,
            lambda notification: self._validate_init_error(notification, uri))
        explorer._initialize_session(validator.request_context, oe_session)

        # Then:
        # ... The error notification was sent, and the failed session was removed from the OE service
        validator.validate()
        self.assertDictEqual(explorer._session_map, {})
    def test_handle_database_change_request(self):
        """Test that handle_change_database_request kicks off a new thread that reconnects the owner URI"""
        # Setup:
        # ... Create a mock request context to handle output
        rc = utils.MockRequestContext()
        # ... Stub connect() to succeed, and stub get_connection_info so it returns a (non-None) mock.
        #     This makes the service believe a connection already exists for the owner URI.
        connect_response = ConnectionCompleteParams()
        self.connection_service.connect = Mock(return_value=connect_response)
        self.connection_service.get_connection_info = mock.MagicMock()

        # If: I request to change the database of the connection
        params: ChangeDatabaseRequestParams = ChangeDatabaseRequestParams.from_dict({
            'owner_uri': 'someUri',
            'new_database': 'newDb'
        })

        self.connection_service.handle_change_database_request(rc, params)

        # Then:
        # ... A connection thread should have been registered for the owner URI; wait for it to finish
        connection_thread = self.connection_service.owner_to_thread_map[params.owner_uri]
        self.assertIsNotNone(connection_thread)
        connection_thread.join()

        # ... Connect should have been called exactly once
        self.connection_service.connect.assert_called_once()

        # ... A True should have been sent as the response to the request
        rc.send_response.assert_called_once_with(True)

        # ... A connection complete notification should have been sent back as well
        rc.send_notification.assert_called_once_with(CONNECTION_COMPLETE_METHOD, connect_response)

        # ... An error should not have been called
        rc.send_error.assert_not_called()
Ejemplo n.º 3
0
    def test_handle_create_session_successful(self):
        """Test that a create-session request with a successful connection produces a ready session"""
        # Setup:
        # ... Create OE service with mock connection service that returns a successful connection response
        mock_connection = utils.MockConnection({
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'port': 123
        })
        cs = ConnectionService()
        cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
        cs.get_connection = mock.MagicMock(return_value=mock_connection)
        oe = ObjectExplorerService()
        oe._service_provider = utils.get_mock_service_provider(
            {constants.CONNECTION_SERVICE_NAME: cs})

        # ... Create parameters, session, request context validator
        params, session_uri = _connection_details()

        # ... Create validation of success notification.
        #     The root node must describe the session's database and match the server the session set up.
        def validate_success_notification(response: SessionCreatedParameters):
            self.assertTrue(response.success)
            self.assertEqual(response.session_id, session_uri)
            self.assertIsNone(response.error_message)

            self.assertIsInstance(response.root_node, NodeInfo)
            self.assertEqual(response.root_node.label, TEST_DBNAME)
            self.assertEqual(response.root_node.node_path, session_uri)
            self.assertEqual(response.root_node.node_type, 'Database')
            self.assertIsInstance(response.root_node.metadata, ObjectMetadata)
            self.assertEqual(response.root_node.metadata.urn,
                             oe._session_map[session_uri].server.urn_base)
            self.assertEqual(
                response.root_node.metadata.name,
                oe._session_map[session_uri].server.maintenance_db_name)
            self.assertEqual(response.root_node.metadata.metadata_type_name,
                             'Database')
            self.assertFalse(response.root_node.is_leaf)

        rc = RequestFlowValidator()
        rc.add_expected_response(
            CreateSessionResponse,
            lambda param: self.assertEqual(param.session_id, session_uri))
        rc.add_expected_notification(SessionCreatedParameters,
                                     SESSION_CREATED_METHOD,
                                     validate_success_notification)

        # If: I create a session and wait for its initialization task to finish
        oe._handle_create_session_request(rc.request_context, params)
        oe._session_map[session_uri].init_task.join()

        # Then:
        # ... A create-session response and a success notification should have been sent
        rc.validate()

        # ... The session should still exist and should have connection and server setup
        self.assertIn(session_uri, oe._session_map)
        self.assertIsInstance(oe._session_map[session_uri].server, Server)
        self.assertTrue(oe._session_map[session_uri].is_ready)
    def test_handle_scriptas_successful_operation(self):
        # NOTE: There's no need to test all types here, the scripter tests should handle this

        # Setup:
        # ... A scripting service whose connection service connects and returns a mock connection
        fake_connection = MockConnection(None)
        conn_service = ConnectionService()
        conn_service.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
        conn_service.get_connection = mock.MagicMock(return_value=fake_connection)
        scripting_service = ScriptingService()
        scripting_service._service_provider = utils.get_mock_service_provider(
            {CONNECTION_SERVICE_NAME: conn_service})

        # ... Every response must echo the owner URI and carry the canned script
        def validate_response(response: ScriptAsResponse) -> None:
            self.assertEqual(response.owner_uri, TestScriptingService.MOCK_URI)
            self.assertEqual(response.script, TestScriptingService.MOCK_SCRIPT)

        # ... The single object scripted for every operation
        target_object = {
            'type': 'Table',
            'name': 'test_table',
            'schema': 'test_schema'
        }

        # ... Replace the Scripter used by the scripting service with one whose script() returns a canned value
        with mock.patch('pgsqltoolsservice.scripting.scripting_service.Scripter') as scripter_cls:
            fake_scripter: Scripter = Scripter(fake_connection)
            fake_scripter.script = mock.MagicMock(return_value=TestScriptingService.MOCK_SCRIPT)
            scripter_cls.return_value = fake_scripter

            for op in ScriptOperation:
                # If: I request a script for this operation
                validator: RequestFlowValidator = RequestFlowValidator()
                validator.add_expected_response(ScriptAsResponse, validate_response)

                request = ScriptAsParameters.from_dict({
                    'ownerUri': TestScriptingService.MOCK_URI,
                    'operation': op,
                    'scripting_objects': [target_object]
                })
                scripting_service._handle_scriptas_request(validator.request_context, request)

                # Then:
                # ... The request should have been handled correctly
                validator.validate()

            # ... script() should have been invoked exactly once per operation
            call_counts = dict.fromkeys(ScriptOperation, 0)
            for args, _ in fake_scripter.script.call_args_list:
                call_counts[args[0]] += 1

            for count in call_counts.values():
                self.assertEqual(count, 1)
Ejemplo n.º 5
0
    def test_create_connection_failed(self):
        """Test that _create_connection raises a RuntimeError carrying the connection error message"""
        # Setup:
        # ... An OE service whose connection service reports a failed connection with a known message
        oe = ObjectExplorerService()
        cs = ConnectionService()
        connect_response = ConnectionCompleteParams()
        error = 'Failed'
        connect_response.error_message = error
        cs.connect = mock.MagicMock(return_value=connect_response)
        oe._service_provider = utils.get_mock_service_provider(
            {constants.CONNECTION_SERVICE_NAME: cs})
        params, session_uri = _connection_details()
        session = ObjectExplorerSession(session_uri, params)

        # If: I try to create a connection for the session
        with self.assertRaises(RuntimeError) as context:
            oe._create_connection(session, 'foo_database')

        # Then:
        # ... The exception should carry the connection's error message.
        #     This check must stay outside the ``with`` block. The raising call exits the block,
        #     so any statement after it inside the block would never run.
        self.assertEqual(error, str(context.exception))

        # ... A connection should have been attempted exactly once
        cs.connect.assert_called_once()
Ejemplo n.º 6
0
def _build_connection_response_error(connection_info: ConnectionInfo, connection_type: ConnectionType, err)\
        -> ConnectionCompleteParams:
    """
    Build a connection complete response that reports a failed connection.

    :param connection_info: Connection info whose owner URI the response is addressed to
    :param connection_type: Type of the connection that failed
    :param err: The error that occurred; its string form becomes both the messages and the error message
    :return: ConnectionCompleteParams with owner_uri, type, messages and error_message populated
    """
    error_text = str(err)

    failure: ConnectionCompleteParams = ConnectionCompleteParams()
    failure.owner_uri = connection_info.owner_uri
    failure.type = connection_type
    failure.messages = error_text
    failure.error_message = error_text
    return failure
    def test_handle_database_change_request_with_empty_connection_info_for_false(self):
        """Test that handle_change_database_request returns False when the owner URI has no connection info"""
        # Setup:
        # ... Create a mock request context to handle output
        rc = utils.MockRequestContext()
        # ... Stub connect() so any unexpected connection attempt can be detected
        connect_response = ConnectionCompleteParams()
        self.connection_service.connect = Mock(return_value=connect_response)

        # If: I request a database change for an owner URI without an existing connection
        params: ChangeDatabaseRequestParams = ChangeDatabaseRequestParams.from_dict({
            'owner_uri': 'someUri',
            'new_database': 'newDb'
        })

        result: bool = self.connection_service.handle_change_database_request(rc, params)

        # Then:
        # ... The request should be rejected, and no connection should have been attempted
        self.assertEqual(result, False)
        self.connection_service.connect.assert_not_called()
Ejemplo n.º 8
0
    def test_create_connection_successful(self):
        # Setup:
        # ... An OE service whose connection service connects and hands back a known mock connection
        expected_connection = MockConnection('test')
        explorer = ObjectExplorerService()
        conn_service = ConnectionService()
        conn_service.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
        conn_service.get_connection = mock.MagicMock(return_value=expected_connection)
        explorer._service_provider = utils.get_mock_service_provider(
            {constants.CONNECTION_SERVICE_NAME: conn_service})

        # If: I create a connection for a session
        conn_params, uri = _connection_details()
        oe_session = ObjectExplorerSession(uri, conn_params)
        created = explorer._create_connection(oe_session, 'foo_database')

        # Then:
        # ... The mock connection is returned, after exactly one connect and one lookup
        self.assertIsNotNone(created)
        self.assertEqual(created, expected_connection)
        conn_service.connect.assert_called_once()
        conn_service.get_connection.assert_called_once()
Ejemplo n.º 9
0
def _build_connection_response(connection_info: ConnectionInfo, connection_type: ConnectionType) -> ConnectionCompleteParams:
    """
    Build a connection complete response describing an established connection.

    :param connection_info: Connection info that owns the connection being reported
    :param connection_type: Type of the connection to report on
    :return: ConnectionCompleteParams containing the connection id, owner URI and type, a summary
             built from the connection's DSN 'host', 'dbname' and 'user' parameters, and the server info
    """
    conn = connection_info.get_connection(connection_type)
    dsn = conn.get_dsn_parameters()

    summary = ConnectionSummary(
        server_name=dsn['host'],
        database_name=dsn['dbname'],
        user_name=dsn['user'],
    )

    result: ConnectionCompleteParams = ConnectionCompleteParams()
    result.connection_id = connection_info.connection_id
    result.connection_summary = summary
    result.owner_uri = connection_info.owner_uri
    result.type = connection_type
    result.server_info = _get_server_info(conn)
    return result
    def test_handle_connect_request(self):
        """Test that the handle_connect_request method kicks off a new thread to do the connection"""
        # Setup: a mock request context to capture output, and a connect() that succeeds immediately
        request_context = utils.MockRequestContext()
        expected_response = ConnectionCompleteParams()
        self.connection_service.connect = Mock(return_value=expected_response)

        # If: I make a request to connect
        connection_details = {
            'server_name': 'someserver',
            'user_name': 'someuser',
            'database_name': 'somedb',
            'options': {
                'password': '******'
            }
        }
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': 'someUri',
            'type': ConnectionType.QUERY,
            'connection': connection_details
        })

        # ... and wait for the spawned connection thread to finish before checking the results
        self.connection_service.handle_connect_request(request_context, params)
        worker = self.connection_service.owner_to_thread_map[params.owner_uri]
        self.assertIsNotNone(worker)
        worker.join()

        # Then:
        # ... connect() ran exactly once, with the request parameters
        self.connection_service.connect.assert_called_once_with(params)

        # ... The request was answered with True
        request_context.send_response.assert_called_once_with(True)

        # ... A connection complete notification carried the connect response
        request_context.send_notification.assert_called_once_with(
            CONNECTION_COMPLETE_METHOD, expected_response)

        # ... No error was reported
        request_context.send_error.assert_not_called()
Ejemplo n.º 11
0
    def test_handle_scriptas_invalid_operation(self):
        """Test that a script-as request without parameters is answered with an error"""
        # Setup: Create a scripting service whose connection service returns a default
        # ConnectionCompleteParams from connect() and a placeholder connection object
        mock_connection = {}
        cs = ConnectionService()
        cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
        cs.get_connection = mock.MagicMock(return_value=mock_connection)
        ss = ScriptingService()
        ss._service_provider = utils.get_mock_service_provider(
            {CONNECTION_SERVICE_NAME: cs})

        # If: I make a script-as request with missing (None) params
        rc: RequestFlowValidator = RequestFlowValidator()
        rc.add_expected_error(type(None),
                              RequestFlowValidator.basic_error_validation)
        ss._handle_scriptas_request(rc.request_context, None)

        # Then:
        # ... I should get an error response
        rc.validate()