Example #1
0
 def __init__(self, database, request_options=None):
     """Remember *database* and normalize ``request_options``.

     :param database: the database object this instance operates on.
         The session and batch slots both start out as ``None``.

     :type request_options:
         :class:`~google.cloud.spanner_v1.types.RequestOptions` or :class:`dict`
     :param request_options:
         (Optional) Common options for requests issued through this object.
         ``None`` becomes an empty ``RequestOptions``; a dict (or dict
         subclass) is converted to ``RequestOptions``; any other value is
         stored as given.
     """
     self._database = database
     self._session = self._batch = None
     if request_options is None:
         self._request_options = RequestOptions()
     elif isinstance(request_options, dict):
         self._request_options = RequestOptions(request_options)
     else:
         self._request_options = request_options
Example #2
0
    def commit(self, return_commit_stats=False, request_options=None):
        """Commit mutations to the database.

        :type return_commit_stats: bool
        :param return_commit_stats:
          If true, the response will return commit stats which can be accessed through commit_stats.

        :type request_options:
            :class:`google.cloud.spanner_v1.types.RequestOptions` or :class:`dict`
        :param request_options:
                (Optional) Common options for this request.
                If a dict is provided, it must be of the same form as the protobuf
                message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
                The argument is copied and never modified.  When this
                transaction has a ``transaction_tag`` it overrides any
                transaction tag in the options; any request tag is dropped.

        :rtype: datetime
        :returns: timestamp of the committed changes.
        :raises ValueError: if there are no mutations to commit.
        """
        self._check_state()

        database = self._session._database
        api = database.spanner_api
        metadata = _metadata_with_prefix(database.name)
        trace_attributes = {"num_mutations": len(self._mutations)}

        # ``RequestOptions(...)`` builds a new message from a dict or copies an
        # existing message, so the tag adjustments below never leak back into
        # the caller's object.
        if request_options is None:
            request_options = RequestOptions()
        else:
            request_options = RequestOptions(request_options)

        # The transaction's own tag wins over one supplied per request;
        # ``getattr`` keeps this a no-op for transactions without a tag.
        transaction_tag = getattr(self, "transaction_tag", None)
        if transaction_tag is not None:
            request_options.transaction_tag = transaction_tag
        # Request tags are not forwarded on commit requests.
        request_options.request_tag = None

        request = CommitRequest(
            session=self._session.name,
            mutations=self._mutations,
            transaction_id=self._transaction_id,
            return_commit_stats=return_commit_stats,
            request_options=request_options,
        )
        with trace_call("CloudSpanner.Commit", self._session,
                        trace_attributes):
            response = api.commit(
                request=request,
                metadata=metadata,
            )
        self.committed = response.commit_timestamp
        if return_commit_stats:
            self.commit_stats = response.commit_stats
        # Detach the finished transaction from its session.
        del self._session._transaction
        return self.committed
