コード例 #1
0
 def handle_restore_request(self, request_context: RequestContext,
                            params: RestoreParams) -> None:
     """
     Handle a restore/restore request by starting a 'Restore' task.

     If no connection is registered for ``params.owner_uri`` an error is sent
     and nothing else happens. Otherwise a Task wrapping ``_perform_restore``
     (bound to the connection info and params) is passed to the task service's
     ``start_task`` and an empty response is sent.
     """
     connection_service = self._service_provider[constants.CONNECTION_SERVICE_NAME]
     conn_info: ConnectionInfo = connection_service.get_connection_info(params.owner_uri)
     if conn_info is None:
         # TODO: Localize
         request_context.send_error('No connection corresponding to the given owner URI')
         return

     conn_options = conn_info.details.options
     host = conn_options['host']
     database = conn_options['dbname']
     restore_work = functools.partial(_perform_restore, conn_info, params)

     # TODO: Localize
     restore_task = Task('Restore', f'Host: {host}, Database: {database}',
                         constants.PROVIDER_NAME, host, database,
                         request_context, restore_work)
     task_service = self._service_provider[constants.TASK_SERVICE_NAME]
     task_service.start_task(restore_task)
     request_context.send_response({})
コード例 #2
0
 def handle_cancellation_request(self, request_context: RequestContext, params: CancelConnectParams) -> None:
     """
     Cancel a connection attempt identified by (owner URI, connection type).

     Responds True if an entry for that key was in the cancellation map (its
     ``cancel()`` is then called while holding the cancellation lock), False
     otherwise.
     """
     lookup_key = (params.owner_uri, params.type)
     with self._cancellation_lock:
         pending = self._cancellation_map
         was_found = lookup_key in pending
         if was_found:
             pending[lookup_key].cancel()
     request_context.send_response(was_found)
コード例 #3
0
 def handle_cancel_request(self, request_context: RequestContext,
                           params: CancelTaskParameters) -> None:
     """
     Respond to tasks/canceltask requests by canceling the requested task.

     Sends False when ``params.task_id`` is not in the task map; otherwise
     sends whatever the task's ``cancel()`` returns.
     """
     # Guard only the lookup: a KeyError raised inside cancel() or
     # send_response() must not be misreported as "task not found", nor cause
     # a second response to be sent for the same request.
     try:
         task = self._task_map[params.task_id]
     except KeyError:
         request_context.send_response(False)
         return
     request_context.send_response(task.cancel())
コード例 #4
0
 def handle_list_request(self, request_context: RequestContext,
                         params: ListTasksParameters) -> None:
     """
     Respond to tasks/listtasks requests with the TaskInfo of known tasks.

     When ``params.list_active_tasks_only`` is truthy, only tasks whose status
     is TaskStatus.IN_PROGRESS are included.
     """
     # Snapshot the values so iteration is over a stable list
     snapshot = list(self._task_map.values())
     active_only = params.list_active_tasks_only
     infos = [
         task.task_info
         for task in snapshot
         if not active_only or task.status is TaskStatus.IN_PROGRESS
     ]
     request_context.send_response(infos)
コード例 #5
0
    def _dispose(self, request_context: RequestContext,
                 params: DisposeRequest) -> None:
        """
        Handle an edit data dispose request by dropping the session for ``params.owner_uri``.

        Sends an error if no active session exists for the owner URI;
        otherwise sends a DisposeResponse.
        """
        try:
            self._active_sessions.pop(params.owner_uri)
        except KeyError:
            request_context.send_error('Edit data session not found')
            # Return so the client does not also receive a success response
            # for the same request.
            return

        request_context.send_response(DisposeResponse())
コード例 #6
0
    def handle_connect_request(self, request_context: RequestContext,
                               params: ConnectRequestParams) -> None:
        """
        Start a connection attempt for an incoming connect request.

        ``_connect_and_respond`` runs on a new daemon thread, which is recorded
        in ``owner_to_thread_map`` under ``params.owner_uri`` after it starts.
        The request itself is acknowledged with True immediately.
        """
        worker = threading.Thread(
            target=self._connect_and_respond,
            args=(request_context, params),
            daemon=True,
        )
        worker.start()
        self.owner_to_thread_map[params.owner_uri] = worker

        # Acknowledge right away; the connection outcome is presumably
        # reported by _connect_and_respond itself.
        request_context.send_response(True)
