def handle_restore_request(self, request_context: RequestContext,
                           params: RestoreParams) -> None:
    """
    Handle a restore/restore request by starting a restore task.

    If the connection service has no connection info for ``params.owner_uri``,
    an error is sent and nothing is started. Otherwise a 'Restore' Task wrapping
    ``_perform_restore(connection_info, params)`` is handed to the task service's
    ``start_task`` and an empty response is sent.
    """
    connection_service = self._service_provider[constants.CONNECTION_SERVICE_NAME]
    connection_info: ConnectionInfo = connection_service.get_connection_info(params.owner_uri)
    if connection_info is None:
        # TODO: Localize
        request_context.send_error('No connection corresponding to the given owner URI')
        return

    options = connection_info.details.options
    host = options['host']
    database = options['dbname']
    restore_action = functools.partial(_perform_restore, connection_info, params)

    # TODO: Localize the task name and description
    restore_task = Task('Restore', f'Host: {host}, Database: {database}',
                        constants.PROVIDER_NAME, host, database,
                        request_context, restore_action)
    self._service_provider[constants.TASK_SERVICE_NAME].start_task(restore_task)
    request_context.send_response({})
    def _connect_and_respond(self, request_context: RequestContext, params: ConnectRequestParams) -> None:
        """
        Connect using ``params``, then notify the client that the connection completed.

        ``self.connect`` yields None for a canceled attempt; no notification is sent then.
        """
        connect_response = self.connect(params)
        if connect_response is None:
            # Canceled connection attempts get no completion notification
            return
        request_context.send_notification(CONNECTION_COMPLETE_METHOD, connect_response)
    def _expand_node_thread(self, is_refresh: bool, request_context: RequestContext, params: ExpandParameters, session: ObjectExplorerSession):
        """
        Expand (or refresh) the node at ``params.node_path`` and notify the client.

        Exceptions raised while routing the request or sending the notification are
        reported through ``_expand_node_error`` rather than propagated.
        """
        try:
            child_nodes = route_request(is_refresh, session, params.node_path)
            expand_result = ExpandCompletedParameters(session.id, params.node_path)
            expand_result.nodes = child_nodes
            request_context.send_notification(EXPAND_COMPLETED_METHOD, expand_result)
        except Exception as err:
            self._expand_node_error(request_context, params, str(err))
    def _expand_node_error(self, request_context: RequestContext, params: ExpandParameters, message: str):
        """
        Log an expand failure and send an expand-completed notification carrying the error.

        :param message: Failure description; sent to the client prefixed with
            'Failed to expand node: '.
        """
        logger = self._service_provider.logger
        if logger is not None:
            logger.warning(f'OE service errored while expanding node: {message}')

        error_response = ExpandCompletedParameters(params.session_id, params.node_path)
        error_response.error_message = f'Failed to expand node: {message}'    # TODO: Localize
        request_context.send_notification(EXPAND_COMPLETED_METHOD, error_response)
 def handle_cancellation_request(self, request_context: RequestContext, params: CancelConnectParams) -> None:
     """
     Cancel a pending connection attempt identified by (owner_uri, type).

     Responds True if an entry for that key was found (and canceled) in the
     cancellation map, otherwise False. The lookup and cancel happen under
     ``_cancellation_lock``.
     """
     key = (params.owner_uri, params.type)
     with self._cancellation_lock:
         if key not in self._cancellation_map:
             found = False
         else:
             self._cancellation_map[key].cancel()
             found = True
     request_context.send_response(found)
Esempio n. 6
0
 def handle_cancel_request(self, request_context: RequestContext,
                           params: CancelTaskParameters) -> None:
     """
     Respond to tasks/canceltask requests by canceling the requested task.

     Responds with the return value of the task's ``cancel()``, or False when no
     task is registered under ``params.task_id``.
     """
     # Guard only the lookup: a KeyError raised inside cancel() or send_response()
     # must not be mistaken for an unknown task id (which would also send a
     # second response).
     try:
         task = self._task_map[params.task_id]
     except KeyError:
         request_context.send_response(False)
         return
     request_context.send_response(task.cancel())
