def test_ensure_where_clause(self):
    """Check ``ensure_where_clause`` against (input, expected) SQL pairs.

    Statements whose outermost UPDATE already carries a WHERE clause are
    expected back unchanged. Statements without one -- including an UPDATE
    whose only WHERE lives inside a subquery -- are expected to come back
    with `` WHERE 1=1`` appended.
    """
    cases = [
        (
            "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
            "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
        ),
        (
            # The only WHERE here belongs to the subquery, not to the UPDATE.
            "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
            "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1",
        ),
        (
            "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
            "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
        ),
        (
            "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
            "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
        ),
        ("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"),
    ]

    for sql, want in cases:
        with self.subTest(sql=sql):
            got = ensure_where_clause(sql)
            self.assertEqual(got, want)
def test_ensure_where_clause(self):
    """Check that ``ensure_where_clause`` appends ``WHERE 1=1`` only when needed.

    ``cases`` holds statements that already have a WHERE clause on the
    outermost UPDATE; they must be returned unchanged. ``err_cases`` holds
    statements without one (the first has a WHERE only inside a subquery);
    they must be returned with `` WHERE 1=1`` appended.
    """
    from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause

    cases = (
        "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
        "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
        "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
    )
    err_cases = (
        "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
        "DELETE * FROM TABLE",
    )

    for sql in cases:
        with self.subTest(sql=sql):
            # Assert the result instead of discarding it: a statement that
            # already has a WHERE clause must come back untouched.
            self.assertEqual(ensure_where_clause(sql), sql)

    for sql in err_cases:
        with self.subTest(sql=sql):
            self.assertEqual(ensure_where_clause(sql), sql + " WHERE 1=1")
def _do_execute_update(self, transaction, sql, params, param_types=None):
    """Run an updating statement inside ``transaction`` and record its row count.

    The statement is first passed through ``parse_utils.ensure_where_clause``
    and ``parse_utils.sql_pyformat_args_to_spanner`` before being handed to
    ``transaction.execute_update``.

    :param transaction: Object exposing ``execute_update(sql, params=...,
        param_types=...)``.

    :type sql: str
    :param sql: The statement to execute.

    :param params: Parameters for the statement, as accepted by
        ``sql_pyformat_args_to_spanner``.

    :param param_types: Accepted for signature compatibility but not used;
        the types sent with the request are always derived from the
        converted ``params`` via ``get_param_types``.

    :returns: Whatever ``transaction.execute_update`` returns.
    """
    sql = parse_utils.ensure_where_clause(sql)
    sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)

    result = transaction.execute_update(
        sql, params=params, param_types=get_param_types(params)
    )

    # An update produces no rows to iterate over.
    self._itr = None
    # Only an integer result is stored as the row count.
    if isinstance(result, int):
        self._row_count = result

    return result
def execute(self, sql, args=None):
    """Prepares and executes a Spanner database operation.

    DDL statements are queued on the connection and only run right away
    when the connection is in autocommit mode. For every other statement,
    previously queued DDL is run first; the statement then goes through
    ``connection.run_statement`` when autocommit is off, or is dispatched
    by statement class when autocommit is on.

    :type sql: str
    :param sql: A SQL query statement.

    :type args: list
    :param args: Additional parameters to supplement the SQL query.

    :raises ProgrammingError: if the cursor has no connection, or when
        ``InvalidArgument`` is raised while executing.
    :raises IntegrityError: when ``AlreadyExists`` or ``FailedPrecondition``
        is raised while executing.
    :raises OperationalError: when ``InternalServerError`` is raised while
        executing.
    :raises ValueError: if ``sql`` is classified as DDL but one of the
        statements it splits into is not.
    """
    if not self.connection:
        raise ProgrammingError("Cursor is not connected to the database")

    self._raise_if_closed()

    self._result_set = None

    # Classify whether this is a read-only SQL statement.
    try:
        classification = parse_utils.classify_stmt(sql)
        if classification == parse_utils.STMT_DDL:
            ddl_statements = []
            for ddl in sqlparse.split(sql):
                if not ddl:
                    continue
                if ddl.endswith(";"):
                    ddl = ddl[:-1]
                if parse_utils.classify_stmt(ddl) != parse_utils.STMT_DDL:
                    raise ValueError("Only DDL statements may be batched.")
                ddl_statements.append(ddl)

            # Only queue DDL statements if they are all correctly classified.
            self.connection._ddl_statements.extend(ddl_statements)

            if self.connection.autocommit:
                self.connection.run_prior_DDL_statements()
            return

        # For every other operation, we've got to ensure that
        # any prior DDL statements were run.
        self.connection.run_prior_DDL_statements()

        if not self.connection.autocommit:
            if classification == parse_utils.STMT_UPDATING:
                sql = parse_utils.ensure_where_clause(sql)

            if classification != parse_utils.STMT_INSERT:
                sql, args = sql_pyformat_args_to_spanner(sql, args or None)

            statement = Statement(
                sql,
                args,
                get_param_types(args or None)
                if classification != parse_utils.STMT_INSERT
                else {},
                ResultsChecksum(),
                classification == parse_utils.STMT_INSERT,
            )
            (self._result_set, self._checksum,) = self.connection.run_statement(
                statement
            )
            # Building the iterator may raise Aborted; ask the connection to
            # retry the transaction and try again until it succeeds.
            while True:
                try:
                    self._itr = PeekIterator(self._result_set)
                    break
                except Aborted:
                    self.connection.retry_transaction()
            return

        if classification == parse_utils.STMT_NON_UPDATING:
            self._handle_DQL(sql, args or None)
        elif classification == parse_utils.STMT_INSERT:
            _helpers.handle_insert(self.connection, sql, args or None)
        else:
            self.connection.database.run_in_transaction(
                self._do_execute_update, sql, args or None
            )
    # Translate backend errors into DB-API exceptions, keeping the original
    # error attached as the explicit cause.
    except (AlreadyExists, FailedPrecondition) as e:
        raise IntegrityError(getattr(e, "details", e)) from e
    except InvalidArgument as e:
        raise ProgrammingError(getattr(e, "details", e)) from e
    except InternalServerError as e:
        raise OperationalError(getattr(e, "details", e)) from e