コード例 #7
0
    def _handle_session_request(
            self, session_operation_request: SessionOperationRequest,
            request_context: RequestContext, session_operation: Callable):
        """
        Run ``session_operation`` on the active edit session and respond with its result.

        :param session_operation_request: request whose ``owner_uri`` selects the session
        :param request_context: context used to send the response or error
        :param session_operation: callable taking the edit session and returning
            the value to send as the response
        Any exception (including one raised while looking up the session) is
        sent to the client via ``send_error`` and logged.
        """
        try:
            # The lookup is inside the try so a failure to find the session is
            # reported to the client instead of escaping with no response sent.
            edit_session = self._get_active_session(
                session_operation_request.owner_uri)
            result = session_operation(edit_session)
            request_context.send_response(result)

        except Exception as ex:
            request_context.send_error(str(ex))
            self._logger.error(str(ex))
コード例 #8
0
 def _metadata_list_worker(self, request_context: RequestContext,
                           params: MetadataListParameters) -> None:
     """
     List metadata for ``params.owner_uri`` and send it as a MetadataListResponse.

     Any exception is logged with its traceback (when a logger is configured)
     and reported to the client as a generic error instead of propagating.
     """
     try:
         listed = self._list_metadata(params.owner_uri)
         request_context.send_response(MetadataListResponse(listed))
     except Exception:
         logger = self._service_provider.logger
         if logger is not None:
             logger.exception('Unhandled exception while executing the metadata list worker thread')
         # TODO: Localize
         request_context.send_error('Unhandled exception while listing metadata')
コード例 #9
0
    def send_definition_using_connected_completions(
            self, request_context: RequestContext,
            scriptparseinfo: ScriptParseInfo, params: TextDocumentPosition,
            context: ConnectionContext) -> bool:
        """
        Answer a definition request by scripting the object under the cursor.

        Finds a completion whose display text equals the word under the
        cursor, generates a CREATE script for it, writes the script to a
        temporary ``.sql`` file (kept on disk: ``delete=False``) and sends a
        location inside that file.

        :returns: False without sending anything if ``context`` is missing or
            not connected; True after sending the location. In every other case
            an empty ``DefinitionResult(True, '', [])`` is sent and False is
            returned.
        """
        if not context or not context.is_connected:
            return False

        completer: PGCompleter = context.pgcompleter
        completions: List[Completion] = completer.get_completions(
            scriptparseinfo.document, None)

        if completions:
            word_under_cursor = scriptparseinfo.document.get_word_under_cursor()
            # Default to None: a bare next() raised StopIteration whenever no
            # completion matched the word under the cursor.
            matching_completion = next(
                (completion for completion in completions
                 if completion.display == word_under_cursor),
                None)
            if matching_completion is not None:
                connection = self._connection_service.get_connection(
                    params.text_document.uri, ConnectionType.QUERY)
                scripter_instance = scripter.Scripter(connection)
                object_metadata = ObjectMetadata(
                    None, None, matching_completion.display_meta,
                    matching_completion.display, matching_completion.schema)
                create_script = scripter_instance.script(
                    ScriptOperation.CREATE, object_metadata)

                if create_script:
                    # The response is sent only after the with block has closed
                    # (and so flushed) the file, so the client never sees a
                    # location in a partially written file.
                    with tempfile.NamedTemporaryFile(
                            mode='wt',
                            delete=False,
                            encoding='utf-8',
                            suffix='.sql',
                            newline=None) as namedfile:
                        namedfile.write(create_script)
                        script_path = namedfile.name

                    if script_path:
                        file_uri = "file:///" + script_path.strip('/')
                        location_in_script = Location(
                            file_uri, Range(Position(0, 1), Position(1, 1)))
                        definition_result = DefinitionResult(
                            False, None, [location_in_script])
                        request_context.send_response(
                            definition_result.locations)
                        return True

        request_context.send_response(DefinitionResult(True, '', []))
        return False
