示例#1
0
    def test_get_connection_creates_connection(self):
        """Verify get_connection returns the connection opened for a URI/type pair.

        A ConnectionInfo with empty details is registered for the owner URI and the
        connection manager's ``_create_connection`` is patched to hand back a mock
        server connection. After connecting, ``get_connection`` must yield that mock
        and the patched factory must have run exactly once.
        """
        owner_uri = 'someuri'
        conn_type = ConnectionType.EDIT
        fake_connection = MockPGServerConnection(
            cur=None, host='myserver', name='postgres', user='******')

        # Pre-register connection info so the service already knows this owner URI
        details = ConnectionDetails.from_data({})
        self.connection_service.owner_to_connection_map[owner_uri] = ConnectionInfo(owner_uri, details)

        factory_target = 'ossdbtoolsservice.driver.connection_manager.ConnectionManager._create_connection'
        fake_factory = mock.Mock(return_value=fake_connection)
        with mock.patch(factory_target, new=fake_factory) as patched_factory:
            self.connection_service.connect(ConnectRequestParams(details, owner_uri, conn_type))

            fetched = self.connection_service.get_connection(owner_uri, conn_type)
            self.assertEqual(fetched, fake_connection)
            patched_factory.assert_called_once()
示例#2
0
    def setUp(self):
        """Build a ConnectionService plus the DEFAULT-type connect parameters shared by the tests."""
        self.connection_service = ConnectionService()
        workspace_services = {WORKSPACE_SERVICE_NAME: WorkspaceService()}
        self.connection_service._service_provider = utils.get_mock_service_provider(workspace_services)

        self.owner_uri = 'test_uri'
        self.connection_type = ConnectionType.DEFAULT
        request_dict = {
            'ownerUri': self.owner_uri,
            'type': self.connection_type,
            'connection': {'options': {}},
        }
        self.connect_params: ConnectRequestParams = ConnectRequestParams.from_dict(request_dict)

        dsn = {'host': 'myserver', 'dbname': 'postgres', 'user': '******'}
        self.mock_psycopg_connection = MockPsycopgConnection(dsn_parameters=dsn)

        # Intended to hold cancellation tokens captured while a connection attempt is in flight.
        # NOTE(review): nothing in setUp fills this list; presumably individual tests install a
        # mocked psycopg2 connect that appends to it — confirm.
        self.token_store = []
示例#3
0
    def _handle_simple_execute_request(self, request_context: RequestContext, params: SimpleExecuteRequest):
        """Run ``params.query_string`` on a fresh QUERY connection and reply with its first result set.

        A new random owner URI is minted and a QUERY connection is opened for it using the
        details of the connection behind ``params.owner_uri``. When the query completes,
        the rows of batch 0 / result set 0 are read back and sent to the client as a
        SimpleExecuteResponse through ``request_context``.
        """
        query_owner_uri = str(uuid.uuid4())

        conn_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME]
        source_details = conn_service.get_connection_info(params.owner_uri).details
        conn_service.connect(ConnectRequestParams(source_details, query_owner_uri, ConnectionType.QUERY))
        query_connection = self._get_connection(query_owner_uri, ConnectionType.QUERY)

        execute_params = ExecuteStringParams()
        execute_params.query = params.query_string
        execute_params.owner_uri = query_owner_uri

        def on_query_complete(query_complete_params):
            # Only the first result set of the first batch is returned.
            # NOTE(review): raises IndexError (and no response is sent) if the query produced
            # no batch or no result set — confirm this is acceptable.
            first_summary = query_complete_params.batch_summaries[0].result_set_summaries[0]

            subset_params = SubsetParams()
            subset_params.owner_uri = query_owner_uri
            subset_params.batch_index = 0
            subset_params.result_set_index = 0
            subset_params.rows_start_index = 0
            subset_params.rows_count = first_summary.row_count

            rows_subset = self._get_result_subset(request_context, subset_params).result_subset
            request_context.send_response(
                SimpleExecuteResponse(rows_subset.rows, rows_subset.row_count, first_summary.column_info))

        worker_args = ExecuteRequestWorkerArgs(
            query_owner_uri, query_connection, request_context, ResultSetStorageType.FILE_STORAGE,
            on_query_complete=on_query_complete)

        # NOTE(review): the connection opened under query_owner_uri is not disconnected in this
        # method — confirm it is released elsewhere.
        self._start_query_execution_thread(request_context, execute_params, worker_args)