Example #3
0
    def execute_sql(
        self,
        sql,
        params=None,
        param_types=None,
        query_mode=None,
        query_options=None,
        request_options=None,
        partition=None,
        retry=gapic_v1.method.DEFAULT,
        timeout=gapic_v1.method.DEFAULT,
    ):
        """Perform an ``ExecuteStreamingSql`` API request.

        :type sql: str
        :param sql: SQL query statement

        :type params: dict, {str -> column value}
        :param params: values for parameter replacement.  Keys must match
                       the names used in ``sql``.

        :type param_types: dict[str -> Union[dict, .types.Type]]
        :param param_types:
            (Optional) maps explicit types for one or more param values;
            required if parameters are passed.

        :type query_mode:
            :class:`~google.cloud.spanner_v1.types.ExecuteSqlRequest.QueryMode`
        :param query_mode: Mode governing return of results / query plan.
            See:
            `QueryMode <https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode>`_.

        :type query_options:
            :class:`~google.cloud.spanner_v1.types.ExecuteSqlRequest.QueryOptions`
                or :class:`dict`
        :param query_options:
                (Optional) Query optimizer configuration to use for the given query.
                If a dict is provided, it must be of the same form as the protobuf
                message :class:`~google.cloud.spanner_v1.types.QueryOptions`.
                Merged over the client-level default query options.

        :type request_options:
            :class:`google.cloud.spanner_v1.types.RequestOptions` or :class:`dict`
        :param request_options:
                (Optional) Common options for this request.
                If a dict is provided, it must be of the same form as the protobuf
                message :class:`~google.cloud.spanner_v1.types.RequestOptions`.

        :type partition: bytes
        :param partition: (Optional) one of the partition tokens returned
                          from :meth:`partition_query`.

        :type retry: :class:`~google.api_core.retry.Retry`
        :param retry: (Optional) The retry settings for this request.

        :type timeout: float
        :param timeout: (Optional) The timeout for this request.

        :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
        :returns: a result set instance which can be used to consume rows.

        :raises ValueError:
            for reuse of single-use snapshots, if a transaction ID is
            already pending for multiple-use snapshots, or if ``params``
            is passed without ``param_types``.
        """
        if self._read_request_count > 0:
            if not self._multi_use:
                raise ValueError("Cannot re-use single-use snapshot.")
            if self._transaction_id is None:
                raise ValueError("Transaction ID pending.")

        if params is not None:
            if param_types is None:
                raise ValueError(
                    "Specify 'param_types' when passing 'params'.")
            params_pb = Struct(fields={
                key: _make_value_pb(value)
                for key, value in params.items()
            })
        else:
            params_pb = {}

        database = self._session._database
        metadata = _metadata_with_prefix(database.name)
        transaction = self._make_txn_selector()
        api = database.spanner_api

        # Query-level options have higher precedence than client-level and
        # environment-level options
        default_query_options = database._instance._client._query_options
        query_options = _merge_query_options(default_query_options,
                                             query_options)

        # Accept any dict (including subclasses) in place of the message.
        if isinstance(request_options, dict):
            request_options = RequestOptions(request_options)

        request = ExecuteSqlRequest(
            session=self._session.name,
            sql=sql,
            transaction=transaction,
            params=params_pb,
            param_types=param_types,
            query_mode=query_mode,
            partition_token=partition,
            seqno=self._execute_sql_count,
            query_options=query_options,
            request_options=request_options,
        )
        # ``restart`` re-issues the same streaming request; it is handed to
        # ``_restart_on_unavailable`` together with the request object.
        restart = functools.partial(
            api.execute_streaming_sql,
            request=request,
            metadata=metadata,
            retry=retry,
            timeout=timeout,
        )

        trace_attributes = {"db.statement": sql}
        iterator = _restart_on_unavailable(
            restart,
            request,
            "CloudSpanner.ReadWriteTransaction",
            self._session,
            trace_attributes,
        )

        self._read_request_count += 1
        self._execute_sql_count += 1

        if self._multi_use:
            return StreamedResultSet(iterator, source=self)
        else:
            return StreamedResultSet(iterator)
