Beispiel #1
0
    def test_ensure_where_clause(self):
        """ensure_where_clause appends `` WHERE 1=1`` only when it is needed.

        Each case is ``(input_sql, expected_sql)``. Statements that already
        have a top-level WHERE clause must be returned unchanged. Statements
        without one must get `` WHERE 1=1`` appended. That includes an UPDATE
        whose only WHERE sits inside a subquery target.
        """
        cases = [
            (
                "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
                "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
            ),
            (
                # The WHERE here belongs to the subquery, not the UPDATE itself.
                "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
                "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1",
            ),
            (
                "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
                "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
            ),
            (
                # Top-level WHERE plus a nested WHERE in an IN-subquery.
                "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
                "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
            ),
            ("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"),
        ]

        for sql, want in cases:
            with self.subTest(sql=sql):
                got = ensure_where_clause(sql)
                self.assertEqual(got, want)
Beispiel #2
0
    def test_ensure_where_clause(self):
        """ensure_where_clause keeps WHERE-bearing SQL and patches the rest.

        SQL that already has a top-level WHERE clause must come back
        unchanged. SQL without one must come back with `` WHERE 1=1``
        appended.
        """
        from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause

        # Already has a top-level WHERE: the function must be a no-op.
        unchanged_cases = (
            "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
            "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
            "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
        )
        # No top-level WHERE (a WHERE inside a subquery does not count).
        missing_where_cases = (
            "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
            "DELETE * FROM TABLE",
        )
        for sql in unchanged_cases:
            with self.subTest(sql=sql):
                # Previously the result was discarded; assert it is untouched.
                self.assertEqual(ensure_where_clause(sql), sql)

        for sql in missing_where_cases:
            with self.subTest(sql=sql):
                self.assertEqual(ensure_where_clause(sql), sql + " WHERE 1=1")
Beispiel #3
0
    def _do_execute_update(self, transaction, sql, params, param_types=None):
        """Run one DML statement inside ``transaction``.

        The SQL gets a WHERE clause added if it lacks one. Its pyformat-style
        arguments are converted with
        ``parse_utils.sql_pyformat_args_to_spanner``. The statement is then
        executed with ``transaction.execute_update``.

        :param transaction: Transaction object that provides ``execute_update``.
        :param sql: The DML statement text.
        :param params: Pyformat-style query arguments (or ``None``).
        :param param_types: Accepted for signature compatibility but ignored.
            Parameter types are always derived from the converted params.
        :returns: Whatever ``transaction.execute_update`` returns. When that
            is an int, it is also stored as the cursor's row count.
        """
        sql = parse_utils.ensure_where_clause(sql)
        sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)

        # NOTE(review): ``param_types`` is intentionally not used here. Any
        # caller-supplied types would be keyed to the pre-conversion params.
        result = transaction.execute_update(
            sql, params=params, param_types=get_param_types(params)
        )
        # A DML statement produces no row iterator.
        self._itr = None
        if isinstance(result, int):
            self._row_count = result

        return result
Beispiel #4
0
    def execute(self, sql, args=None):
        """Prepares and executes a Spanner database operation.

        DDL statements are split, validated and queued on the connection, and
        run immediately only in autocommit mode. Any other statement first
        runs the queued DDL.
        Outside autocommit mode, the statement is wrapped in a ``Statement``
        and passed to ``connection.run_statement``; an ``Aborted`` error while
        building the result iterator triggers ``connection.retry_transaction``.
        In autocommit mode, the statement is dispatched directly: queries go
        to ``_handle_DQL``, inserts to ``_helpers.handle_insert``, and other
        DML to ``_do_execute_update`` inside ``run_in_transaction``.

        :type sql: str
        :param sql: A SQL query statement.

        :type args: list
        :param args: Additional parameters to supplement the SQL query.

        :raises ProgrammingError: if the cursor has no connection, or on
            ``InvalidArgument``.
        :raises ValueError: if a multi-statement DDL string contains a
            non-DDL statement.
        :raises IntegrityError: on ``AlreadyExists`` or ``FailedPrecondition``.
        :raises OperationalError: on ``InternalServerError``.
        """
        if not self.connection:
            raise ProgrammingError("Cursor is not connected to the database")

        self._raise_if_closed()

        self._result_set = None

        try:
            # Classify the statement: DDL, insert, other DML, or a query.
            classification = parse_utils.classify_stmt(sql)
            if classification == parse_utils.STMT_DDL:
                ddl_statements = []
                for ddl in sqlparse.split(sql):
                    if ddl:
                        # Strip a single trailing ';' from each split statement.
                        if ddl[-1] == ";":
                            ddl = ddl[:-1]
                        if parse_utils.classify_stmt(ddl) != parse_utils.STMT_DDL:
                            raise ValueError("Only DDL statements may be batched.")
                        ddl_statements.append(ddl)
                # Only queue DDL statements if they are all correctly classified.
                self.connection._ddl_statements.extend(ddl_statements)
                if self.connection.autocommit:
                    self.connection.run_prior_DDL_statements()
                return

            # For every other operation, we've got to ensure that
            # any prior DDL statements were run.
            self.connection.run_prior_DDL_statements()

            if not self.connection.autocommit:
                if classification == parse_utils.STMT_UPDATING:
                    sql = parse_utils.ensure_where_clause(sql)

                if classification != parse_utils.STMT_INSERT:
                    sql, args = sql_pyformat_args_to_spanner(sql, args or None)

                statement = Statement(
                    sql,
                    args,
                    get_param_types(args or None)
                    if classification != parse_utils.STMT_INSERT
                    else {},
                    ResultsChecksum(),
                    classification == parse_utils.STMT_INSERT,
                )
                self._result_set, self._checksum = self.connection.run_statement(
                    statement
                )
                # Building the iterator can raise Aborted; in that case retry
                # the transaction and try again.
                while True:
                    try:
                        self._itr = PeekIterator(self._result_set)
                        break
                    except Aborted:
                        self.connection.retry_transaction()
                return

            # Autocommit mode: dispatch directly by statement class.
            if classification == parse_utils.STMT_NON_UPDATING:
                self._handle_DQL(sql, args or None)
            elif classification == parse_utils.STMT_INSERT:
                _helpers.handle_insert(self.connection, sql, args or None)
            else:
                self.connection.database.run_in_transaction(
                    self._do_execute_update, sql, args or None
                )
        # Map caught errors onto DB-API exception classes. The original error
        # is kept as __cause__ so its traceback is not lost.
        except (AlreadyExists, FailedPrecondition) as e:
            raise IntegrityError(getattr(e, "details", e)) from e
        except InvalidArgument as e:
            raise ProgrammingError(getattr(e, "details", e)) from e
        except InternalServerError as e:
            raise OperationalError(getattr(e, "details", e)) from e