Esempio n. 7
0
 def handle_list_request(self, request_context: RequestContext,
                         params: ListTasksParameters) -> None:
     """
     Respond to tasks/listtasks requests with the TaskInfo of the known tasks.

     When ``params.list_active_tasks_only`` is truthy, only tasks whose status is
     TaskStatus.IN_PROGRESS are included.
     """
     active_only = params.list_active_tasks_only
     task_infos = []
     # Iterate over a snapshot list rather than the live map's values view
     for task in list(self._task_map.values()):
         if active_only and task.status is not TaskStatus.IN_PROGRESS:
             continue
         task_infos.append(task.task_info)
     request_context.send_response(task_infos)
Esempio n. 8
0
    def _dispose(self, request_context: RequestContext,
                 params: DisposeRequest) -> None:
        """
        Dispose of the edit data session registered for ``params.owner_uri``.

        Sends an error if no such session exists; otherwise removes it and sends a
        DisposeResponse. Exactly one reply is sent per request.
        """
        try:
            self._active_sessions.pop(params.owner_uri)
        except KeyError:
            request_context.send_error('Edit data session not found')
            # Stop here so the failed request is not also answered with a success response
            return

        request_context.send_response(DisposeResponse())
Esempio n. 9
0
 def __init__(self):
     """
     Build a RequestContext test double with no message or queue.

     Each send_* method is replaced by a ``mock.Mock`` whose side effect is the
     matching ``*_impl`` method, so calls can be asserted on while the impl still runs.
     """
     RequestContext.__init__(self, None, None)

     # Recorded values start as None (presumably filled in by the *_impl methods,
     # which are not visible here)
     for attr_name in ('last_response_params', 'last_notification_method',
                       'last_notification_params', 'last_error_message'):
         setattr(self, attr_name, None)

     for method_name in ('send_response', 'send_notification', 'send_error',
                         'send_unhandled_error_response'):
         impl = getattr(self, method_name + '_impl')
         setattr(self, method_name, mock.Mock(side_effect=impl))
    def handle_connect_request(self, request_context: RequestContext,
                               params: ConnectRequestParams) -> None:
        """
        Start a connection attempt on a background daemon thread.

        The thread runs ``_connect_and_respond`` and is recorded in
        ``owner_to_thread_map`` under ``params.owner_uri``; the request itself is
        acknowledged with True right away.
        """
        worker = threading.Thread(
            target=self._connect_and_respond,
            args=(request_context, params),
            daemon=True,
        )
        worker.start()
        self.owner_to_thread_map[params.owner_uri] = worker
        request_context.send_response(True)
Esempio n. 11
0
    def _handle_session_request(
            self, session_operation_request: SessionOperationRequest,
            request_context: RequestContext, session_operation: Callable):
        """
        Run ``session_operation`` on the active edit session and reply with its result.

        The session lookup happens outside the try block, so a failure there propagates.
        Exceptions from the operation (or from sending the response) are sent to the
        client via send_error and then logged.
        """
        session = self._get_active_session(session_operation_request.owner_uri)
        try:
            request_context.send_response(session_operation(session))
        except Exception as error:
            error_text = str(error)
            request_context.send_error(error_text)
            self._logger.error(error_text)
Esempio n. 12
0
 def _metadata_list_worker(self, request_context: RequestContext,
                           params: MetadataListParameters) -> None:
     """
     List metadata for ``params.owner_uri`` and send it as a MetadataListResponse.

     Any exception is logged with its traceback (when a logger is configured) and a
     generic error is sent to the client instead of propagating.
     """
     try:
         response = MetadataListResponse(self._list_metadata(params.owner_uri))
         request_context.send_response(response)
     except Exception:
         logger = self._service_provider.logger
         if logger is not None:
             logger.exception(
                 'Unhandled exception while executing the metadata list worker thread')
         # TODO: Localize
         request_context.send_error('Unhandled exception while listing metadata')