Example #4
0
    def read(
        self,
        table,
        columns,
        keyset,
        index="",
        limit=0,
        partition=None,
        request_options=None,
        *,
        retry=gapic_v1.method.DEFAULT,
        timeout=gapic_v1.method.DEFAULT,
    ):
        """Perform a ``StreamingRead`` API request for rows in a table.

        :type table: str
        :param table: name of the table from which to fetch data

        :type columns: list of str
        :param columns: names of columns to be retrieved

        :type keyset: :class:`~google.cloud.spanner_v1.keyset.KeySet`
        :param keyset: keys / ranges identifying rows to be retrieved

        :type index: str
        :param index: (Optional) name of index to use, rather than the
                      table's primary key

        :type limit: int
        :param limit: (Optional) maximum number of rows to return.
                      Incompatible with ``partition``.

        :type partition: bytes
        :param partition: (Optional) one of the partition tokens returned
                          from :meth:`partition_read`.  Incompatible with
                          ``limit``.

        :type request_options:
            :class:`google.cloud.spanner_v1.types.RequestOptions` or :class:`dict`
        :param request_options:
                (Optional) Common options for this request.
                If a dict is provided, it must be of the same form as the protobuf
                message :class:`~google.cloud.spanner_v1.types.RequestOptions`.

        :type retry: :class:`~google.api_core.retry.Retry`
        :param retry: (Optional) The retry settings for this request.

        :type timeout: float
        :param timeout: (Optional) The timeout for this request.

        :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
        :returns: a result set instance which can be used to consume rows.

        :raises ValueError:
            for reuse of single-use snapshots, or if a transaction ID is
            already pending for multiple-use snapshots.
        """
        if self._read_request_count > 0:
            if not self._multi_use:
                raise ValueError("Cannot re-use single-use snapshot.")
            if self._transaction_id is None:
                raise ValueError("Transaction ID pending.")

        database = self._session._database
        api = database.spanner_api
        metadata = _metadata_with_prefix(database.name)
        transaction = self._make_txn_selector()

        # Accept any dict (including subclasses) in place of the message.
        if isinstance(request_options, dict):
            request_options = RequestOptions(request_options)

        request = ReadRequest(
            session=self._session.name,
            table=table,
            columns=columns,
            key_set=keyset._to_pb(),
            transaction=transaction,
            index=index,
            limit=limit,
            partition_token=partition,
            request_options=request_options,
        )
        # ``restart`` re-issues the same streaming read; it is handed to
        # ``_restart_on_unavailable`` together with the request object.
        restart = functools.partial(
            api.streaming_read,
            request=request,
            metadata=metadata,
            retry=retry,
            timeout=timeout,
        )

        trace_attributes = {"table_id": table, "columns": columns}
        iterator = _restart_on_unavailable(
            restart,
            request,
            "CloudSpanner.ReadOnlyTransaction",
            self._session,
            trace_attributes,
        )

        self._read_request_count += 1

        if self._multi_use:
            return StreamedResultSet(iterator, source=self)
        else:
            return StreamedResultSet(iterator)
Example #5
0
    def batch_update(self, statements, request_options=None):
        """Perform a batch of DML statements via an ``ExecuteBatchDml`` request.

        :type statements:
            Sequence[Union[ str, Tuple[str, Dict[str, Any], Dict[str, Union[dict, .types.Type]]]]]

        :param statements:
            List of DML statements, with optional params / param types.
            If passed, 'params' is a dict mapping names to the values
            for parameter replacement.  Keys must match the names used in the
            corresponding DML statement.  If 'params' is passed, 'param_types'
            must also be passed, as a dict mapping names to the type of
            value passed in 'params'.

        :type request_options:
            :class:`google.cloud.spanner_v1.types.RequestOptions` or :class:`dict`
        :param request_options:
                (Optional) Common options for this request.
                If a dict is provided, it must be of the same form as the protobuf
                message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
                The argument is copied and never modified.  When this
                transaction has a ``transaction_tag`` it overrides any
                transaction tag in the options.

        :rtype:
            Tuple(status, Sequence[int])
        :returns:
            Status code, plus counts of rows affected by each completed DML
            statement.  Note that if the status code is not ``OK``, the
            statement triggering the error will not have an entry in the
            list, nor will any statements following that one.
        """
        parsed = []
        for statement in statements:
            if isinstance(statement, str):
                parsed.append(ExecuteBatchDmlRequest.Statement(sql=statement))
            else:
                dml, params, param_types = statement
                params_pb = self._make_params_pb(params, param_types)
                parsed.append(
                    ExecuteBatchDmlRequest.Statement(sql=dml,
                                                     params=params_pb,
                                                     param_types=param_types))

        database = self._session._database
        metadata = _metadata_with_prefix(database.name)
        transaction = self._make_txn_selector()
        api = database.spanner_api

        # Reserve this request's sequence number and advance the counter.
        seqno, self._execute_sql_count = (
            self._execute_sql_count,
            self._execute_sql_count + 1,
        )

        # ``RequestOptions(...)`` builds a new message from a dict or copies an
        # existing message, so setting the tag below never mutates the
        # caller's object.
        if request_options is None:
            request_options = RequestOptions()
        else:
            request_options = RequestOptions(request_options)
        # The transaction's own tag wins over one supplied per request;
        # ``getattr`` keeps this a no-op for transactions without a tag.
        transaction_tag = getattr(self, "transaction_tag", None)
        if transaction_tag is not None:
            request_options.transaction_tag = transaction_tag

        trace_attributes = {
            # Get just the queries from the DML statement batch
            "db.statement": ";".join(statement.sql for statement in parsed)
        }
        request = ExecuteBatchDmlRequest(
            session=self._session.name,
            transaction=transaction,
            statements=parsed,
            seqno=seqno,
            request_options=request_options,
        )
        with trace_call("CloudSpanner.DMLTransaction", self._session,
                        trace_attributes):
            response = api.execute_batch_dml(request=request,
                                             metadata=metadata)
        row_counts = [
            result_set.stats.row_count_exact
            for result_set in response.result_sets
        ]
        return response.status, row_counts
