def setUp(self):
        """Build a two-statement query, a mocked connection and text column metadata."""
        self.statement_list = ['select version;', 'select * from t1;']
        self.statement_str = ''.join(self.statement_list)
        self.query_uri = 'test_uri'

        settings = QueryExecutionSettings(ExecutionPlanOptions(),
                                          ResultSetStorageType.FILE_STORAGE)
        self.query = Query(self.query_uri, self.statement_str, settings,
                           QueryEvents())

        # Rows handed back by the mocked cursor
        self.mock_query_results = [('Id1', 'Value1'), ('Id2', 'Value2')]
        self.cursor = utils.MockCursor(self.mock_query_results)
        self.connection = utils.MockConnection(cursor=self.cursor)

        def _text_column(name):
            # One DbColumn of type 'text' with the given name
            column = DbColumn()
            column.data_type = 'text'
            column.column_name = name
            return column

        self.columns_info = [_text_column('Id'), _text_column('Value')]
        self.get_columns_info_mock = mock.Mock(return_value=self.columns_info)
    def execute_get_subset_raises_error_when_index_not_in_range(
            self, batch_index: int):
        """Assert that ``Query.get_subset`` rejects an out-of-range batch index.

        Shared helper for the negative / too-large index tests. Its name lacks
        the ``test`` prefix, so unittest does not collect it on its own.

        :param batch_index: batch index that must be rejected with IndexError
        """
        full_query = 'Select * from t1;'
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        with self.assertRaises(IndexError) as context_manager:
            query.get_subset(batch_index, 0, 10)

        # Checked after the ``with`` block: inside it this line would never
        # run, because get_subset raises before reaching it.
        self.assertEqual(
            'Batch index cannot be less than 0 or greater than the number of batches',
            context_manager.exception.args[0])
    def test_get_subset(self):
        """Query.get_subset hands the row range to the selected batch and returns its result."""
        query = Query('test_uri', 'Select * from t1;',
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        expected_subset = []
        fake_batch = mock.MagicMock()
        fake_batch.get_subset = mock.Mock(return_value=expected_subset)
        query._batches = [fake_batch]

        actual = query.get_subset(0, 0, 10)

        self.assertEqual(expected_subset, actual)
        fake_batch.get_subset.assert_called_once_with(0, 10)
    def test_batches_strip_comments(self):
        """Test that we do not attempt to execute a batch consisting only of comments"""
        full_query = '''select * from t1;
-- test
-- test
;select * from t1;
-- test
-- test;'''

        # If I build a query that contains a batch consisting of only comments, in addition to other valid batches
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        # Then there is only a batch for each non-comment statement
        self.assertEqual(len(query.batches), 2)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0,
                          start_column=0,
                          end_line=0,
                          end_column=17),
            SelectionData(start_line=3,
                          start_column=1,
                          end_line=3,
                          end_column=18)
        ]

        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))
    def test_batch_selections_do_block(self):
        """Test that the query sets up batch objects with correct selection information for blocks containing statements"""
        full_query = '''DO $$
BEGIN
RAISE NOTICE 'Hello world 1';
RAISE NOTICE 'Hello world 2';
END $$;
select * from t1;'''

        # If I build a query that contains a block that contains several statements
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        # Then there is a batch for each top-level statement
        self.assertEqual(len(query.batches), 2)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0,
                          start_column=0,
                          end_line=4,
                          end_column=7),
            SelectionData(start_line=5,
                          start_column=0,
                          end_line=5,
                          end_column=17)
        ]
        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))