Esempio n. 13
0
    def send_definition_using_connected_completions(
            self, request_context: RequestContext,
            scriptparseinfo: ScriptParseInfo, params: TextDocumentPosition,
            context: ConnectionContext) -> bool:
        """
        Answer a definition request by scripting the object under the cursor.

        Looks for a completion whose display text equals the word under the cursor,
        scripts a CREATE statement for it, writes that script to a temporary ``.sql``
        file and responds with a location inside that file.

        :param request_context: Context used to send the response
        :param scriptparseinfo: Parse info whose ``document`` holds the text and cursor
        :param params: Position request; its document URI selects the QUERY connection
        :param context: Connection context providing the PGCompleter
        :returns: False, without responding, if ``context`` is missing or not connected.
            True after responding with the script location. Otherwise False after
            responding with an empty DefinitionResult.
        """
        if not context or not context.is_connected:
            return False

        completer: PGCompleter = context.pgcompleter
        completions: List[Completion] = completer.get_completions(
            scriptparseinfo.document, None)

        if completions:
            word_under_cursor = scriptparseinfo.document.get_word_under_cursor()
            # Supply a default: without it next() raises StopIteration when nothing
            # matches, instead of falling through to the empty response below.
            matching_completion = next(
                (completion for completion in completions
                 if completion.display == word_under_cursor),
                None)
            if matching_completion:
                connection = self._connection_service.get_connection(
                    params.text_document.uri, ConnectionType.QUERY)
                scripter_instance = scripter.Scripter(connection)
                object_metadata = ObjectMetadata(
                    None, None, matching_completion.display_meta,
                    matching_completion.display, matching_completion.schema)
                create_script = scripter_instance.script(
                    ScriptOperation.CREATE, object_metadata)

                if create_script:
                    # delete=False keeps the file on disk after it is closed, presumably
                    # so the client can open it from the returned location.
                    with tempfile.NamedTemporaryFile(
                            mode='wt',
                            delete=False,
                            encoding='utf-8',
                            suffix='.sql',
                            newline=None) as namedfile:
                        namedfile.write(create_script)
                        script_path = namedfile.name

                    # Respond only after the with-block has closed (and flushed) the
                    # file, so the script is fully written before the client sees it.
                    if script_path:
                        file_uri = "file:///" + script_path.strip('/')
                        location_in_script = Location(
                            file_uri, Range(Position(0, 1), Position(1, 1)))
                        definition_result = DefinitionResult(
                            False, None, [location_in_script])
                        request_context.send_response(definition_result.locations)
                        return True

        request_context.send_response(DefinitionResult(True, '', []))
        return False
Esempio n. 14
0
 def _send_default_completions(self, request_context: RequestContext,
                               script_file: ScriptFile,
                               params: TextDocumentPosition) -> bool:
     """
     Respond with completions from the local completion helper.

     The token at the cursor is matched via ``_completion_helper.get_matches``; an
     empty list is sent when there is no token. Always returns True.
     """
     current_line: str = script_file.get_line(params.position.line)
     token_text, text_range = TextUtilities.get_text_and_range(params.position, current_line)
     if not token_text:
         matches = []
     else:
         matches = self._completion_helper.get_matches(
             token_text, text_range, self.should_lowercase)
     request_context.send_response(matches)
     return True
