Esempio n. 1
0
    def test_init_session_failed_connection(self):
        """Initializing a session whose connect() reports an error sends an error notification and drops the session."""
        # Setup:
        # ... Connection service whose connect() always answers with an error message
        connection_service = ConnectionService()
        failed_response = ConnectionCompleteParams()
        failed_response.error_message = 'Boom!'
        connection_service.connect = mock.MagicMock(return_value=failed_response)

        # ... OE service wired to that connection service
        oe_service = ObjectExplorerService()
        oe_service._service_provider = utils.get_mock_service_provider(
            {constants.CONNECTION_SERVICE_NAME: connection_service})

        # If: a session is registered and initialized directly
        # (calling _initialize_session instead of the request handler avoids threading issues)
        params, session_uri = _connection_details()
        session = ObjectExplorerSession(session_uri, params)
        oe_service._session_map[session_uri] = session

        validator = RequestFlowValidator()
        validator.add_expected_notification(
            SessionCreatedParameters,
            SESSION_CREATED_METHOD,
            lambda notification: self._validate_init_error(notification, session_uri))
        oe_service._initialize_session(validator.request_context, session)

        # Then: the error notification was sent and the session map is empty again
        validator.validate()
        self.assertDictEqual(oe_service._session_map, {})
    def setUp(self):
        """Build a connection service plus the connect parameters and mock connection shared by the tests."""
        # Connection service whose provider exposes a workspace service
        self.connection_service = ConnectionService()
        provider = {constants.WORKSPACE_SERVICE_NAME: WorkspaceService()}
        self.connection_service._service_provider = provider

        # Request parameters for a DEFAULT connection with no extra options
        self.owner_uri = 'test_uri'
        self.connection_type = ConnectionType.DEFAULT
        request_dict = {
            'ownerUri': self.owner_uri,
            'type': self.connection_type,
            'connection': {'options': {}},
        }
        self.connect_params: ConnectRequestParams = ConnectRequestParams.from_dict(request_dict)

        # Stand-in for the object psycopg2.connect would return
        dsn = {
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
        }
        self.mock_connection = MockConnection(dsn_parameters=dsn)

        # Holds cancellation tokens captured during a connect. Nothing in this method patches
        # psycopg2 to fill it -- NOTE(review): presumably individual tests install that patch; confirm.
        self.token_store = []
    def test_handle_scriptas_successful_operation(self):
        """Every ScriptOperation yields a ScriptAsResponse and exactly one scripter.script() call."""
        # NOTE: Scripting of individual object types is covered by the scripter tests

        # Setup:
        # ... Scripting service whose connection service always hands out a mock connection
        connection = MockConnection(None)
        conn_service = ConnectionService()
        conn_service.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
        conn_service.get_connection = mock.MagicMock(return_value=connection)
        scripting = ScriptingService()
        scripting._service_provider = utils.get_mock_service_provider(
            {CONNECTION_SERVICE_NAME: conn_service})

        # ... Every response must echo the owner URI and carry the canned script
        def check_response(response: ScriptAsResponse) -> None:
            self.assertEqual(response.owner_uri, TestScriptingService.MOCK_URI)
            self.assertEqual(response.script, TestScriptingService.MOCK_SCRIPT)

        target = {
            'type': 'Table',
            'name': 'test_table',
            'schema': 'test_schema'
        }

        # ... The service's Scripter is patched to return a scripter whose script() is mocked
        scripter_path = 'pgsqltoolsservice.scripting.scripting_service.Scripter'
        with mock.patch(scripter_path) as scripter_class:
            fake_scripter: Scripter = Scripter(connection)
            fake_scripter.script = mock.MagicMock(return_value=TestScriptingService.MOCK_SCRIPT)
            scripter_class.return_value = fake_scripter

            for op in ScriptOperation:
                # If: a script-as request is made for this operation
                validator: RequestFlowValidator = RequestFlowValidator()
                validator.add_expected_response(ScriptAsResponse, check_response)
                request = ScriptAsParameters.from_dict({
                    'ownerUri': TestScriptingService.MOCK_URI,
                    'operation': op,
                    'scripting_objects': [target],
                })
                scripting._handle_scriptas_request(validator.request_context, request)

                # Then: the request was answered with the expected response
                validator.validate()

            # ... and script() was invoked exactly once per operation (first positional arg)
            call_counts = dict.fromkeys(ScriptOperation, 0)
            for recorded in fake_scripter.script.call_args_list:
                call_counts[recorded[0][0]] += 1

            for count in call_counts.values():
                self.assertEqual(count, 1)
Esempio n. 4
0
    def test_handle_create_session_successful(self):
        """Creating a session over a working connection answers the request, notifies success and leaves a ready session."""
        # Setup:
        # ... Create OE service with mock connection service that returns a successful connection response
        mock_connection = utils.MockConnection({
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'port': 123
        })
        cs = ConnectionService()
        cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
        cs.get_connection = mock.MagicMock(return_value=mock_connection)
        oe = ObjectExplorerService()
        oe._service_provider = utils.get_mock_service_provider(
            {constants.CONNECTION_SERVICE_NAME: cs})

        # ... Create parameters, session, request context validator
        params, session_uri = _connection_details()

        # ... Create validation of success notification
        def validate_success_notification(response: SessionCreatedParameters):
            self.assertTrue(response.success)
            self.assertEqual(response.session_id, session_uri)
            self.assertIsNone(response.error_message)

            # The root node describes the session's database
            self.assertIsInstance(response.root_node, NodeInfo)
            self.assertEqual(response.root_node.label, TEST_DBNAME)
            self.assertEqual(response.root_node.node_path, session_uri)
            self.assertEqual(response.root_node.node_type, 'Database')
            self.assertIsInstance(response.root_node.metadata, ObjectMetadata)
            self.assertEqual(response.root_node.metadata.urn,
                             oe._session_map[session_uri].server.urn_base)
            self.assertEqual(
                response.root_node.metadata.name,
                oe._session_map[session_uri].server.maintenance_db_name)
            self.assertEqual(response.root_node.metadata.metadata_type_name,
                             'Database')
            self.assertFalse(response.root_node.is_leaf)

        rc = RequestFlowValidator()
        rc.add_expected_response(
            CreateSessionResponse,
            lambda param: self.assertEqual(param.session_id, session_uri))
        rc.add_expected_notification(SessionCreatedParameters,
                                     SESSION_CREATED_METHOD,
                                     validate_success_notification)

        # If: I create a session and wait for its init task to finish
        oe._handle_create_session_request(rc.request_context, params)
        oe._session_map[session_uri].init_task.join()

        # Then:
        # ... A CreateSessionResponse and a successful session-created notification should have been sent
        rc.validate()

        # ... The session should still exist and should have connection and server setup
        self.assertIn(session_uri, oe._session_map)
        self.assertIsInstance(oe._session_map[session_uri].server, Server)
        self.assertTrue(oe._session_map[session_uri].is_ready)
