def _resolve_query_exception(self, e: Exception, query: Query, worker_args: ExecuteRequestWorkerArgs, is_rollback_error: bool = False):
    """Report a query execution failure to the client and roll back a failed transaction.

    :param e: The exception raised while executing ``query``
    :param query: The query whose execution failed; the id of its current batch tags the error message
    :param worker_args: Worker arguments holding the connection and the notification callbacks
    :param is_rollback_error: True when ``e`` came from the ROLLBACK this method issues itself,
        in which case the message is prefixed accordingly and no further rollback is attempted
    """
    utils.log.log_debug(self._service_provider.logger, f'Query execution failed for following query: {query.query_text}\n {e}')

    connection = worker_args.connection
    # If the error relates to the database, display the appropriate error message based on the provider
    if isinstance(e, (connection.database_error, connection.query_canceled_error)):
        # get_error_message may return None; fall back to the exception text so the client
        # is never shown the literal string 'None'
        provider_message = connection.get_error_message(e)
        error_message = str(provider_message) if provider_message is not None else str(e)
    elif isinstance(e, RuntimeError):
        error_message = str(e)
    else:
        error_message = 'Unhandled exception while executing query: {}'.format(str(e))  # TODO: Localize
        if self._service_provider.logger is not None:
            self._service_provider.logger.exception('Unhandled exception while executing query')

    # If the error occurred during rollback, add a note about it
    if is_rollback_error:
        error_message = 'Error while rolling back open transaction due to previous failure: ' + error_message  # TODO: Localize

    # Send a message with the error to the client
    result_message_params = self.build_message_params(query.owner_uri, query.batches[query.current_batch_index].id, error_message, True)
    _check_and_fire(worker_args.on_message_notification, result_message_params)

    # If there was a failure in the middle of a transaction, roll it back.
    # Note that conn.rollback() won't work since the connection is in autocommit mode
    if not is_rollback_error and connection.transaction_in_error:
        rollback_query = Query(query.owner_uri, 'ROLLBACK', QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())
        try:
            rollback_query.execute(connection)
        except Exception as rollback_exception:
            # If the rollback failed, handle the error as usual but don't try to roll back again
            self._resolve_query_exception(rollback_exception, rollback_query, worker_args, True)
def setUp(self):
    """Set up the test by creating a query with multiple batches

    Also builds a mock cursor/connection that yield two rows. It builds the matching text-column
    metadata (columns_info) and a mock returning that metadata (get_columns_info_mock).
    """
    self.statement_list = [
        'select version;',
        'select * from t1;'
    ]
    self.statement_str = ''.join(self.statement_list)
    self.query_uri = 'test_uri'
    self.query = Query(
        self.query_uri,
        self.statement_str,
        QueryExecutionSettings(ExecutionPlanOptions(), ResultSetStorageType.FILE_STORAGE),
        QueryEvents())

    self.mock_query_results = [('Id1', 'Value1'), ('Id2', 'Value2')]
    self.cursor = MockCursor(self.mock_query_results)
    self.connection = MockPGServerConnection(cur=self.cursor)

    def make_text_column(column_name):
        # One PostgreSQL text column with the given name
        column = DbColumn()
        column.data_type = 'text'
        column.column_name = column_name
        column.provider = PG_PROVIDER_NAME
        return column

    # Metadata for the two columns of mock_query_results, in order
    self.columns_info = [make_text_column('Id'), make_text_column('Value')]
    self.get_columns_info_mock = mock.Mock(return_value=self.columns_info)
def execute_get_subset_raises_error_when_index_not_in_range(
        self, batch_index: int):
    """Assert that Query.get_subset rejects batch_index with the expected IndexError message

    Shared helper for the out-of-range get_subset tests. Its name lacks the ``test`` prefix,
    so unittest does not collect it on its own.
    """
    full_query = 'Select * from t1;'
    query = Query('test_uri', full_query, QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())

    with self.assertRaises(IndexError) as context_manager:
        query.get_subset(batch_index, 0, 10)

    # Checked after the with block: code following the raising call inside it never runs.
    # assertEqual, not the assertEquals alias, which was removed in Python 3.12.
    self.assertEqual(
        'Batch index cannot be less than 0 or greater than the number of batches',
        context_manager.exception.args[0])