Beispiel #6
0
    def _resolve_query_exception(self,
                                 e: Exception,
                                 query: Query,
                                 request_context: RequestContext,
                                 conn: 'psycopg2.connection',
                                 is_rollback_error=False):
        """Report a query execution failure to the client and roll back if needed.

        Logs the failure and sends an error MESSAGE_NOTIFICATION tagged with
        the id of the batch at ``query.current_batch_index``. When the
        connection reports TRANSACTION_STATUS_INERROR, it issues an explicit
        ``ROLLBACK`` query. If that rollback itself fails, the failure is
        reported through this same method with ``is_rollback_error=True``, so
        no second rollback is attempted.

        :param e: exception raised while executing ``query``
        :param query: the query that failed
        :param request_context: channel used to notify the client
        :param conn: connection the query ran on
        :param is_rollback_error: True when ``e`` came from the automatic
            rollback; prefixes the message and suppresses another rollback
        """
        utils.log.log_debug(
            self._service_provider.logger,
            f'Query execution failed for following query: {query.query_text}\n {e}'
        )
        # Database, runtime and cancellation errors are reported with their
        # own text; anything else is wrapped and logged via logger.exception.
        if isinstance(e, (psycopg2.DatabaseError, RuntimeError,
                          psycopg2.extensions.QueryCanceledError)):
            error_message = str(e)
        else:
            error_message = 'Unhandled exception while executing query: {}'.format(
                str(e))  # TODO: Localize
            if self._service_provider.logger is not None:
                self._service_provider.logger.exception(
                    'Unhandled exception while executing query')

        # If the error occurred during rollback, add a note about it
        if is_rollback_error:
            error_message = 'Error while rolling back open transaction due to previous failure: ' + error_message  # TODO: Localize

        # Send a message with the error to the client
        result_message_params = self.build_message_params(
            query.owner_uri, query.batches[query.current_batch_index].id,
            error_message, True)
        request_context.send_notification(MESSAGE_NOTIFICATION,
                                          result_message_params)

        # If there was a failure in the middle of a transaction, roll it back.
        # Note that conn.rollback() won't work since the connection is in autocommit mode.
        # The status is an int constant: compare by value (==), not identity
        # (`is`), which only worked by accident of CPython's small-int cache.
        if (not is_rollback_error and conn.get_transaction_status() ==
                psycopg2.extensions.TRANSACTION_STATUS_INERROR):
            rollback_query = Query(
                query.owner_uri, 'ROLLBACK',
                QueryExecutionSettings(ExecutionPlanOptions(), None),
                QueryEvents())
            try:
                rollback_query.execute(conn)
            except Exception as rollback_exception:
                # If the rollback failed, handle the error as usual but don't try to roll back again
                self._resolve_query_exception(rollback_exception,
                                              rollback_query, request_context,
                                              conn, True)
    def test_hash_character_processed_correctly(self):
        """A '#' used as the xor operator must not be treated as an inline comment marker."""
        sql = "select 42 # 24;"
        settings = QueryExecutionSettings(ExecutionPlanOptions(), None)
        query = Query('test_uri', sql, settings, QueryEvents())

        # The whole statement, including '# 24', stays in a single batch
        self.assertEqual(len(query.batches), 1)
        self.assertEqual(sql, query.batches[0].batch_text)
    def test_initialize_calls_failure_when_query_status_is_not_executed(self):
        """initialize goes through the query executer even for a query that never executed.

        NOTE(review): despite the name, only the executer call is asserted;
        whether ``_on_failure`` fires is not checked here.
        """
        unexecuted_query = Query('owner', '', QueryExecutionSettings(None, None),
                                 QueryEvents())
        execution_state = DataEditSessionExecutionState(unexecuted_query)
        self._query_executer = mock.MagicMock(return_value=execution_state)

        self._data_editor_session.initialize(
            self._initialize_edit_request, self._connection,
            self._query_executer, self._on_success, self._on_failure)

        self._query_executer.assert_called_once()
    def test_initialize_calls_success(self):
        """initialize goes through the query executer for an already-executed query with results.

        NOTE(review): only the executer call is asserted; whether
        ``_on_success`` fires is not checked here.
        """
        query = Query('owner', '', QueryExecutionSettings(None, None),
                      QueryEvents())
        query._execution_state = ExecutionState.EXECUTED

        # A single batch holding two rows; the second row has a None cell
        result_set = self.get_result_set([("Result1", 53), ("Result2", None)])
        batch = Batch('', 1, None)
        batch._result_set = result_set
        query._batches = [batch]

        execution_state = DataEditSessionExecutionState(query)
        self._query_executer = mock.MagicMock(return_value=execution_state)

        self._data_editor_session.initialize(
            self._initialize_edit_request, self._connection,
            self._query_executer, self._on_success, self._on_failure)
        self._query_executer.assert_called_once()
