def test_commit_retry_aborted_statements(self, mock_client):
    """Check that an aborted commit replays the remembered statements."""
    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.connection import connect
    from google.cloud.spanner_dbapi.cursor import Statement

    expected_row = ["field1", "field2"]

    connection = connect("test-instance", "test-database")
    cursor = connection.cursor()

    checksum = ResultsChecksum()
    checksum.consume_result(expected_row)
    cursor._checksum = checksum

    remembered = Statement("SELECT 1", [], {}, checksum, False)
    connection._statements.append(remembered)

    fake_transaction = mock.Mock(rolled_back=False, committed=False)
    # First commit attempt aborts, the second one succeeds.
    fake_transaction.commit.side_effect = [Aborted("Aborted"), None]
    connection._transaction = fake_transaction

    rerun = mock.Mock(return_value=([expected_row], ResultsChecksum()))
    connection.run_statement = rerun

    connection.commit()

    rerun.assert_called_with(remembered, retried=True)
def test_retry_transaction(self):
    """Check that a retry replays statements and compares their checksums."""
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.cursor import Statement

    row = ["field1", "field2"]
    connection = self._make_connection()

    original_checksum = ResultsChecksum()
    original_checksum.consume_result(row)
    retried_checksum = ResultsChecksum()

    statement = Statement("SELECT 1", [], {}, original_checksum, False)
    connection._statements.append(statement)

    with mock.patch(
        "google.cloud.spanner_dbapi.connection.Connection.run_statement",
        return_value=([row], retried_checksum),
    ) as run_mock, mock.patch(
        "google.cloud.spanner_dbapi.connection._compare_checksums"
    ) as compare_mock:
        connection.retry_transaction()

    # Mocks keep their call records after the patches are undone.
    compare_mock.assert_called_with(original_checksum, retried_checksum)
    run_mock.assert_called_with(statement, retried=True)
def test_commit_retry_aborted_statements(self):
    """Check that an aborted commit replays the remembered statements."""
    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.connection import connect
    from google.cloud.spanner_dbapi.cursor import Statement

    fields = ["field1", "field2"]

    # connect() is run with the instance/database existence checks stubbed out.
    with mock.patch(
        "google.cloud.spanner_v1.instance.Instance.exists", return_value=True,
    ), mock.patch(
        "google.cloud.spanner_v1.database.Database.exists", return_value=True,
    ):
        connection = connect("test-instance", "test-database")

    cursor = connection.cursor()
    cursor._checksum = ResultsChecksum()
    cursor._checksum.consume_result(fields)

    remembered = Statement("SELECT 1", [], {}, cursor._checksum, False)
    connection._statements.append(remembered)
    connection._transaction = mock.Mock(rolled_back=False, committed=False)

    commit_outcomes = (Aborted("Aborted"), None)
    with mock.patch.object(
        connection._transaction, "commit", side_effect=commit_outcomes,
    ), mock.patch(
        "google.cloud.spanner_dbapi.connection.Connection.run_statement",
        return_value=([fields], ResultsChecksum()),
    ) as run_mock:
        connection.commit()

    run_mock.assert_called_with(remembered, retried=True)
def test_retry_transaction_checksum_mismatch(self):
    """
    Check that a retry fails when the replayed results differ from the
    results seen originally.
    """
    from google.cloud.spanner_dbapi.exceptions import RetryAborted
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.cursor import Statement

    original_row = ["field1", "field2"]
    changed_row = ["field3", "field4"]
    connection = self._make_connection()

    original_checksum = ResultsChecksum()
    original_checksum.consume_result(original_row)

    connection._statements.append(
        Statement("SELECT 1", [], {}, original_checksum, False)
    )

    replay_patch = mock.patch(
        "google.cloud.spanner_dbapi.connection.Connection.run_statement",
        return_value=([changed_row], ResultsChecksum()),
    )
    with replay_patch, self.assertRaises(RetryAborted):
        connection.retry_transaction()