示例#4
0
    def test_connect(self):
        """Connecting with user/password/host/dbname options should open a pymysql connection.

        The stored connection must wrap the patched pymysql connection, and the response
        must carry a connection id and a server version, with ``is_cloud`` False.
        """
        options = {
            'user': '******',
            'password': '******',
            'host': 'myserver',
            'dbname': 'mysql',
        }
        params: ConnectRequestParams = ConnectRequestParams.from_dict(
            {'ownerUri': 'someUri', 'type': ConnectionType.DEFAULT, 'connection': {'options': options}})

        fake_connect = mock.Mock(return_value=self.mock_pymysql_connection)
        with mock.patch('pymysql.connect', new=fake_connect):
            response = self.connection_service.connect(params)

        stored_info = self.connection_service.owner_to_connection_map[params.owner_uri]
        self.assertIs(stored_info.get_connection(params.type)._conn, self.mock_pymysql_connection)
        self.assertIsNotNone(response.connection_id)
        server_info = response.server_info
        self.assertIsNotNone(server_info.server_version)
        self.assertFalse(server_info.is_cloud)
示例#5
0
    def test_non_default_database(self):
        """An explicitly supplied dbname must be passed to psycopg2 instead of the configured default."""
        default_db = 'test_db'
        actual_db = 'postgres'
        workspace = self.connection_service._service_provider[WORKSPACE_SERVICE_NAME]
        workspace.configuration.pgsql.default_database = default_db

        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': 'someUri',
            'type': ConnectionType.DEFAULT,
            'connection': {
                'options': {
                    'user': '******',
                    'password': '******',
                    'host': 'myserver',
                    'dbname': actual_db,
                },
            },
        })

        # Connect with an explicit (non-empty) database name, with the response builder
        # and psycopg2.connect both patched out
        response_builder = 'ossdbtoolsservice.connection.connection_service._build_connection_response'
        with mock.patch(response_builder), \
                mock.patch('psycopg2.connect', return_value=MockPsycopgConnection()) as patched_connect:
            self.connection_service.connect(params)

            # psycopg2.connect ran once, with the requested database rather than the default
            recorded = patched_connect.mock_calls
            self.assertEqual(len(recorded), 1)
            call_kwargs = recorded[0][2]
            self.assertNotEqual(call_kwargs['dbname'], default_db)
            self.assertEqual(call_kwargs['dbname'], actual_db)