Esempio n. 5
0
 def setUp(self):
     """Register an AdminService with a mock provider that also exposes a ConnectionService."""
     self.admin_service = AdminService()
     self.connection_service = ConnectionService()
     services = {
         constants.ADMIN_SERVICE_NAME: self.admin_service,
         constants.CONNECTION_SERVICE_NAME: self.connection_service,
     }
     self.service_provider = ServiceProviderMock(services)
     self.admin_service.register(self.service_provider)
    def setUp(self):
        """Register an EditDataService with stubbed connection and query-execution services."""
        self._service_under_test = EditDataService()
        self._mock_connection = mock.MagicMock()
        self._service_provider = ServiceProviderMock({
            'query_execution': {},
            'connection': self._mock_connection
        })

        # Cursor and connection that reference each other
        cursor = utils.MockCursor(None)
        connection = utils.MockConnection(cursor=cursor)
        cursor.connection = connection
        self.cursor = cursor
        self.connection = connection

        # Connection service whose get_connection always returns the mock connection
        self.connection_service = ConnectionService()
        self.connection_service.get_connection = mock.Mock(return_value=connection)
        self.query_execution_service = QueryExecutionService()

        # Swap the provider's initial services for the ones built above, then register
        self._service_provider._services = {
            constants.CONNECTION_SERVICE_NAME: self.connection_service,
            constants.QUERY_EXECUTION_SERVICE_NAME: self.query_execution_service
        }
        self._service_provider._is_initialized = True
        self._service_under_test.register(self._service_provider)

        # Initialize-edit request targeting the public.Employee table
        request = InitializeEditParams()
        request.schema_name = 'public'
        request.object_name = 'Employee'
        request.object_type = 'Table'
        request.owner_uri = 'testuri'
        self._initialize_edit_request = request
Esempio n. 7
0
 def setUp(self):
     """Create a JSON-RPC server, workspace/connection services and shared text-document fixtures."""
     self.default_uri = 'file://my.sql'
     self.flow_validator = RequestFlowValidator()

     # Server whose set_request_handler is replaced by a MagicMock so registrations are recorded
     self.mock_server_set_request = mock.MagicMock()
     self.mock_server = JSONRPCServer(None, None)
     self.mock_server.set_request_handler = self.mock_server_set_request

     # Service provider holding the workspace and connection services, marked initialized
     self.mock_workspace_service = WorkspaceService()
     self.mock_connection_service = ConnectionService()
     provider = ServiceProvider(self.mock_server, {}, None)
     provider._services[constants.WORKSPACE_SERVICE_NAME] = self.mock_workspace_service
     provider._services[constants.CONNECTION_SERVICE_NAME] = self.mock_connection_service
     provider._is_initialized = True
     self.mock_service_provider = provider

     # Line 3, character 10 of the default document, plus an identifier for that document
     self.default_text_position = TextDocumentPosition.from_dict({
         'text_document': {'uri': self.default_uri},
         'position': {'line': 3, 'character': 10},
     })
     self.default_text_document_id = TextDocumentIdentifier.from_dict({'uri': self.default_uri})
Esempio n. 8
0
    def setUp(self):
        """Build a service provider with a connection service plus shared connection fixtures."""
        self.default_connection_key = 'server_db_user'
        self.mock_connection_service = ConnectionService()
        self.mock_server = JSONRPCServer(None, None)
        provider = ServiceProvider(self.mock_server, {}, None)
        provider._services[constants.CONNECTION_SERVICE_NAME] = self.mock_connection_service
        provider._is_initialized = True
        self.mock_service_provider = provider

        # Connection details and the context key / intellisense URI expected for them
        details = ConnectionDetails.from_data({})
        details.server_name = 'test_host'
        details.database_name = 'test_db'
        details.user_name = 'user'
        self.connection_details = details
        self.expected_context_key = 'test_host|test_db|user'
        self.expected_connection_uri = INTELLISENSE_URI + self.expected_context_key
        self.test_uri = 'test_uri'
        self.connection_info = ConnectionInfo(self.test_uri, self.connection_details)

        # Stand-in for CompletionRefresher so tests do not start a separate refresh thread
        self.refresher_mock = mock.MagicMock()
        self.refresh_method_mock = mock.MagicMock()
        self.refresher_mock.refresh = self.refresh_method_mock
    def test_list_databases(self):
        """Test that the list databases handler correctly lists the connection's databases"""
        # Setup: a connection service whose get_connection returns a live psycopg2 connection
        connection_service = ConnectionService()
        connection_uri = 'someuri'
        request_context = MockRequestContext()
        params = ListDatabasesParams()
        params.owner_uri = connection_uri
        connection = psycopg2.connect(**get_connection_details())
        # Release the real database connection when the test finishes, even if an assertion fails
        self.addCleanup(connection.close)
        connection_service.get_connection = mock.Mock(return_value=connection)

        # If I call the list database handler
        connection_service.handle_list_databases(request_context, params)

        # Then a response is returned that lists all the databases, including the connected one
        database_names = request_context.last_response_params.database_names
        self.assertGreater(len(database_names), 0)
        self.assertIn(connection.get_dsn_parameters()['dbname'], database_names)
