예제 #1
0
    def batch_update(self, statements, request_options=None):
        """Perform a batch of DML statements via an ``ExecuteBatchDml`` request.

        Each call consumes one sequence number: the current value of
        ``self._execute_sql_count`` is sent as the request ``seqno`` and the
        counter is then incremented.

        :type statements:
            Sequence[Union[ str, Tuple[str, Dict[str, Any], Dict[str, Union[dict, .types.Type]]]]]

        :param statements:
            List of DML statements, with optional params / param types.
            A plain ``str`` entry is sent as-is.  Any other entry must unpack
            into exactly ``(dml, params, param_types)``.
            If passed, 'params' is a dict mapping names to the values
            for parameter replacement.  Keys must match the names used in the
            corresponding DML statement.  If 'params' is passed, 'param_types'
            must also be passed, as a dict mapping names to the type of
            value passed in 'params'.

        :type request_options:
            :class:`google.cloud.spanner_v1.types.RequestOptions`
        :param request_options:
                (Optional) Common options for this request.
                If a dict (or dict subclass) is provided, it must be of the
                same form as the protobuf message
                :class:`~google.cloud.spanner_v1.types.RequestOptions`.

        :rtype:
            Tuple(status, Sequence[int])
        :returns:
            Status code, plus counts of rows affected by each completed DML
            statement.  Note that if the status code is not ``OK``, the
            statement triggering the error will not have an entry in the
            list, nor will any statements following that one.
        """
        parsed = []
        for statement in statements:
            if isinstance(statement, str):
                parsed.append(ExecuteBatchDmlRequest.Statement(sql=statement))
                continue

            dml, params, param_types = statement
            params_pb = self._make_params_pb(params, param_types)
            parsed.append(
                ExecuteBatchDmlRequest.Statement(
                    sql=dml,
                    params=params_pb,
                    param_types=param_types,
                )
            )

        database = self._session._database
        metadata = _metadata_with_prefix(database.name)
        transaction = self._make_txn_selector()
        api = database.spanner_api

        # Claim the current sequence number for this request, then advance
        # the counter for the next one.
        seqno = self._execute_sql_count
        self._execute_sql_count = seqno + 1

        # isinstance (rather than an exact type comparison) so that dict
        # subclasses such as OrderedDict are converted as well.
        if isinstance(request_options, dict):
            request_options = RequestOptions(request_options)

        trace_attributes = {
            # Get just the queries from the DML statement batch
            "db.statement": ";".join(statement.sql for statement in parsed)
        }
        request = ExecuteBatchDmlRequest(
            session=self._session.name,
            transaction=transaction,
            statements=parsed,
            seqno=seqno,
            request_options=request_options,
        )
        with trace_call(
            "CloudSpanner.DMLTransaction", self._session, trace_attributes
        ):
            response = api.execute_batch_dml(request=request, metadata=metadata)

        row_counts = [
            result_set.stats.row_count_exact for result_set in response.result_sets
        ]
        return response.status, row_counts
예제 #2
0
    def _batch_update_helper(self, error_after=None, count=0):
        """Run ``batch_update`` against a mocked API and check the results.

        Three DML statements are submitted (one parameterized insert, one
        plain update, one plain delete) and the mocked ``execute_batch_dml``
        returns a canned response.

        :param error_after:
            (Optional) When not ``None``, the canned response carries only the
            first ``error_after`` result sets and a status with code 400;
            otherwise all three result sets and a status with code 200.
        :param count:
            Initial value of the transaction's ``_execute_sql_count``; it is
            expected to be sent as the request ``seqno`` and to be incremented
            by one afterwards.
        """
        from google.rpc.status_pb2 import Status
        from google.protobuf.struct_pb2 import Struct
        from google.cloud.spanner_v1 import param_types
        from google.cloud.spanner_v1 import ResultSet
        from google.cloud.spanner_v1 import ResultSetStats
        from google.cloud.spanner_v1 import ExecuteBatchDmlRequest
        from google.cloud.spanner_v1 import ExecuteBatchDmlResponse
        from google.cloud.spanner_v1 import TransactionSelector
        from google.cloud.spanner_v1._helpers import _make_value_pb

        # --- Input statements -------------------------------------------
        insert_dml = "INSERT INTO table(pkey, desc) VALUES (%pkey, %desc)"
        insert_params = {"pkey": 12345, "desc": "DESCRIPTION"}
        insert_param_types = {"pkey": param_types.INT64, "desc": param_types.STRING}
        update_dml = 'UPDATE table SET desc = desc + "-amended"'
        delete_dml = "DELETE FROM table WHERE desc IS NULL"
        batch = [(insert_dml, insert_params, insert_param_types), update_dml, delete_dml]

        # --- Canned API response ----------------------------------------
        every_stats = [ResultSetStats(row_count_exact=rows) for rows in (1, 2, 3)]
        if error_after is None:
            returned_stats = every_stats
            expected_status = Status(code=200)
        else:
            # Only the statements before the failing one report stats.
            returned_stats = every_stats[:error_after]
            expected_status = Status(code=400)
        expected_row_counts = [item.row_count_exact for item in returned_stats]

        canned_result_sets = [ResultSet(stats=item) for item in returned_stats]
        response = ExecuteBatchDmlResponse(
            status=expected_status, result_sets=canned_result_sets
        )

        # --- Transaction wired to the mocked API ------------------------
        database = _Database()
        api = self._make_spanner_api()
        database.spanner_api = api
        api.execute_batch_dml.return_value = response
        transaction = self._make_one(_Session(database))
        transaction._transaction_id = self.TRANSACTION_ID
        transaction._execute_sql_count = count

        # --- Exercise ---------------------------------------------------
        status, row_counts = transaction.batch_update(batch)

        # --- Verify return values ---------------------------------------
        self.assertEqual(status, expected_status)
        self.assertEqual(row_counts, expected_row_counts)

        # --- Verify the request sent to the API -------------------------
        insert_fields = {
            name: _make_value_pb(raw) for name, raw in insert_params.items()
        }
        insert_statement = ExecuteBatchDmlRequest.Statement(
            sql=insert_dml,
            params=Struct(fields=insert_fields),
            param_types=insert_param_types,
        )
        plain_statements = [
            ExecuteBatchDmlRequest.Statement(sql=sql)
            for sql in (update_dml, delete_dml)
        ]
        expected_request = ExecuteBatchDmlRequest(
            session=self.SESSION_NAME,
            transaction=TransactionSelector(id=self.TRANSACTION_ID),
            statements=[insert_statement] + plain_statements,
            seqno=count,
        )
        api.execute_batch_dml.assert_called_once_with(
            request=expected_request,
            metadata=[("google-cloud-resource-prefix", database.name)],
        )

        # The sequence counter must have advanced by exactly one.
        self.assertEqual(transaction._execute_sql_count, count + 1)