コード例 #10
0
 def _send_default_completions(self, request_context: RequestContext,
                               script_file: ScriptFile,
                               params: TextDocumentPosition) -> bool:
     """
     Send completion-helper matches for the token at ``params.position``.

     An empty list is sent when there is no token text at the position.
     Always returns True.
     """
     cursor = params.position
     current_line: str = script_file.get_line(cursor.line)
     token_text, text_range = TextUtilities.get_text_and_range(cursor, current_line)
     if not token_text:
         request_context.send_response([])
         return True

     matches = self._completion_helper.get_matches(
         token_text, text_range, self.should_lowercase)
     request_context.send_response(matches)
     return True
コード例 #11
0
    def _edit_subset(self, request_context: RequestContext,
                     params: EditSubsetParams) -> None:
        """
        Handle an edit data subset request with rows from the session for ``params.owner_uri``.

        Rows are fetched with ``get_rows(owner_uri, row_start_index,
        row_start_index + row_count)``, passed through
        ``_handle_create_row_default_values``, and sent as an EditSubsetResponse.
        Sends an error if no active session exists for the owner URI.
        """
        session: DataEditorSession = self._active_sessions.get(
            params.owner_uri)
        if session is None:
            # Without this guard session.get_rows raised AttributeError and no
            # response was sent; the message matches the one used on dispose.
            request_context.send_error('Edit data session not found')
            return

        rows = session.get_rows(params.owner_uri, params.row_start_index,
                                params.row_start_index + params.row_count)

        self._handle_create_row_default_values(rows, session)

        request_context.send_response(EditSubsetResponse(len(rows), rows))
コード例 #12
0
 def _handle_dispose_request(self, request_context: RequestContext,
                             params: QueryDisposeParams):
     """
     Handle a query dispose request by dropping the stored results for ``params.owner_uri``.

     A query not yet in ExecutionState.EXECUTED is cancelled first, so that it
     never starts if pending or is stopped if executing. An error is sent when
     no query is stored for the URI; any unexpected exception is reported via
     send_unhandled_error_response.
     """
     try:
         uri = params.owner_uri
         results = self.query_results
         if uri not in results:
             # TODO: Localize
             request_context.send_error(NO_QUERY_MESSAGE)
             return

         needs_cancel = results[uri].execution_state is not ExecutionState.EXECUTED
         if needs_cancel:
             self.cancel_query(uri)
         del self.query_results[uri]
         request_context.send_response({})
     except Exception as e:
         request_context.send_unhandled_error_response(e)
コード例 #13
0
 def handle_list_databases(self, request_context: RequestContext, params: ListDatabasesParams):
     """
     List all non-template databases on the server that the given URI has a connection to.

     Sends a ListDatabasesResponse with the database names. A ValueError from
     obtaining the connection, or any psycopg2 error from running the query,
     is sent back to the client via send_error.
     """
     try:
         connection = self.get_connection(params.owner_uri, ConnectionType.DEFAULT)
     except ValueError as err:
         request_context.send_error(str(err))
         return

     try:
         query_results = _execute_query(connection, 'SELECT datname FROM pg_database WHERE datistemplate = false;')
     except psycopg2.Error as err:
         # psycopg2.Error is the base of ProgrammingError, OperationalError, etc.
         # Catching only ProgrammingError let other database failures (such as a
         # dropped connection) escape without any response being sent.
         if self._service_provider is not None and self._service_provider.logger is not None:
             self._service_provider.logger.exception('Error listing databases')
         request_context.send_error(str(err))
         return

     database_names = [result[0] for result in query_results]
     request_context.send_response(ListDatabasesResponse(database_names))