Esempio n. 10
0
    def test_create_connection_successful(self):
        """_create_connection connects through the connection service and returns its connection."""
        # Setup: connection service that connects cleanly and hands back a mock connection
        expected_connection = MockConnection('test')
        oe_service = ObjectExplorerService()
        conn_service = ConnectionService()
        conn_service.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
        conn_service.get_connection = mock.MagicMock(return_value=expected_connection)
        oe_service._service_provider = utils.get_mock_service_provider(
            {constants.CONNECTION_SERVICE_NAME: conn_service})
        params, session_uri = _connection_details()
        session = ObjectExplorerSession(session_uri, params)

        # If: a connection is created for a database of the session
        result = oe_service._create_connection(session, 'foo_database')

        # Then: the service's connection comes back after one connect and one get_connection call
        self.assertIsNotNone(result)
        self.assertEqual(result, expected_connection)
        conn_service.connect.assert_called_once()
        conn_service.get_connection.assert_called_once()
Esempio n. 11
0
    def test_create_connection_failed(self):
        """A failed connect makes _create_connection raise a RuntimeError carrying the connect error message."""
        # Setup: connection service whose connect() reports an error
        oe = ObjectExplorerService()
        cs = ConnectionService()
        connect_response = ConnectionCompleteParams()
        error = 'Failed'
        connect_response.error_message = error
        cs.connect = mock.MagicMock(return_value=connect_response)
        oe._service_provider = utils.get_mock_service_provider(
            {constants.CONNECTION_SERVICE_NAME: cs})
        params, session_uri = _connection_details()
        session = ObjectExplorerSession(session_uri, params)

        # If: I try to create a connection for the session
        with self.assertRaises(RuntimeError) as context:
            oe._create_connection(session, 'foo_database')

        # Then: the exception carries the connect error message. This check must live outside the
        # `with` block -- any statement after the raising call inside it is never executed.
        self.assertEqual(error, str(context.exception))
        cs.connect.assert_called_once()
    def test_handle_scriptas_invalid_operation(self):
        """A script-as request without parameters is answered with an error response."""
        # Setup: Create a scripting service
        mock_connection = {}
        cs = ConnectionService()
        cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
        cs.get_connection = mock.MagicMock(return_value=mock_connection)
        ss = ScriptingService()
        ss._service_provider = utils.get_mock_service_provider(
            {CONNECTION_SERVICE_NAME: cs})

        # If: I send a script-as request with missing (None) params
        rc: RequestFlowValidator = RequestFlowValidator()
        rc.add_expected_error(type(None),
                              RequestFlowValidator.basic_error_validation)
        ss._handle_scriptas_request(rc.request_context, None)

        # Then:
        # ... I should get an error response
        rc.validate()
 def setUp(self):
     """Register a MetadataService with a mock provider that also exposes a ConnectionService."""
     self.metadata_service = MetadataService()
     self.connection_service = ConnectionService()
     services = {
         constants.METADATA_SERVICE_NAME: self.metadata_service,
         constants.CONNECTION_SERVICE_NAME: self.connection_service,
     }
     self.service_provider = ServiceProviderMock(services)
     self.metadata_service.register(self.service_provider)
     self.test_uri = 'test_uri'
Esempio n. 14
0
    def setUp(self):
        """Create an OE service holding one session whose server has a single Database child."""
        self.cs = ConnectionService()
        self.mock_connection = {}
        self.oe = ObjectExplorerService()

        # Register a session built from the standard connection details
        params, session_uri = _connection_details()
        self.session = ObjectExplorerSession(session_uri, params)
        self.oe._session_map[session_uri] = self.session

        # Give the session a server whose only Database child is named 'dbname'
        db_name = 'dbname'
        self.mock_server = Server(MockConnection(db_name))
        self.session.server = self.mock_server
        self.db = Database(self.mock_server, db_name)
        self.db._connection = MockConnection(db_name)
        self.session.server._child_objects[Database.__name__] = [self.db]

        # Connection service stubs: get_connection returns the placeholder, disconnect returns True
        self.cs.get_connection = mock.MagicMock(return_value=self.mock_connection)
        self.cs.disconnect = mock.MagicMock(return_value=True)
        self.oe._service_provider = utils.get_mock_service_provider(
            {constants.CONNECTION_SERVICE_NAME: self.cs})