def test_fetchall_retry_aborted_statements(self, mock_client):
    """Check that an aborted fetchall() replays the remembered statement."""
    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.connection import connect
    from google.cloud.spanner_dbapi.cursor import Statement

    row = ["field1", "field2"]
    connection = connect("test-instance", "test-database")

    cursor = connection.cursor()
    cursor._checksum = ResultsChecksum()
    cursor._checksum.consume_result(row)

    remembered = Statement("SELECT 1", [], {}, cursor._checksum, False)
    connection._statements.append(remembered)

    # The first iteration attempt aborts; the second one succeeds.
    iter_patch = mock.patch(
        "google.cloud.spanner_dbapi.cursor.Cursor.__iter__",
        side_effect=(Aborted("Aborted"), iter(row)),
    )
    run_patch = mock.patch(
        "google.cloud.spanner_dbapi.connection.Connection.run_statement",
        return_value=([row], ResultsChecksum()),
    )
    with iter_patch, run_patch as run_mock:
        cursor.fetchall()

    run_mock.assert_called_with(remembered, retried=True)
def test_fetchmany_retry_aborted_statements_checksums_mismatch(self, mock_client):
    """Check that a retry fails when the underlying data changed meanwhile."""
    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_dbapi.exceptions import RetryAborted
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.connection import connect
    from google.cloud.spanner_dbapi.cursor import Statement

    row = ["field1", "field2"]
    changed_row = ["updated_field1", "field2"]

    connection = connect("test-instance", "test-database")
    cursor = connection.cursor()
    cursor._checksum = ResultsChecksum()
    cursor._checksum.consume_result(row)

    remembered = Statement("SELECT 1", [], {}, cursor._checksum, False)
    connection._statements.append(remembered)

    with mock.patch(
        "google.cloud.spanner_dbapi.cursor.Cursor.__next__",
        side_effect=(Aborted("Aborted"), None),
    ), mock.patch(
        "google.cloud.spanner_dbapi.connection.Connection.run_statement",
        return_value=([changed_row], ResultsChecksum()),
    ) as run_mock:
        with self.assertRaises(RetryAborted):
            cursor.fetchmany(len(row))

    run_mock.assert_called_with(remembered, retried=True)
def test_execute_insert_statement_autocommit_off(self):
    """Check that an INSERT outside autocommit mode goes through run_statement."""
    from google.cloud.spanner_dbapi import parse_utils
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.utils import PeekIterator

    connection = self._make_connection(self.INSTANCE, mock.MagicMock())
    cursor = self._make_one(connection)
    cursor.connection._autocommit = False
    # A plain MagicMock: ``autospec`` is not a MagicMock option (passing it
    # only set an ``autospec`` attribute on the mock), and this test never
    # inspects the checkout.
    cursor.connection.transaction_checkout = mock.MagicMock()
    cursor._checksum = ResultsChecksum()

    with mock.patch(
        "google.cloud.spanner_dbapi.parse_utils.classify_stmt",
        return_value=parse_utils.STMT_INSERT,
    ):
        with mock.patch(
            "google.cloud.spanner_dbapi.connection.Connection.run_statement",
            return_value=(mock.MagicMock(), ResultsChecksum()),
        ):
            cursor.execute(
                sql="INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)"
            )

    self.assertIsInstance(cursor._result_set, mock.MagicMock)
    self.assertIsInstance(cursor._itr, PeekIterator)
def test_retry_aborted_retry_without_delay(self, mock_client):
    """
    Check that a retry which itself aborts is attempted once more, with the
    retry delay lookup stubbed to return False.
    """
    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.connection import connect
    from google.cloud.spanner_dbapi.cursor import Statement

    row = ["field1", "field2"]
    connection = connect("test-instance", "test-database")

    cursor = connection.cursor()
    cursor._checksum = ResultsChecksum()
    cursor._checksum.consume_result(row)

    remembered = Statement("SELECT 1", [], {}, cursor._checksum, False)
    connection._statements.append(remembered)

    # The abort's error entry reports empty trailing metadata.
    error_entry = mock.Mock()
    error_entry.trailing_metadata.return_value = {}

    run_mock = mock.Mock(
        side_effect=[
            Aborted("Aborted", errors=[error_entry]),
            ([row], ResultsChecksum()),
        ]
    )
    connection.run_statement = run_mock
    connection._get_retry_delay = mock.Mock(return_value=False)

    connection.retry_transaction()

    expected_call = mock.call(remembered, retried=True)
    run_mock.assert_has_calls((expected_call, expected_call))