コード例 #14
0
    def _handle_create_session_request(self, request_context: RequestContext, params: ConnectionDetails) -> None:
        """
        Handle a request to create an Object Explorer session.

        Phase 1 (synchronous): validate ``params``, fill in the database name
        from the workspace pgsql configuration when it is None or empty, then
        build the session and register it in the session map. The client gets
        a CreateSessionResponse, ``False`` if a session with the same ID is
        already registered, or an error if anything in this phase raises.

        Phase 2: start a daemon thread running ``_initialize_session``; a
        failure to start it is reported through ``_session_created_error``.
        """
        try:
            utils.validate.is_not_none('params', params)

            if params.database_name in (None, ''):
                workspace_service = self._service_provider[utils.constants.WORKSPACE_SERVICE_NAME]
                params.database_name = workspace_service.configuration.pgsql.default_database

            session_id = self._generate_session_uri(params)
            session: ObjectExplorerSession = ObjectExplorerSession(session_id, params)

            # Check-and-insert under one lock so two concurrent requests cannot
            # both register the same session ID.
            with self._session_lock:
                if session_id in self._session_map:
                    # A duplicate ID used to raise; it is now only logged and
                    # answered with False. Why duplicates occur is still open.
                    if self._service_provider.logger is not None:
                        self._service_provider.logger.error(f'Object explorer session for {session_id} already exists!')
                    request_context.send_response(False)
                    return
                self._session_map[session_id] = session

            request_context.send_response(CreateSessionResponse(session_id))
        except Exception as e:
            failure = f'Failed to create OE session: {str(e)}'
            if self._service_provider.logger is not None:
                self._service_provider.logger.error(failure)
            request_context.send_error(failure)
            return

        try:
            session.init_task = threading.Thread(
                target=self._initialize_session,
                args=(request_context, session),
                daemon=True)
            session.init_task.start()
        except Exception as e:
            # TODO: Localize
            self._session_created_error(request_context, session, f'Failed to start OE init task: {str(e)}')
コード例 #15
0
 def send_connected_completions(self, request_context: RequestContext,
                                scriptparseinfo: ScriptParseInfo,
                                params: TextDocumentPosition,
                                context: ConnectionContext) -> bool:
     """
     Send completions computed by the connection's PGCompleter.

     :returns: True if completions were found and sent. False, with nothing
         sent, when there is no connected context or the completer found
         nothing, so the caller can send the timeout task instead.
     """
     if not (context and context.is_connected):
         return False

     completer: PGCompleter = context.pgcompleter
     found: List[Completion] = completer.get_completions(scriptparseinfo.document, None)
     if not found:
         # Let the caller fall back to the timeout task
         return False

     items = [LanguageService.to_completion_item(item, params) for item in found]
     request_context.send_response(items)
     return True
コード例 #16
0
    def _handle_get_database_info_request(
            self, request_context: RequestContext,
            params: GetDatabaseInfoParameters) -> None:
        """
        Respond with the owner of the database that ``params.owner_uri`` is connected to.

        The database name is read from the DEFAULT connection's DSN parameters
        and its owner is looked up in pg_catalog.pg_database. The result is
        sent as a GetDatabaseInfoResponse whose options map DatabaseInfo.OWNER
        to that owner.
        """
        conn_service = self._service_provider[constants.CONNECTION_SERVICE_NAME]
        conn = conn_service.get_connection(params.owner_uri, ConnectionType.DEFAULT)

        db_name = conn.get_dsn_parameters()['dbname']
        sql = 'SELECT pg_catalog.pg_get_userbyid(db.datdba) FROM pg_catalog.pg_database db WHERE db.datname = %s'
        with conn.cursor() as cur:
            cur.execute(sql, (db_name, ))
            rows = cur.fetchall()
            owner_name = rows[0][0]

        info = DatabaseInfo({DatabaseInfo.OWNER: owner_name})
        request_context.send_response(GetDatabaseInfoResponse(info))