示例#6
0
    def server_info_is_cloud_internal(self, host_suffix, is_cloud):
        """Connect to ``myserver<host_suffix>`` and check ``server_info.is_cloud`` equals ``is_cloud``.

        Parameterized helper (no ``test_`` prefix): ``host_suffix`` is appended to the
        host name used both in the request options and in the mocked psycopg2 DSN.
        """
        host = f'myserver{host_suffix}'
        connection_details = ConnectionDetails()
        connection_details.options = {
            'user': '******',
            'password': '******',
            'host': host,
            'dbname': 'postgres',
        }
        fake_psycopg = MockPsycopgConnection(
            dsn_parameters={'host': host, 'dbname': 'postgres', 'user': '******'})

        request = ConnectRequestParams(connection_details, 'someuri', ConnectionType.DEFAULT)
        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=fake_psycopg)):
            response = self.connection_service.connect(request)

        # A successful response must be produced and carry the expected cloud flag
        self.assertIsNotNone(response.connection_id)
        self.assertIsNotNone(response.server_info.server_version)
        self.assertEqual(response.server_info.is_cloud, is_cloud)
    def _initialize_session(self, request_context: RequestContext,
                            session: ObjectExplorerSession):
        """Connect the object explorer ``session`` and announce its root node.

        On success a ``SESSION_CREATED_METHOD`` notification carrying a 'Database' root
        node is sent and ``session.is_ready`` is set to True. Failures inside the setup
        steps are reported through ``_session_created_error`` rather than raised, and if
        the connect step had already succeeded the session's OBJECT_EXLPORER connection
        is disconnected again.
        """
        conn_service = self._service_provider[
            utils.constants.CONNECTION_SERVICE_NAME]
        # True once connect() has reported success. Cleanup keys off this flag rather than
        # off the connection object, so a failure in get_connection (after the connection
        # was opened) still releases it.
        connected = False

        try:
            # Step 1: Connect with the provided connection details
            connect_request = ConnectRequestParams(
                session.connection_details, session.id,
                ConnectionType.OBJECT_EXLPORER)
            connect_result = conn_service.connect(connect_request)
            if connect_result is None:
                raise RuntimeError(
                    'Connection was cancelled during connect')  # TODO Localize
            if connect_result.error_message is not None:
                raise RuntimeError(connect_result.error_message)
            connected = True

            # Step 2: Get the connection to use for object explorer
            connection = conn_service.get_connection(
                session.id, ConnectionType.OBJECT_EXLPORER)

            # Step 3: Create the Server object for the session and the root node for it.
            # The partial binds this session into _create_connection; presumably the server
            # uses it to open connections to other databases — confirm.
            session.server = self._server(
                connection, functools.partial(self._create_connection,
                                              session))
            metadata = ObjectMetadata(session.server.urn_base, None,
                                      'Database',
                                      session.server.maintenance_db_name)
            node = NodeInfo()
            node.label = session.connection_details.database_name
            node.is_leaf = False
            node.node_path = session.id
            node.node_type = 'Database'
            node.metadata = metadata

            # Step 4: Send the session-created notification
            response = SessionCreatedParameters()
            response.success = True
            response.session_id = session.id
            response.root_node = node
            response.error_message = None
            request_context.send_notification(SESSION_CREATED_METHOD, response)

            # Mark the session as complete
            session.is_ready = True

        except Exception as e:
            # Return a notification that an error occurred
            message = f'Failed to initialize object explorer session: {str(e)}'  # TODO Localize
            self._session_created_error(request_context, session, message)

            # Release the connection opened in step 1 so it is not left behind
            if connected:
                conn_service.disconnect(session.id,
                                        ConnectionType.OBJECT_EXLPORER)
    def get_connection(self, owner_uri: str, connection_type: ConnectionType) -> Optional[ServerConnection]:
        """
        Return the connection of ``connection_type`` registered under ``owner_uri``.

        If the owner URI is known but has no connection of the requested type yet, one is
        opened first by calling ``connect`` with the URI's stored connection details.

        :raises ValueError: If there is no connection associated with the provided URI
        """
        info = self.owner_to_connection_map.get(owner_uri)
        if info is None:
            raise ValueError('No connection associated with given owner URI')

        if info.has_connection(connection_type):
            return info.get_connection(connection_type)

        # NOTE(review): connect()'s result is not checked here; on failure the return value
        # is whatever info.get_connection yields for the missing type.
        self.connect(ConnectRequestParams(info.details, owner_uri, connection_type))
        return info.get_connection(connection_type)
示例#9
0
    def _create_connection(
            self, connection_key: str,
            conn_info: ConnectionInfo) -> Optional[ServerConnection]:
        """Open an INTELLISENSE connection using ``conn_info``'s details and return it.

        The connection is registered under the owner URI ``INTELLISENSE_URI + connection_key``.

        :raises RuntimeError: if the connect attempt is cancelled or reports an error
        """
        conn_service = self._connection_service
        key_uri = INTELLISENSE_URI + connection_key
        connect_request = ConnectRequestParams(conn_info.details, key_uri,
                                               ConnectionType.INTELLISENSE)
        connect_result = conn_service.connect(connect_request)
        # connect() can return None when the attempt is cancelled; check before
        # dereferencing the result instead of failing with an AttributeError.
        if connect_result is None:
            raise RuntimeError(
                'Connection was cancelled during connect')  # TODO Localize
        if connect_result.error_message is not None:
            raise RuntimeError(connect_result.error_message)

        return conn_service.get_connection(key_uri,
                                           ConnectionType.INTELLISENSE)
    def handle_change_database_request(self, request_context: RequestContext,
                                       params: ChangeDatabaseRequestParams) -> bool:
        """Change the database of an existing connection by reconnecting to ``params.new_database``.

        The owner's current connection options are copied, ``dbname`` is replaced with the
        requested database, and a DEFAULT-type connect request for the same owner URI is
        passed to ``handle_connect_request``.

        :returns: False if no connection info exists for ``params.owner_uri``; True once
            the connect request has been passed to ``handle_connect_request``.
        """
        connection_info: ConnectionInfo = self.get_connection_info(params.owner_uri)

        if connection_info is None:
            return False

        # Copy so the existing connection's stored options are not mutated
        connection_info_params: Dict[str, str] = connection_info.details.options.copy()
        connection_info_params["dbname"] = params.new_database
        connection_details: ConnectionDetails = ConnectionDetails.from_data(connection_info_params)

        connection_request_params: ConnectRequestParams = ConnectRequestParams(
            connection_details, params.owner_uri, ConnectionType.DEFAULT)
        self.handle_connect_request(request_context, connection_request_params)
        # Honor the declared bool return type on the success path
        return True