Beispiel #10
0
    def test_batch_selections(self):
        """Test that the query sets up batch objects with correct selection information"""
        full_query = '''select * from
t1;
select * from t2;;;
;  ;
select version(); select * from
t3 ;
select * from t2
'''

        # If I build a query that contains several statements
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        # Then there is a batch for each non-empty statement
        self.assertEqual(len(query.batches), 5)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0,
                          start_column=0,
                          end_line=1,
                          end_column=3),
            SelectionData(start_line=2,
                          start_column=0,
                          end_line=2,
                          end_column=17),
            SelectionData(start_line=4,
                          start_column=0,
                          end_line=4,
                          end_column=17),
            SelectionData(start_line=4,
                          start_column=18,
                          end_line=5,
                          end_column=4),
            SelectionData(start_line=6,
                          start_column=0,
                          end_line=6,
                          end_column=16)
        ]
        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))
Beispiel #11
0
    def _start_query_execution_thread(
            self,
            request_context: RequestContext,
            params: ExecuteRequestParamsBase,
            worker_args: ExecuteRequestWorkerArgs = None):
        """Create (or reuse) the Query for ``params.owner_uri`` and run it on a daemon thread.

        A new Query is built when none exists for the owner URI, or when the
        previous one is in ``ExecutionState.EXECUTED``. If the existing query
        is ``ExecutionState.EXECUTING``, an error is sent through
        ``request_context`` and nothing is started. In every other case, a
        thread targeting ``self._execute_query_request_worker(worker_args)``
        is stored in ``self.owner_to_thread_map`` and started as a daemon.

        :param request_context: used to report the "already executing" error
        :param params: execute request; supplies ``owner_uri``,
            ``execution_plan_options`` and (via
            ``_get_query_text_from_execute_params``) the query text
        :param worker_args: owner URI, result-set storage type and
            notification callbacks for the worker. NOTE(review): defaults to
            None, but it is dereferenced unconditionally when a new Query is
            created and inside the batch callbacks; callers presumably always
            pass it. Confirm.
        """

        # Set up batch execution callback methods for sending notifications
        def _batch_execution_started_callback(batch: Batch) -> None:
            # Forward the batch summary to worker_args.on_batch_start
            # (presumably _check_and_fire skips unset callbacks; verify)
            batch_event_params = BatchNotificationParams(
                batch.batch_summary, worker_args.owner_uri)
            _check_and_fire(worker_args.on_batch_start, batch_event_params)

        def _batch_execution_finished_callback(batch: Batch) -> None:
            # Emits, in order: notices message (if any), resultSetComplete,
            # rows-affected message (only on success), then batchComplete.
            # Send back notices as a separate message to avoid error coloring / highlighting of text
            notices = batch.notices
            if notices:
                notice_message_params = self.build_message_params(
                    worker_args.owner_uri, batch.id, ''.join(notices), False)
                _check_and_fire(worker_args.on_message_notification,
                                notice_message_params)

            batch_summary = batch.batch_summary

            # send query/resultSetComplete response
            result_set_params = self.build_result_set_complete_params(
                batch_summary, worker_args.owner_uri)
            _check_and_fire(worker_args.on_resultset_complete,
                            result_set_params)

            # If the batch was successful, send a message to the client
            if not batch.has_error:
                rows_message = _create_rows_affected_message(batch)
                message_params = self.build_message_params(
                    worker_args.owner_uri, batch.id, rows_message, False)
                _check_and_fire(worker_args.on_message_notification,
                                message_params)

            # send query/batchComplete and query/complete response
            batch_event_params = BatchNotificationParams(
                batch_summary, worker_args.owner_uri)
            _check_and_fire(worker_args.on_batch_complete, batch_event_params)

        # Create a new query if one does not already exist or we already executed the previous one
        # (the callbacks above are only attached to a newly created Query; a
        # reused Query keeps the QueryEvents it was built with)
        if params.owner_uri not in self.query_results or self.query_results[
                params.owner_uri].execution_state is ExecutionState.EXECUTED:
            query_text = self._get_query_text_from_execute_params(params)

            execution_settings = QueryExecutionSettings(
                params.execution_plan_options,
                worker_args.result_set_storage_type)
            query_events = QueryEvents(
                None, None,
                BatchEvents(_batch_execution_started_callback,
                            _batch_execution_finished_callback))
            self.query_results[params.owner_uri] = Query(
                params.owner_uri, query_text, execution_settings, query_events)
        elif self.query_results[
                params.owner_uri].execution_state is ExecutionState.EXECUTING:
            # Refuse to start a second execution for the same owner URI
            request_context.send_error(
                'Another query is currently executing.')  # TODO: Localize
            return

        # Run the worker on a daemon thread so it does not block interpreter shutdown
        thread = threading.Thread(target=self._execute_query_request_worker,
                                  args=(worker_args, ))
        self.owner_to_thread_map[params.owner_uri] = thread
        thread.daemon = True
        thread.start()