class TestConnectionService(unittest.TestCase):
    """Methods for testing the connection service"""

    def setUp(self):
        """Create a fresh connection service backed by a mock provider holding a workspace service"""
        self.connection_service = ConnectionService()
        services = {constants.WORKSPACE_SERVICE_NAME: WorkspaceService()}
        self.connection_service._service_provider = utils.get_mock_service_provider(services)

    def test_connect(self):
        """connect() stores the connection psycopg2 returns for the owner URI and reports server info"""
        options = {
            'user': '******',
            'password': '******',
            'host': 'myserver',
            'dbname': 'postgres'
        }
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': 'someUri',
            'type': ConnectionType.DEFAULT,
            'connection': {'options': options}
        })

        # Object the patched psycopg2.connect hands back
        fake_connection = MockConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })

        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=fake_connection)):
            response = self.connection_service.connect(params)

        # The connection stored for this owner URI and type is the one psycopg2 returned
        stored_info = self.connection_service.owner_to_connection_map[params.owner_uri]
        self.assertIs(stored_info.get_connection(params.type), fake_connection)
        # The response has a connection id, a server version and is not flagged as cloud
        self.assertIsNotNone(response.connection_id)
        self.assertIsNotNone(response.server_info.server_version)
        self.assertFalse(response.server_info.is_cloud)

    def test_server_info_is_cloud(self):
        """The connect response's is_cloud flag follows the host name suffix"""
        cases = (
            ('postgres.database.azure.com', True),
            ('postgres.database.windows.net', True),
            ('some.host.com', False),
        )
        for suffix, expected in cases:
            self.server_info_is_cloud_internal(suffix, expected)

    def server_info_is_cloud_internal(self, host_suffix, is_cloud):
        """
        Connect to host 'myserver<host_suffix>' and check the response's server_info.is_cloud flag.

        :param host_suffix: text appended to 'myserver' to form the host name
        :param is_cloud: expected value of response.server_info.is_cloud
        """
        host = f'myserver{host_suffix}'
        details = ConnectionDetails()
        details.options = {
            'user': '******',
            'password': '******',
            'host': host,
            'dbname': 'postgres'}

        # Object the patched psycopg2.connect hands back
        fake_connection = MockConnection(dsn_parameters={
            'host': host,
            'dbname': 'postgres',
            'user': '******'
        })

        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=fake_connection)):
            request = ConnectRequestParams(details, 'someuri', ConnectionType.DEFAULT)
            response = self.connection_service.connect(request)

        # A successful response whose cloud flag matches the expectation
        self.assertIsNotNone(response.connection_id)
        self.assertIsNotNone(response.server_info.server_version)
        self.assertEqual(response.server_info.is_cloud, is_cloud)

    def test_changing_options_disconnects_existing_connection(self):
        """
        Reconnecting an owner URI with different options closes the connection already open for it
        """
        owner_uri = 'someuri'
        conn_type = ConnectionType.DEFAULT
        existing_connection = MockConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })

        # Seed the service with a ConnectionInfo that already holds the connection
        previous_details = ConnectionDetails.from_data({
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'abc': 123
        })
        previous_info = ConnectionInfo(owner_uri, previous_details)
        previous_info.add_connection(conn_type, existing_connection)
        self.connection_service.owner_to_connection_map[owner_uri] = previous_info

        # Same owner URI, but the 'abc' option differs from the stored details
        changed_options = {
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'abc': 234
        }
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': owner_uri,
            'type': conn_type,
            'connection': {'options': changed_options}
        })

        patched_connect = mock.Mock(return_value=existing_connection)
        with mock.patch('psycopg2.connect', new=patched_connect):
            self.connection_service.connect(params)

        # The previously open connection was closed exactly once
        existing_connection.close.assert_called_once()

    def test_same_options_uses_existing_connection(self):
        """Reconnecting with identical options reuses the open connection instead of opening a new one"""
        owner_uri = 'someuri'
        conn_type = ConnectionType.DEFAULT
        existing_connection = MockConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })

        # Seed the service with a ConnectionInfo that already holds the connection
        stored_details = ConnectionDetails.from_data({
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'abc': 123
        })
        stored_info = ConnectionInfo(owner_uri, stored_details)
        stored_info.add_connection(conn_type, existing_connection)
        self.connection_service.owner_to_connection_map[owner_uri] = stored_info

        # Reconnect with the very options the stored details hold
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': owner_uri,
            'type': conn_type,
            'connection': {'options': stored_details.options}
        })
        patched_connect = mock.Mock(return_value=existing_connection)
        with mock.patch('psycopg2.connect', new=patched_connect) as psycopg2_connect:
            response = self.connection_service.connect(params)
            # psycopg2 was never asked for a new connection
            psycopg2_connect.assert_not_called()

        # The existing connection stays open and the response still carries a connection id
        existing_connection.close.assert_not_called()
        self.assertIsNotNone(response.connection_id)

    def test_response_when_connect_fails(self):
        """When psycopg2.connect raises, the response has no connection id and carries the error text"""
        error_message = 'some error'
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': 'someUri',
            'type': ConnectionType.DEFAULT,
            'connection': {'options': {'connectionString': ''}}
        })

        failing_connect = mock.Mock(side_effect=Exception(error_message))
        with mock.patch('psycopg2.connect', new=failing_connect):
            response = self.connection_service.connect(params)

        self.assertIsNone(response.connection_id)
        self.assertEqual(response.error_message, error_message)

    def test_register_on_connect_callback(self):
        """Registering a callback adds it to the service's on-connect callback list"""
        on_connect = MagicMock()
        self.connection_service.register_on_connect_callback(on_connect)
        registered = self.connection_service._on_connect_callbacks
        self.assertListEqual(registered, [on_connect])

    def test_on_connect_backs_called_on_connect(self):
        """On-connect callbacks are expected for DEFAULT connections only, not EDIT, INTELLISENSE or QUERY"""
        expectations = (
            (ConnectionType.DEFAULT, True),
            (ConnectionType.EDIT, False),
            (ConnectionType.INTELLISENSE, False),
            (ConnectionType.QUERY, False),
        )
        for conn_type, should_fire in expectations:
            self.run_on_connect_callback(conn_type, should_fire)

    def run_on_connect_callback(self, conn_type: ConnectionType, expect_callback: bool) -> None:
        """
        Register two new on-connect callbacks, connect 'someuri' with the given type and check whether they fired.

        Callbacks registered by earlier calls on the same service stay registered; only the two
        registered here are checked.

        :param conn_type: connection type to open for 'someuri'
        :param expect_callback: True if each new callback must be called exactly once with a ConnectionInfo
            whose owner_uri is 'someuri'; False if neither may be called
        """
        callbacks = [MagicMock(), MagicMock()]
        for callback in callbacks:
            self.connection_service.register_on_connect_callback(callback)

        # Set up the parameters for the connection
        connection_uri = 'someuri'
        connection_details = ConnectionDetails()
        connection_details.options = {
            'user': '******',
            'password': '******',
            'host': 'myserver',
            'dbname': 'postgres'}

        # Set up the mock connection for psycopg2's connect method to return
        mock_connection = MockConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })

        # Connect with psycopg2.connect patched, then fetch the connection of the requested type
        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_connection)):
            self.connection_service.connect(
                ConnectRequestParams(connection_details, connection_uri, conn_type))
            self.connection_service.get_connection(connection_uri, conn_type)

        # ... The callbacks registered above should have been called only when expected
        for callback in callbacks:
            if expect_callback:
                callback.assert_called_once()
                # The first positional argument is the ConnectionInfo of the connected owner URI
                connection_info: ConnectionInfo = callback.call_args[0][0]
                self.assertEqual(connection_info.owner_uri, connection_uri)
            else:
                callback.assert_not_called()

    def test_disconnect_single_type(self):
        """_close_connections given a type closes only that type's connection and returns True"""
        owner_uri = 'someuri'
        default_type = ConnectionType.DEFAULT
        edit_type = ConnectionType.EDIT
        default_connection = MockConnection(dsn_parameters={
            'host': 'myserver1',
            'dbname': 'postgres1',
            'user': '******'
        })
        edit_connection = MockConnection(dsn_parameters={
            'host': 'myserver2',
            'dbname': 'postgres2',
            'user': '******'
        })

        # Both connections live under one owner URI
        info = ConnectionInfo(owner_uri, ConnectionDetails.from_data({'abc': 123}))
        info.add_connection(default_type, default_connection)
        info.add_connection(edit_type, edit_connection)
        self.connection_service.owner_to_connection_map[owner_uri] = info

        # Close just the DEFAULT connection
        closed = self.connection_service._close_connections(info, default_type)

        default_connection.close.assert_called_once()
        edit_connection.close.assert_not_called()
        self.assertTrue(closed)

    def test_disconnect_all_types(self):
        """_close_connections without a type closes every open connection type and returns True"""
        owner_uri = 'someuri'
        default_type = ConnectionType.DEFAULT
        edit_type = ConnectionType.EDIT
        default_connection = MockConnection(dsn_parameters={
            'host': 'myserver1',
            'dbname': 'postgres1',
            'user': '******'
        })
        edit_connection = MockConnection(dsn_parameters={
            'host': 'myserver2',
            'dbname': 'postgres2',
            'user': '******'
        })

        # Both connections live under one owner URI
        info = ConnectionInfo(owner_uri, ConnectionDetails.from_data({'abc': 123}))
        info.add_connection(default_type, default_connection)
        info.add_connection(edit_type, edit_connection)
        self.connection_service.owner_to_connection_map[owner_uri] = info

        # Close everything for the owner URI
        closed = self.connection_service._close_connections(info)

        default_connection.close.assert_called_once()
        edit_connection.close.assert_called_once()
        self.assertTrue(closed)

    def test_disconnect_for_invalid_connection(self):
        """_close_connections returns False and closes nothing when the requested type was never opened"""
        owner_uri = 'someuri'
        default_connection = MockConnection(dsn_parameters={
            'host': 'myserver1',
            'dbname': 'postgres1',
            'user': '******'
        })

        # Only a DEFAULT connection exists for the owner URI
        info = ConnectionInfo(owner_uri, ConnectionDetails.from_data({'abc': 123}))
        info.add_connection(ConnectionType.DEFAULT, default_connection)
        self.connection_service.owner_to_connection_map[owner_uri] = info

        # Ask to close an EDIT connection instead
        closed = self.connection_service._close_connections(info, ConnectionType.EDIT)

        default_connection.close.assert_not_called()
        self.assertFalse(closed)

    def test_handle_disconnect_request_unknown_uri(self):
        """Disconnecting an unknown owner URI answers False and sends no notification or error"""
        request_context = utils.MockRequestContext()
        params: DisconnectRequestParams = DisconnectRequestParams.from_dict({'ownerUri': 'nonexistent'})

        self.connection_service.handle_disconnect_request(request_context, params)

        # Exactly one response, False, and nothing else
        request_context.send_response.assert_called_once_with(False)
        request_context.send_notification.assert_not_called()
        request_context.send_error.assert_not_called()

    def test_handle_connect_request(self):
        """handle_connect_request runs connect() on a tracked thread, answers True and sends the completion notification"""
        request_context = utils.MockRequestContext()
        expected_result = ConnectionCompleteParams()
        self.connection_service.connect = Mock(return_value=expected_result)

        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': 'someUri',
            'type': ConnectionType.QUERY,
            'connection': {
                'server_name': 'someserver',
                'user_name': 'someuser',
                'database_name': 'somedb',
                'options': {'password': '******'}
            }
        })

        # Kick off the request, then wait for the thread recorded for this owner URI
        self.connection_service.handle_connect_request(request_context, params)
        worker = self.connection_service.owner_to_thread_map[params.owner_uri]
        self.assertIsNotNone(worker)
        worker.join()

        # connect() ran once with the request's params
        self.connection_service.connect.assert_called_once_with(params)
        # The request itself was answered with True
        request_context.send_response.assert_called_once_with(True)
        # connect()'s result went out as the connection-complete notification
        request_context.send_notification.assert_called_once_with(CONNECTION_COMPLETE_METHOD, expected_result)
        # No error was reported
        request_context.send_error.assert_not_called()

    def test_handle_database_change_request_with_empty_connection_info_for_false(self):
        """Test that handle_change_database_request returns False when the owner URI has no connection info"""
        # Setup: Create a mock request context, and stub out connect so no real connection can be attempted
        rc = utils.MockRequestContext()
        connect_response = ConnectionCompleteParams()
        self.connection_service.connect = Mock(return_value=connect_response)

        # If: I request a database change for a URI that was never registered with the service
        params: ChangeDatabaseRequestParams = ChangeDatabaseRequestParams.from_dict({
            'owner_uri': 'someUri',
            'new_database': 'newDb'
        })
        result: bool = self.connection_service.handle_change_database_request(rc, params)

        # Then: The handler should report that the change could not be made
        self.assertEqual(result, False)

    def test_handle_database_change_request(self):
        """Test that handle_change_database_request reconnects on a background thread when the owner URI has
        connection info, and reports the result the same way a connect request does"""
        # Setup: Create a mock request context to handle output and stub out connect. get_connection_info is
        # replaced by a MagicMock so that the owner URI appears to have existing connection info.
        rc = utils.MockRequestContext()
        connect_response = ConnectionCompleteParams()
        self.connection_service.connect = Mock(return_value=connect_response)
        self.connection_service.get_connection_info = mock.MagicMock()

        # If: I request a change to a new database
        params: ChangeDatabaseRequestParams = ChangeDatabaseRequestParams.from_dict({
            'owner_uri': 'someUri',
            'new_database': 'newDb'
        })
        self.connection_service.handle_change_database_request(rc, params)

        # Then:
        # ... A connection thread should have been started for the URI; wait for it to finish
        connection_thread = self.connection_service.owner_to_thread_map[params.owner_uri]
        self.assertIsNotNone(connection_thread)
        connection_thread.join()

        # ... Connect should have been called once
        self.connection_service.connect.assert_called_once()

        # ... A True should have been sent as the response to the request
        rc.send_response.assert_called_once_with(True)

        # ... A connection complete notification should have been sent back as well
        rc.send_notification.assert_called_once_with(CONNECTION_COMPLETE_METHOD, connect_response)

        # ... An error should not have been called
        rc.send_error.assert_not_called()

    def test_list_databases(self):
        """The list-databases handler should respond with one database name per row returned by the query"""
        rows = [('database1',), ('database2',)]
        uri = 'someuri'
        fake_connection = MockConnection(
            dsn_parameters={'host': 'myserver', 'dbname': 'postgres', 'user': '******'},
            cursor=MockCursor(rows))
        request_context = utils.MockRequestContext()

        # Register (empty) connection info for the URI so the handler can open a connection for it
        self.connection_service.owner_to_connection_map[uri] = ConnectionInfo(uri, ConnectionDetails.from_data({}))

        list_params = ListDatabasesParams()
        list_params.owner_uri = uri

        # psycopg2.connect is patched so the handler receives the fake connection and its canned cursor
        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=fake_connection)):
            self.connection_service.handle_list_databases(request_context, list_params)

        expected_names = [name for (name,) in rows]
        self.assertEqual(request_context.last_response_params.database_names, expected_names)

    def test_get_connection_for_existing_connection(self):
        """Test that get_connection opens a new connection on demand when the URI has connection info registered
        but no connection of the requested type has been opened yet"""
        # Set up the test with mock data
        connection_uri = 'someuri'
        connection_type = ConnectionType.EDIT
        mock_connection = MockConnection(
            dsn_parameters={
                'host': 'myserver',
                'dbname': 'postgres',
                'user': '******'
            })

        # Insert a ConnectionInfo object (holding no open connections) into the connection service's map
        connection_details = ConnectionDetails.from_data({})
        connection_info = ConnectionInfo(connection_uri, connection_details)
        self.connection_service.owner_to_connection_map[connection_uri] = connection_info

        # Get the connection without first creating it: psycopg2.connect should be invoked exactly once
        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_connection)) as mock_psycopg2_connect:
            connection = self.connection_service.get_connection(connection_uri, connection_type)
            mock_psycopg2_connect.assert_called_once()
        self.assertEqual(connection, mock_connection)

    def test_get_connection_creates_connection(self):
        """Test that get_connection returns the connection already opened by connect for the given URI and type,
        without opening a second one"""
        # Set up the test with mock data
        connection_uri = 'someuri'
        connection_type = ConnectionType.EDIT
        mock_connection = MockConnection(
            dsn_parameters={
                'host': 'myserver',
                'dbname': 'postgres',
                'user': '******'
            })

        # Insert a ConnectionInfo object into the connection service's map
        connection_details = ConnectionDetails.from_data({})
        connection_info = ConnectionInfo(connection_uri, connection_details)
        self.connection_service.owner_to_connection_map[connection_uri] = connection_info

        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_connection)) as mock_psycopg2_connect:
            # Open the connection
            self.connection_service.connect(ConnectRequestParams(connection_details, connection_uri, connection_type))

            # Get the connection: it should be the one opened above, and psycopg2.connect must not run again
            connection = self.connection_service.get_connection(connection_uri, connection_type)
            self.assertEqual(connection, mock_connection)
            mock_psycopg2_connect.assert_called_once()

    def test_get_connection_for_invalid_uri(self):
        """get_connection must raise ValueError for an owner URI the service has never seen"""
        unknown_uri = 'someuri'
        self.assertRaises(ValueError, self.connection_service.get_connection, unknown_uri, ConnectionType.DEFAULT)

    def test_list_databases_handles_invalid_uri(self):
        """Listing databases for an unknown owner URI should produce an error and nothing else"""
        request_context = utils.MockRequestContext()
        list_params = ListDatabasesParams()
        list_params.owner_uri = 'unknown_uri'

        self.connection_service.handle_list_databases(request_context, list_params)

        # No notification or response may have been sent ...
        for attribute in ('last_notification_method', 'last_notification_params', 'last_response_params'):
            self.assertIsNone(getattr(request_context, attribute), msg=attribute)
        # ... but an error message must have been
        self.assertIsNotNone(request_context.last_error_message)

    def test_list_databases_handles_query_failure(self):
        """Test that the list databases handler returns an error if the list databases query fails for any reason"""
        # Set up the test with mock data whose cursor raises a ProgrammingError when results are fetched
        mock_query_results = [('database1',), ('database2',)]
        connection_uri = 'someuri'
        mock_cursor = MockCursor(mock_query_results)
        mock_cursor.fetchall.side_effect = psycopg2.ProgrammingError('')
        mock_connection = MockConnection(
            dsn_parameters={
                'host': 'myserver',
                'dbname': 'postgres',
                'user': '******'
            },
            cursor=mock_cursor)
        mock_request_context = utils.MockRequestContext()

        # Insert a ConnectionInfo object into the connection service's map
        connection_details = ConnectionDetails.from_data({})
        connection_info = ConnectionInfo(connection_uri, connection_details)
        self.connection_service.owner_to_connection_map[connection_uri] = connection_info

        # If: I call the listdatabases handler and the query fails while fetching its results
        params = ListDatabasesParams()
        params.owner_uri = connection_uri

        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_connection)):
            self.connection_service.handle_list_databases(mock_request_context, params)

        # Then: Only an error should have been sent -- no notification and no response
        self.assertIsNone(mock_request_context.last_notification_method)
        self.assertIsNone(mock_request_context.last_notification_params)
        self.assertIsNone(mock_request_context.last_response_params)
        self.assertIsNotNone(mock_request_context.last_error_message)

    def test_build_connection_response(self):
        """_build_connection_response should report the connection's owner, type, server info and summary"""
        server, database, username = 'testserver', 'testdb', '******'
        fake_connection = MockConnection({'host': server, 'dbname': database, 'user': username})
        conn_type = ConnectionType.EDIT
        details = ConnectionDetails.from_data(opts={})
        uri = 'test_uri'
        info = ConnectionInfo(uri, details)
        # Inject the fake connection directly instead of going through connect
        info._connection_map = {conn_type: fake_connection}

        response = pgsqltoolsservice.connection.connection_service._build_connection_response(info, conn_type)

        # Owner and type come straight from the inputs
        self.assertEqual(response.owner_uri, uri)
        self.assertEqual(response.type, conn_type)
        # Server info reflects the underlying connection
        self.assertEqual(response.server_info.server_version, fake_connection.server_version)
        self.assertEqual(response.server_info.is_cloud, False)
        # The summary mirrors the connection's DSN parameters
        summary = response.connection_summary
        self.assertEqual(summary.server_name, server)
        self.assertEqual(summary.database_name, database)
        self.assertEqual(summary.user_name, username)

    def test_default_database(self):
        """An empty dbname in the connection options should be replaced by the configured default database"""
        default_db = 'test_db'
        workspace_service = self.connection_service._service_provider[constants.WORKSPACE_SERVICE_NAME]
        workspace_service.configuration.pgsql.default_database = default_db

        options = {
            'user': '******',
            'password': '******',
            'host': 'myserver',
            'dbname': '',
        }
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': 'someUri',
            'type': ConnectionType.DEFAULT,
            'connection': {'options': options},
        })

        # Connect with the empty database name; response building is stubbed out
        with mock.patch('pgsqltoolsservice.connection.connection_service._build_connection_response'), \
                mock.patch('psycopg2.connect') as mock_psycopg2_connect:
            self.connection_service.connect(params)

            # psycopg2.connect should have been invoked once, with the default database substituted in
            recorded_calls = mock_psycopg2_connect.mock_calls
            self.assertEqual(len(recorded_calls), 1)
            _, _, connect_kwargs = recorded_calls[0]
            self.assertEqual(connect_kwargs['dbname'], default_db)

    def test_non_default_database(self):
        """Test that if a database is given, the default database is not used"""
        # Set up the connection params with an explicit database and a different configured default
        default_db = 'test_db'
        actual_db = 'postgres'
        self.connection_service._service_provider[constants.WORKSPACE_SERVICE_NAME].configuration.pgsql.default_database = default_db
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': 'someUri',
            'type': ConnectionType.DEFAULT,
            'connection': {
                'options': {
                    'user': '******',
                    'password': '******',
                    'host': 'myserver',
                    'dbname': actual_db
                }
            }
        })

        # If I connect with an explicit database name
        with mock.patch('pgsqltoolsservice.connection.connection_service._build_connection_response'), \
                mock.patch('psycopg2.connect') as mock_psycopg2_connect:
            self.connection_service.connect(params)

            # Then psycopg2's connect method was called once with the given database, not the default one
            calls = mock_psycopg2_connect.mock_calls
            self.assertEqual(len(calls), 1)
            self.assertNotEqual(calls[0][2]['dbname'], default_db)
            self.assertEqual(calls[0][2]['dbname'], actual_db)

    def test_get_connection_info(self):
        """get_connection_info should hand back the very ConnectionInfo object stored for the owner URI"""
        uri = 'someuri'
        expected_info = ConnectionInfo(uri, ConnectionDetails.from_data({}))
        self.connection_service.owner_to_connection_map[uri] = expected_info

        # Identity, not just equality: the stored object itself must be returned
        self.assertIs(self.connection_service.get_connection_info(uri), expected_info)

    def test_get_connection_info_no_connection(self):
        """get_connection_info should yield None for an owner URI with no stored connection info"""
        unknown_uri = 'someuri'
        self.assertIsNone(self.connection_service.get_connection_info(unknown_uri))
 def setUp(self):
     """Create the ConnectionService under test, wired to a mock provider exposing a WorkspaceService"""
     service = ConnectionService()
     service._service_provider = utils.get_mock_service_provider(
         {constants.WORKSPACE_SERVICE_NAME: WorkspaceService()})
     self.connection_service = service