コード例 #17
0
    def _handle_initialize_request(request_context: RequestContext, params: Optional[InitializeRequestParams]) -> None:
        """
        Reply to an initialize request with the language-feature capabilities of the tools service.

        Advertises incremental text sync, go-to-definition, whole-document and
        range formatting, and completion with the trigger characters below.
        References, document highlights and hover are reported as unsupported.
        ``params`` is not read.

        NOTE(review): there is no ``self`` parameter; presumably this is
        decorated with ``@staticmethod`` outside the visible range. Confirm.

        :param request_context: Context for the request
        :param params: Initialization request parameters (unused)
        """
        trigger_characters = ['.', '-', ':', '\\', '[', '"']
        capabilities = ServerCapabilities(
            completion_provider=CompletionOptions(True, trigger_characters),
            definition_provider=True,
            document_formatting_provider=True,
            document_highlight_provider=False,
            document_range_formatting_provider=True,
            hover_provider=False,
            references_provider=False,
            text_document_sync=TextDocumentSyncKind.INCREMENTAL,
        )
        request_context.send_response(InitializeResult(capabilities))
コード例 #18
0
    def _get_session(self, request_context: RequestContext, params: ExpandParameters) -> Optional[ObjectExplorerSession]:
        """
        Validate an expand request and return the ready session it targets.

        On success ``True`` is sent as the response and the session is
        returned. If validation fails, the session ID is unknown, or the
        session is not ready, the failure is logged (when a logger exists),
        sent to the client as an error, and None is returned.
        """
        try:
            utils.validate.is_not_none('params', params)
            utils.validate.is_not_none_or_whitespace('params.node_path', params.node_path)
            utils.validate.is_not_none_or_whitespace('params.session_id', params.session_id)

            found = self._session_map.get(params.session_id)
            if found is None:
                raise ValueError(f'OE session with ID {params.session_id} does not exist')   # TODO: Localize
            if not found.is_ready:
                raise ValueError(f'Object Explorer session with ID {params.session_id} is not ready, yet.')     # TODO: Localize

            request_context.send_response(True)
            return found
        except Exception as e:
            failure = f'Failed to expand node: {str(e)}'    # TODO: Localize
            logger = self._service_provider.logger
            if logger is not None:
                logger.error(failure)
            request_context.send_error(failure)
            return None
コード例 #19
0
    def _handle_scriptas_request(self, request_context: RequestContext,
                                 params: ScriptAsParameters) -> None:
        """
        Handle a scripting request by generating a script for the described object.

        The script is produced by a Scripter on the QUERY connection for
        ``params.owner_uri`` and returned in a ScriptAsResponse. Any failure is
        logged with its traceback (when a logger exists) and sent back via
        ``send_error``, with ``params`` passed as the second argument.
        """
        try:
            utils.validate.is_not_none('params', params)

            operation = params.operation
            conn_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME]
            conn = conn_service.get_connection(params.owner_uri, ConnectionType.QUERY)
            metadata = self.create_metadata(params)

            generated = Scripter(conn).script(operation, metadata)
            request_context.send_response(ScriptAsResponse(params.owner_uri, generated))
        except Exception as e:
            if self._service_provider.logger is not None:
                self._service_provider.logger.exception('Scripting operation failed')
            request_context.send_error(str(e), params)
コード例 #20
0
    def _handle_cancel_query_request(self, request_context: RequestContext,
                                     params: QueryCancelParams):
        """
        Handle a 'query/cancel' request.

        Responds with a QueryCancelResult carrying a message when there is no
        query for ``params.owner_uri`` or it has already executed. Otherwise
        the query is flagged as canceled, ``cancel_query`` is called if it is
        currently executing, and an argument-less QueryCancelResult is sent.
        Unexpected errors are logged and reported via
        send_unhandled_error_response.
        """
        try:
            uri = params.owner_uri
            if uri not in self.query_results:
                # TODO: Localize
                request_context.send_response(QueryCancelResult(NO_QUERY_MESSAGE))
                return
            query = self.query_results[uri]

            # A finished query can no longer be cancelled
            if query.execution_state is ExecutionState.EXECUTED:
                # TODO: Localize
                request_context.send_response(QueryCancelResult('Query already executed'))
                return

            # Flag first; a query that has not started yet is presumably kept
            # from running by this flag, so only a running one needs cancel_query.
            query.is_canceled = True
            if query.execution_state is ExecutionState.EXECUTING:
                self.cancel_query(uri)

            request_context.send_response(QueryCancelResult())
        except Exception as e:
            logger = self._service_provider.logger
            if logger is not None:
                logger.exception(str(e))
            request_context.send_unhandled_error_response(e)