Example #6
0
    def execute_update(
        self,
        dml,
        params=None,
        param_types=None,
        query_mode=None,
        query_options=None,
        request_options=None,
        *,
        retry=gapic_v1.method.DEFAULT,
        timeout=gapic_v1.method.DEFAULT,
    ):
        """Perform an ``ExecuteSql`` API request with DML.

        :type dml: str
        :param dml: SQL DML statement

        :type params: dict, {str -> column value}
        :param params: values for parameter replacement.  Keys must match
                       the names used in ``dml``.

        :type param_types: dict[str -> Union[dict, .types.Type]]
        :param param_types:
            (Optional) maps explicit types for one or more param values;
            required if parameters are passed.

        :type query_mode:
            :class:`~google.cloud.spanner_v1.types.ExecuteSqlRequest.QueryMode`
        :param query_mode: Mode governing return of results / query plan.
            See:
            `QueryMode <https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode>`_.

        :type query_options:
            :class:`~google.cloud.spanner_v1.types.ExecuteSqlRequest.QueryOptions`
            or :class:`dict`
        :param query_options: (Optional) Options that are provided for query plan stability.
            Merged over the client-level default query options.

        :type request_options:
            :class:`google.cloud.spanner_v1.types.RequestOptions` or :class:`dict`
        :param request_options:
                (Optional) Common options for this request.
                If a dict is provided, it must be of the same form as the protobuf
                message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
                The argument is copied and never modified.  When this
                transaction has a ``transaction_tag`` it overrides any
                transaction tag in the options.

        :type retry: :class:`~google.api_core.retry.Retry`
        :param retry: (Optional) The retry settings for this request.

        :type timeout: float
        :param timeout: (Optional) The timeout for this request.

        :rtype: int
        :returns: Count of rows affected by the DML statement.
        """
        params_pb = self._make_params_pb(params, param_types)
        database = self._session._database
        metadata = _metadata_with_prefix(database.name)
        transaction = self._make_txn_selector()
        api = database.spanner_api

        # Reserve this request's sequence number and advance the counter.
        seqno, self._execute_sql_count = (
            self._execute_sql_count,
            self._execute_sql_count + 1,
        )

        # Query-level options have higher precedence than client-level and
        # environment-level options
        default_query_options = database._instance._client._query_options
        query_options = _merge_query_options(default_query_options,
                                             query_options)

        # ``RequestOptions(...)`` builds a new message from a dict or copies an
        # existing message, so setting the tag below never mutates the
        # caller's object.
        if request_options is None:
            request_options = RequestOptions()
        else:
            request_options = RequestOptions(request_options)
        # The transaction's own tag wins over one supplied per request;
        # ``getattr`` keeps this a no-op for transactions without a tag.
        transaction_tag = getattr(self, "transaction_tag", None)
        if transaction_tag is not None:
            request_options.transaction_tag = transaction_tag

        trace_attributes = {"db.statement": dml}

        request = ExecuteSqlRequest(
            session=self._session.name,
            sql=dml,
            transaction=transaction,
            params=params_pb,
            param_types=param_types,
            query_mode=query_mode,
            query_options=query_options,
            seqno=seqno,
            request_options=request_options,
        )
        with trace_call("CloudSpanner.ReadWriteTransaction", self._session,
                        trace_attributes):
            response = api.execute_sql(request=request,
                                       metadata=metadata,
                                       retry=retry,
                                       timeout=timeout)
        return response.stats.row_count_exact
 def test_batch_update_wo_errors(self):
     # Batch update with no injected error, using a medium-priority request.
     medium = RequestOptions.Priority.PRIORITY_MEDIUM
     options = RequestOptions(priority=medium)
     self._batch_update_helper(request_options=options)
 def test_execute_update_w_request_options(self):
     # Pass an explicit RequestOptions that carries only a priority.
     medium = RequestOptions.Priority.PRIORITY_MEDIUM
     options = RequestOptions(priority=medium)
     self._execute_update_helper(request_options=options)