Esempio n. 15
0
    def _edit_subset(self, request_context: RequestContext,
                     params: EditSubsetParams) -> None:
        """
        Send a subset of rows from the edit data session for ``params.owner_uri``.

        Rows are requested from ``row_start_index`` to ``row_start_index + row_count``
        (end-bound semantics are defined by ``session.get_rows``). If no session is
        active for the owner URI, 'Edit data session not found' is sent as an error.
        """
        session: DataEditorSession = self._active_sessions.get(params.owner_uri)
        if session is None:
            # Without this guard the get_rows call below raises AttributeError on None
            request_context.send_error('Edit data session not found')
            return

        row_end_index = params.row_start_index + params.row_count
        rows = session.get_rows(params.owner_uri, params.row_start_index, row_end_index)

        self._handle_create_row_default_values(rows, session)

        request_context.send_response(EditSubsetResponse(len(rows), rows))
    def _session_created_error(self, request_context: RequestContext, session: ObjectExplorerSession, message: str):
        """
        Report a failed Object Explorer session creation and forget the session.

        Logs a warning (when a logger is configured), sends a SESSION_CREATED
        notification with ``success=False`` and ``message`` as the error, then removes
        the session from the session map if it is still present.
        """
        if self._service_provider.logger is not None:
            self._service_provider.logger.warning(f'OE service errored while creating session: {message}')

        # Create error notification
        response = SessionCreatedParameters()
        response.success = False
        response.session_id = session.id
        response.root_node = None
        response.error_message = message
        request_context.send_notification(SESSION_CREATED_METHOD, response)

        # Clean up the session from the session map. Tolerate an already-removed
        # session: a KeyError here would escape the caller's error handler (e.g. the
        # connection cleanup in _initialize_session would be skipped).
        self._session_map.pop(session.id, None)
    def _handle_close_session_request(self, request_context: RequestContext, params: CloseSessionParameters) -> None:
        """
        Handle a request to close an Object Explorer session.

        Responds True when the session existed and its connection was disconnected, and
        False when the session was unknown or the disconnect reported failure. Any
        exception is logged (when a logger is configured) and sent back as an error.
        """
        try:
            utils.validate.is_not_none('params', params)

            # The session is removed from the map before disconnecting, so it is
            # forgotten even if the disconnect fails
            session = self._session_map.pop(params.session_id, None)
            if session is None:
                request_context.send_response(False)
                return

            self._close_database_connections(session)
            conn_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME]
            disconnected = conn_service.disconnect(session.id, ConnectionType.OBJECT_EXLPORER)
            if disconnected:
                request_context.send_response(True)
                return

            logger = self._service_provider.logger
            if logger is not None:
                logger.info(f'Could not close the OE session with Id {session.id}')
            request_context.send_response(False)
        except Exception as e:
            error_message = f'Failed to close OE session: {str(e)}'   # TODO: Localize
            logger = self._service_provider.logger
            if logger is not None:
                logger.error(error_message)
            request_context.send_error(error_message)