def test_less_results(self):
    """Check that a retry producing fewer results than the original fails."""
    from google.cloud.spanner_dbapi.checksum import _compare_checksums
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.exceptions import RetryAborted

    original = ResultsChecksum()
    original.consume_result(5)
    empty_retry = ResultsChecksum()

    with self.assertRaises(RetryAborted):
        _compare_checksums(original, empty_retry)
def run_statement(self, statement, retried=False):
    """Run a single SQL statement in the begun transaction.

    This method is never used in autocommit mode. Outside autocommit mode,
    every statement executed for the first time is appended to
    ``self._statements`` so that it can be re-executed later.

    :type statement: :class:`google.cloud.spanner_dbapi.cursor.Statement`
    :param statement: Statement to execute. Its ``sql``, ``params``,
                      ``param_types``, ``checksum`` and ``is_insert``
                      fields are used.

    :type retried: bool
    :param retried: (Optional) True when the statement is being re-executed
                    as part of a transaction retry. A retried statement is
                    not appended to ``self._statements`` again, and a fresh
                    checksum is returned instead of ``statement.checksum``.
                    Defaults to False.

    :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`,
            :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
    :returns: Streamed result set of the statement (an empty iterator for
              INSERT statements) and a checksum for this statement's results.
    """
    transaction = self.transaction_checkout()
    if not retried:
        self._statements.append(statement)

    # A retried run accumulates into a new checksum so the caller can compare
    # it against the checksum recorded during the original run.
    checksum = ResultsChecksum() if retried else statement.checksum

    if statement.is_insert:
        parts = parse_insert(statement.sql, statement.params)
        if parts.get("homogenous"):
            _execute_insert_homogenous(transaction, parts)
        else:
            _execute_insert_heterogenous(
                transaction, parts.get("sql_params_list"),
            )
        # INSERTs return no rows.
        return iter(()), checksum

    result_set = transaction.execute_sql(
        statement.sql, statement.params, param_types=statement.param_types,
    )
    return result_set, checksum
def test_fetchmany_retry_aborted(self):
    """Check that an aborted fetchmany() triggers a transaction retry."""
    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.connection import connect

    # connect() is run with the instance/database existence checks stubbed out.
    with mock.patch(
        "google.cloud.spanner_v1.instance.Instance.exists", return_value=True,
    ), mock.patch(
        "google.cloud.spanner_v1.database.Database.exists", return_value=True,
    ):
        connection = connect("test-instance", "test-database")

    cursor = connection.cursor()
    cursor._checksum = ResultsChecksum()

    with mock.patch(
        "google.cloud.spanner_dbapi.cursor.Cursor.__next__",
        side_effect=(Aborted("Aborted"), None),
    ), mock.patch(
        "google.cloud.spanner_dbapi.connection.Connection.retry_transaction"
    ) as retry_mock:
        cursor.fetchmany()

    retry_mock.assert_called_with()
def test_retry_aborted_retry(self):
    """
    Check that a retried transaction which aborts again is retried once more.
    """
    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.connection import connect
    from google.cloud.spanner_dbapi.cursor import Statement

    row = ["field1", "field2"]

    with mock.patch(
        "google.cloud.spanner_v1.instance.Instance.exists", return_value=True,
    ), mock.patch(
        "google.cloud.spanner_v1.database.Database.exists", return_value=True,
    ):
        connection = connect("test-instance", "test-database")

    cursor = connection.cursor()
    cursor._checksum = ResultsChecksum()
    cursor._checksum.consume_result(row)

    remembered = Statement("SELECT 1", [], {}, cursor._checksum, False)
    connection._statements.append(remembered)

    # The abort's error entry reports empty trailing metadata.
    error_entry = mock.Mock()
    error_entry.trailing_metadata.return_value = {}

    outcomes = (
        Aborted("Aborted", errors=[error_entry]),
        ([row], ResultsChecksum()),
    )
    with mock.patch.object(
        connection, "run_statement", side_effect=outcomes,
    ) as run_mock:
        connection.retry_transaction()

    expected_call = mock.call(remembered, retried=True)
    run_mock.assert_has_calls((expected_call, expected_call))