コード例 #21
0
    def _handle_close_session_request(self, request_context: RequestContext, params: CloseSessionParameters) -> None:
        """
        Handle a request to close an Object Explorer session.

        The session is removed from the session map, its database connections
        are closed, and its OBJECT_EXLPORER connection is disconnected.
        Responds True only if the session existed and ``disconnect`` returned
        a truthy result, otherwise False. Unexpected errors are logged and sent
        as errors.
        """
        try:
            utils.validate.is_not_none('params', params)

            session = self._session_map.pop(params.session_id, None)
            if session is None:
                request_context.send_response(False)
                return

            self._close_database_connections(session)
            conn_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME]
            disconnected = conn_service.disconnect(session.id, ConnectionType.OBJECT_EXLPORER)
            if not disconnected and self._service_provider.logger is not None:
                self._service_provider.logger.info(f'Could not close the OE session with Id {session.id}')
            request_context.send_response(bool(disconnected))
        except Exception as e:
            failure = f'Failed to close OE session: {str(e)}'   # TODO: Localize
            if self._service_provider.logger is not None:
                self._service_provider.logger.error(failure)
            request_context.send_error(failure)
コード例 #22
0
 def handle_completion_resolve_request(self,
                                       request_context: RequestContext,
                                       params: CompletionItem) -> None:
     """
     Resolve additional details for a CompletionItem.

     No extra details are filled in here: the item received is sent back over
     the wire unchanged.
     """
     resolved_item: CompletionItem = params
     request_context.send_response(resolved_item)
コード例 #23
0
 def _handle_subset_request(self, request_context: RequestContext,
                            params: SubsetParams):
     """Answer a query/subset request with the value computed by ``_get_result_subset``."""
     subset = self._get_result_subset(request_context, params)
     request_context.send_response(subset)