Esempio n. 18
0
    def _handle_cancel_query_request(self, request_context: RequestContext,
                                     params: QueryCancelParams):
        """
        Handle a 'query/cancel' request.

        Replies with a QueryCancelResult carrying a message when there is no query for
        ``params.owner_uri`` or it has already executed. Otherwise the query is flagged
        as canceled, actively canceled via ``cancel_query`` if it is currently
        executing, and an empty QueryCancelResult is sent. Unexpected exceptions are
        logged and reported through send_unhandled_error_response.
        """
        try:
            owner_uri = params.owner_uri
            if owner_uri not in self.query_results:
                request_context.send_response(
                    QueryCancelResult(NO_QUERY_MESSAGE))  # TODO: Localize
                return
            query = self.query_results[owner_uri]

            # A query that has already finished can no longer be canceled
            if query.execution_state is ExecutionState.EXECUTED:
                request_context.send_response(
                    QueryCancelResult('Query already executed'))  # TODO: Localize
                return

            query.is_canceled = True
            # Re-read the state after flagging: only a running query needs active cancellation
            if query.execution_state is ExecutionState.EXECUTING:
                self.cancel_query(owner_uri)
            request_context.send_response(QueryCancelResult())
        except Exception as e:
            logger = self._service_provider.logger
            if logger is not None:
                logger.exception(str(e))
            request_context.send_unhandled_error_response(e)
    def _initialize_session(self, request_context: RequestContext, session: ObjectExplorerSession):
        """
        Connect an Object Explorer session and announce its root node to the client.

        Connects with ``session.connection_details`` under ``session.id`` using the
        ObjectExplorer connection type, builds a PGSMO ``Server`` on that connection and
        sends a successful SESSION_CREATED notification whose root node is the
        session's database. ``session.is_ready`` is set only after that notification.

        On any exception an error notification is sent via ``_session_created_error``
        and, if a connection was already obtained in step 2, it is disconnected.
        """
        conn_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME]
        # Stays None until step 2 succeeds; the except block uses it to decide on cleanup
        connection = None

        try:
            # Step 1: Connect with the provided connection details
            connect_request = ConnectRequestParams(
                session.connection_details,
                session.id,
                ConnectionType.OBJECT_EXLPORER
            )
            connect_result = conn_service.connect(connect_request)
            # A None result means the connect attempt was cancelled
            if connect_result is None:
                raise RuntimeError('Connection was cancelled during connect')   # TODO Localize
            if connect_result.error_message is not None:
                raise RuntimeError(connect_result.error_message)

            # Step 2: Get the connection to use for object explorer
            # NOTE(review): if this call raises, `connection` stays None and the
            # connection opened in step 1 is not disconnected below — confirm intended.
            connection = conn_service.get_connection(session.id, ConnectionType.OBJECT_EXLPORER)

            # Step 3: Create the PGSMO Server object for the session and create the root node for the server
            # The partial presumably lets Server open additional connections for this session
            session.server = Server(connection, functools.partial(self._create_connection, session))
            metadata = ObjectMetadata(session.server.urn_base, None, 'Database', session.server.maintenance_db_name)
            node = NodeInfo()
            node.label = session.connection_details.options['dbname']
            node.is_leaf = False
            node.node_path = session.id
            node.node_type = 'Database'
            node.metadata = metadata

            # Step 4: Send the session-created (success) notification to the client
            response = SessionCreatedParameters()
            response.success = True
            response.session_id = session.id
            response.root_node = node
            response.error_message = None
            request_context.send_notification(SESSION_CREATED_METHOD, response)

            # Mark the session as complete
            session.is_ready = True

        except Exception as e:
            # Return a notification that an error occurred
            message = f'Failed to initialize object explorer session: {str(e)}'  # TODO Localize
            # NOTE(review): if this call raises, the disconnect below is skipped
            self._session_created_error(request_context, session, message)

            # Attempt to clean up the connection
            if connection is not None:
                conn_service.disconnect(session.id, ConnectionType.OBJECT_EXLPORER)
Esempio n. 20
0
    def handle_definition_request(
            self, request_context: RequestContext,
            text_document_position: TextDocumentPosition) -> None:
        """
        Handle a definition request for a position in a document.

        Sends a 'DefinitionRequested' status notification first. The request is answered
        immediately with an empty list when intellisense is skipped for the document,
        the file is not in the workspace, or there is no queueable parse info.
        Otherwise an operation is queued that resolves the definition through
        ``send_definition_using_connected_completions``, with the empty-list reply as its
        fallback, and a 'DefinitionRequestCompleted' status notification is sent once
        it has been queued.
        """
        uri = text_document_position.text_document.uri
        request_context.send_notification(
            STATUS_CHANGE_NOTIFICATION,
            StatusChangeParams(owner_uri=uri, status="DefinitionRequested"))

        def send_empty_response():
            request_context.send_response([])

        if self.should_skip_intellisense(uri):
            send_empty_response()
            return

        script_file: ScriptFile = self._workspace_service.workspace.get_file(uri)
        if script_file is None:
            send_empty_response()
            return

        parse_info: ScriptParseInfo = self.get_script_parse_info(uri, create_if_not_exists=False)
        if not parse_info or not parse_info.can_queue():
            send_empty_response()
            return

        # The cursor offset is the length of all text before the requested position
        position = text_document_position.position
        text_before_cursor = script_file.get_text_in_range(
            Range.from_data(0, 0, position.line, position.character))
        parse_info.document = Document(script_file.get_all_text(), len(text_before_cursor))

        definition_lookup = functools.partial(
            self.send_definition_using_connected_completions,
            request_context, parse_info, text_document_position)
        self.operations_queue.add_operation(
            QueuedOperation(parse_info.connection_key, definition_lookup, send_empty_response))

        request_context.send_notification(
            STATUS_CHANGE_NOTIFICATION,
            StatusChangeParams(owner_uri=uri, status="DefinitionRequestCompleted"))