def test_initialize_calls_success(self):
    """Initializing an edit session over an EXECUTED query invokes the query executer exactly once"""
    executed_query = Query('owner', '', QueryExecutionSettings(None, None), QueryEvents())
    executed_query._execution_state = ExecutionState.EXECUTED

    # Give the query a single batch backed by a two-row result set
    two_row_result_set = self.get_result_set([("Result1", 53), ("Result2", None)])
    only_batch = Batch('', 1, None)
    only_batch._result_set = two_row_result_set
    executed_query._batches = [only_batch]

    execution_state = DataEditSessionExecutionState(executed_query)
    self._query_executer = mock.MagicMock(return_value=execution_state)

    self._data_editor_session.initialize(
        self._initialize_edit_request,
        self._connection,
        self._query_executer,
        self._on_success,
        self._on_failure)

    self._query_executer.assert_called_once()
def test_get_subset(self):
    """Query.get_subset returns what the selected batch's get_subset returns for the row range"""
    query = Query(
        'test_uri',
        'Select * from t1;',
        QueryExecutionSettings(ExecutionPlanOptions(), None),
        QueryEvents())

    # Swap the parsed batches for one mock batch that returns a known subset
    expected_subset = []
    fake_batch = mock.MagicMock()
    fake_batch.get_subset = mock.Mock(return_value=expected_subset)
    query._batches = [fake_batch]

    actual_subset = query.get_subset(0, 0, 10)

    self.assertEqual(expected_subset, actual_subset)
    fake_batch.get_subset.assert_called_once_with(0, 10)
def test_batches_strip_comments(self):
    """Test that we do not attempt to execute a batch consisting only of comments"""
    # Two real statements, each followed by a statement made only of '--' comments
    full_query = ('select * from t1;\n'
                  '-- test\n'
                  '-- test\n'
                  ';select * from t1;\n'
                  '-- test\n'
                  '-- test;')
    query = Query('test_uri', full_query, QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())

    # Only the non-comment statements become batches
    self.assertEqual(len(query.batches), 2)

    # ...and each batch records where its statement sits in the query text
    expected_selections = [
        SelectionData(start_line=0, start_column=0, end_line=0, end_column=17),
        SelectionData(start_line=3, start_column=1, end_line=3, end_column=18),
    ]
    for actual_batch, expected_selection in zip(query.batches, expected_selections):
        self.assertEqual(
            _tuple_from_selection_data(actual_batch.selection),
            _tuple_from_selection_data(expected_selection))
def test_batch_selections_do_block(self):
    """Test that the query sets up batch objects with correct selection information for blocks containing statements"""
    # A DO block wrapping two statements, followed by one top-level statement
    full_query = ('DO $$\n'
                  'BEGIN\n'
                  "RAISE NOTICE 'Hello world 1';\n"
                  "RAISE NOTICE 'Hello world 2';\n"
                  'END $$;\n'
                  'select * from t1;')
    query = Query('test_uri', full_query, QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())

    # The statements inside the block are not split out: one batch per top-level statement
    self.assertEqual(len(query.batches), 2)

    expected_selections = [
        SelectionData(start_line=0, start_column=0, end_line=4, end_column=7),
        SelectionData(start_line=5, start_column=0, end_line=5, end_column=17),
    ]
    for actual_batch, expected_selection in zip(query.batches, expected_selections):
        self.assertEqual(
            _tuple_from_selection_data(actual_batch.selection),
            _tuple_from_selection_data(expected_selection))
def test_initialize_calls_failure_when_query_status_is_not_executed(self):
    """Initializing an edit session over a query whose state was never set to EXECUTED

    Only asserts that the query executer is invoked exactly once.
    NOTE(review): despite the name, self._on_failure is not asserted here.
    """
    unexecuted_query = Query('owner', '', QueryExecutionSettings(None, None), QueryEvents())
    execution_state = DataEditSessionExecutionState(unexecuted_query)
    self._query_executer = mock.MagicMock(return_value=execution_state)

    self._data_editor_session.initialize(
        self._initialize_edit_request,
        self._connection,
        self._query_executer,
        self._on_success,
        self._on_failure)

    self._query_executer.assert_called_once()