示例#11
0
    def test_changing_options_disconnects_existing_connection(self):
        """
        Reconnecting an already-connected owner URI with different options must close the old connection.

        The registered ConnectionInfo was built with ``abc: 123``; the new request uses ``abc: 234``.
        """
        owner_uri = 'someuri'
        conn_type = ConnectionType.DEFAULT
        existing_connection = MockPGServerConnection(cur=None,
                                                     host='myserver',
                                                     name='postgres',
                                                     user='******')

        # Register an open connection built from the original options
        previous_details = ConnectionDetails.from_data({
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'abc': 123,
        })
        previous_info = ConnectionInfo(owner_uri, previous_details)
        previous_info.add_connection(conn_type, existing_connection)
        self.connection_service.owner_to_connection_map[owner_uri] = previous_info

        # Same owner URI, but one option value differs
        changed_options = {
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'abc': 234,
        }
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': owner_uri,
            'type': conn_type,
            'connection': {'options': changed_options},
        })

        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=existing_connection)):
            self.connection_service.connect(params)
        # The previously registered connection must have been closed
        existing_connection.close.assert_called_once()
示例#12
0
    def test_same_options_uses_existing_connection(self):
        """Test that the connect method uses an existing connection when connecting again with the same options"""
        # Set up the test with mock data
        connection_uri = 'someuri'
        connection_type = ConnectionType.DEFAULT
        mock_psycopg_connection = MockPsycopgConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })
        mock_server_connection = MockPGServerConnection(cur=None,
                                                        host='myserver',
                                                        name='postgres',
                                                        user='******')

        # Insert a ConnectionInfo object (already holding an open connection) into the service's map
        old_connection_details = ConnectionDetails.from_data({
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'abc': 123
        })
        old_connection_info = ConnectionInfo(connection_uri,
                                             old_connection_details)
        old_connection_info.add_connection(connection_type,
                                           mock_server_connection)
        self.connection_service.owner_to_connection_map[
            connection_uri] = old_connection_info

        # Connect with identical options (the very same options dict)
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri':
            connection_uri,
            'type':
            connection_type,
            'connection': {
                'options': old_connection_details.options
            }
        })
        with mock.patch('psycopg2.connect',
                        new=mock.Mock(return_value=mock_psycopg_connection)
                        ) as mock_psycopg2_connect:
            response = self.connection_service.connect(params)
            # No new connection should have been opened
            mock_psycopg2_connect.assert_not_called()

        # The registered connection is the one a disconnect would close, so verify it was
        # left open; the unused psycopg mock must not have been closed either.
        mock_server_connection.close.assert_not_called()
        mock_psycopg_connection.close.assert_not_called()
        self.assertIsNotNone(response.connection_id)
    def _create_connection(self, session: ObjectExplorerSession,
                           database_name: str) -> Optional[ServerConnection]:
        """Open an OBJECT_EXLPORER connection to ``database_name`` for ``session``.

        The session's connection options are copied with ``dbname`` overridden, and the
        connection is registered under the owner URI ``session.id + database_name``.

        :raises RuntimeError: if the connect attempt is cancelled or reports an error
        """
        conn_service = self._service_provider[
            utils.constants.CONNECTION_SERVICE_NAME]

        # Copy so the session's own options are left untouched
        options = session.connection_details.options.copy()
        options['dbname'] = database_name
        conn_details = ConnectionDetails.from_data(options)

        key_uri = session.id + database_name
        connect_request = ConnectRequestParams(conn_details, key_uri,
                                               ConnectionType.OBJECT_EXLPORER)
        connect_result = conn_service.connect(connect_request)
        # connect() can return None when the attempt is cancelled; check before
        # dereferencing the result instead of failing with an AttributeError.
        if connect_result is None:
            raise RuntimeError(
                'Connection was cancelled during connect')  # TODO Localize
        if connect_result.error_message is not None:
            raise RuntimeError(connect_result.error_message)

        return conn_service.get_connection(
            key_uri, ConnectionType.OBJECT_EXLPORER)