Esempio n. 21
0
    def _handle_execute_query_request(
            self, request_context: RequestContext,
            params: ExecuteRequestParamsBase) -> None:
        """
        Kick off a thread to execute the query of an incoming execute request.

        A QUERY connection is obtained for ``params.owner_uri``. If that fails, the
        exception is logged and reported via send_unhandled_error_response and nothing
        is started. Otherwise ``_start_query_execution_thread`` is called with callbacks
        that send the request's response before query initialization and relay batch,
        message, result-set and query-complete events as notifications.
        """
        try:
            conn = self._get_connection(params.owner_uri, ConnectionType.QUERY)
        except Exception as e:
            logger = self._service_provider.logger
            if logger is not None:
                logger.exception(
                    'Encountered exception while handling query request'
                )  # TODO: Localize
            request_context.send_unhandled_error_response(e)
            return

        def before_query_initialize(before_query_initialize_params):
            # The execute request's own response marks the query as kicked off
            request_context.send_response(before_query_initialize_params)

        def on_batch_start(batch_event_params):
            request_context.send_notification(BATCH_START_NOTIFICATION, batch_event_params)

        def on_message_notification(notice_message_params):
            request_context.send_notification(MESSAGE_NOTIFICATION, notice_message_params)

        def on_resultset_complete(result_set_params):
            # A finished result set is announced as both available and complete
            for method in (RESULT_SET_AVAILABLE_NOTIFICATION, RESULT_SET_COMPLETE_NOTIFICATION):
                request_context.send_notification(method, result_set_params)

        def on_batch_complete(batch_event_params):
            request_context.send_notification(BATCH_COMPLETE_NOTIFICATION, batch_event_params)

        def on_query_complete(query_complete_params):
            request_context.send_notification(QUERY_COMPLETE_NOTIFICATION, query_complete_params)

        worker_args = ExecuteRequestWorkerArgs(
            params.owner_uri, conn, request_context,
            ResultSetStorageType.FILE_STORAGE, before_query_initialize,
            on_batch_start, on_message_notification, on_resultset_complete,
            on_batch_complete, on_query_complete)
        self._start_query_execution_thread(request_context, params, worker_args)
 def test_initialization_with_none_as_schema_param(self):
     """_edit_initialize must reject a None schema_name with a ValueError."""
     context = RequestContext(None, None)
     self._initialize_edit_request.schema_name = None
     self.assert_exception_on_method_call(
         ValueError,
         'Parameter schema_name contains a None, empty, or whitespace string',
         self._service_under_test._edit_initialize,
         context,
         self._initialize_edit_request,
     )