Example #9
0
    def execute_partitioned_dml(
        self,
        dml,
        params=None,
        param_types=None,
        query_options=None,
        request_options=None,
    ):
        """Execute a partitionable DML statement.

        :type dml: str
        :param dml: DML statement

        :type params: dict, {str -> column value}
        :param params: values for parameter replacement.  Keys must match
                       the names used in ``dml``.

        :type param_types: dict[str -> Union[dict, .types.Type]]
        :param param_types:
            (Optional) maps explicit types for one or more param values;
            required if parameters are passed.

        :type query_options:
            :class:`~google.cloud.spanner_v1.types.ExecuteSqlRequest.QueryOptions`
            or :class:`dict`
        :param query_options:
                (Optional) Query optimizer configuration to use for the given query.
                If a dict is provided, it must be of the same form as the protobuf
                message :class:`~google.cloud.spanner_v1.types.QueryOptions`

        :type request_options:
            :class:`google.cloud.spanner_v1.types.RequestOptions` or :class:`dict`
        :param request_options:
            (Optional) Common options for this request.
            If a dict is provided, it must be of the same form as the protobuf
            message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
            Please note, the `transactionTag` setting will be ignored as it is
            not supported for partitioned DML.  The argument is copied and
            never modified.

        :rtype: int
        :returns: Count of rows affected by the DML statement.

        :raises ValueError: if ``params`` is passed without ``param_types``.
        """
        query_options = _merge_query_options(
            self._instance._client._query_options, query_options)

        # Work on a private copy: ``RequestOptions(...)`` builds one from a
        # dict or copies an existing message, so clearing the transaction tag
        # below no longer mutates the caller's object.
        if request_options is None:
            request_options = RequestOptions()
        else:
            request_options = RequestOptions(request_options)
        # Transaction tags are not supported for partitioned DML.
        request_options.transaction_tag = None

        if params is not None:
            from google.cloud.spanner_v1.transaction import Transaction

            if param_types is None:
                raise ValueError(
                    "Specify 'param_types' when passing 'params'.")
            params_pb = Transaction._make_params_pb(params, param_types)
        else:
            params_pb = {}

        api = self.spanner_api

        txn_options = TransactionOptions(
            partitioned_dml=TransactionOptions.PartitionedDml())

        metadata = _metadata_with_prefix(self.name)

        def execute_pdml():
            # One attempt: check out a session, begin a partitioned-DML
            # transaction, stream the statement, and drain the results.
            with SessionCheckout(self._pool) as session:

                txn = api.begin_transaction(session=session.name,
                                            options=txn_options,
                                            metadata=metadata)

                txn_selector = TransactionSelector(id=txn.id)

                request = ExecuteSqlRequest(
                    session=session.name,
                    sql=dml,
                    transaction=txn_selector,
                    params=params_pb,
                    param_types=param_types,
                    query_options=query_options,
                    request_options=request_options,
                )
                method = functools.partial(
                    api.execute_streaming_sql,
                    metadata=metadata,
                )

                iterator = _restart_on_unavailable(method, request)

                result_set = StreamedResultSet(iterator)
                list(result_set)  # consume all partials

                return result_set.stats.row_count_lower_bound

        # Re-run the whole attempt if it is aborted.
        return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()