Beispiel #5
0
    def execute(self, sql, args=None):
        """Prepares and executes a Spanner database operation.

        On a read-only connection, every statement is sent to ``_handle_DQL``.
        Otherwise:
        - DDL is batched via ``_batch_DDLs`` and run immediately only in
          autocommit mode.
        - Any other statement first runs the queued DDL. Updating DML gets a
          WHERE clause ensured, and non-insert arguments are converted from
          pyformat.
        - Outside autocommit mode, the statement goes through
          ``connection.run_statement``.
        - In autocommit mode, it is dispatched directly as a query, an insert,
          or DML run in its own transaction.

        :type sql: str
        :param sql: A SQL query statement.

        :type args: list
        :param args: Additional parameters to supplement the SQL query.

        :raises IntegrityError: on ``AlreadyExists``, ``FailedPrecondition``
            or ``OutOfRange``.
        :raises ProgrammingError: on ``InvalidArgument``.
        :raises OperationalError: on ``InternalServerError``.
        """
        self._result_set = None

        try:
            if self.connection.read_only:
                self._handle_DQL(sql, args or None)
                return

            class_ = parse_utils.classify_stmt(sql)
            if class_ == parse_utils.STMT_DDL:
                self._batch_DDLs(sql)
                if self.connection.autocommit:
                    self.connection.run_prior_DDL_statements()
                return

            # For every other operation, we've got to ensure that
            # any prior DDL statements were run.
            self.connection.run_prior_DDL_statements()

            if class_ == parse_utils.STMT_UPDATING:
                sql = parse_utils.ensure_where_clause(sql)

            if class_ != parse_utils.STMT_INSERT:
                sql, args = sql_pyformat_args_to_spanner(sql, args or None)

            if not self.connection.autocommit:
                statement = Statement(
                    sql,
                    args,
                    get_param_types(args or None)
                    if class_ != parse_utils.STMT_INSERT
                    else {},
                    ResultsChecksum(),
                    class_ == parse_utils.STMT_INSERT,
                )

                (
                    self._result_set,
                    self._checksum,
                ) = self.connection.run_statement(statement)
                # Building the iterator can raise Aborted; in that case retry
                # the transaction and try again.
                while True:
                    try:
                        self._itr = PeekIterator(self._result_set)
                        break
                    except Aborted:
                        self.connection.retry_transaction()
                return

            # Autocommit mode: dispatch directly by statement class.
            if class_ == parse_utils.STMT_NON_UPDATING:
                self._handle_DQL(sql, args or None)
            elif class_ == parse_utils.STMT_INSERT:
                _helpers.handle_insert(self.connection, sql, args or None)
            else:
                self.connection.database.run_in_transaction(
                    self._do_execute_update, sql, args or None
                )
        # Map caught errors onto DB-API exception classes. The original error
        # is kept as __cause__ so its traceback is not lost.
        except (AlreadyExists, FailedPrecondition, OutOfRange) as e:
            raise IntegrityError(getattr(e, "details", e)) from e
        except InvalidArgument as e:
            raise ProgrammingError(getattr(e, "details", e)) from e
        except InternalServerError as e:
            raise OperationalError(getattr(e, "details", e)) from e