Esempio n. 23
0
    def _resolve_query_exception(self,
                                 e: Exception,
                                 query: Query,
                                 request_context: RequestContext,
                                 conn: 'psycopg2.connection',
                                 is_rollback_error=False):
        """
        Report a query execution failure to the client and roll back a failed transaction.

        :param e: Exception raised while executing ``query``
        :param query: The failed query; its current batch id tags the error message
        :param request_context: Context used to send the MESSAGE_NOTIFICATION
        :param conn: Connection the query ran on
        :param is_rollback_error: True when ``e`` came from the ROLLBACK issued by this
            method; the message is prefixed accordingly and no further rollback is tried
        """
        utils.log.log_debug(
            self._service_provider.logger,
            f'Query execution failed for following query: {query.query_text}\n {e}'
        )
        # Database, runtime and cancellation errors are reported with their own message;
        # anything else is treated as unexpected and logged with its traceback.
        if isinstance(e, (psycopg2.DatabaseError, RuntimeError,
                          psycopg2.extensions.QueryCanceledError)):
            error_message = str(e)
        else:
            error_message = 'Unhandled exception while executing query: {}'.format(
                str(e))  # TODO: Localize
            if self._service_provider.logger is not None:
                self._service_provider.logger.exception(
                    'Unhandled exception while executing query')

        # If the error occurred during rollback, add a note about it
        if is_rollback_error:
            error_message = 'Error while rolling back open transaction due to previous failure: ' + error_message  # TODO: Localize

        # Send a message with the error to the client
        result_message_params = self.build_message_params(
            query.owner_uri, query.batches[query.current_batch_index].id,
            error_message, True)
        request_context.send_notification(MESSAGE_NOTIFICATION,
                                          result_message_params)

        # If there was a failure in the middle of a transaction, roll it back.
        # Note that conn.rollback() won't work since the connection is in autocommit mode.
        # The transaction status is an int constant, so compare by value (==), not identity.
        if (not is_rollback_error
                and conn.get_transaction_status() == psycopg2.extensions.TRANSACTION_STATUS_INERROR):
            rollback_query = Query(
                query.owner_uri, 'ROLLBACK',
                QueryExecutionSettings(ExecutionPlanOptions(), None),
                QueryEvents())
            try:
                rollback_query.execute(conn)
            except Exception as rollback_exception:
                # If the rollback failed, handle the error as usual but don't try to roll back again
                self._resolve_query_exception(rollback_exception,
                                              rollback_query, request_context,
                                              conn, True)
    def test_initialization_with_empty_object_type(self):
        """_edit_initialize must reject a whitespace-only object_type with a ValueError."""
        context = RequestContext(None, None)
        self._initialize_edit_request.object_type = ' '
        self.assert_exception_on_method_call(
            ValueError,
            'Parameter object_type contains a None, empty, or whitespace string',
            self._service_under_test._edit_initialize,
            context,
            self._initialize_edit_request,
        )
Esempio n. 25
0
 def send_connected_completions(self, request_context: RequestContext,
                                scriptparseinfo: ScriptParseInfo,
                                params: TextDocumentPosition,
                                context: ConnectionContext) -> bool:
     """
     Respond with completions from the connection's PGCompleter, if it finds any.

     :returns: True if completions were found and sent. False when ``context`` is
         missing or disconnected, or when the completer found nothing; nothing is
         sent in that case, so the caller can send its timeout response instead.
     """
     if not (context and context.is_connected):
         return False

     completer: PGCompleter = context.pgcompleter
     completions: List[Completion] = completer.get_completions(scriptparseinfo.document, None)
     if not completions:
         return False

     items = [LanguageService.to_completion_item(item, params) for item in completions]
     request_context.send_response(items)
     return True
    def _handle_initialize_request(request_context: RequestContext, params: Optional[InitializeRequestParams]) -> None:
        """
        Reply to an initialize request with the language service's capabilities.

        :param request_context: Context used to send the InitializeResult response
        :param params: Initialization request parameters (not read here)
        """
        # Second CompletionOptions argument: presumably the completion trigger characters
        trigger_characters = ['.', '-', ':', '\\', '[', '"']
        capabilities = ServerCapabilities(
            completion_provider=CompletionOptions(True, trigger_characters),
            definition_provider=True,
            document_formatting_provider=True,
            document_range_formatting_provider=True,
            document_highlight_provider=False,
            hover_provider=False,
            references_provider=False,
            text_document_sync=TextDocumentSyncKind.INCREMENTAL,
        )
        request_context.send_response(InitializeResult(capabilities))