def test_fetchall(self):
    """Check that fetchall() returns every remaining row."""
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum

    connection = self._make_connection(self.INSTANCE, mock.MagicMock())
    cursor = self._make_one(connection)
    cursor._checksum = ResultsChecksum()

    rows = [(1,), (2,), (3,)]
    cursor._itr = iter(rows)

    fetched = cursor.fetchall()
    self.assertEqual(fetched, rows)
def test_equal(self):
    """Check that checksums over identical results compare as equal."""
    from google.cloud.spanner_dbapi.checksum import _compare_checksums
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum

    original, retried = ResultsChecksum(), ResultsChecksum()
    for checksum in (original, retried):
        checksum.consume_result(5)

    self.assertIsNone(_compare_checksums(original, retried))
def _rerun_previous_statements(self):
    """
    Helper to run all the remembered statements
    from the last transaction.

    Each entry of ``self._statements`` is either a single statement (re-run
    through ``run_statement``) or a ``(statements, checksum)`` pair recorded
    for a batch DML call (re-run through ``batch_update``). The results of
    every re-run entry are checksummed and passed to ``_compare_checksums``
    together with the checksum recorded during the original run.

    :raises: :class:`google.api_core.exceptions.Aborted` if a re-run batch
             DML call reports an ABORTED status.
    """
    last_index = len(self._statements) - 1
    for index, statement in enumerate(self._statements):
        # Batch entries are appended as a ``(statements, checksum)`` tuple by
        # Cursor.executemany, so accept tuples as well as lists. Single
        # statements expose ``.sql`` (see run_statement), which tells them
        # apart from a batch pair.
        is_batch = isinstance(statement, (list, tuple)) and not hasattr(
            statement, "sql"
        )
        if is_batch:
            statements, checksum = statement

            transaction = self.transaction_checkout()
            status, res = transaction.batch_update(statements)

            if status.code == ABORTED:
                # ``self`` is the connection here, so the transaction lives
                # on ``self._transaction``. Dropping it presumably makes the
                # next checkout begin a new transaction.
                self._transaction = None
                raise Aborted(status.details)

            retried_checksum = ResultsChecksum()
            retried_checksum.consume_result(res)
            retried_checksum.consume_result(status.code)

            _compare_checksums(checksum, retried_checksum)
            continue

        res_iter, retried_checksum = self.run_statement(statement, retried=True)

        if index != last_index:
            # A completed statement: consume all of its results.
            for res in res_iter:
                retried_checksum.consume_result(res)
        else:
            # The statement that was in progress when the transaction failed:
            # stream only up to the number of results seen originally, or to
            # the end of the results. A single iterator is created once so
            # each next() advances it instead of starting over.
            rows = iter(res_iter)
            while len(retried_checksum) < len(statement.checksum):
                try:
                    res = next(rows)
                except StopIteration:
                    break
                retried_checksum.consume_result(res)

        _compare_checksums(statement.checksum, retried_checksum)
def test_fetchone(self):
    """Check that fetchone() returns rows one by one, then None."""
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum

    connection = self._make_connection(self.INSTANCE, mock.MagicMock())
    cursor = self._make_one(connection)
    cursor._checksum = ResultsChecksum()

    values = [1, 2, 3]
    cursor._itr = iter(values)

    for expected in values:
        self.assertEqual(cursor.fetchone(), expected)

    # The test expects None once the rows are exhausted.
    self.assertIsNone(cursor.fetchone())
def test_retry_transaction_raise_max_internal_retries(self):
    """Check that retrying raises once the internal retry limit is exceeded."""
    from google.cloud.spanner_dbapi import connection as conn
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.cursor import Statement

    # Patch the module-level limit instead of assigning it. The patch restores
    # the real original value even when the assertion fails. Assigning 0
    # directly and resetting to a hard-coded 50 leaked 0 into later tests on
    # failure.
    with mock.patch.object(conn, "MAX_INTERNAL_RETRIES", 0):
        row = ["field1", "field2"]
        connection = self._make_connection()

        checksum = ResultsChecksum()
        checksum.consume_result(row)

        statement = Statement("SELECT 1", [], {}, checksum, False)
        connection._statements.append(statement)

        with self.assertRaises(Exception):
            connection.retry_transaction()
