Example #1
 def _handle_DQL(self, sql, params):
     with self.connection.database.snapshot() as snapshot:
         # Reference
         #  https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql
         sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
         res = snapshot.execute_sql(sql,
                                    params=params,
                                    param_types=get_param_types(params))
         if isinstance(res, int):
             self._row_count = res
             self._itr = None
         else:
             # Iterate over the result set immediately: this Spanner
             # API provides no easy way to detect whether a statement
             # returns a single item or many, and mixing results meant
             # for .fetchone() with multi-row results raises a
             # RuntimeError when .fetchone() is invoked, and vice
             # versa.
             self._result_set = res
             # Read the first element so that the StreamedResultSet can
             # return the metadata after a DQL statement. See issue #155.
             self._itr = PeekIterator(self._result_set)
             # Unfortunately, Spanner doesn't seem to send back
             # information about the number of rows available.
             self._row_count = _UNSET_COUNT
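For context, a minimal sketch of what the PeekIterator used above could look like. This is an illustration only; the real class lives in google.cloud.spanner_dbapi.utils, and this sketch only mimics its observable behavior of eagerly opening the stream so metadata becomes available:

class PeekIteratorSketch:
    """Hypothetical peek-ahead iterator: pulls the first element
    eagerly so that stream metadata is populated, then replays that
    element on the first next() call."""

    def __init__(self, source):
        self._iter = iter(source)
        try:
            # Force the stream open by consuming one element up front.
            self._head = [next(self._iter)]
        except StopIteration:
            self._head = []

    def __iter__(self):
        return self

    def __next__(self):
        if self._head:
            return self._head.pop(0)
        return next(self._iter)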
Example #2
 def _handle_DQL_with_snapshot(self, snapshot, sql, params):
     self._result_set = snapshot.execute_sql(sql, params, get_param_types(params))
     # Read the first element so that the StreamedResultSet can
     # return the metadata after a DQL statement.
     self._itr = PeekIterator(self._result_set)
     # Unfortunately, Spanner doesn't seem to send back
     # information about the number of rows available.
     self._row_count = _UNSET_COUNT
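In this variant the caller supplies the snapshot itself. An illustrative call site, where database is assumed to be a google.cloud.spanner_v1 Database handle and cursor is the enclosing cursor object:

# Illustrative only: run the DQL against a fresh single-use snapshot.
with database.snapshot() as snapshot:
    cursor._handle_DQL_with_snapshot(snapshot, "SELECT 1", None)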
Example #3
    def _do_execute_update(self, transaction, sql, params):
        result = transaction.execute_update(
            sql, params=params, param_types=get_param_types(params)
        )
        self._itr = None
        if isinstance(result, int):
            self._row_count = result

        return result
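This method is written to be handed to run_in_transaction, which injects the transaction as the first argument and retries the whole callable on transient aborts. A hedged usage sketch, where database and cursor are assumed handles and the SQL is already in Spanner's named-parameter form (this version performs no pyformat conversion):

# Illustrative only: run_in_transaction invokes
# cursor._do_execute_update(transaction, sql, params) on our behalf.
row_count = database.run_in_transaction(
    cursor._do_execute_update,
    "UPDATE people SET name = @a0 WHERE id = @a1",
    {"a0": "Ada", "a1": 1},
)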
Example #4
def _execute_insert_heterogenous(transaction, sql_params_list):
    for sql, params in sql_params_list:
        sql, params = sql_pyformat_args_to_spanner(sql, params)
        param_types = get_param_types(params)
        res = transaction.execute_sql(sql,
                                      params=params,
                                      param_types=param_types)
        # TODO: File a bug with Cloud Spanner and the Python client maintainers
        # about a lost commit when res isn't read from.
        _ = list(res)
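The sql_pyformat_args_to_spanner step above rewrites DB-API pyformat placeholders into Spanner's named parameters. A sketch of the observable transformation (the @a0-style names are how the library numbers positional placeholders; treat the exact output as illustrative):

# Before conversion: DB-API pyformat placeholders plus a tuple of args.
sql_in = "INSERT INTO people (id, name) VALUES (%s, %s)"
args_in = (1, "Ada")
# After conversion: Spanner named parameters plus a params dict.
sql_out = "INSERT INTO people (id, name) VALUES (@a0, @a1)"
params_out = {"a0": 1, "a1": "Ada"}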
Example #5
    def _do_execute_update(self, transaction, sql, params, param_types=None):
        sql = parse_utils.ensure_where_clause(sql)
        sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)

        result = transaction.execute_update(
            sql, params=params, param_types=get_param_types(params))
        self._itr = None
        if isinstance(result, int):
            self._row_count = result

        return result
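ensure_where_clause exists because Cloud Spanner rejects UPDATE and DELETE statements without a WHERE clause. A naive sketch of the idea; the real helper in parse_utils parses the statement rather than string-matching:

def ensure_where_clause_sketch(sql):
    # Append a trivially-true predicate when no WHERE clause exists.
    if "where" not in sql.lower():
        return sql + " WHERE 1=1"
    return sql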
Example #6
    def execute(self, sql, args=None):
        """Prepares and executes a Spanner database operation.

        :type sql: str
        :param sql: A SQL query statement.

        :type args: list
        :param args: Additional parameters to supplement the SQL query.
        """
        if not self.connection:
            raise ProgrammingError("Cursor is not connected to the database")

        self._raise_if_closed()

        self._result_set = None

        # Classify the statement so it can be routed to the right
        # execution path.
        try:
            classification = parse_utils.classify_stmt(sql)
            if classification == parse_utils.STMT_DDL:
                self.connection._ddl_statements.append(sql)
                return

            # For every other operation, any prior DDL statements
            # must be run first.
            self.connection.run_prior_DDL_statements()

            if not self.connection.autocommit:
                transaction = self.connection.transaction_checkout()

                sql, params = parse_utils.sql_pyformat_args_to_spanner(
                    sql, args)

                self._result_set = transaction.execute_sql(
                    sql, params, param_types=get_param_types(params))
                self._itr = PeekIterator(self._result_set)
                return

            if classification == parse_utils.STMT_NON_UPDATING:
                self._handle_DQL(sql, args or None)
            elif classification == parse_utils.STMT_INSERT:
                _helpers.handle_insert(self.connection, sql, args or None)
            else:
                self.connection.database.run_in_transaction(
                    self._do_execute_update, sql, args or None)
        except (AlreadyExists, FailedPrecondition) as e:
            raise IntegrityError(e.details if hasattr(e, "details") else e)
        except InvalidArgument as e:
            raise ProgrammingError(e.details if hasattr(e, "details") else e)
        except InternalServerError as e:
            raise OperationalError(e.details if hasattr(e, "details") else e)
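End to end, callers reach this method through the ordinary DB-API surface. An illustrative session; the instance and database names are placeholders:

from google.cloud.spanner_dbapi import connect

connection = connect("my-instance", "my-database")
cursor = connection.cursor()
# Pyformat parameters are converted to Spanner named parameters
# internally via parse_utils.sql_pyformat_args_to_spanner.
cursor.execute(
    "SELECT name FROM people WHERE id = %(id)s",
    {"id": 42},
)
print(cursor.fetchall())
connection.commit()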
Example #7
    def test_get_param_types(self):
        import datetime
        import decimal

        from google.cloud.spanner_dbapi.parse_utils import (
            DateStr,
            TimestampStr,
            get_param_types,
        )
        from google.cloud.spanner_v1 import JsonObject

        params = {
            "a1": 10,
            "b1": "string",
            "c1": 10.39,
            "d1": TimestampStr("2005-08-30T01:01:01.000001Z"),
            "e1": DateStr("2019-12-05"),
            "f1": True,
            "g1": datetime.datetime(2011, 9, 1, 13, 20, 30),
            "h1": datetime.date(2011, 9, 1),
            "i1": b"bytes",
            "j1": None,
            "k1": decimal.Decimal("3.194387483193242e+19"),
            "l1": JsonObject({"key": "value"}),
        }
        want_types = {
            "a1": param_types.INT64,
            "b1": param_types.STRING,
            "c1": param_types.FLOAT64,
            "d1": param_types.TIMESTAMP,
            "e1": param_types.DATE,
            "f1": param_types.BOOL,
            "g1": param_types.TIMESTAMP,
            "h1": param_types.DATE,
            "i1": param_types.BYTES,
            "k1": param_types.NUMERIC,
            "l1": param_types.JSON,
        }
        got_types = get_param_types(params)
        self.assertEqual(got_types, want_types)
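A rough sketch of the inference this test exercises. This is illustrative only; the real mapping lives in parse_utils and additionally special-cases the TimestampStr, DateStr, and JsonObject wrappers, which this sketch does not handle:

import datetime
import decimal

from google.cloud.spanner_v1 import param_types


def infer_param_types_sketch(params):
    # Simplified exact-type lookup table; None values are skipped
    # entirely, which is why "j1" is absent from want_types above.
    type_map = {
        bool: param_types.BOOL,
        int: param_types.INT64,
        float: param_types.FLOAT64,
        str: param_types.STRING,
        bytes: param_types.BYTES,
        datetime.datetime: param_types.TIMESTAMP,
        datetime.date: param_types.DATE,
        decimal.Decimal: param_types.NUMERIC,
    }
    return {
        key: type_map[type(value)]
        for key, value in params.items()
        if value is not None and type(value) in type_map
    }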
Example #8
    def test_get_param_types_none(self):
        from google.cloud.spanner_dbapi.parse_utils import get_param_types

        self.assertEqual(get_param_types(None), None)
Example #9
def _execute_insert_heterogenous(transaction, sql_params_list):
    for sql, params in sql_params_list:
        sql, params = sql_pyformat_args_to_spanner(sql, params)
        transaction.execute_update(sql, params, get_param_types(params))
Example #10
    def executemany(self, operation, seq_of_params):
        """Execute the given SQL with every parameters set
        from the given sequence of parameters.

        :type operation: str
        :param operation: SQL code to execute.

        :type seq_of_params: list
        :param seq_of_params: Sequence of additional parameters to run
                              the query with.
        """
        self._raise_if_closed()

        classification = parse_utils.classify_stmt(operation)
        if classification == parse_utils.STMT_DDL:
            raise ProgrammingError(
                "Executing DDL statements with executemany() method is not allowed."
            )

        many_result_set = StreamedManyResultSets()

        if classification in (parse_utils.STMT_INSERT, parse_utils.STMT_UPDATING):
            statements = []

            for params in seq_of_params:
                sql, params = parse_utils.sql_pyformat_args_to_spanner(
                    operation, params
                )
                statements.append((sql, params, get_param_types(params)))

            if self.connection.autocommit:
                self.connection.database.run_in_transaction(
                    self._do_batch_update, statements, many_result_set
                )
            else:
                retried = False
                while True:
                    try:
                        transaction = self.connection.transaction_checkout()

                        res_checksum = ResultsChecksum()
                        if not retried:
                            self.connection._statements.append(
                                (statements, res_checksum)
                            )

                        status, res = transaction.batch_update(statements)
                        many_result_set.add_iter(res)
                        res_checksum.consume_result(res)
                        res_checksum.consume_result(status.code)

                        if status.code == ABORTED:
                            self.connection._transaction = None
                            raise Aborted(status.details)
                        elif status.code != OK:
                            raise OperationalError(status.details)
                        break
                    except Aborted:
                        self.connection.retry_transaction()
                        retried = True

        else:
            for params in seq_of_params:
                self.execute(operation, params)
                many_result_set.add_iter(self._itr)

        self._result_set = many_result_set
        self._itr = many_result_set
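An illustrative call, reusing the connection and cursor from the sketch after Example #6: the same INSERT is prepared once per parameter set and submitted as a batch.

# Illustrative only: one statement, several parameter sets.
cursor.executemany(
    "INSERT INTO people (id, name) VALUES (%(id)s, %(name)s)",
    [
        {"id": 1, "name": "Ada"},
        {"id": 2, "name": "Grace"},
    ],
)
connection.commit()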
Example #11
    def execute(self, sql, args=None):
        """Prepares and executes a Spanner database operation.

        :type sql: str
        :param sql: A SQL query statement.

        :type args: list
        :param args: Additional parameters to supplement the SQL query.
        """
        if not self.connection:
            raise ProgrammingError("Cursor is not connected to the database")

        self._raise_if_closed()

        self._result_set = None

        # Classify the statement so it can be routed to the right
        # execution path.
        try:
            classification = parse_utils.classify_stmt(sql)
            if classification == parse_utils.STMT_DDL:
                ddl_statements = []
                for ddl in sqlparse.split(sql):
                    if ddl:
                        if ddl[-1] == ";":
                            ddl = ddl[:-1]
                        if parse_utils.classify_stmt(ddl) != parse_utils.STMT_DDL:
                            raise ValueError("Only DDL statements may be batched.")
                        ddl_statements.append(ddl)
                # Only queue DDL statements if they are all correctly classified.
                self.connection._ddl_statements.extend(ddl_statements)
                if self.connection.autocommit:
                    self.connection.run_prior_DDL_statements()
                return

            # For every other operation, any prior DDL statements
            # must be run first.
            self.connection.run_prior_DDL_statements()

            if not self.connection.autocommit:
                if classification == parse_utils.STMT_UPDATING:
                    sql = parse_utils.ensure_where_clause(sql)

                if classification != parse_utils.STMT_INSERT:
                    sql, args = sql_pyformat_args_to_spanner(sql, args or None)

                statement = Statement(
                    sql,
                    args,
                    get_param_types(args or None)
                    if classification != parse_utils.STMT_INSERT
                    else {},
                    ResultsChecksum(),
                    classification == parse_utils.STMT_INSERT,
                )
                (self._result_set, self._checksum,) = self.connection.run_statement(
                    statement
                )
                while True:
                    try:
                        self._itr = PeekIterator(self._result_set)
                        break
                    except Aborted:
                        self.connection.retry_transaction()
                return

            if classification == parse_utils.STMT_NON_UPDATING:
                self._handle_DQL(sql, args or None)
            elif classification == parse_utils.STMT_INSERT:
                _helpers.handle_insert(self.connection, sql, args or None)
            else:
                self.connection.database.run_in_transaction(
                    self._do_execute_update, sql, args or None
                )
        except (AlreadyExists, FailedPrecondition) as e:
            raise IntegrityError(e.details if hasattr(e, "details") else e)
        except InvalidArgument as e:
            raise ProgrammingError(e.details if hasattr(e, "details") else e)
        except InternalServerError as e:
            raise OperationalError(e.details if hasattr(e, "details") else e)
Example #12
    def execute(self, sql, args=None):
        """Prepares and executes a Spanner database operation.

        :type sql: str
        :param sql: A SQL query statement.

        :type args: list
        :param args: Additional parameters to supplement the SQL query.
        """
        self._result_set = None

        try:
            if self.connection.read_only:
                self._handle_DQL(sql, args or None)
                return

            class_ = parse_utils.classify_stmt(sql)
            if class_ == parse_utils.STMT_DDL:
                self._batch_DDLs(sql)
                if self.connection.autocommit:
                    self.connection.run_prior_DDL_statements()
                return

            # For every other operation, any prior DDL statements
            # must be run first.
            self.connection.run_prior_DDL_statements()

            if class_ == parse_utils.STMT_UPDATING:
                sql = parse_utils.ensure_where_clause(sql)

            if class_ != parse_utils.STMT_INSERT:
                sql, args = sql_pyformat_args_to_spanner(sql, args or None)

            if not self.connection.autocommit:
                statement = Statement(
                    sql,
                    args,
                    get_param_types(args or None)
                    if class_ != parse_utils.STMT_INSERT
                    else {},
                    ResultsChecksum(),
                    class_ == parse_utils.STMT_INSERT,
                )

                (
                    self._result_set,
                    self._checksum,
                ) = self.connection.run_statement(statement)
                while True:
                    try:
                        self._itr = PeekIterator(self._result_set)
                        break
                    except Aborted:
                        self.connection.retry_transaction()
                return

            if class_ == parse_utils.STMT_NON_UPDATING:
                self._handle_DQL(sql, args or None)
            elif class_ == parse_utils.STMT_INSERT:
                _helpers.handle_insert(self.connection, sql, args or None)
            else:
                self.connection.database.run_in_transaction(
                    self._do_execute_update, sql, args or None
                )
        except (AlreadyExists, FailedPrecondition, OutOfRange) as e:
            raise IntegrityError(getattr(e, "details", e))
        except InvalidArgument as e:
            raise ProgrammingError(getattr(e, "details", e))
        except InternalServerError as e:
            raise OperationalError(getattr(e, "details", e))
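Illustratively, the read_only short-circuit in this version means a read-only connection never checks out a read-write transaction. This assumes connection exposes a settable read_only flag, as in recent versions of the library:

# Illustrative only: every statement on a read-only connection is
# routed straight to _handle_DQL and served from a snapshot.
connection.read_only = True
cursor.execute("SELECT COUNT(*) FROM people")
print(cursor.fetchone())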
Example #13
    def test_get_param_types(self):
        cases = [
            (
                {
                    "a1": 10,
                    "b1": "2005-08-30T01:01:01.000001Z",
                    "c1": "2019-12-05",
                    "d1": 10.39,
                },
                {
                    "a1": param_types.INT64,
                    "b1": param_types.STRING,
                    "c1": param_types.STRING,
                    "d1": param_types.FLOAT64,
                },
            ),
            (
                {
                    "a1": 10,
                    "b1": TimestampStr("2005-08-30T01:01:01.000001Z"),
                    "c1": "2019-12-05",
                },
                {
                    "a1": param_types.INT64,
                    "b1": param_types.TIMESTAMP,
                    "c1": param_types.STRING,
                },
            ),
            (
                {
                    "a1": 10,
                    "b1": "2005-08-30T01:01:01.000001Z",
                    "c1": DateStr("2019-12-05"),
                },
                {
                    "a1": param_types.INT64,
                    "b1": param_types.STRING,
                    "c1": param_types.DATE,
                },
            ),
            (
                {
                    "a1": 10,
                    "b1": "2005-08-30T01:01:01.000001Z"
                },
                {
                    "a1": param_types.INT64,
                    "b1": param_types.STRING
                },
            ),
            (
                {
                    "a1": 10,
                    "b1": TimestampStr("2005-08-30T01:01:01.000001Z"),
                    "c1": DateStr("2005-08-30"),
                },
                {
                    "a1": param_types.INT64,
                    "b1": param_types.TIMESTAMP,
                    "c1": param_types.DATE,
                },
            ),
            (
                {
                    "a1": 10,
                    "b1": "aaaaa08-30T01:01:01.000001Z"
                },
                {
                    "a1": param_types.INT64,
                    "b1": param_types.STRING
                },
            ),
            (
                {
                    "a1": 10,
                    "b1": "2005-08-30T01:01:01.000001",
                    "t1": True,
                    "t2": False,
                    "f1": 99e9,
                },
                {
                    "a1": param_types.INT64,
                    "b1": param_types.STRING,
                    "t1": param_types.BOOL,
                    "t2": param_types.BOOL,
                    "f1": param_types.FLOAT64,
                },
            ),
            (
                {
                    "a1": 10,
                    "b1": "2019-11-26T02:55:41.000000Z"
                },
                {
                    "a1": param_types.INT64,
                    "b1": param_types.STRING
                },
            ),
            (
                {
                    "a1": 10,
                    "b1": TimestampStr("2019-11-26T02:55:41.000000Z"),
                    "dt1": datetime.datetime(2011, 9, 1, 13, 20, 30),
                    "d1": datetime.date(2011, 9, 1),
                },
                {
                    "a1": param_types.INT64,
                    "b1": param_types.TIMESTAMP,
                    "dt1": param_types.TIMESTAMP,
                    "d1": param_types.DATE,
                },
            ),
            (
                {
                    "a1": 10,
                    "b1": TimestampStr("2019-11-26T02:55:41.000000Z")
                },
                {
                    "a1": param_types.INT64,
                    "b1": param_types.TIMESTAMP
                },
            ),
            ({"a1": b"bytes"}, {"a1": param_types.BYTES}),
            ({"a1": 10, "b1": None}, {"a1": param_types.INT64}),
            (None, None),
        ]

        for i, (params, want_param_types) in enumerate(cases):
            with self.subTest(i=i):
                got_param_types = get_param_types(params)
                self.assertEqual(got_param_types, want_param_types)
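The practical upshot of these cases: bare strings are always inferred as STRING, so callers who want TIMESTAMP or DATE semantics must wrap their values explicitly. An illustrative write; the table and column names are made up:

from google.cloud.spanner_dbapi.parse_utils import DateStr, TimestampStr

# Without the wrappers, both values would be sent as STRING.
cursor.execute(
    "INSERT INTO events (id, created_at, event_date)"
    " VALUES (%(id)s, %(ts)s, %(d)s)",
    {
        "id": 1,
        "ts": TimestampStr("2019-11-26T02:55:41.000000Z"),
        "d": DateStr("2019-12-05"),
    },
)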