Esempio n. 27
0
    def _handle_get_database_info_request(
            self, request_context: RequestContext,
            params: GetDatabaseInfoParameters) -> None:
        """
        Respond with database info (currently only the owner) for ``params.owner_uri``.

        Uses the DEFAULT connection for the owner URI, takes the database name from its
        DSN parameters and looks up the owning role in pg_catalog.pg_database.
        """
        connection_service = self._service_provider[constants.CONNECTION_SERVICE_NAME]
        connection = connection_service.get_connection(params.owner_uri, ConnectionType.DEFAULT)
        current_db = connection.get_dsn_parameters()['dbname']

        with connection.cursor() as cursor:
            # The database name is passed as a bound parameter, not interpolated
            cursor.execute(
                'SELECT pg_catalog.pg_get_userbyid(db.datdba) FROM pg_catalog.pg_database db WHERE db.datname = %s',
                (current_db, ))
            rows = cursor.fetchall()
            owner = rows[0][0]

        request_context.send_response(
            GetDatabaseInfoResponse(DatabaseInfo({DatabaseInfo.OWNER: owner})))
 def test_initialization(self, mockdataeditorsession):
     """_edit_initialize with a real RequestContext should call the injected session mock."""
     output_queue = Queue()
     initialize_message = JSONRPCMessage.from_dictionary({
         'id': '123',
         'method': 'edit/initialize',
         'params': {},
     })
     context = RequestContext(initialize_message, output_queue)

     self._service_under_test._edit_initialize(context, self._initialize_edit_request)

     mockdataeditorsession.assert_called()
    def _get_session(self, request_context: RequestContext, params: ExpandParameters) -> Optional[ObjectExplorerSession]:
        """
        Validate an expand request and return its ready Object Explorer session.

        On success responds True and returns the session. If ``params`` is invalid, the
        session id is unknown, or the session is not ready yet, the problem is logged
        (when a logger is configured), sent as an error, and None is returned.
        """
        try:
            validate = utils.validate
            validate.is_not_none('params', params)
            validate.is_not_none_or_whitespace('params.node_path', params.node_path)
            validate.is_not_none_or_whitespace('params.session_id', params.session_id)

            session_id = params.session_id
            session = self._session_map.get(session_id)
            if session is None:
                raise ValueError(f'OE session with ID {session_id} does not exist')   # TODO: Localize
            if not session.is_ready:
                raise ValueError(f'Object Explorer session with ID {session_id} is not ready, yet.')     # TODO: Localize

            request_context.send_response(True)
            return session
        except Exception as e:
            error_message = f'Failed to expand node: {str(e)}'    # TODO: Localize
            logger = self._service_provider.logger
            if logger is not None:
                logger.error(error_message)
            request_context.send_error(error_message)
            return None
Esempio n. 30
0
    def _handle_scriptas_request(self, request_context: RequestContext,
                                 params: ScriptAsParameters) -> None:
        """
        Script an object with the requested operation and send the script back.

        Responds with a ScriptAsResponse for ``params.owner_uri``, using the QUERY
        connection for that URI. On failure the exception is logged with its traceback
        (when a logger is configured) and sent via send_error together with ``params``.
        """
        try:
            utils.validate.is_not_none('params', params)

            operation = params.operation
            connection_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME]
            connection = connection_service.get_connection(params.owner_uri, ConnectionType.QUERY)
            metadata = self.create_metadata(params)
            script_text = Scripter(connection).script(operation, metadata)
            request_context.send_response(ScriptAsResponse(params.owner_uri, script_text))
        except Exception as e:
            logger = self._service_provider.logger
            if logger is not None:
                logger.exception('Scripting operation failed')
            request_context.send_error(str(e), params)