示例#14
0
    def test_connect_with_access_token(self):
        """An ``azureAccountToken`` option should be forwarded to pymysql in place of a password."""
        request = {
            'ownerUri': 'someUri',
            'type': ConnectionType.DEFAULT,
            'connection': {
                'options': {
                    'user': '******',
                    'azureAccountToken': 'exampleToken',
                    'host': 'myserver',
                    'dbname': 'mysql',
                },
            },
        }
        params: ConnectRequestParams = ConnectRequestParams.from_dict(request)

        pymysql_connect = mock.Mock(return_value=self.mock_pymysql_connection)
        with mock.patch('pymysql.connect', new=pymysql_connect):
            response = self.connection_service.connect(params)

        # The token is passed as the password; no port was given, so the MySQL default is used
        pymysql_connect.assert_called_once_with(
            user='******',
            password='******',
            host='myserver',
            port=DEFAULT_PORT[MYSQL_PROVIDER_NAME],
            database='mysql')

        # The stored connection wraps the pymysql connection and the response reports success
        stored_info = self.connection_service.owner_to_connection_map[params.owner_uri]
        self.assertIs(stored_info.get_connection(params.type)._conn, self.mock_pymysql_connection)
        self.assertIsNotNone(response.connection_id)
        self.assertIsNotNone(response.server_info.server_version)
        self.assertFalse(response.server_info.is_cloud)
示例#15
0
    def run_on_connect_callback(self, conn_type: ConnectionType,
                                expect_callback: bool) -> None:
        """Connect with ``conn_type`` and check whether registered on-connect callbacks fired.

        Two MagicMock callbacks are registered. When ``expect_callback`` is true, each must
        have been called exactly once with a first argument whose ``owner_uri`` matches the
        connection's URI; otherwise neither may have been called.
        """
        callbacks = [MagicMock() for _ in range(2)]
        for cb in callbacks:
            self.connection_service.register_on_connect_callback(cb)

        owner_uri = 'someuri'
        details = ConnectionDetails()
        details.options = {
            'user': '******',
            'password': '******',
            'host': 'myserver',
            'dbname': 'postgres',
        }
        fake_psycopg = MockPsycopgConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
        })

        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=fake_psycopg)):
            self.connection_service.connect(ConnectRequestParams(details, owner_uri, conn_type))
        self.connection_service.get_connection(owner_uri, conn_type)

        for cb in callbacks:
            if not expect_callback:
                cb.assert_not_called()
                continue
            cb.assert_called_once()
            # The first positional argument is the ConnectionInfo passed to the callback
            connection_info: ConnectionInfo = cb.call_args[0][0]
            self.assertEqual(connection_info.owner_uri, owner_uri)
示例#16
0
    def test_handle_connect_request(self):
        """handle_connect_request should run connect on a worker thread and report the outcome."""
        rc = utils.MockRequestContext()
        connect_response = ConnectionCompleteParams()
        self.connection_service.connect = Mock(return_value=connect_response)

        request_payload = {
            'ownerUri': 'someUri',
            'type': ConnectionType.QUERY,
            'connection': {
                'server_name': 'someserver',
                'user_name': 'someuser',
                'database_name': 'somedb',
                'options': {'password': '******'},
            },
        }
        params: ConnectRequestParams = ConnectRequestParams.from_dict(request_payload)

        # Kick off the request, then block until its worker thread has finished
        self.connection_service.handle_connect_request(rc, params)
        worker = self.connection_service.owner_to_thread_map[params.owner_uri]
        self.assertIsNotNone(worker)
        worker.join()

        # connect ran exactly once with the request parameters
        self.connection_service.connect.assert_called_once_with(params)
        # the request itself was answered with True
        rc.send_response.assert_called_once_with(True)
        # a completion notification carried the connect result
        rc.send_notification.assert_called_once_with(CONNECTION_COMPLETE_METHOD, connect_response)
        # and no error was reported
        rc.send_error.assert_not_called()
示例#17
0
 def test_response_when_connect_fails(self):
     """A failing psycopg2.connect should yield a response with no connection id and the error text."""
     error_message = 'some error'
     params: ConnectRequestParams = ConnectRequestParams.from_dict({
         'ownerUri': 'someUri',
         'type': ConnectionType.DEFAULT,
         'connection': {'options': {'connectionString': ''}},
     })

     failing_connect = mock.Mock(side_effect=Exception(error_message))
     with mock.patch('psycopg2.connect', new=failing_connect):
         response = self.connection_service.connect(params)

     # No connection id is issued and the raised message is surfaced verbatim
     self.assertIsNone(response.connection_id)
     self.assertEqual(response.error_message, error_message)