def test_fetchmany_w_autocommit(self):
    """Check fetchmany() with and without an explicit size in autocommit mode."""
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum

    connection = self._make_connection(self.INSTANCE, mock.MagicMock())
    connection.autocommit = True
    cursor = self._make_one(connection)
    cursor._checksum = ResultsChecksum()

    rows = [(1,), (2,), (3,)]
    cursor._itr = iter(rows)

    # Without a size the test expects exactly one row back.
    self.assertEqual(cursor.fetchmany(), [rows[0]])
    self.assertEqual(cursor.fetchmany(len(rows)), rows[1:])
def test_run_statement_w_retried(self):
    """Check that re-executed (retried) statements are not remembered."""
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.cursor import Statement

    connection = self._make_connection()
    connection.transaction_checkout = mock.Mock()

    retried_statement = Statement(
        "SELECT 23 FROM table WHERE id = @a1",
        {"a1": "value"},
        {"a1": str},
        ResultsChecksum(),
        False,
    )
    connection.run_statement(retried_statement, retried=True)

    self.assertEqual(len(connection._statements), 0)
def test_run_statement_w_homogeneous_insert_statements(self):
    """Check that a retried homogeneous INSERT runs without being remembered."""
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.cursor import Statement

    connection = self._make_connection()
    connection.transaction_checkout = mock.Mock()

    insert = Statement(
        "INSERT INTO T (f1, f2) VALUES (%s, %s), (%s, %s)",
        ["a", "b", "c", "d"],
        {"f1": str, "f2": str},
        ResultsChecksum(),
        True,
    )
    connection.run_statement(insert, retried=True)

    self.assertEqual(len(connection._statements), 0)
def test_run_statement_wo_retried(self):
    """Check that a first-time statement is remembered by the connection."""
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.cursor import Statement

    query = "SELECT 23 FROM table WHERE id = @a1"
    query_params = {"a1": "value"}
    query_types = {"a1": str}

    connection = self._make_connection()
    connection.transaction_checkout = mock.Mock()

    connection.run_statement(
        Statement(query, query_params, query_types, ResultsChecksum(), False)
    )

    remembered = connection._statements[0]
    self.assertEqual(remembered.sql, query)
    self.assertEqual(remembered.params, query_params)
    self.assertEqual(remembered.param_types, query_types)
    self.assertIsInstance(remembered.checksum, ResultsChecksum)
def test_run_statement_w_heterogenous_insert_statements(self):
    """Check that a retried heterogeneous INSERT is not remembered."""
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.cursor import Statement

    connection = self._make_connection()
    insert = Statement(
        "INSERT INTO T (f1, f2) VALUES (1, 2)", None, None, ResultsChecksum(), True
    )

    checkout_patch = mock.patch(
        "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout"
    )
    with checkout_patch:
        connection.run_statement(insert, retried=True)

    self.assertEqual(len(connection._statements), 0)
def test_run_statement_w_heterogenous_insert_statements(self):
    """Check that a retried heterogeneous INSERT is not remembered.

    The checked-out transaction is a mock whose ``batch_update`` reports
    an OK status.
    """
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.cursor import Statement
    from google.rpc.status_pb2 import Status
    from google.rpc.code_pb2 import OK

    connection = self._make_connection()

    fake_transaction = mock.MagicMock()
    fake_transaction.batch_update = mock.Mock(return_value=(Status(code=OK), 1))
    connection.transaction_checkout = mock.Mock(return_value=fake_transaction)

    insert = Statement(
        "INSERT INTO T (f1, f2) VALUES (1, 2)", None, None, ResultsChecksum(), True
    )
    connection.run_statement(insert, retried=True)

    self.assertEqual(len(connection._statements), 0)