class TestConnectionCancellation(unittest.TestCase):
    """Methods for testing connection cancellation requests"""

    def setUp(self):
        """Set up the tests with common connection parameters"""
        # Set up the connection service (with a real WorkspaceService) and the connection info
        self.connection_service = ConnectionService()
        self.connection_service._service_provider = {constants.WORKSPACE_SERVICE_NAME: WorkspaceService()}
        self.owner_uri = 'test_uri'
        self.connection_type = ConnectionType.DEFAULT
        self.connect_params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': self.owner_uri,
            'type': self.connection_type,
            'connection': {
                'options': {
                }
            }
        })
        self.mock_connection = MockConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })

        # Cancellation tokens recorded by _mock_connect, which the tests install as psycopg2.connect's side
        # effect. This captures the cancellation token state as it would be during a long-running connection.
        self.token_store = []

    def test_connecting_sets_cancellation_token(self):
        """Test that a cancellation token is set before a connection thread attempts to connect"""
        # If I attempt to connect
        with mock.patch('psycopg2.connect', new=mock.Mock(side_effect=self._mock_connect)):
            response = self.connection_service.connect(self.connect_params)

        # Then the cancellation token should have been set and should not have been canceled
        self.assertEqual(len(self.token_store), 1)
        self.assertFalse(self.token_store[0].canceled)

        # And the cancellation token should have been cleared when the connection succeeded
        self.assertIsNone(response.error_message)
        self.assertFalse((self.owner_uri, self.connection_type) in self.connection_service._cancellation_map)

    def test_connection_failed_removes_own_token(self):
        """Test that the cancellation token is removed after a connection fails"""
        # If I attempt to connect and psycopg2 raises
        with mock.patch('psycopg2.connect', new=mock.Mock(side_effect=Exception())):
            response = self.connection_service.connect(self.connect_params)

        # Then the cancellation token should have been cleared when the connection failed
        self.assertIsNotNone(response.error_message)
        self.assertFalse((self.owner_uri, self.connection_type) in self.connection_service._cancellation_map)

    def test_connecting_cancels_previous_connection(self):
        """Test that opening a new connection while one is ongoing cancels the previous connection"""
        # Set up psycopg2's connection method to kick off a new connection. This simulates the case
        # where a call to psycopg2.connect is taking a long time and another connection request for
        # the same URI and connection type comes in and finishes before the current connection
        with mock.patch('psycopg2.connect', new=mock.Mock(side_effect=self._mock_connect)) as mock_psycopg2_connect:
            old_mock_connect = mock_psycopg2_connect.side_effect

            def first_mock_connect(**kwargs):
                """Mock connection method to store the current cancellation token, and kick off another connection"""
                mock_connection = self._mock_connect()
                # Restore the plain side effect so the nested connect does not recurse
                mock_psycopg2_connect.side_effect = old_mock_connect
                self.connection_service.connect(self.connect_params)
                return mock_connection

            mock_psycopg2_connect.side_effect = first_mock_connect

            # If I attempt to connect, and then kick off a new connection while connecting
            response = self.connection_service.connect(self.connect_params)

        # Then the connection should have been canceled and returned none
        self.assertIsNone(response)

        # And the recorded cancellation tokens should show that the first request was canceled
        self.assertEqual(len(self.token_store), 2)
        self.assertTrue(self.token_store[0].canceled)
        self.assertFalse(self.token_store[1].canceled)

    def test_newer_cancellation_token_not_removed(self):
        """Test that a newer connection's cancellation token is not removed after a connection completes"""
        # Set up psycopg2's connection method to simulate a new connection by overriding the
        # current cancellation token. This simulates the case where a call to psycopg2.connect is
        # taking a long time and another connection request for the same URI and connection type
        # comes in and finishes after the current connection
        cancellation_token = CancellationToken()
        cancellation_key = (self.owner_uri, self.connection_type)

        def override_mock_connect(**kwargs):
            """Mock connection method to override the current connection token, as if another connection is executing"""
            mock_connection = self._mock_connect()
            self.connection_service._cancellation_map[cancellation_key].cancel()
            self.connection_service._cancellation_map[cancellation_key] = cancellation_token
            return mock_connection

        with mock.patch('psycopg2.connect', new=mock.Mock(side_effect=override_mock_connect)):
            # If I attempt to connect, and the cancellation token gets updated while connecting
            response = self.connection_service.connect(self.connect_params)

        # Then the connection should have been canceled and returned none
        self.assertIsNone(response)

        # And the current (newer) cancellation token should not have been removed
        self.assertIs(self.connection_service._cancellation_map[cancellation_key], cancellation_token)

    def test_handle_cancellation_request(self):
        """Test that handling a cancellation request modifies the cancellation token for a matched connection"""
        # Set up the connection service with a mock request handler and cancellation token
        cancellation_key = (self.owner_uri, self.connection_type)
        cancellation_token = CancellationToken()
        self.connection_service._cancellation_map[cancellation_key] = cancellation_token
        request_context = utils.MockRequestContext()

        # If I call the cancellation request handler
        cancel_params = CancelConnectParams(self.owner_uri, self.connection_type)
        self.connection_service.handle_cancellation_request(request_context, cancel_params)

        # Then the handler should have responded and set the cancellation flag
        request_context.send_response.assert_called_once_with(True)
        self.assertTrue(cancellation_token.canceled)

    def test_handle_cancellation_no_match(self):
        """Test that handling a cancellation request returns false if there is no matching connection to cancel"""
        # Set up a mock request handler
        request_context = utils.MockRequestContext()

        # If I call the cancellation request handler
        cancel_params = CancelConnectParams(self.owner_uri, self.connection_type)
        self.connection_service.handle_cancellation_request(request_context, cancel_params)

        # Then the handler should have responded false to indicate that no matching connection was in progress
        request_context.send_response.assert_called_once_with(False)

    def test_connect_with_access_token(self):
        """Test that the service connects to a PostgreSQL server using an access token as a password"""
        # Set up the parameters for the connection
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': 'someUri',
            'type': ConnectionType.DEFAULT,
            'connection': {
                'options': {
                    'user': '******',
                    'azureAccountToken': 'exampleToken',
                    'host': 'myserver',
                    'dbname': 'postgres'
                }
            }
        })

        # Set up the mock connection for psycopg2's connect method to return
        mock_connection = MockConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })

        # Set up psycopg2 instance for connection service to call
        mock_connect_method = mock.Mock(return_value=mock_connection)

        # Call the connection service's connect method with the supported options
        with mock.patch('psycopg2.connect', new=mock_connect_method):
            response = self.connection_service.connect(params)

        # Verify that psycopg2's connection method was called with password set to the account token
        # supplied in the 'azureAccountToken' option above
        mock_connect_method.assert_called_once_with(user='******', password='exampleToken', host='myserver', dbname='postgres')

        # Verify that psycopg2's connection method was called and that the
        # response has a connection id, indicating success.
        self.assertIs(self.connection_service.owner_to_connection_map[params.owner_uri].get_connection(params.type),
                      mock_connection)
        self.assertIsNotNone(response.connection_id)
        self.assertIsNotNone(response.server_info.server_version)
        self.assertFalse(response.server_info.is_cloud)

    def _mock_connect(self, **kwargs):
        """Implementation for the mock psycopg2.connect method that saves the current cancellation token"""
        self.token_store.append(self.connection_service._cancellation_map[(self.owner_uri, self.connection_type)])
        return self.mock_connection