Example #10
0
    def _batch_update_helper(self,
                             error_after=None,
                             count=0,
                             request_options=None):
        """Run ``batch_update`` against a mocked API and verify the request.

        :param error_after: if not None, only the first ``error_after``
            statements report stats and the response status is 400.
        :param count: initial ``_execute_sql_count`` (expected seqno).
        :param request_options: ``None``, a dict, or ``RequestOptions``
            passed through to ``batch_update``.  Never mutated here.
        """
        from google.rpc.status_pb2 import Status
        from google.protobuf.struct_pb2 import Struct
        from google.cloud.spanner_v1 import param_types
        from google.cloud.spanner_v1 import ResultSet
        from google.cloud.spanner_v1 import ResultSetStats
        from google.cloud.spanner_v1 import ExecuteBatchDmlRequest
        from google.cloud.spanner_v1 import ExecuteBatchDmlResponse
        from google.cloud.spanner_v1 import TransactionSelector
        from google.cloud.spanner_v1._helpers import _make_value_pb

        insert_dml = "INSERT INTO table(pkey, desc) VALUES (%pkey, %desc)"
        insert_params = {"pkey": 12345, "desc": "DESCRIPTION"}
        insert_param_types = {
            "pkey": param_types.INT64,
            "desc": param_types.STRING
        }
        update_dml = 'UPDATE table SET desc = desc + "-amended"'
        delete_dml = "DELETE FROM table WHERE desc IS NULL"

        dml_statements = [
            (insert_dml, insert_params, insert_param_types),
            update_dml,
            delete_dml,
        ]

        stats_pbs = [
            ResultSetStats(row_count_exact=1),
            ResultSetStats(row_count_exact=2),
            ResultSetStats(row_count_exact=3),
        ]
        if error_after is not None:
            stats_pbs = stats_pbs[:error_after]
            expected_status = Status(code=400)
        else:
            expected_status = Status(code=200)
        expected_row_counts = [stats.row_count_exact for stats in stats_pbs]

        response = ExecuteBatchDmlResponse(
            status=expected_status,
            result_sets=[ResultSet(stats=stats_pb) for stats_pb in stats_pbs],
        )
        database = _Database()
        api = database.spanner_api = self._make_spanner_api()
        api.execute_batch_dml.return_value = response
        session = _Session(database)
        transaction = self._make_one(session)
        transaction._transaction_id = self.TRANSACTION_ID
        transaction.transaction_tag = self.TRANSACTION_TAG
        transaction._execute_sql_count = count

        if request_options is None:
            request_options = RequestOptions()
        elif type(request_options) == dict:
            request_options = RequestOptions(request_options)

        status, row_counts = transaction.batch_update(
            dml_statements, request_options=request_options)

        self.assertEqual(status, expected_status)
        self.assertEqual(row_counts, expected_row_counts)

        expected_transaction = TransactionSelector(id=self.TRANSACTION_ID)
        expected_insert_params = Struct(fields={
            key: _make_value_pb(value)
            for (key, value) in insert_params.items()
        })
        expected_statements = [
            ExecuteBatchDmlRequest.Statement(
                sql=insert_dml,
                params=expected_insert_params,
                param_types=insert_param_types,
            ),
            ExecuteBatchDmlRequest.Statement(sql=update_dml),
            ExecuteBatchDmlRequest.Statement(sql=delete_dml),
        ]
        # Copy instead of aliasing so the caller's options are not mutated;
        # the request must carry the transaction's own tag.
        expected_request_options = RequestOptions(request_options)
        expected_request_options.transaction_tag = self.TRANSACTION_TAG

        expected_request = ExecuteBatchDmlRequest(
            session=self.SESSION_NAME,
            transaction=expected_transaction,
            statements=expected_statements,
            seqno=count,
            request_options=expected_request_options,
        )
        api.execute_batch_dml.assert_called_once_with(
            request=expected_request,
            metadata=[("google-cloud-resource-prefix", database.name)],
        )

        self.assertEqual(transaction._execute_sql_count, count + 1)
