Example #1
    def _resolve_query_exception(self, e: Exception, query: Query, worker_args: ExecuteRequestWorkerArgs, is_rollback_error=False):
        utils.log.log_debug(self._service_provider.logger, f'Query execution failed for following query: {query.query_text}\n {e}')

        # If the error relates to the database, display the appropriate error message based on the provider
        if isinstance(e, worker_args.connection.database_error) or isinstance(e, worker_args.connection.query_canceled_error):
            # get_error_message may return None so ensure error_message is str type
            error_message = str(worker_args.connection.get_error_message(e))

        elif isinstance(e, RuntimeError):
            error_message = str(e)

        else:
            error_message = 'Unhandled exception while executing query: {}'.format(str(e))  # TODO: Localize
            if self._service_provider.logger is not None:
                self._service_provider.logger.exception('Unhandled exception while executing query')

        # If the error occurred during rollback, add a note about it
        if is_rollback_error:
            error_message = 'Error while rolling back open transaction due to previous failure: ' + error_message  # TODO: Localize

        # Send a message with the error to the client
        result_message_params = self.build_message_params(query.owner_uri, query.batches[query.current_batch_index].id, error_message, True)
        _check_and_fire(worker_args.on_message_notification, result_message_params)

        # If there was a failure in the middle of a transaction, roll it back.
        # Note that conn.rollback() won't work since the connection is in autocommit mode
        if not is_rollback_error and worker_args.connection.transaction_in_error:
            rollback_query = Query(query.owner_uri, 'ROLLBACK', QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())
            try:
                rollback_query.execute(worker_args.connection)
            except Exception as rollback_exception:
                # If the rollback failed, handle the error as usual but don't try to roll back again
                self._resolve_query_exception(rollback_exception, rollback_query, worker_args, True)
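
The handler above delivers its error notification through a module-level helper, _check_and_fire, which is not part of this excerpt. A minimal sketch of what such a helper presumably does (invoke the callback only when the caller actually registered one) is shown below; the name and signature are taken from the call site, everything else is an assumption.

from typing import Any, Callable, Optional


def _check_and_fire(callback: Optional[Callable[..., Any]], *args: Any) -> None:
    # Assumed sketch: fire the notification callback only if one was supplied,
    # so callers can pass None when they do not care about the event.
    if callback is not None:
        callback(*args)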
Example #2
class TestQuery(unittest.TestCase):
    """Unit tests for Query and Batch objects"""
    def setUp(self):
        """Set up the test by creating a query with multiple batches"""
        self.statement_list = statement_list = [
            'select version;', 'select * from t1;'
        ]
        self.statement_str = ''.join(statement_list)
        self.query_uri = 'test_uri'
        self.query = Query(
            self.query_uri, self.statement_str,
            QueryExecutionSettings(ExecutionPlanOptions(),
                                   ResultSetStorageType.FILE_STORAGE),
            QueryEvents())

        self.mock_query_results = [('Id1', 'Value1'), ('Id2', 'Value2')]
        self.cursor = MockCursor(self.mock_query_results)
        self.connection = MockPGServerConnection(cur=self.cursor)

        self.columns_info = []
        db_column_id = DbColumn()
        db_column_id.data_type = 'text'
        db_column_id.column_name = 'Id'
        db_column_id.provider = PG_PROVIDER_NAME
        db_column_value = DbColumn()
        db_column_value.data_type = 'text'
        db_column_value.column_name = 'Value'
        db_column_value.provider = PG_PROVIDER_NAME
        self.columns_info = [db_column_id, db_column_value]
        self.get_columns_info_mock = mock.Mock(return_value=self.columns_info)

    def test_query_creates_batches(self):
        """Test that creating a query also creates batches for each statement in the query"""
        # Verify that the query created in setUp has a batch corresponding to each statement
        for index, statement in enumerate(self.statement_list):
            self.assertEqual(self.query.batches[index].batch_text, statement)

    def test_executing_query_executes_batches(self):
        """Test that executing a query also executes all of the query's batches in order"""

        # If I call query.execute
        with mock.patch(
                'ossdbtoolsservice.query.data_storage.storage_data_reader.get_columns_info',
                new=self.get_columns_info_mock):
            self.query.execute(self.connection)

        # Then each of the batches executed in order
        expected_calls = [
            mock.call(statement) for statement in self.statement_list
        ]

        self.cursor.execute.assert_has_calls(expected_calls)
        self.assertEqual(len(self.cursor.execute.mock_calls), 2)

        # And each of the batches holds the expected results
        for batch in self.query.batches:
            for index in range(0, batch.result_set.row_count):
                current_row = batch.result_set.get_row(index)
                row_tuple = ()
                for cell in current_row:
                    row_tuple += (cell.display_value, )
                self.assertEqual(row_tuple, self.mock_query_results[index])

        # And the query is marked as executed
        self.assertIs(self.query.execution_state, ExecutionState.EXECUTED)

    def test_batch_failure(self):
        """Test that query execution handles a batch execution failure by stopping further execution"""
        # Set up the cursor to fail when executed
        self.cursor.execute.side_effect = self.cursor.execute_failure_side_effects

        # If I call query.execute then it raises the database error
        with self.assertRaises(psycopg2.DatabaseError):
            self.query.execute(self.connection)

        # And only the first batch was executed
        expected_calls = [mock.call(self.statement_list[0])]
        self.cursor.execute.assert_has_calls(expected_calls)
        self.assertEqual(len(self.cursor.execute.mock_calls), 1)

        # And the query is marked as executed
        self.assertIs(self.query.execution_state, ExecutionState.EXECUTED)

    def test_batch_selections(self):
        """Test that the query sets up batch objects with correct selection information"""
        full_query = '''select * from
t1;
select * from t2;;;
;  ;
select version(); select * from
t3 ;
select * from t2
'''

        # If I build a query that contains several statements
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        # Then there is a batch for each non-empty statement
        self.assertEqual(len(query.batches), 5)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0,
                          start_column=0,
                          end_line=1,
                          end_column=3),
            SelectionData(start_line=2,
                          start_column=0,
                          end_line=2,
                          end_column=17),
            SelectionData(start_line=4,
                          start_column=0,
                          end_line=4,
                          end_column=17),
            SelectionData(start_line=4,
                          start_column=18,
                          end_line=5,
                          end_column=4),
            SelectionData(start_line=6,
                          start_column=0,
                          end_line=6,
                          end_column=16)
        ]
        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))

    def test_batch_selections_do_block(self):
        """Test that the query sets up batch objects with correct selection information for blocks containing statements"""
        full_query = '''DO $$
BEGIN
RAISE NOTICE 'Hello world 1';
RAISE NOTICE 'Hello world 2';
END $$;
select * from t1;'''

        # If I build a query that contains a block that contains several statements
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        # Then there is a batch for each top-level statement
        self.assertEqual(len(query.batches), 2)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0,
                          start_column=0,
                          end_line=4,
                          end_column=7),
            SelectionData(start_line=5,
                          start_column=0,
                          end_line=5,
                          end_column=17)
        ]
        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))

    def test_batches_strip_comments(self):
        """Test that we do not attempt to execute a batch consisting only of comments"""
        full_query = '''select * from t1;
-- test
-- test
;select * from t1;
-- test
-- test;'''

        # If I build a query that contains a batch consisting of only comments, in addition to other valid batches
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        # Then there is only a batch for each non-comment statement
        self.assertEqual(len(query.batches), 2)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0,
                          start_column=0,
                          end_line=0,
                          end_column=17),
            SelectionData(start_line=3,
                          start_column=1,
                          end_line=3,
                          end_column=18)
        ]

        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))

    def test_hash_character_processed_correctly(self):
        """Test that xor operator is not taken for an inline comment delimiter"""
        full_query = "select 42 # 24;"
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        self.assertEqual(len(query.batches), 1)
        self.assertEqual(full_query, query.batches[0].batch_text)

    def execute_get_subset_raises_error_when_index_not_in_range(
            self, batch_index: int):
        """Helper that asserts get_subset raises IndexError for an out-of-range batch index"""
        full_query = 'Select * from t1;'
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())

        with self.assertRaises(IndexError) as context_manager:
            query.get_subset(batch_index, 0, 10)

        self.assertEqual(
            'Batch index cannot be less than 0 or greater than the number of batches',
            context_manager.exception.args[0])

    def test_get_subset_raises_error_when_index_is_negative(self):
        self.execute_get_subset_raises_error_when_index_not_in_range(-1)

    def test_get_subset_raises_error_when_index_is_greater_than_batch_size(
            self):
        self.execute_get_subset_raises_error_when_index_not_in_range(20)

    def test_get_subset(self):
        """Test that get_subset delegates to the batch at the requested index"""
        full_query = 'Select * from t1;'
        query = Query('test_uri', full_query,
                      QueryExecutionSettings(ExecutionPlanOptions(), None),
                      QueryEvents())
        expected_subset = []

        mock_batch = mock.MagicMock()
        mock_batch.get_subset = mock.Mock(return_value=expected_subset)

        query._batches = [mock_batch]

        subset = query.get_subset(0, 0, 10)

        self.assertEqual(expected_subset, subset)
        mock_batch.get_subset.assert_called_once_with(0, 10)

    def test_save_as_with_invalid_batch_index(self):
        """Test that save_as raises IndexError when the batch index is out of range"""
        def execute_with_batch_index(index: int):
            params = SaveResultsRequestParams()
            params.batch_index = index

            with self.assertRaises(IndexError) as context_manager:
                self.query.save_as(params, None, None, None)

            self.assertEqual(
                'Batch index cannot be less than 0 or greater than the number of batches',
                context_manager.exception.args[0])

        execute_with_batch_index(-1)
        execute_with_batch_index(2)

    def test_save_as(self):
        """Test that save_as delegates to the batch specified in the request params"""
        params = SaveResultsRequestParams()
        params.batch_index = 0

        file_factory = mock.MagicMock()
        on_success = mock.MagicMock()
        on_error = mock.MagicMock()

        batch_save_as_mock = mock.MagicMock()
        self.query.batches[0].save_as = batch_save_as_mock

        self.query.save_as(params, file_factory, on_success, on_error)

        batch_save_as_mock.assert_called_once_with(params, file_factory,
                                                   on_success, on_error)
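
The tests above rely on MockCursor and MockPGServerConnection test doubles that are defined elsewhere in the suite. As a rough illustration of the behavior these tests exercise, a cursor stub could look like the sketch below: it records execute calls through a Mock, hands back canned rows, and exposes an execute_failure_side_effects attribute that makes the first execute call raise psycopg2.DatabaseError (which is why test_batch_failure sees exactly one attempted statement). This is an assumed sketch for illustration, not the repository's actual helper.

import unittest.mock as mock

import psycopg2


class MockCursor:
    # Hypothetical stand-in for the suite's real cursor double.
    def __init__(self, query_results):
        # Canned rows returned by fetchall(), mirroring mock_query_results in setUp
        self._results = query_results
        # execute is a Mock so tests can inspect mock_calls and assign a side_effect
        self.execute = mock.Mock()
        # Assigned to execute.side_effect in test_batch_failure: the first execute
        # call then raises, so only one statement is ever attempted
        self.execute_failure_side_effects = psycopg2.DatabaseError()
        self.fetchall = mock.Mock(return_value=query_results)
        self.rowcount = len(query_results)
        self.description = None

    def close(self):
        pass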