def test_hash_character_processed_correctly(self):
    """Test that xor operator is not taken for an inline comment delimiter"""
    # '#' here is an operator, so the whole statement must survive as one batch
    xor_statement = "select 42 # 24;"
    query = Query(
        'test_uri',
        xor_statement,
        QueryExecutionSettings(ExecutionPlanOptions(), None),
        QueryEvents())

    self.assertEqual(len(query.batches), 1)
    self.assertEqual(xor_statement, query.batches[0].batch_text)
def _start_query_execution_thread(self, request_context: RequestContext, params: ExecuteRequestParamsBase, worker_args: ExecuteRequestWorkerArgs = None):
    """Prepare the Query for params.owner_uri and run it on a daemon worker thread.

    The stored query is handled according to its state:
    - No query yet for the owner URI, or the stored one is EXECUTED: a new Query is created.
      Its batch callbacks fire notifications through worker_args.
    - Stored query still EXECUTING: an error is sent through request_context and no thread
      starts.
    - Any other state: the stored query is left in place.

    The thread runs self._execute_query_request_worker(worker_args) and is recorded in
    self.owner_to_thread_map.

    NOTE(review): worker_args defaults to None, yet the batch callbacks dereference it
    unconditionally; callers presumably always pass it — confirm.
    """
    def handle_batch_started(batch: Batch) -> None:
        started_params = BatchNotificationParams(batch.batch_summary, worker_args.owner_uri)
        _check_and_fire(worker_args.on_batch_start, started_params)

    def handle_batch_finished(batch: Batch) -> None:
        # Notices go out as their own non-error message so the client does not
        # render them with error coloring / highlighting
        pending_notices = batch.notices
        if pending_notices:
            notices_params = self.build_message_params(worker_args.owner_uri, batch.id, ''.join(pending_notices), False)
            _check_and_fire(worker_args.on_message_notification, notices_params)

        summary = batch.batch_summary

        # query/resultSetComplete notification
        completed_set_params = self.build_result_set_complete_params(summary, worker_args.owner_uri)
        _check_and_fire(worker_args.on_resultset_complete, completed_set_params)

        # Only a batch without errors reports a rows-affected message
        if not batch.has_error:
            affected_text = _create_rows_affected_message(batch)
            affected_params = self.build_message_params(worker_args.owner_uri, batch.id, affected_text, False)
            _check_and_fire(worker_args.on_message_notification, affected_params)

        # query/batchComplete notification
        finished_params = BatchNotificationParams(summary, worker_args.owner_uri)
        _check_and_fire(worker_args.on_batch_complete, finished_params)

    owner_uri = params.owner_uri
    needs_new_query = (owner_uri not in self.query_results
                       or self.query_results[owner_uri].execution_state is ExecutionState.EXECUTED)
    if needs_new_query:
        text_to_run = self._get_query_text_from_execute_params(params)
        settings = QueryExecutionSettings(params.execution_plan_options, worker_args.result_set_storage_type)
        events = QueryEvents(None, None, BatchEvents(handle_batch_started, handle_batch_finished))
        self.query_results[owner_uri] = Query(owner_uri, text_to_run, settings, events)
    elif self.query_results[owner_uri].execution_state is ExecutionState.EXECUTING:
        request_context.send_error('Another query is currently executing.')  # TODO: Localize
        return

    worker_thread = threading.Thread(target=self._execute_query_request_worker, args=(worker_args,), daemon=True)
    self.owner_to_thread_map[owner_uri] = worker_thread
    worker_thread.start()
def test_batch_selections(self):
    """Test that the query sets up batch objects with correct selection information"""
    # Statements spread over several lines, with empty statements (';;', '; ;') mixed in
    full_query = ('select * from\n'
                  't1;\n'
                  'select * from t2;;;\n'
                  '; ;\n'
                  'select version(); select * from\n'
                  't3 ;\n'
                  'select * from t2\n')
    query = Query('test_uri', full_query, QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())

    # Empty statements are dropped: one batch per non-empty statement
    self.assertEqual(len(query.batches), 5)

    expected_selections = [
        SelectionData(start_line=0, start_column=0, end_line=1, end_column=3),
        SelectionData(start_line=2, start_column=0, end_line=2, end_column=17),
        SelectionData(start_line=4, start_column=0, end_line=4, end_column=17),
        SelectionData(start_line=4, start_column=18, end_line=5, end_column=4),
        SelectionData(start_line=6, start_column=0, end_line=6, end_column=16),
    ]
    for actual_batch, expected_selection in zip(query.batches, expected_selections):
        self.assertEqual(
            _tuple_from_selection_data(actual_batch.selection),
            _tuple_from_selection_data(expected_selection))