def execute(self, sql, args=None):
    """Prepares and executes a Spanner database operation.

    On a read-only connection the statement is handed straight to
    ``_handle_DQL``. Otherwise DDL statements are batched via
    ``_batch_DDLs`` and only run right away in autocommit mode. For every
    other statement, previously queued DDL is run first; the statement then
    goes through ``connection.run_statement`` when autocommit is off, or is
    dispatched by statement class when autocommit is on.

    :type sql: str
    :param sql: A SQL query statement.

    :type args: list
    :param args: Additional parameters to supplement the SQL query.

    :raises IntegrityError: when ``AlreadyExists``, ``FailedPrecondition``
        or ``OutOfRange`` is raised while executing.
    :raises ProgrammingError: when ``InvalidArgument`` is raised while
        executing.
    :raises OperationalError: when ``InternalServerError`` is raised while
        executing.
    """
    self._result_set = None

    try:
        if self.connection.read_only:
            self._handle_DQL(sql, args or None)
            return

        class_ = parse_utils.classify_stmt(sql)
        if class_ == parse_utils.STMT_DDL:
            self._batch_DDLs(sql)
            if self.connection.autocommit:
                self.connection.run_prior_DDL_statements()
            return

        # For every other operation, we've got to ensure that
        # any prior DDL statements were run.
        self.connection.run_prior_DDL_statements()

        if class_ == parse_utils.STMT_UPDATING:
            sql = parse_utils.ensure_where_clause(sql)

        if class_ != parse_utils.STMT_INSERT:
            sql, args = sql_pyformat_args_to_spanner(sql, args or None)

        if not self.connection.autocommit:
            statement = Statement(
                sql,
                args,
                get_param_types(args or None)
                if class_ != parse_utils.STMT_INSERT
                else {},
                ResultsChecksum(),
                class_ == parse_utils.STMT_INSERT,
            )

            (
                self._result_set,
                self._checksum,
            ) = self.connection.run_statement(statement)
            # Building the iterator may raise Aborted; ask the connection to
            # retry the transaction and try again until it succeeds.
            while True:
                try:
                    self._itr = PeekIterator(self._result_set)
                    break
                except Aborted:
                    self.connection.retry_transaction()
            return

        if class_ == parse_utils.STMT_NON_UPDATING:
            self._handle_DQL(sql, args or None)
        elif class_ == parse_utils.STMT_INSERT:
            _helpers.handle_insert(self.connection, sql, args or None)
        else:
            self.connection.database.run_in_transaction(
                self._do_execute_update, sql, args or None
            )
    # Translate backend errors into DB-API exceptions, keeping the original
    # error attached as the explicit cause.
    except (AlreadyExists, FailedPrecondition, OutOfRange) as e:
        raise IntegrityError(getattr(e, "details", e)) from e
    except InvalidArgument as e:
        raise ProgrammingError(getattr(e, "details", e)) from e
    except InternalServerError as e:
        raise OperationalError(getattr(e, "details", e)) from e