Example #11
0
 def test_batch_update_w_request_and_transaction_tag_success(self):
     # Caller supplies both a request tag and a transaction tag.
     self._batch_update_helper(request_options=RequestOptions(
         request_tag="tag-1",
         transaction_tag="tag-1-1",
     ))
Example #12
0
 def test_execute_update_w_transaction_tag_success(self):
     # Caller supplies only a transaction tag.
     self._execute_update_helper(
         request_options=RequestOptions(transaction_tag="tag-1-1"))
Example #13
0
    def _execute_update_helper(
        self,
        count=0,
        query_options=None,
        request_options=None,
        retry=gapic_v1.method.DEFAULT,
        timeout=gapic_v1.method.DEFAULT,
    ):
        """Run ``execute_update`` against a mocked API and verify the request.

        :param count: initial ``_execute_sql_count`` (expected seqno).
        :param query_options: passed through to ``execute_update``.
        :param request_options: ``None``, a dict, or ``RequestOptions``
            passed through to ``execute_update``.  Never mutated here.
        :param retry: passed through and expected on the API call.
        :param timeout: passed through and expected on the API call.
        """
        from google.protobuf.struct_pb2 import Struct
        from google.cloud.spanner_v1 import (
            ResultSet,
            ResultSetStats,
        )
        from google.cloud.spanner_v1 import TransactionSelector
        from google.cloud.spanner_v1._helpers import (
            _make_value_pb,
            _merge_query_options,
        )
        from google.cloud.spanner_v1 import ExecuteSqlRequest

        MODE = 2  # PROFILE
        stats_pb = ResultSetStats(row_count_exact=1)
        database = _Database()
        api = database.spanner_api = self._make_spanner_api()
        api.execute_sql.return_value = ResultSet(stats=stats_pb)
        session = _Session(database)
        transaction = self._make_one(session)
        transaction._transaction_id = self.TRANSACTION_ID
        transaction.transaction_tag = self.TRANSACTION_TAG
        transaction._execute_sql_count = count

        if request_options is None:
            request_options = RequestOptions()
        elif type(request_options) == dict:
            request_options = RequestOptions(request_options)

        row_count = transaction.execute_update(
            DML_QUERY_WITH_PARAM,
            PARAMS,
            PARAM_TYPES,
            query_mode=MODE,
            query_options=query_options,
            request_options=request_options,
            retry=retry,
            timeout=timeout,
        )

        self.assertEqual(row_count, 1)

        expected_transaction = TransactionSelector(id=self.TRANSACTION_ID)
        expected_params = Struct(fields={
            key: _make_value_pb(value)
            for (key, value) in PARAMS.items()
        })

        expected_query_options = database._instance._client._query_options
        if query_options:
            expected_query_options = _merge_query_options(
                expected_query_options, query_options)
        # Copy instead of aliasing so the caller's options are not mutated;
        # the request must carry the transaction's own tag.
        expected_request_options = RequestOptions(request_options)
        expected_request_options.transaction_tag = self.TRANSACTION_TAG

        expected_request = ExecuteSqlRequest(
            session=self.SESSION_NAME,
            sql=DML_QUERY_WITH_PARAM,
            transaction=expected_transaction,
            params=expected_params,
            param_types=PARAM_TYPES,
            query_mode=MODE,
            query_options=expected_query_options,
            request_options=expected_request_options,
            seqno=count,
        )
        api.execute_sql.assert_called_once_with(
            request=expected_request,
            retry=retry,
            timeout=timeout,
            metadata=[("google-cloud-resource-prefix", database.name)],
        )

        self.assertEqual(transaction._execute_sql_count, count + 1)