def test_fetchall_retry_aborted(self, mock_client):
    """Check that an aborted fetchall() triggers a transaction retry."""
    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_dbapi.checksum import ResultsChecksum
    from google.cloud.spanner_dbapi.connection import connect

    connection = connect("test-instance", "test-database")
    cursor = connection.cursor()
    cursor._checksum = ResultsChecksum()

    # First iteration attempt aborts; the second yields nothing.
    with mock.patch(
        "google.cloud.spanner_dbapi.cursor.Cursor.__iter__",
        side_effect=(Aborted("Aborted"), iter([])),
    ), mock.patch(
        "google.cloud.spanner_dbapi.connection.Connection.retry_transaction"
    ) as retry_mock:
        cursor.fetchall()

    retry_mock.assert_called_with()
def execute(self, sql, args=None):
    """Prepare and run a single Spanner database operation.

    :type sql: str
    :param sql: The SQL statement to run.

    :type args: list
    :param args: (Optional) Values for the statement's placeholders.

    :raises IntegrityError: On ``AlreadyExists``, ``FailedPrecondition`` or
        ``OutOfRange`` errors.
    :raises ProgrammingError: On ``InvalidArgument`` errors.
    :raises OperationalError: On ``InternalServerError`` errors.
    """
    self._result_set = None

    try:
        # Read-only connections only run queries.
        if self.connection.read_only:
            self._handle_DQL(sql, args or None)
            return

        stmt_kind = parse_utils.classify_stmt(sql)
        is_insert = stmt_kind == parse_utils.STMT_INSERT

        if stmt_kind == parse_utils.STMT_DDL:
            # DDL is queued; it is sent right away only in autocommit mode.
            self._batch_DDLs(sql)
            if self.connection.autocommit:
                self.connection.run_prior_DDL_statements()
            return

        # Any queued DDL must be applied before other operations run.
        self.connection.run_prior_DDL_statements()

        if stmt_kind == parse_utils.STMT_UPDATING:
            sql = parse_utils.ensure_where_clause(sql)

        if not is_insert:
            sql, args = sql_pyformat_args_to_spanner(sql, args or None)

        if not self.connection.autocommit:
            param_types = {} if is_insert else get_param_types(args or None)
            statement = Statement(
                sql, args, param_types, ResultsChecksum(), is_insert,
            )
            self._result_set, self._checksum = self.connection.run_statement(
                statement
            )
            # Wrapping the result set may surface an abort; retry the
            # transaction until the wrap succeeds.
            while True:
                try:
                    self._itr = PeekIterator(self._result_set)
                except Aborted:
                    self.connection.retry_transaction()
                else:
                    break
            return

        if stmt_kind == parse_utils.STMT_NON_UPDATING:
            self._handle_DQL(sql, args or None)
        elif is_insert:
            _helpers.handle_insert(self.connection, sql, args or None)
        else:
            self.connection.database.run_in_transaction(
                self._do_execute_update, sql, args or None
            )
    except (AlreadyExists, FailedPrecondition, OutOfRange) as err:
        raise IntegrityError(getattr(err, "details", err))
    except InvalidArgument as err:
        raise ProgrammingError(getattr(err, "details", err))
    except InternalServerError as err:
        raise OperationalError(getattr(err, "details", err))