Beispiel #12
0
class TestQuery(unittest.TestCase):
    """Unit tests for Query and Batch objects"""

    # Message expected on the IndexError raised for an out-of-range batch index
    BATCH_INDEX_ERROR_MESSAGE = 'Batch index cannot be less than 0 or greater than the number of batches'

    def setUp(self):
        """Set up the test by creating a query with multiple batches"""
        self.statement_list = statement_list = [
            'select version;', 'select * from t1;'
        ]
        self.statement_str = ''.join(statement_list)
        self.query_uri = 'test_uri'
        self.query = Query(
            self.query_uri, self.statement_str,
            QueryExecutionSettings(ExecutionPlanOptions(),
                                   ResultSetStorageType.FILE_STORAGE),
            QueryEvents())

        # Rows returned by the mocked cursor
        self.mock_query_results = [('Id1', 'Value1'), ('Id2', 'Value2')]
        self.cursor = utils.MockCursor(self.mock_query_results)
        self.connection = utils.MockConnection(cursor=self.cursor)

        # Column metadata for the two text columns of mock_query_results
        db_column_id = DbColumn()
        db_column_id.data_type = 'text'
        db_column_id.column_name = 'Id'
        db_column_value = DbColumn()
        db_column_value.data_type = 'text'
        db_column_value.column_name = 'Value'
        self.columns_info = [db_column_id, db_column_value]
        self.get_columns_info_mock = mock.Mock(return_value=self.columns_info)

    def test_query_creates_batches(self):
        """Test that creating a query also creates batches for each statement in the query"""
        # Verify that the query created in setUp has a batch corresponding to each statement
        for index, statement in enumerate(self.statement_list):
            self.assertEqual(self.query.batches[index].batch_text, statement)

    def test_executing_query_executes_batches(self):
        """Test that executing a query also executes all of the query's batches in order"""

        # If I call query.execute
        with mock.patch(
                'pgsqltoolsservice.query.data_storage.storage_data_reader.get_columns_info',
                new=self.get_columns_info_mock):
            self.query.execute(self.connection)

        # Then each of the batches executed in order
        expected_calls = [
            mock.call(statement) for statement in self.statement_list
        ]

        self.cursor.execute.assert_has_calls(expected_calls)
        self.assertEqual(len(self.cursor.execute.mock_calls), 2)

        # And each of the batches holds the expected results
        for batch in self.query.batches:
            for index in range(0, batch.result_set.row_count):
                current_row = batch.result_set.get_row(index)
                row_tuple = ()
                for cell in current_row:
                    row_tuple += (cell.display_value, )
                self.assertEqual(row_tuple, self.mock_query_results[index])

        # And the query is marked as executed
        self.assertIs(self.query.execution_state, ExecutionState.EXECUTED)

    def test_batch_failure(self):
        """Test that query execution handles a batch execution failure by stopping further execution"""
        # Set up the cursor to fail when executed
        self.cursor.execute.side_effect = self.cursor.execute_failure_side_effects

        # If I call query.execute then it raises the database error
        with self.assertRaises(psycopg2.DatabaseError):
            self.query.execute(self.connection)

        # And only the first batch was executed
        expected_calls = [mock.call(self.statement_list[0])]
        self.cursor.execute.assert_has_calls(expected_calls)
        self.assertEqual(len(self.cursor.execute.mock_calls), 1)

        # And the query is marked as executed
        self.assertIs(self.query.execution_state, ExecutionState.EXECUTED)

    def test_batch_selections(self):
        """Test that the query sets up batch objects with correct selection information"""
        full_query = '''select * from
t1;
select * from t2;;;
;  ;
select version(); select * from
t3 ;
select * from t2
'''

        # If I build a query that contains several statements
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        # Then there is a batch for each non-empty statement
        self.assertEqual(len(query.batches), 5)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0,
                          start_column=0,
                          end_line=1,
                          end_column=3),
            SelectionData(start_line=2,
                          start_column=0,
                          end_line=2,
                          end_column=17),
            SelectionData(start_line=4,
                          start_column=0,
                          end_line=4,
                          end_column=17),
            SelectionData(start_line=4,
                          start_column=18,
                          end_line=5,
                          end_column=4),
            SelectionData(start_line=6,
                          start_column=0,
                          end_line=6,
                          end_column=16)
        ]
        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))

    def test_batch_selections_do_block(self):
        """Test that the query sets up batch objects with correct selection information for blocks containing statements"""
        full_query = '''DO $$
BEGIN
RAISE NOTICE 'Hello world 1';
RAISE NOTICE 'Hello world 2';
END $$;
select * from t1;'''

        # If I build a query that contains a block that contains several statements
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        # Then there is a batch for each top-level statement
        self.assertEqual(len(query.batches), 2)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0,
                          start_column=0,
                          end_line=4,
                          end_column=7),
            SelectionData(start_line=5,
                          start_column=0,
                          end_line=5,
                          end_column=17)
        ]
        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))

    def test_batches_strip_comments(self):
        """Test that we do not attempt to execute a batch consisting only of comments"""
        full_query = '''select * from t1;
-- test
-- test
;select * from t1;
-- test
-- test;'''

        # If I build a query that contains a batch consisting of only comments, in addition to other valid batches
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        # Then there is only a batch for each non-comment statement
        self.assertEqual(len(query.batches), 2)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0,
                          start_column=0,
                          end_line=0,
                          end_column=17),
            SelectionData(start_line=3,
                          start_column=1,
                          end_line=3,
                          end_column=18)
        ]

        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))

    def test_hash_character_processed_correctly(self):
        """Test that xor operator is not taken for an inline comment delimiter"""
        full_query = "select 42 # 24;"
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        self.assertEqual(len(query.batches), 1)
        self.assertEqual(full_query, query.batches[0].batch_text)

    def execute_get_subset_raises_error_when_index_not_in_range(
            self, batch_index: int):
        """Assert that Query.get_subset rejects ``batch_index`` with the expected IndexError.

        Helper for the out-of-range tests below. It lacks the ``test`` prefix,
        so unittest does not collect it directly.
        """
        full_query = 'Select * from t1;'
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        with self.assertRaises(IndexError) as context_manager:
            query.get_subset(batch_index, 0, 10)

        # Must be outside the ``with``: get_subset raises before any later
        # statement inside the block could run.
        self.assertEqual(self.BATCH_INDEX_ERROR_MESSAGE,
                         context_manager.exception.args[0])

    def test_get_subset_raises_error_when_index_is_negetive(self):
        self.execute_get_subset_raises_error_when_index_not_in_range(-1)

    def test_get_subset_raises_error_when_index_is_greater_than_batch_size(
            self):
        self.execute_get_subset_raises_error_when_index_not_in_range(20)

    def test_get_subset(self):
        """Test that get_subset delegates the row range to the selected batch"""
        full_query = 'Select * from t1;'
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())
        expected_subset = []

        mock_batch = mock.MagicMock()
        mock_batch.get_subset = mock.Mock(return_value=expected_subset)

        query._batches = [mock_batch]

        subset = query.get_subset(0, 0, 10)

        self.assertEqual(expected_subset, subset)
        mock_batch.get_subset.assert_called_once_with(0, 10)

    def test_save_as_with_invalid_batch_index(self):
        """Test that save_as raises IndexError for batch indices before or after the two batches"""
        def execute_with_batch_index(index: int):
            params = SaveResultsRequestParams()
            params.batch_index = index

            with self.assertRaises(IndexError) as context_manager:
                self.query.save_as(params, None, None, None)

            # Must be outside the ``with``: save_as raises before any later
            # statement inside the block could run.
            self.assertEqual(self.BATCH_INDEX_ERROR_MESSAGE,
                             context_manager.exception.args[0])

        execute_with_batch_index(-1)

        execute_with_batch_index(2)

    def test_save_as(self):
        """Test that save_as forwards its arguments to the selected batch"""
        params = SaveResultsRequestParams()
        params.batch_index = 0

        file_factory = mock.MagicMock()
        on_success = mock.MagicMock()
        on_error = mock.MagicMock()

        batch_save_as_mock = mock.MagicMock()
        self.query.batches[0].save_as = batch_save_as_mock

        self.query.save_as(params, file_factory, on_success, on_error)

        batch_save_as_mock.assert_called_once_with(params, file_factory,
                                                   on_success, on_error)