Example #14
0
 def test_commit_w_request_and_transaction_tag_success(self):
     # Commit with caller-supplied request and transaction tags.
     options = RequestOptions(request_tag="tag-1", transaction_tag="tag-1-1")
     self._commit_helper(request_options=options)
Example #15
0
 def test_commit_w_transaction_tag_ignored_success(self):
     # Commit with only a caller-supplied transaction tag.
     self._commit_helper(
         request_options=RequestOptions(transaction_tag="tag-1-1"))
Example #16
0
    def _commit_helper(self,
                       mutate=True,
                       return_commit_stats=False,
                       request_options=None):
        """Commit against a faux API and verify what it received.

        :param mutate: if true, queue a delete mutation before committing.
        :param return_commit_stats: passed through to ``commit``; when true
            the faux response carries ``mutation_count == 4``.
        :param request_options: ``None``, a dict, or ``RequestOptions``
            passed through to ``commit``.  Never mutated here.
        """
        import datetime
        from google.cloud.spanner_v1 import CommitResponse
        from google.cloud.spanner_v1.keyset import KeySet
        from google.cloud._helpers import UTC

        now = datetime.datetime.utcnow().replace(tzinfo=UTC)
        keys = [[0], [1], [2]]
        keyset = KeySet(keys=keys)
        response = CommitResponse(commit_timestamp=now)
        if return_commit_stats:
            response.commit_stats.mutation_count = 4
        database = _Database()
        api = database.spanner_api = _FauxSpannerAPI(_commit_response=response)
        session = _Session(database)
        transaction = self._make_one(session)
        transaction._transaction_id = self.TRANSACTION_ID
        transaction.transaction_tag = self.TRANSACTION_TAG

        if mutate:
            transaction.delete(TABLE_NAME, keyset)

        transaction.commit(return_commit_stats=return_commit_stats,
                           request_options=request_options)

        self.assertEqual(transaction.committed, now)
        self.assertIsNone(session._transaction)

        session_id, mutations, txn_id, actual_request_options, metadata = api._committed

        if request_options is None:
            expected_request_options = RequestOptions(
                transaction_tag=self.TRANSACTION_TAG)
        else:
            # ``RequestOptions(...)`` builds from a dict or copies a message,
            # so one branch covers both and the caller's object is untouched.
            # Commit is expected to carry the transaction's tag and no
            # request tag.
            expected_request_options = RequestOptions(request_options)
            expected_request_options.transaction_tag = self.TRANSACTION_TAG
            expected_request_options.request_tag = None

        self.assertEqual(session_id, session.name)
        self.assertEqual(txn_id, self.TRANSACTION_ID)
        self.assertEqual(mutations, transaction._mutations)
        self.assertEqual(metadata,
                         [("google-cloud-resource-prefix", database.name)])
        self.assertEqual(actual_request_options, expected_request_options)

        if return_commit_stats:
            self.assertEqual(transaction.commit_stats.mutation_count, 4)

        self.assertSpanAttributes(
            "CloudSpanner.Commit",
            attributes=dict(
                TestTransaction.BASE_ATTRIBUTES,
                num_mutations=len(transaction._mutations),
            ),
        )