コード例 #24
0
    def _handle_dmp_capabilities_request(
            self,
            request_context: RequestContext,
            params: Optional[CapabilitiesRequestParams]
    ) -> None:
        """
        Sends the capabilities of the tools service data protocol features.

        Builds the list of PostgreSQL connection options (identity/required
        flags, value types, group names and defaults) plus the backup, restore
        and serialization option sets, and sends them as a CapabilitiesResult.
        The default for 'dbname' is read from the workspace's pgsql
        configuration each time this request is handled.

        :param request_context: Context of the request
        :param params: Parameters for the capabilities request (not read here)
        """
        # Needed only for the default database name offered for 'dbname'
        workspace_service = self._service_provider[constants.WORKSPACE_SERVICE_NAME]
        # Connection options advertised to the client. host, dbname, user,
        # password and azureAccountToken are marked is_identity=True; group_name
        # sorts options into 'Source', 'Security', 'Server', 'Client' and 'SSL'.
        conn_provider_opts = ConnectionProviderOptions([
            ConnectionOption(
                name='host',
                display_name='Server name',
                description='Name of the PostgreSQL instance',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                special_value_type=ConnectionOption.SPECIAL_VALUE_SERVER_NAME,
                is_identity=True,
                is_required=True,
                group_name='Source'
            ),
            ConnectionOption(
                name='dbname',
                display_name='Database name',
                description='The name of the initial catalog or database in the data source',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                special_value_type=ConnectionOption.SPECIAL_VALUE_DATABASE_NAME,
                is_identity=True,
                is_required=False,
                group_name='Source',
                default_value=workspace_service.configuration.pgsql.default_database
            ),
            ConnectionOption(
                name='user',
                display_name='User name',
                description='Indicates the user ID to be used when connecting to the data source',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                special_value_type=ConnectionOption.SPECIAL_VALUE_USER_NAME,
                is_identity=True,
                is_required=True,
                group_name='Security'
            ),
            ConnectionOption(
                name='password',
                display_name='Password',
                description='Indicates the password to be used when connecting to the data source',
                value_type=ConnectionOption.VALUE_TYPE_PASSWORD,
                special_value_type=ConnectionOption.SPECIAL_VALUE_PASSWORD_NAME,
                is_identity=True,
                is_required=True,
                group_name='Security'
            ),
            ConnectionOption(
                name='azureAccountToken',
                display_name='Access Token',
                description='Indicates an Active Directory access token to be used when connecting to the data source',
                value_type=ConnectionOption.VALUE_TYPE_ACCESS_TOKEN,
                special_value_type=ConnectionOption.SPECIAL_VALUE_ACCESS_TOKEN_NAME,
                is_identity=True,
                is_required=False,
                group_name='Security'
            ),
            ConnectionOption(
                name='hostaddr',
                display_name='Host IP address',
                description='IP address of the server',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                group_name='Server'
            ),
            ConnectionOption(
                name='port',
                display_name='Port',
                description='Port number for the server',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                group_name='Server'
            ),
            # NOTE(review): numeric option with a string default; presumably
            # what the protocol expects. Confirm.
            ConnectionOption(
                name='connectTimeout',
                display_name='Connect timeout',
                description='Seconds to wait before timing out when connecting',
                value_type=ConnectionOption.VALUE_TYPE_NUMBER,
                group_name='Client',
                default_value='15'
            ),
            ConnectionOption(
                name='clientEncoding',
                display_name='Client encoding',
                description='The client encoding for the connection',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                group_name='Client'
            ),
            ConnectionOption(
                name='options',
                display_name='Command-line options',
                description='Command-line options to send to the server when the connection starts',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                group_name='Server'
            ),
            ConnectionOption(
                name='applicationName',
                display_name='Application name',
                description='Value for the "application_name" configuration parameter',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                group_name='Client',
                special_value_type=ConnectionOption.SPECIAL_VALUE_APP_NAME
            ),
            # Category option: each CategoryValue pairs a display label with
            # the sslmode value; 'prefer' is the default.
            ConnectionOption(
                name='sslmode',
                display_name='SSL mode',
                description='The SSL mode to use when connecting',
                value_type=ConnectionOption.VALUE_TYPE_CATEGORY,
                group_name='SSL',
                category_values=[
                    CategoryValue('Disable', 'disable'),
                    CategoryValue('Allow', 'allow'),
                    CategoryValue('Prefer', 'prefer'),
                    CategoryValue('Require', 'require'),
                    CategoryValue('Verify-CA', 'verify-ca'),
                    CategoryValue('Verify-Full', 'verify-full'),
                ],
                default_value='prefer'
            ),
            ConnectionOption(
                name='sslcompression',
                display_name='Use SSL compression',
                description='Whether to compress SSL connections',
                value_type=ConnectionOption.VALUE_TYPE_BOOLEAN,
                group_name='SSL'
            ),
            ConnectionOption(
                name='sslcert',
                display_name='SSL certificate filename',
                description='The filename of the SSL certificate to use',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                group_name='SSL'
            ),
            ConnectionOption(
                name='sslkey',
                display_name='SSL key filename',
                description='The filename of the key to use for the SSL certificate',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                group_name='SSL'
            ),
            ConnectionOption(
                name='sslrootcert',
                display_name='SSL root certificate filename',
                description='The filename of the SSL root CA certificate to use',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                group_name='SSL'
            ),
            ConnectionOption(
                name='sslcrl',
                display_name='SSL CRL filename',
                description='The filename of the SSL certificate revocation list to use',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                group_name='SSL'
            ),
            ConnectionOption(
                name='requirepeer',
                display_name='Require peer',
                description='The required username of the server process',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                group_name='Server'
            ),
            ConnectionOption(
                name='service',
                display_name='Service name',
                description='The service name in pg_service.conf to use for connection parameters',
                value_type=ConnectionOption.VALUE_TYPE_STRING,
                group_name='Client'
            )
        ])
        # NOTE(review): positional args presumably are (protocol version,
        # provider name, provider display name, connection options, feature
        # option metadata). Confirm against DMPServerCapabilities.
        capabilities = DMPServerCapabilities('1.0', 'PGSQL', 'PostgreSQL', conn_provider_opts, [BACKUP_OPTIONS, RESTORE_OPTIONS, SERIALIZATION_OPTIONS])
        result = CapabilitiesResult(capabilities)

        # Send the response
        request_context.send_response(result)