class TestQuery(unittest.TestCase):
    """Unit tests for Query and Batch objects"""

    def setUp(self):
        """Set up the test by creating a query with multiple batches"""
        self.statement_list = [
            'select version;',
            'select * from t1;'
        ]
        self.statement_str = ''.join(self.statement_list)
        self.query_uri = 'test_uri'
        self.query = Query(
            self.query_uri,
            self.statement_str,
            QueryExecutionSettings(ExecutionPlanOptions(), ResultSetStorageType.FILE_STORAGE),
            QueryEvents())

        self.mock_query_results = [('Id1', 'Value1'), ('Id2', 'Value2')]
        self.cursor = MockCursor(self.mock_query_results)
        self.connection = MockPGServerConnection(cur=self.cursor)

        # Metadata for the two columns of mock_query_results, in order
        self.columns_info = [self._make_text_column('Id'), self._make_text_column('Value')]
        self.get_columns_info_mock = mock.Mock(return_value=self.columns_info)

    @staticmethod
    def _make_text_column(column_name: str):
        """Build a PostgreSQL text DbColumn with the given name"""
        column = DbColumn()
        column.data_type = 'text'
        column.column_name = column_name
        column.provider = PG_PROVIDER_NAME
        return column

    def test_query_creates_batches(self):
        """Test that creating a query also creates batches for each statement in the query"""
        # Verify that the query created in setUp has a batch corresponding to each statement
        for index, statement in enumerate(self.statement_list):
            self.assertEqual(self.query.batches[index].batch_text, statement)

    def test_executing_query_executes_batches(self):
        """Test that executing a query also executes all of the query's batches in order"""
        # If I call query.execute
        with mock.patch(
                'ossdbtoolsservice.query.data_storage.storage_data_reader.get_columns_info',
                new=self.get_columns_info_mock):
            self.query.execute(self.connection)

        # Then each of the batches executed in order
        expected_calls = [mock.call(statement) for statement in self.statement_list]
        self.cursor.execute.assert_has_calls(expected_calls)
        self.assertEqual(len(self.cursor.execute.mock_calls), 2)

        # And each of the batches holds the expected results
        for batch in self.query.batches:
            for index in range(0, batch.result_set.row_count):
                current_row = batch.result_set.get_row(index)
                row_tuple = tuple(cell.display_value for cell in current_row)
                self.assertEqual(row_tuple, self.mock_query_results[index])

        # And the query is marked as executed
        self.assertIs(self.query.execution_state, ExecutionState.EXECUTED)

    def test_batch_failure(self):
        """Test that query execution handles a batch execution failure by stopping further execution"""
        # Set up the cursor to fail when executed
        self.cursor.execute.side_effect = self.cursor.execute_failure_side_effects

        # If I call query.execute then it raises the database error
        with self.assertRaises(psycopg2.DatabaseError):
            self.query.execute(self.connection)

        # And only the first batch was executed
        expected_calls = [mock.call(self.statement_list[0])]
        self.cursor.execute.assert_has_calls(expected_calls)
        self.assertEqual(len(self.cursor.execute.mock_calls), 1)

        # And the query is marked as executed
        self.assertIs(self.query.execution_state, ExecutionState.EXECUTED)

    def test_batch_selections(self):
        """Test that the query sets up batch objects with correct selection information"""
        full_query = ('select * from\n'
                      't1;\n'
                      'select * from t2;;;\n'
                      '; ;\n'
                      'select version(); select * from\n'
                      't3 ;\n'
                      'select * from t2\n')

        # If I build a query that contains several statements
        query = Query('test_uri', full_query, QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())

        # Then there is a batch for each non-empty statement
        self.assertEqual(len(query.batches), 5)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0, start_column=0, end_line=1, end_column=3),
            SelectionData(start_line=2, start_column=0, end_line=2, end_column=17),
            SelectionData(start_line=4, start_column=0, end_line=4, end_column=17),
            SelectionData(start_line=4, start_column=18, end_line=5, end_column=4),
            SelectionData(start_line=6, start_column=0, end_line=6, end_column=16)
        ]
        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))

    def test_batch_selections_do_block(self):
        """Test that the query sets up batch objects with correct selection information for blocks containing statements"""
        full_query = ('DO $$\n'
                      'BEGIN\n'
                      "RAISE NOTICE 'Hello world 1';\n"
                      "RAISE NOTICE 'Hello world 2';\n"
                      'END $$;\n'
                      'select * from t1;')

        # If I build a query that contains a block that contains several statements
        query = Query('test_uri', full_query, QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())

        # Then there is a batch for each top-level statement
        self.assertEqual(len(query.batches), 2)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0, start_column=0, end_line=4, end_column=7),
            SelectionData(start_line=5, start_column=0, end_line=5, end_column=17)
        ]
        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))

    def test_batches_strip_comments(self):
        """Test that we do not attempt to execute a batch consisting only of comments"""
        full_query = ('select * from t1;\n'
                      '-- test\n'
                      '-- test\n'
                      ';select * from t1;\n'
                      '-- test\n'
                      '-- test;')

        # If I build a query that contains a batch consisting of only comments, in addition to other valid batches
        query = Query('test_uri', full_query, QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())

        # Then there is only a batch for each non-comment statement
        self.assertEqual(len(query.batches), 2)

        # And each batch should have the correct location information
        expected_selections = [
            SelectionData(start_line=0, start_column=0, end_line=0, end_column=17),
            SelectionData(start_line=3, start_column=1, end_line=3, end_column=18)
        ]
        for index, batch in enumerate(query.batches):
            self.assertEqual(
                _tuple_from_selection_data(batch.selection),
                _tuple_from_selection_data(expected_selections[index]))

    def test_hash_character_processed_correctly(self):
        """Test that xor operator is not taken for an inline comment delimiter"""
        full_query = "select 42 # 24;"
        query = Query('test_uri', full_query, QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())

        self.assertEqual(len(query.batches), 1)
        self.assertEqual(full_query, query.batches[0].batch_text)

    def execute_get_subset_raises_error_when_index_not_in_range(
            self, batch_index: int):
        """Assert that Query.get_subset rejects batch_index with the expected IndexError message

        Shared helper; its name lacks the ``test`` prefix, so unittest does not collect it on its own.
        """
        full_query = 'Select * from t1;'
        query = Query('test_uri', full_query, QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())

        with self.assertRaises(IndexError) as context_manager:
            query.get_subset(batch_index, 0, 10)

        # Checked after the with block: code following the raising call inside it never runs.
        # assertEqual, not the assertEquals alias, which was removed in Python 3.12.
        self.assertEqual(
            'Batch index cannot be less than 0 or greater than the number of batches',
            context_manager.exception.args[0])

    def test_get_subset_raises_error_when_index_is_negetive(self):
        self.execute_get_subset_raises_error_when_index_not_in_range(-1)

    def test_get_subset_raises_error_when_index_is_greater_than_batch_size(
            self):
        self.execute_get_subset_raises_error_when_index_not_in_range(20)

    def test_get_subset(self):
        full_query = 'Select * from t1;'
        query = Query('test_uri', full_query, QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())
        expected_subset = []

        mock_batch = mock.MagicMock()
        mock_batch.get_subset = mock.Mock(return_value=expected_subset)

        query._batches = [mock_batch]

        subset = query.get_subset(0, 0, 10)

        self.assertEqual(expected_subset, subset)
        mock_batch.get_subset.assert_called_once_with(0, 10)

    def test_save_as_with_invalid_batch_index(self):
        def execute_with_batch_index(index: int):
            params = SaveResultsRequestParams()
            params.batch_index = index

            with self.assertRaises(IndexError) as context_manager:
                self.query.save_as(params, None, None, None)

            # Checked outside the with block so the assertion actually runs
            self.assertEqual(
                'Batch index cannot be less than 0 or greater than the number of batches',
                context_manager.exception.args[0])

        execute_with_batch_index(-1)

        execute_with_batch_index(2)

    def test_save_as(self):
        params = SaveResultsRequestParams()
        params.batch_index = 0

        file_factory = mock.MagicMock()
        on_success = mock.MagicMock()
        on_error = mock.MagicMock()

        batch_save_as_mock = mock.MagicMock()
        self.query.batches[0].save_as = batch_save_as_mock
        self.query.save_as(params, file_factory, on_success, on_error)

        batch_save_as_mock.assert_called_once_with(params, file_factory, on_success, on_error)