def execute(self, sql, args=None):
    """Prepare and run a single Spanner database operation.

    :type sql: str
    :param sql: The SQL statement to run, or several ``;``-separated DDL
                statements.

    :type args: list
    :param args: (Optional) Values for the statement's placeholders.

    :raises ProgrammingError: If the cursor has no connection, or on
        ``InvalidArgument`` errors.
    :raises ValueError: If a multi-statement DDL string contains a non-DDL
        statement.
    :raises IntegrityError: On ``AlreadyExists`` or ``FailedPrecondition``.
    :raises OperationalError: On ``InternalServerError``.
    """
    if not self.connection:
        raise ProgrammingError("Cursor is not connected to the database")

    self._raise_if_closed()
    self._result_set = None

    try:
        kind = parse_utils.classify_stmt(sql)

        if kind == parse_utils.STMT_DDL:
            queued = []
            for piece in sqlparse.split(sql):
                if not piece:
                    continue
                if piece.endswith(";"):
                    piece = piece[:-1]
                if parse_utils.classify_stmt(piece) != parse_utils.STMT_DDL:
                    raise ValueError("Only DDL statements may be batched.")
                queued.append(piece)

            # Queue the pieces only after every one was confirmed to be DDL.
            self.connection._ddl_statements.extend(queued)
            if self.connection.autocommit:
                self.connection.run_prior_DDL_statements()
            return

        # Any queued DDL must be applied before other operations run.
        self.connection.run_prior_DDL_statements()

        if not self.connection.autocommit:
            is_insert = kind == parse_utils.STMT_INSERT
            if kind == parse_utils.STMT_UPDATING:
                sql = parse_utils.ensure_where_clause(sql)
            if not is_insert:
                sql, args = sql_pyformat_args_to_spanner(sql, args or None)

            param_types = {} if is_insert else get_param_types(args or None)
            statement = Statement(
                sql, args, param_types, ResultsChecksum(), is_insert,
            )
            self._result_set, self._checksum = self.connection.run_statement(
                statement
            )
            # Wrapping the result set may surface an abort; retry the
            # transaction until the wrap succeeds.
            while True:
                try:
                    self._itr = PeekIterator(self._result_set)
                except Aborted:
                    self.connection.retry_transaction()
                else:
                    break
            return

        if kind == parse_utils.STMT_NON_UPDATING:
            self._handle_DQL(sql, args or None)
        elif kind == parse_utils.STMT_INSERT:
            _helpers.handle_insert(self.connection, sql, args or None)
        else:
            self.connection.database.run_in_transaction(
                self._do_execute_update, sql, args or None
            )
    except (AlreadyExists, FailedPrecondition) as err:
        raise IntegrityError(getattr(err, "details", err))
    except InvalidArgument as err:
        raise ProgrammingError(getattr(err, "details", err))
    except InternalServerError as err:
        raise OperationalError(getattr(err, "details", err))
def executemany(self, operation, seq_of_params):
    """Execute the given SQL once for every set of parameters in a sequence.

    INSERT and updating (``STMT_UPDATING``) statements are sent together as
    one batch DML call. Any other statement is run through :meth:`execute`
    once per parameter set.

    :type operation: str
    :param operation: SQL code to execute.

    :type seq_of_params: list
    :param seq_of_params: Sequence of parameter sets to run the query with.

    :raises ProgrammingError: If ``operation`` is a DDL statement.
    :raises OperationalError: If the batch DML call returns a status other
        than OK or ABORTED.
    """
    self._raise_if_closed()

    classification = parse_utils.classify_stmt(operation)
    if classification == parse_utils.STMT_DDL:
        raise ProgrammingError(
            "Executing DDL statements with executemany() method is not allowed."
        )

    many_result_set = StreamedManyResultSets()

    if classification in (parse_utils.STMT_INSERT, parse_utils.STMT_UPDATING):
        statements = []
        for raw_params in seq_of_params:
            sql, params = parse_utils.sql_pyformat_args_to_spanner(
                operation, raw_params
            )
            statements.append((sql, params, get_param_types(params)))

        if self.connection.autocommit:
            self.connection.database.run_in_transaction(
                self._do_batch_update, statements, many_result_set
            )
        else:
            while True:
                try:
                    transaction = self.connection.transaction_checkout()
                    status, res = transaction.batch_update(statements)

                    if status.code == ABORTED:
                        self.connection._transaction = None
                        raise Aborted(status.details)

                    # Record the batch only after it has actually run in this
                    # transaction. A checksum of an aborted attempt includes
                    # the ABORTED status, so a replay could never match it.
                    # An aborted attempt's results must not be returned
                    # either.
                    many_result_set.add_iter(res)
                    res_checksum = ResultsChecksum()
                    res_checksum.consume_result(res)
                    res_checksum.consume_result(status.code)
                    self.connection._statements.append((statements, res_checksum))

                    if status.code != OK:
                        raise OperationalError(status.details)
                    break
                except Aborted:
                    # Replays the previously recorded statements; the batch is
                    # then executed again by the next loop iteration.
                    self.connection.retry_transaction()
    else:
        for params in seq_of_params:
            self.execute(operation, params)
            many_result_set.add_iter(self._itr)

    self._result_set = many_result_set
    self._itr = many_result_set