コード例 #25
0
 def handle_disconnect_request(self, request_context: RequestContext, params: DisconnectRequestParams) -> None:
     """
     Close a connection in response to an incoming disconnection request.

     Responds with whatever ``self.disconnect`` returns for the request's
     owner URI and connection type.
     """
     outcome = self.disconnect(params.owner_uri, params.type)
     request_context.send_response(outcome)
コード例 #26
0
    def _edit_initialize(self, request_context: RequestContext,
                         params: InitializeEditParams) -> None:
        """
        Handle an edit data initialize request.

        Validates the required params and rejects custom query strings. It then
        registers a new DataEditorSession for ``params.owner_uri`` and calls
        ``session.initialize`` with:

        - a query executor that runs queries through the query execution service;
        - success/failure callbacks that send SESSION_READY_NOTIFICATION.

        An empty response is sent once ``session.initialize`` returns.
        """
        utils.validate.is_object_params_not_none_or_whitespace(
            'params', params, 'owner_uri', 'schema_name', 'object_name',
            'object_type')

        # Reject unsupported requests before fetching a connection or
        # registering a session, so a rejected request does not leave a new
        # session in _active_sessions (replacing any existing one for the URI).
        if params.query_string is not None:
            request_context.send_error(
                'Edit data with custom query is not supported currently.')
            return

        connection = self._connection_service.get_connection(
            params.owner_uri, ConnectionType.QUERY)
        session = DataEditorSession(SmoEditTableMetadataFactory())
        self._active_sessions[params.owner_uri] = session

        def query_executer(query: str, columns: List[DbColumn],
                           on_query_execution_complete: Callable):
            """Run ``query`` for this owner URI through the query execution service."""

            def on_resultset_complete(
                    result_set_params: ResultSetNotificationParams):
                # Attach the caller-supplied column info, then forward the
                # result set as updated, available and complete notifications.
                result_set_params.result_set_summary.column_info = columns
                request_context.send_notification(
                    RESULT_SET_UPDATED_NOTIFICATION, result_set_params)
                request_context.send_notification(
                    RESULT_SET_AVAILABLE_NOTIFICATION, result_set_params)
                request_context.send_notification(
                    RESULT_SET_COMPLETE_NOTIFICATION, result_set_params)

            def on_query_complete(
                    query_complete_params: QueryCompleteNotificationParams):
                # Hand the finished query's state to the session's callback
                # before notifying the client.
                on_query_execution_complete(
                    DataEditSessionExecutionState(
                        self._query_execution_service.get_query(
                            params.owner_uri)))
                request_context.send_notification(QUERY_COMPLETE_NOTIFICATION,
                                                  query_complete_params)

            worker_args = ExecuteRequestWorkerArgs(
                params.owner_uri,
                connection,
                request_context,
                ResultSetStorageType.IN_MEMORY,
                on_resultset_complete=on_resultset_complete,
                on_query_complete=on_query_complete)
            execution_params = ExecuteStringParams()
            execution_params.query = query
            execution_params.owner_uri = params.owner_uri
            self._query_execution_service._start_query_execution_thread(
                request_context, execution_params, worker_args)

        def on_success():
            request_context.send_notification(
                SESSION_READY_NOTIFICATION,
                SessionReadyNotificationParams(params.owner_uri, True, None))

        def on_failure(error: str):
            request_context.send_notification(
                SESSION_READY_NOTIFICATION,
                SessionReadyNotificationParams(params.owner_uri, False, error))

        session.initialize(params, connection, query_executer, on_success,
                           on_failure)
        request_context.send_response({})