def test_read_grpc_error(self):
    from google.cloud.proto.spanner.v1.transaction_pb2 import (
        TransactionSelector)
    from google.gax.errors import GaxError
    from google.cloud.spanner.keyset import KeySet

    KEYSET = KeySet(all_=True)
    database = _Database()
    api = database.spanner_api = _FauxSpannerAPI(_random_gax_error=True)
    session = _Session(database)
    derived = self._makeDerived(session)

    with self.assertRaises(GaxError):
        list(derived.read(TABLE_NAME, COLUMNS, KEYSET))

    (r_session, table, columns, key_set, transaction, index,
     limit, resume_token, options) = api._streaming_read_with
    self.assertEqual(r_session, self.SESSION_NAME)
    self.assertEqual(table, TABLE_NAME)
    self.assertEqual(columns, COLUMNS)
    self.assertEqual(key_set, KEYSET.to_pb())
    self.assertIsInstance(transaction, TransactionSelector)
    self.assertTrue(transaction.single_use.read_only.strong)
    self.assertEqual(index, '')
    self.assertEqual(limit, 0)
    self.assertEqual(resume_token, b'')
    self.assertEqual(options.kwargs['metadata'],
                     [('google-cloud-resource-prefix', database.name)])

def test_read_w_ranges(self):
    ROW_COUNT = 4000
    START = 1000
    END = 2000
    session, committed = self._set_up_table(ROW_COUNT)
    snapshot = session.snapshot(read_timestamp=committed, multi_use=True)
    all_data_rows = list(self._row_data(ROW_COUNT))

    # Closed-closed: both endpoints included.
    closed_closed = KeyRange(start_closed=[START], end_closed=[END])
    keyset = KeySet(ranges=(closed_closed,))
    rows = list(snapshot.read(self.TABLE, self.COLUMNS, keyset))
    expected = all_data_rows[START:END + 1]
    self._check_row_data(rows, expected)

    # Closed-open: start included, end excluded.
    closed_open = KeyRange(start_closed=[START], end_open=[END])
    keyset = KeySet(ranges=(closed_open,))
    rows = list(snapshot.read(self.TABLE, self.COLUMNS, keyset))
    expected = all_data_rows[START:END]
    self._check_row_data(rows, expected)

    # Open-open: both endpoints excluded.
    open_open = KeyRange(start_open=[START], end_open=[END])
    keyset = KeySet(ranges=(open_open,))
    rows = list(snapshot.read(self.TABLE, self.COLUMNS, keyset))
    expected = all_data_rows[START + 1:END]
    self._check_row_data(rows, expected)

    # Open-closed: start excluded, end included.
    open_closed = KeyRange(start_open=[START], end_closed=[END])
    keyset = KeySet(ranges=(open_closed,))
    rows = list(snapshot.read(self.TABLE, self.COLUMNS, keyset))
    expected = all_data_rows[START + 1:END + 1]
    self._check_row_data(rows, expected)

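# A minimal sketch (helper name hypothetical, not part of the suite) of the
# bound semantics exercised above: closed bounds include their endpoint and
# open bounds exclude it, so each KeyRange over consecutive integer keys
# maps onto a Python slice of the ordered row data.
def _range_to_slice_bounds(start, end, start_closed, end_closed):
    """Translate KeyRange bounds into ``(lo, hi)`` for ``rows[lo:hi]``."""
    lo = start if start_closed else start + 1
    hi = end + 1 if end_closed else end
    return lo, hi
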
def _read_w_concurrent_update(self, transaction, pkey):
    keyset = KeySet(keys=[(pkey,)])
    rows = list(transaction.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset))
    self.assertEqual(len(rows), 1)
    pkey, value = rows[0]
    transaction.update(COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, value + 1]])

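# A descriptive note, not part of the original suite: a unit of work like
# ``_read_w_concurrent_update`` is driven via ``run_in_transaction``, which
# passes the extra positional arguments through to the callable and re-runs
# the whole callable if the transaction aborts, e.g.:
#
#     session.run_in_transaction(self._read_w_concurrent_update, pkey)
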
def test_commit_ok(self):
    import datetime
    from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse
    from google.cloud.spanner.keyset import KeySet
    from google.cloud._helpers import UTC
    from google.cloud._helpers import _datetime_to_pb_timestamp

    now = datetime.datetime.utcnow().replace(tzinfo=UTC)
    now_pb = _datetime_to_pb_timestamp(now)
    keys = [[0], [1], [2]]
    keyset = KeySet(keys=keys)
    response = CommitResponse(commit_timestamp=now_pb)
    database = _Database()
    api = database.spanner_api = _FauxSpannerAPI(_commit_response=response)
    session = _Session(database)
    transaction = self._make_one(session)
    transaction._transaction_id = self.TRANSACTION_ID
    transaction.delete(TABLE_NAME, keyset)

    transaction.commit()

    self.assertEqual(transaction.committed, now)
    self.assertIsNone(session._transaction)

    session_id, mutations, txn_id, options = api._committed
    self.assertEqual(session_id, session.name)
    self.assertEqual(txn_id, self.TRANSACTION_ID)
    self.assertEqual(mutations, transaction._mutations)
    self.assertEqual(options.kwargs['metadata'],
                     [('google-cloud-resource-prefix', database.name)])

def test_read(self):
    from google.cloud.spanner.keyset import KeySet

    TABLE_NAME = 'citizens'
    COLUMNS = ['email', 'first_name', 'last_name', 'age']
    KEYS = ['*****@*****.**', '*****@*****.**']
    KEYSET = KeySet(keys=KEYS)
    INDEX = 'email-address-index'
    LIMIT = 20
    TOKEN = b'DEADBEEF'
    client = _Client()
    instance = _Instance(self.INSTANCE_NAME, client=client)
    pool = _Pool()
    session = _Session()
    pool.put(session)
    database = self._make_one(self.DATABASE_ID, instance, pool=pool)

    rows = list(database.read(
        TABLE_NAME, COLUMNS, KEYSET, INDEX, LIMIT, TOKEN))

    self.assertEqual(rows, [])

    (table, columns, key_set, index, limit,
     resume_token) = session._read_with
    self.assertEqual(table, TABLE_NAME)
    self.assertEqual(columns, COLUMNS)
    self.assertEqual(key_set, KEYSET)
    self.assertEqual(index, INDEX)
    self.assertEqual(limit, LIMIT)
    self.assertEqual(resume_token, TOKEN)

def test_read_w_single_key(self):
    ROW_COUNT = 40
    session, committed = self._set_up_table(ROW_COUNT)
    snapshot = session.snapshot(read_timestamp=committed)
    rows = list(snapshot.read(
        self.TABLE, self.COLUMNS, KeySet(keys=[(0,)])))

    all_data_rows = list(self._row_data(ROW_COUNT))
    expected = [all_data_rows[0]]
    self._check_row_data(rows, expected)

def test_read(self):
    from google.cloud.spanner import session as MUT
    from google.cloud._testing import _Monkey
    from google.cloud.spanner.keyset import KeySet

    TABLE_NAME = 'citizens'
    COLUMNS = ['email', 'first_name', 'last_name', 'age']
    KEYS = ['*****@*****.**', '*****@*****.**']
    KEYSET = KeySet(keys=KEYS)
    INDEX = 'email-address-index'
    LIMIT = 20
    TOKEN = b'DEADBEEF'
    database = _Database(self.DATABASE_NAME)
    session = self._make_one(database)
    session._session_id = 'DEADBEEF'

    _read_with = []
    expected = object()

    class _Snapshot(object):

        def __init__(self, session, **kwargs):
            self._session = session
            self._kwargs = kwargs.copy()

        def read(self, table, columns, keyset, index='', limit=0,
                 resume_token=b''):
            _read_with.append(
                (table, columns, keyset, index, limit, resume_token))
            return expected

    with _Monkey(MUT, Snapshot=_Snapshot):
        found = session.read(
            TABLE_NAME, COLUMNS, KEYSET,
            index=INDEX, limit=LIMIT, resume_token=TOKEN)

    self.assertIs(found, expected)

    self.assertEqual(len(_read_with), 1)
    (table, columns, key_set, index, limit, resume_token) = _read_with[0]

    self.assertEqual(table, TABLE_NAME)
    self.assertEqual(columns, COLUMNS)
    self.assertEqual(key_set, KEYSET)
    self.assertEqual(index, INDEX)
    self.assertEqual(limit, LIMIT)
    self.assertEqual(resume_token, TOKEN)

def test_read_not_created(self):
    from google.cloud.spanner.keyset import KeySet

    TABLE_NAME = 'citizens'
    COLUMNS = ['email', 'first_name', 'last_name', 'age']
    KEYS = ['*****@*****.**', '*****@*****.**']
    KEYSET = KeySet(keys=KEYS)
    database = _Database(self.DATABASE_NAME)
    session = self._make_one(database)

    with self.assertRaises(ValueError):
        session.read(TABLE_NAME, COLUMNS, KEYSET)

def test_commit_already_committed(self):
    from google.cloud.spanner.keyset import KeySet

    keys = [[0], [1], [2]]
    keyset = KeySet(keys=keys)
    database = _Database()
    session = _Session(database)
    batch = self._make_one(session)
    batch.committed = object()
    batch.delete(TABLE_NAME, keyset=keyset)

    with self.assertRaises(ValueError):
        batch.commit()

def test_read_w_multiple_keys(self):
    ROW_COUNT = 40
    indices = [0, 5, 17]
    session, committed = self._set_up_table(ROW_COUNT)
    snapshot = session.snapshot(read_timestamp=committed)
    rows = list(snapshot.read(
        self.TABLE, self.COLUMNS,
        KeySet(keys=[(index,) for index in indices])))

    all_data_rows = list(self._row_data(ROW_COUNT))
    expected = [row for row in all_data_rows if row[0] in indices]
    self._check_row_data(rows, expected)

def _handle_abort_unit_of_work(self, transaction):
    keyset_1 = KeySet(keys=[(self.KEY1,)])
    rows_1 = list(
        transaction.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset_1))

    assert len(rows_1) == 1
    row_1 = rows_1[0]
    value_1 = row_1[1]

    self.handler_running.set()

    self.provoker_done.wait()

    keyset_2 = KeySet(keys=[(self.KEY2,)])
    rows_2 = list(
        transaction.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset_2))

    assert len(rows_2) == 1
    row_2 = rows_2[0]
    value_2 = row_2[1]

    transaction.update(
        COUNTERS_TABLE, COUNTERS_COLUMNS,
        [[self.KEY2, value_1 + value_2]])

def _provoke_abort_unit_of_work(self, transaction):
    keyset = KeySet(keys=[(self.KEY1,)])
    rows = list(transaction.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset))

    assert len(rows) == 1
    row = rows[0]
    value = row[1]

    self.provoker_started.set()

    self.handler_running.wait()

    transaction.update(
        COUNTERS_TABLE, COUNTERS_COLUMNS, [[self.KEY1, value + 1]])

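# How the two units of work above interleave (a descriptive note, not part
# of the original suite): the provoker reads KEY1 and signals
# ``provoker_started``; the handler also reads KEY1 and signals
# ``handler_running``, which unblocks the provoker to update KEY1 and
# commit. Once ``provoker_done`` is set, the handler reads KEY2 and writes
# its update, but its earlier read of KEY1 is now stale, so Cloud Spanner
# aborts the handler's commit and ``run_in_transaction`` retries the
# handler callable from the top.
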
def test_read_w_range(self):
    from google.cloud.spanner.keyset import KeyRange

    ROW_COUNT = 4000
    START_CLOSED = 1000
    END_OPEN = 2000
    session, committed = self._set_up_table(ROW_COUNT)
    key_range = KeyRange(start_closed=[START_CLOSED], end_open=[END_OPEN])
    keyset = KeySet(ranges=(key_range,))
    snapshot = session.snapshot(read_timestamp=committed)
    rows = list(snapshot.read(self.TABLE, self.COLUMNS, keyset))

    all_data_rows = list(self._row_data(ROW_COUNT))
    expected = all_data_rows[START_CLOSED:END_OPEN]
    self._check_row_data(rows, expected)

def populate_table_2_columns(database, table_name, row_count, val_size):
    all_ = KeySet(all_=True)
    columns = ('pkey', 'chunk_me', 'chunk_me_2')
    rows = list(
        database.execute_sql('SELECT COUNT(*) FROM {}'.format(table_name)))
    assert len(rows) == 1
    count = rows[0][0]
    if count != row_count:
        print_func("Repopulating table: {}".format(table_name))
        chunk_me = 'X' * val_size
        row_data = [(index, chunk_me, chunk_me)
                    for index in range(row_count)]
        with database.batch() as batch:
            batch.delete(table_name, all_)
            batch.insert(table_name, columns, row_data)
    else:
        print_func("Leaving table: {}".format(table_name))

def populate_table(database, table_desc):
    all_ = KeySet(all_=True)
    columns = ('pkey', 'chunk_me')
    with database.snapshot() as snapshot:
        rows = list(snapshot.execute_sql(
            'SELECT COUNT(*) FROM {}'.format(table_desc.table)))
    assert len(rows) == 1
    count = rows[0][0]
    if count != table_desc.row_count:
        print_func("Repopulating table: {}".format(table_desc.table))
        chunk_me = table_desc.value()
        row_data = [(index, chunk_me)
                    for index in range(table_desc.row_count)]
        with database.batch() as batch:
            batch.delete(table_desc.table, all_)
            batch.insert(table_desc.table, columns, row_data)
    else:
        print_func("Leaving table: {}".format(table_desc.table))

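# A minimal sketch (hypothetical; the real descriptor is defined elsewhere
# in the benchmark script) of the shape ``populate_table`` expects from
# ``table_desc``: a table name, a target row count, and a ``value()``
# factory that builds the payload for the 'chunk_me' column.
class _TableDescSketch(object):

    def __init__(self, table, row_count, val_size):
        self.table = table
        self.row_count = row_count
        self.val_size = val_size

    def value(self):
        return 'X' * self.val_size
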
def _transaction_concurrency_helper(self, unit_of_work, pkey):
    INITIAL_VALUE = 123
    NUM_THREADS = 3  # conforms to equivalent Java systest.

    retry = RetryInstanceState(_has_all_ddl)
    retry(self._db.reload)()

    session = self._db.session()
    session.create()
    self.to_delete.append(session)

    with session.batch() as batch:
        batch.insert_or_update(
            self.COUNTERS_TABLE, self.COUNTERS_COLUMNS,
            [[pkey, INITIAL_VALUE]])

    # We don't want to run the threads' transactions in the current
    # session, which would fail.
    txn_sessions = []
    for _ in range(NUM_THREADS):
        txn_session = self._db.session()
        txn_sessions.append(txn_session)
        txn_session.create()
        self.to_delete.append(txn_session)

    threads = [
        threading.Thread(
            target=txn_session.run_in_transaction,
            args=(unit_of_work, pkey))
        for txn_session in txn_sessions]

    for thread in threads:
        thread.start()

    for thread in threads:
        thread.join()

    keyset = KeySet(keys=[(pkey,)])
    rows = list(session.read(
        self.COUNTERS_TABLE, self.COUNTERS_COLUMNS, keyset))
    self.assertEqual(len(rows), 1)
    _, value = rows[0]
    self.assertEqual(value, INITIAL_VALUE + len(threads))

class _TestData(object):
    TABLE = 'contacts'
    COLUMNS = ('contact_id', 'first_name', 'last_name', 'email')
    ROW_DATA = (
        (1, u'Phred', u'Phlyntstone', u'*****@*****.**'),
        (2, u'Bharney', u'Rhubble', u'*****@*****.**'),
        (3, u'Wylma', u'Phlyntstone', u'*****@*****.**'),
    )
    ALL = KeySet(all_=True)
    SQL = 'SELECT * FROM contacts ORDER BY contact_id'

    def _assert_timestamp(self, value, nano_value):
        self.assertIsInstance(value, datetime.datetime)
        self.assertIsNone(value.tzinfo)
        self.assertIs(nano_value.tzinfo, UTC)

        self.assertEqual(value.year, nano_value.year)
        self.assertEqual(value.month, nano_value.month)
        self.assertEqual(value.day, nano_value.day)
        self.assertEqual(value.hour, nano_value.hour)
        self.assertEqual(value.minute, nano_value.minute)
        self.assertEqual(value.second, nano_value.second)
        self.assertEqual(value.microsecond, nano_value.microsecond)
        if isinstance(value, TimestampWithNanoseconds):
            self.assertEqual(value.nanosecond, nano_value.nanosecond)
        else:
            # A plain datetime carries microsecond precision only.
            self.assertEqual(value.microsecond * 1000, nano_value.nanosecond)

    def _check_row_data(self, row_data, expected=None):
        if expected is None:
            expected = self.ROW_DATA

        self.assertEqual(len(row_data), len(expected))
        for found, expected_row in zip(row_data, expected):
            self.assertEqual(len(found), len(expected_row))
            for found_cell, expected_cell in zip(found, expected_row):
                if isinstance(found_cell, TimestampWithNanoseconds):
                    self._assert_timestamp(expected_cell, found_cell)
                elif isinstance(found_cell, float) and math.isnan(found_cell):
                    self.assertTrue(math.isnan(expected_cell))
                else:
                    self.assertEqual(found_cell, expected_cell)

def test_delete(self):
    from google.cloud.proto.spanner.v1.mutation_pb2 import Mutation
    from google.cloud.spanner.keyset import KeySet

    keys = [[0], [1], [2]]
    keyset = KeySet(keys=keys)
    session = _Session()
    base = self._make_one(session)

    base.delete(TABLE_NAME, keyset=keyset)

    self.assertEqual(len(base._mutations), 1)
    mutation = base._mutations[0]
    self.assertIsInstance(mutation, Mutation)
    delete = mutation.delete
    self.assertIsInstance(delete, Mutation.Delete)
    self.assertEqual(delete.table, TABLE_NAME)
    key_set_pb = delete.key_set
    self.assertEqual(len(key_set_pb.ranges), 0)
    self.assertEqual(len(key_set_pb.keys), len(keys))
    for found, expected in zip(key_set_pb.keys, keys):
        self.assertEqual(
            [int(value.string_value) for value in found.values], expected)

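# Why the assertions above parse ``string_value``: Spanner's wire encoding
# represents INT64 values as decimal strings inside ``google.protobuf.Value``
# messages. A minimal standalone sketch (assuming the same library version
# the tests import):
from google.cloud.spanner.keyset import KeySet

keyset_pb = KeySet(keys=[[0], [1], [2]]).to_pb()
for key in keyset_pb.keys:
    # Each key is a ListValue whose items hold string-encoded integers.
    print([value.string_value for value in key.values])
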
def test_commit_grpc_error(self):
    from google.gax.errors import GaxError
    from google.cloud.proto.spanner.v1.transaction_pb2 import (
        TransactionOptions)
    from google.cloud.proto.spanner.v1.mutation_pb2 import (
        Mutation as MutationPB)
    from google.cloud.spanner.keyset import KeySet

    keys = [[0], [1], [2]]
    keyset = KeySet(keys=keys)
    database = _Database()
    api = database.spanner_api = _FauxSpannerAPI(_random_gax_error=True)
    session = _Session(database)
    batch = self._make_one(session)
    batch.delete(TABLE_NAME, keyset=keyset)

    with self.assertRaises(GaxError):
        batch.commit()

    (session, mutations, single_use_txn, options) = api._committed
    self.assertEqual(session, self.SESSION_NAME)
    self.assertEqual(len(mutations), 1)
    mutation = mutations[0]
    self.assertIsInstance(mutation, MutationPB)
    self.assertTrue(mutation.HasField('delete'))
    delete = mutation.delete
    self.assertEqual(delete.table, TABLE_NAME)
    keyset_pb = delete.key_set
    self.assertEqual(len(keyset_pb.ranges), 0)
    self.assertEqual(len(keyset_pb.keys), len(keys))
    for found, expected in zip(keyset_pb.keys, keys):
        self.assertEqual(
            [int(value.string_value) for value in found.values], expected)
    self.assertIsInstance(single_use_txn, TransactionOptions)
    self.assertTrue(single_use_txn.HasField('read_write'))
    self.assertEqual(options.kwargs['metadata'],
                     [('google-cloud-resource-prefix', database.name)])

def _read_helper(self, multi_use, first=True, count=0):
    from google.protobuf.struct_pb2 import Struct
    from google.cloud.proto.spanner.v1.result_set_pb2 import (
        PartialResultSet, ResultSetMetadata, ResultSetStats)
    from google.cloud.proto.spanner.v1.transaction_pb2 import (
        TransactionSelector)
    from google.cloud.proto.spanner.v1.type_pb2 import Type, StructType
    from google.cloud.proto.spanner.v1.type_pb2 import STRING, INT64
    from google.cloud.spanner.keyset import KeySet
    from google.cloud.spanner._helpers import _make_value_pb

    TXN_ID = b'DEADBEEF'
    VALUES = [
        [u'bharney', 31],
        [u'phred', 32],
    ]
    VALUE_PBS = [[_make_value_pb(item) for item in row] for row in VALUES]
    struct_type_pb = StructType(fields=[
        StructType.Field(name='name', type=Type(code=STRING)),
        StructType.Field(name='age', type=Type(code=INT64)),
    ])
    metadata_pb = ResultSetMetadata(row_type=struct_type_pb)
    stats_pb = ResultSetStats(query_stats=Struct(fields={
        'rows_returned': _make_value_pb(2),
    }))
    result_sets = [
        PartialResultSet(values=VALUE_PBS[0], metadata=metadata_pb),
        PartialResultSet(values=VALUE_PBS[1], stats=stats_pb),
    ]
    KEYS = ['*****@*****.**', '*****@*****.**']
    KEYSET = KeySet(keys=KEYS)
    INDEX = 'email-address-index'
    LIMIT = 20
    TOKEN = b'DEADBEEF'
    database = _Database()
    api = database.spanner_api = _FauxSpannerAPI(
        _streaming_read_response=_MockIterator(*result_sets))
    session = _Session(database)
    derived = self._makeDerived(session)
    derived._multi_use = multi_use
    derived._read_request_count = count
    if not first:
        derived._transaction_id = TXN_ID

    result_set = derived.read(
        TABLE_NAME, COLUMNS, KEYSET, index=INDEX, limit=LIMIT)

    self.assertEqual(derived._read_request_count, count + 1)

    if multi_use:
        self.assertIs(result_set._source, derived)
    else:
        self.assertIsNone(result_set._source)

    result_set.consume_all()

    self.assertEqual(list(result_set.rows), VALUES)
    self.assertEqual(result_set.metadata, metadata_pb)
    self.assertEqual(result_set.stats, stats_pb)

    (r_session, table, columns, key_set, transaction, index,
     limit, resume_token, options) = api._streaming_read_with

    self.assertEqual(r_session, self.SESSION_NAME)
    self.assertEqual(table, TABLE_NAME)
    self.assertEqual(columns, COLUMNS)
    self.assertEqual(key_set, KEYSET.to_pb())
    self.assertIsInstance(transaction, TransactionSelector)
    if multi_use:
        if first:
            self.assertTrue(transaction.begin.read_only.strong)
        else:
            self.assertEqual(transaction.id, TXN_ID)
    else:
        self.assertTrue(transaction.single_use.read_only.strong)
    self.assertEqual(index, INDEX)
    self.assertEqual(limit, LIMIT)
    self.assertEqual(resume_token, b'')
    self.assertEqual(options.kwargs['metadata'],
                     [('google-cloud-resource-prefix', database.name)])

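# A minimal sketch (not part of the suite) of the three TransactionSelector
# shapes ``_read_helper`` asserts against, built directly from the v1 protos:
from google.cloud.proto.spanner.v1.transaction_pb2 import (
    TransactionOptions, TransactionSelector)

_STRONG_READ_ONLY = TransactionOptions(
    read_only=TransactionOptions.ReadOnly(strong=True))

# Single-use read: the snapshot is created and discarded with the request.
single_use = TransactionSelector(single_use=_STRONG_READ_ONLY)

# First read of a multi-use snapshot: begin a transaction as a side effect.
begin = TransactionSelector(begin=_STRONG_READ_ONLY)

# Subsequent reads: reuse the transaction ID returned by the first read.
existing = TransactionSelector(id=b'DEADBEEF')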