Esempio n. 18
0
    def setUp(self):
        """Set up the tests with a disaster recovery service and connection service with mock connection info"""
        self.disaster_recovery_service = DisasterRecoveryService()
        self.connection_service = ConnectionService()
        self.task_service = TaskService()
        self.disaster_recovery_service._service_provider = utils.get_mock_service_provider(
            {
                constants.CONNECTION_SERVICE_NAME: self.connection_service,
                constants.TASK_SERVICE_NAME: self.task_service
            })

        # Create connection information for use in the tests
        self.connection_details = ConnectionDetails()
        self.host = 'test_host'
        self.dbname = 'test_db'
        self.username = '******'
        self.connection_details.options = {
            'host': self.host,
            'dbname': self.dbname,
            'user': self.username,
            'port': 5432
        }
        self.test_uri = 'test_uri'
        self.connection_info = ConnectionInfo(self.test_uri,
                                              self.connection_details)

        # Create backup parameters for the tests
        self.request_context = utils.MockRequestContext()
        self.backup_path = 'mock/path/test.sql'
        self.backup_type = 'sql'
        self.data_only = False
        self.no_owner = True
        self.schema = 'test_schema'
        self.backup_params = BackupParams.from_dict({
            'ownerUri': self.test_uri,
            'backupInfo': {
                'type': self.backup_type,
                'path': self.backup_path,
                'data_only': self.data_only,
                'no_owner': self.no_owner,
                'schema': self.schema
            }
        })
        self.restore_path = 'mock/path/test.dump'
        self.restore_params = RestoreParams.from_dict({
            'ownerUri': self.test_uri,
            'options': {
                'path': self.restore_path,
                'data_only': self.data_only,
                'no_owner': self.no_owner,
                'schema': self.schema
            }
        })
        self.pg_dump_exe = 'pg_dump'
        self.pg_restore_exe = 'pg_restore'

        # Create the mock task for the tests
        self.mock_action = mock.Mock()
        self.mock_task = Task(None, None, None, None, None,
                              self.request_context, self.mock_action)
        self.mock_task.start = mock.Mock()