class TestPut(unittest.TestCase, TestBase):
    """Tests for ``self.handle.put`` and the ``PutRequest`` accessors.

    ``setUpClass`` creates ``table_name`` with a table default TTL of 30 days,
    published as the module-level ``table_ttl``. Tests that write to that
    table keep ``self.key`` pointing at the row they wrote so that
    ``tearDown`` can delete it again.
    """

    @classmethod
    def setUpClass(cls):
        cls.set_up_class()
        # Module-level so the test methods can compute the expected
        # expiration time of rows that fall back to the table default TTL.
        global table_ttl
        table_ttl = TimeToLive.of_days(30)
        create_statement = (
            'CREATE TABLE ' + table_name + '(fld_id INTEGER, fld_long LONG, '
            'fld_float FLOAT, fld_double DOUBLE, fld_bool BOOLEAN, '
            'fld_str STRING, fld_bin BINARY, fld_time TIMESTAMP(9), '
            'fld_num NUMBER, fld_json JSON, fld_arr ARRAY(STRING), '
            'fld_map MAP(STRING), '
            'fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), '
            'PRIMARY KEY(fld_id)) USING TTL ' + str(table_ttl))
        create_request = TableRequest().set_statement(
            create_statement).set_table_limits(TableLimits(50, 50, 1))
        cls.table_request(create_request)

    @classmethod
    def tearDownClass(cls):
        cls.tear_down_class()

    def setUp(self):
        self.set_up()
        self.row = get_row(with_sid=False)
        self.key = {'fld_id': 1}
        self.put_request = PutRequest().set_value(
            self.row).set_table_name(table_name).set_timeout(timeout)
        self.get_request = GetRequest().set_key(
            self.key).set_table_name(table_name)
        self.ttl = TimeToLive.of_hours(24)
        self.hour_in_milliseconds = 60 * 60 * 1000
        self.day_in_milliseconds = 24 * 60 * 60 * 1000

    def tearDown(self):
        # Delete the row stored under self.key in table_name (tests that
        # write a different row must update self.key accordingly).
        request = DeleteRequest().set_key(self.key).set_table_name(table_name)
        self.handle.delete(request)
        self.tear_down()

    @staticmethod
    def _now_millis():
        """Return the current wall-clock time in whole milliseconds."""
        return int(round(time() * 1000))

    def testPutSetIllegalValue(self):
        self.assertRaises(IllegalArgumentException, self.put_request.set_value,
                          'IllegalValue')

    def testPutSetIllegalValueFromJson(self):
        self.assertRaises(ValueError, self.put_request.set_value_from_json,
                          'IllegalJson')
        # Well-formed JSON naming a column the table does not have is only
        # rejected when the put is executed.
        self.put_request.set_value_from_json('{"invalid_field": "value"}')
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def testPutSetIllegalCompartment(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_compartment, {})
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_compartment, '')

    def testPutSetIllegalOption(self):
        # set_option does not validate its argument; the bad option is
        # expected to surface as IllegalStateException at put time.
        self.put_request.set_option('IllegalOption')
        self.assertRaises(IllegalStateException, self.handle.put,
                          self.put_request)

    def testPutSetIllegalMatchVersion(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_match_version,
                          'IllegalMatchVersion')

    def testPutSetIllegalTtl(self):
        self.assertRaises(IllegalArgumentException, self.put_request.set_ttl,
                          'IllegalTtl')

    def testPutSetIllegalUseTableDefaultTtl(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_use_table_default_ttl,
                          'IllegalUseTableDefaultTtl')

    def testPutSetIllegalExactMatch(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_exact_match,
                          'IllegalExactMatch')

    def testPutSetIllegalIdentityCacheSize(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_identity_cache_size,
                          'IllegalIdentityCacheSize')

    def testPutSetIllegalTimeout(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_timeout, 'IllegalTimeout')
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_timeout, 0)
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_timeout, -1)

    def testPutSetIllegalTableName(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_table_name,
                          {'name': table_name})
        self.put_request.set_table_name('IllegalTable')
        self.assertRaises(TableNotFoundException, self.handle.put,
                          self.put_request)

    def testPutSetIllegalReturnRow(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_return_row, 'IllegalReturnRow')

    def testPutSetLargeSizeValue(self):
        self.row['fld_str'] = self.get_random_str(2)
        self.put_request.set_value(self.row)
        if is_onprem():
            # The on-premise configuration is expected to accept this row.
            version = self.handle.put(self.put_request).get_version()
            self.assertIsNotNone(version)
        else:
            self.assertRaises(RequestSizeLimitException, self.handle.put,
                              self.put_request)

    def testPutIfVersionWithoutMatchVersion(self):
        self.put_request.set_option(PutOption.IF_VERSION)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def testPutNoVersionWithMatchVersion(self):
        # A match version combined with IF_ABSENT or IF_PRESENT is rejected.
        version = self.handle.put(self.put_request).get_version()
        self.put_request.set_option(
            PutOption.IF_ABSENT).set_match_version(version)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)
        self.put_request.set_option(
            PutOption.IF_PRESENT).set_match_version(version)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def testPutSetTtlAndUseTableDefaultTtl(self):
        # An explicit TTL and "use table default TTL" are mutually exclusive.
        self.put_request.set_ttl(self.ttl).set_use_table_default_ttl(True)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def testPutGets(self):
        identity_cache_size = 5
        version = self.handle.put(self.put_request).get_version()
        (self.put_request
         .set_option(PutOption.IF_ABSENT)
         .set_match_version(version)
         .set_ttl(self.ttl)
         .set_use_table_default_ttl(True)
         .set_exact_match(True)
         .set_identity_cache_size(identity_cache_size)
         .set_return_row(True))
        self.assertEqual(self.put_request.get_value(), self.row)
        self.assertEqual(self.put_request.get_compartment(), tenant_id)
        self.assertEqual(self.put_request.get_option(), PutOption.IF_ABSENT)
        self.assertEqual(self.put_request.get_match_version(), version)
        self.assertEqual(self.put_request.get_ttl(), self.ttl)
        self.assertTrue(self.put_request.get_use_table_default_ttl())
        self.assertTrue(self.put_request.get_update_ttl())
        self.assertEqual(self.put_request.get_timeout(), timeout)
        self.assertEqual(self.put_request.get_table_name(), table_name)
        self.assertTrue(self.put_request.get_exact_match())
        self.assertEqual(self.put_request.get_identity_cache_size(),
                         identity_cache_size)
        self.assertTrue(self.put_request.get_return_row())

    def testPutIllegalRequest(self):
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          'IllegalRequest')

    def testPutNormal(self):
        # test put with normal values
        result = self.handle.put(self.put_request)
        tb_expect_expiration = table_ttl.to_expiration_time(self._now_millis())
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 0, 0, 1, 1)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, tb_expect_expiration,
                              TimeUnit.DAYS)
        self.check_cost(result, 1, 2, 0, 0)
        # put a row with the same primary key to update the row
        self.row['fld_long'] = 2147483649
        self.put_request.set_value(self.row).set_ttl(self.ttl)
        result = self.handle.put(self.put_request)
        expect_expiration = self.ttl.to_expiration_time(self._now_millis())
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 0, 0, 2, 2)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, expect_expiration,
                              TimeUnit.HOURS)
        self.check_cost(result, 1, 2, 0, 0)
        # update the ttl of the row to never expire
        self.put_request.set_ttl(TimeToLive.of_days(0))
        result = self.handle.put(self.put_request)
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 0, 0, 2, 2)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version)
        self.check_cost(result, 1, 2, 0, 0)

    def testPutIfAbsent(self):
        # test PutIfAbsent with normal values
        self.put_request.set_option(PutOption.IF_ABSENT).set_ttl(
            self.ttl).set_return_row(True)
        result = self.handle.put(self.put_request)
        expect_expiration = self.ttl.to_expiration_time(self._now_millis())
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 1, 2, 1, 1)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, expect_expiration,
                              TimeUnit.HOURS)
        self.check_cost(result, 1, 2, 0, 0)
        # put a row with the same primary key to update the row, operation
        # should fail, and return the existing row
        result = self.handle.put(self.put_request)
        self._check_put_result(result, False, existing_version=version,
                               existing_value=self.row)
        self.check_cost(result, 1, 2, 0, 0)

    def testPutIfPresent(self):
        # test PutIfPresent with normal values, operation should fail because
        # there is no existing row in store
        self.put_request.set_option(PutOption.IF_PRESENT)
        result = self.handle.put(self.put_request)
        self._check_put_result(result, False)
        self.check_cost(result, 1, 2, 0, 0)
        # insert a row
        self.put_request.set_option(PutOption.IF_ABSENT).set_ttl(self.ttl)
        self.handle.put(self.put_request)
        expect_expiration = self.ttl.to_expiration_time(self._now_millis())
        # test PutIfPresent with normal values, operation should succeed
        self.row['fld_long'] = 2147483649
        self.put_request.set_value(self.row).set_option(
            PutOption.IF_PRESENT).set_return_row(True)
        result = self.handle.put(self.put_request)
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 1, 2, 2, 2)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, expect_expiration,
                              TimeUnit.HOURS)
        self.check_cost(result, 1, 2, 0, 0)
        # test PutIfPresent with normal values, update the ttl with table
        # default ttl
        self.put_request.set_ttl(None).set_use_table_default_ttl(True)
        result = self.handle.put(self.put_request)
        tb_expect_expiration = table_ttl.to_expiration_time(self._now_millis())
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 1, 2, 2, 2)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, tb_expect_expiration,
                              TimeUnit.DAYS)
        self.check_cost(result, 1, 2, 0, 0)

    def testPutIfVersion(self):
        # insert a row
        result = self.handle.put(self.put_request)
        version_old = result.get_version()
        # test PutIfVersion with normal values, operation should succeed.
        # No option is set explicitly here; presumably set_match_version
        # selects IF_VERSION when no option was chosen (confirm in PutRequest).
        self.row['fld_bool'] = False
        self.put_request.set_value(self.row).set_ttl(
            self.ttl).set_match_version(version_old).set_return_row(True)
        result = self.handle.put(self.put_request)
        expect_expiration = self.ttl.to_expiration_time(self._now_millis())
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 1, 2, 2, 2)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, expect_expiration,
                              TimeUnit.HOURS)
        self.check_cost(result, 1, 2, 0, 0)
        # test PutIfVersion with normal values, operation should fail because
        # version_old no longer matches, and return the existing row
        self.put_request.set_ttl(None).set_use_table_default_ttl(True)
        result = self.handle.put(self.put_request)
        self._check_put_result(result, False, existing_version=version,
                               existing_value=self.row)
        self.check_cost(result, 1, 2, 0, 0)

    def testPutWithExactMatch(self):
        # test put a row with an extra field not in the table, by default this
        # will succeed
        row = deepcopy(self.row)
        row.update({'fld_id': 2, 'extra': 5})
        key = {'fld_id': 2}
        # This test writes the fld_id=2 row, never the default fld_id=1 row,
        # so point tearDown at it; otherwise the row is left in the table.
        self.key = key
        # Expected stored value: the original row under the new key, without
        # the extra field.
        self.row['fld_id'] = 2
        self.put_request.set_value(row)
        result = self.handle.put(self.put_request)
        tb_expect_expiration = table_ttl.to_expiration_time(self._now_millis())
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 0, 0, 1, 1)
        self.get_request.set_key(key)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, tb_expect_expiration,
                              TimeUnit.DAYS)
        self.check_cost(result, 1, 2, 0, 0)
        # test put a row with an extra field not in the table, this will fail
        # because it's not an exact match when we set exact_match=True
        self.put_request.set_exact_match(True)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def testPutWithIdentityColumn(self):
        id_table = table_prefix + 'Identity'
        create_request = TableRequest().set_statement(
            'CREATE TABLE ' + id_table + '(sid INTEGER, id LONG GENERATED '
            'ALWAYS AS IDENTITY, name STRING, PRIMARY KEY(SHARD(sid), id))')
        create_request.set_table_limits(TableLimits(50, 50, 1))
        self.table_request(create_request)
        # test put a row with an extra field not in the table, by default this
        # will succeed
        row = {'name': 'myname', 'extra': 'extra', 'sid': 1}
        # Assumes the first generated identity value is 1.
        key = {'sid': 1, 'id': 1}
        expected = OrderedDict()
        expected['sid'] = 1
        expected['id'] = 1
        expected['name'] = 'myname'
        self.put_request.set_table_name(id_table).set_value(row)
        result = self.handle.put(self.put_request)
        version = result.get_version()
        self._check_put_result(result, has_generated_value=True)
        self.check_cost(result, 0, 0, 1, 1)
        self.get_request.set_table_name(id_table).set_key(key)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, expected, version)
        self.check_cost(result, 1, 2, 0, 0)
        # test put a row with identity field, this will fail because id is
        # 'generated always' and in that path it is not legal to provide a value
        # for id
        row['id'] = 1
        # Re-set the value explicitly instead of relying on the request
        # holding a reference to ``row``.
        self.put_request.set_value(row)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def _check_put_result(self, result, has_version=True,
                          has_generated_value=False, existing_version=None,
                          existing_value=None):
        """Assert the fields of a put result.

        :param result: the result returned by ``self.handle.put``.
        :param has_version: whether the result must carry a new version
            (callers pass False for puts that are expected not to write).
        :param has_generated_value: whether the result must carry a
            generated (identity) value.
        :param existing_version: the expected existing version, compared by
            ``get_bytes()``; None means no existing version is expected.
        :param existing_value: the expected existing value (None if none).
        """
        version = result.get_version()
        if has_version:
            self.assertIsNotNone(version)
        else:
            self.assertIsNone(version)

        generated_value = result.get_generated_value()
        if has_generated_value:
            self.assertIsNotNone(generated_value)
        else:
            self.assertIsNone(generated_value)

        actual_existing_version = result.get_existing_version()
        if existing_version is None:
            self.assertIsNone(actual_existing_version)
        else:
            self.assertEqual(actual_existing_version.get_bytes(),
                             existing_version.get_bytes())

        self.assertEqual(result.get_existing_value(), existing_value)
class TestGet(unittest.TestCase, TestBase):
    """Exercise ``self.handle.get`` and the ``GetRequest`` setters/getters.

    ``setUpClass`` creates ``table_name`` (table default TTL of 16 hours) and
    writes one row. The row, its version and its expected expiration time
    are stored in the module globals ``row``, ``version`` and
    ``tb_expect_expiration`` so the tests can compare against them.
    """

    @classmethod
    def setUpClass(cls):
        global row, tb_expect_expiration, version
        cls.handle = None
        cls.set_up_class()
        default_ttl = TimeToLive.of_hours(16)
        ddl = ('CREATE TABLE ' + table_name +
               '(fld_sid INTEGER, fld_id INTEGER, fld_long LONG, '
               'fld_float FLOAT, fld_double DOUBLE, fld_bool BOOLEAN, '
               'fld_str STRING, fld_bin BINARY, fld_time TIMESTAMP(7), '
               'fld_num NUMBER, fld_json JSON, fld_arr ARRAY(STRING), '
               'fld_map MAP(STRING), '
               'fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), '
               'PRIMARY KEY(SHARD(fld_sid), fld_id)) USING TTL ' +
               str(default_ttl))
        table_req = TableRequest().set_statement(ddl)
        table_req.set_table_limits(TableLimits(100, 100, 1))
        cls.table_request(table_req)

        row = get_row()
        writer = PutRequest().set_value(row).set_table_name(table_name)
        version = cls.handle.put(writer).get_version()
        # Sampled right after the write; the row should expire one table
        # default TTL from (approximately) now.
        now_ms = int(round(time() * 1000))
        tb_expect_expiration = default_ttl.to_expiration_time(now_ms)

    @classmethod
    def tearDownClass(cls):
        cls.tear_down_class()

    def setUp(self):
        self.set_up()
        self.key = {'fld_sid': 1, 'fld_id': 1}
        self.get_request = (GetRequest()
                            .set_key(self.key)
                            .set_table_name(table_name)
                            .set_timeout(timeout))

    def tearDown(self):
        self.tear_down()

    def testGetSetIllegalKey(self):
        with self.assertRaises(IllegalArgumentException):
            self.get_request.set_key('IllegalKey')
        # A key that omits either primary-key column is rejected by get().
        for partial_key in ({'fld_sid': 1}, {'fld_id': 1}):
            self.get_request.set_key(partial_key)
            with self.assertRaises(IllegalArgumentException):
                self.handle.get(self.get_request)

    def testGetSetIllegalKeyFromJson(self):
        with self.assertRaises(ValueError):
            self.get_request.set_key_from_json('IllegalJson')
        self.get_request.set_key_from_json('{"invalid_field": "key"}')
        with self.assertRaises(IllegalArgumentException):
            self.handle.get(self.get_request)

    def testGetSetIllegalTableName(self):
        with self.assertRaises(IllegalArgumentException):
            self.get_request.set_table_name({'name': table_name})
        self.get_request.set_table_name('IllegalTable')
        with self.assertRaises(TableNotFoundException):
            self.handle.get(self.get_request)

    def testGetSetIllegalCompartment(self):
        for bad_compartment in ({}, ''):
            with self.assertRaises(IllegalArgumentException):
                self.get_request.set_compartment(bad_compartment)

    def testGetSetIllegalConsistency(self):
        with self.assertRaises(IllegalArgumentException):
            self.get_request.set_consistency('IllegalConsistency')

    def testGetSetIllegalTimeout(self):
        for bad_timeout in ('IllegalTimeout', 0, -1):
            with self.assertRaises(IllegalArgumentException):
                self.get_request.set_timeout(bad_timeout)

    def testGetWithoutKey(self):
        self.get_request.set_key(None)
        with self.assertRaises(IllegalArgumentException):
            self.handle.get(self.get_request)

    def testGetGets(self):
        self.assertEqual(self.get_request.get_key(), self.key)
        self.assertIsNone(self.get_request.get_compartment())

    def testGetIllegalRequest(self):
        with self.assertRaises(IllegalArgumentException):
            self.handle.get('IllegalRequest')

    def testGetNormal(self):
        outcome = self.handle.get(self.get_request)
        self.check_get_result(outcome, row, version, tb_expect_expiration,
                              TimeUnit.HOURS)
        self.check_cost(outcome, 1, 2, 0, 0)

    def testGetEventual(self):
        self.get_request.set_consistency(Consistency.EVENTUAL)
        outcome = self.handle.get(self.get_request)
        # ver_eq=False: presumably the version is not required to match for
        # an eventually consistent read (see check_get_result).
        self.check_get_result(outcome, row,
                              expect_expiration=tb_expect_expiration,
                              timeunit=TimeUnit.HOURS, ver_eq=False)
        self.check_cost(outcome, 1, 1, 0, 0)

    def testGetNonExisting(self):
        self.get_request.set_key({'fld_sid': 2, 'fld_id': 2})
        outcome = self.handle.get(self.get_request)
        self.check_get_result(outcome)
        self.check_cost(outcome, 1, 2, 0, 0)
class TestQuery(unittest.TestCase, TestBase): @classmethod def setUpClass(cls): cls.set_up_class() index_name = 'idx_' + table_name create_statement = ('CREATE TABLE ' + table_name + '(fld_sid INTEGER, fld_id INTEGER, \ fld_long LONG, fld_float FLOAT, fld_double DOUBLE, fld_bool BOOLEAN, \ fld_str STRING, fld_bin BINARY, fld_time TIMESTAMP(6), fld_num NUMBER, \ fld_json JSON, fld_arr ARRAY(STRING), fld_map MAP(STRING), \ fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), \ PRIMARY KEY(SHARD(fld_sid), fld_id))') limits = TableLimits(100, 100, 1) create_request = TableRequest().set_statement( create_statement).set_table_limits(limits) cls.table_request(create_request) create_idx_request = TableRequest() create_idx_statement = ('CREATE INDEX ' + index_name + '1 ON ' + table_name + '(fld_long)') create_idx_request.set_statement(create_idx_statement) cls.table_request(create_idx_request) create_idx_statement = ('CREATE INDEX ' + index_name + '2 ON ' + table_name + '(fld_str)') create_idx_request.set_statement(create_idx_statement) cls.table_request(create_idx_request) create_idx_statement = ('CREATE INDEX ' + index_name + '3 ON ' + table_name + '(fld_bool)') create_idx_request.set_statement(create_idx_statement) cls.table_request(create_idx_request) create_idx_statement = ('CREATE INDEX ' + index_name + '4 ON ' + table_name + '(fld_json.location as point)') create_idx_request.set_statement(create_idx_statement) cls.table_request(create_idx_request) global prepare_cost prepare_cost = 2 global query_statement query_statement = ('SELECT fld_sid, fld_id FROM ' + table_name + ' WHERE fld_sid = 1') @classmethod def tearDownClass(cls): cls.tear_down_class() def setUp(self): self.set_up() self.handle_config = get_handle_config(tenant_id) self.min_time = list() self.max_time = list() shardkeys = 2 ids = 6 write_multiple_request = WriteMultipleRequest() for sk in range(shardkeys): for i in range(ids): row = get_row() if i == 0: self.min_time.append(row['fld_time']) 
elif i == ids - 1: self.max_time.append(row['fld_time']) row['fld_sid'] = sk row['fld_id'] = i row['fld_bool'] = False if sk == 0 else True row['fld_str'] = ( '{"name": u' + str(shardkeys * ids - sk * ids - i - 1).zfill(2) + '}') row['fld_json']['location']['coordinates'] = ([ 23.549 - sk * 0.5 - i, 35.2908 + sk * 0.5 + i ]) write_multiple_request.add( PutRequest().set_value(row).set_table_name(table_name), True) self.handle.write_multiple(write_multiple_request) write_multiple_request.clear() prepare_statement_update = ( 'DECLARE $fld_sid INTEGER; $fld_id INTEGER; UPDATE ' + table_name + ' u SET u.fld_long = u.fld_long + 1 WHERE fld_sid = $fld_sid ' + 'AND fld_id = $fld_id') prepare_request_update = PrepareRequest().set_statement( prepare_statement_update) self.prepare_result_update = self.handle.prepare( prepare_request_update) prepare_statement_select = ( 'DECLARE $fld_long LONG; SELECT fld_sid, fld_id, fld_long FROM ' + table_name + ' WHERE fld_long = $fld_long') prepare_request_select = PrepareRequest().set_statement( prepare_statement_select) self.prepare_result_select = self.handle.prepare( prepare_request_select) self.query_request = QueryRequest().set_timeout(timeout) self.get_request = GetRequest().set_table_name(table_name) def tearDown(self): self.tear_down() def testQuerySetIllegalCompartment(self): self.assertRaises(IllegalArgumentException, self.query_request.set_compartment, {}) self.assertRaises(IllegalArgumentException, self.query_request.set_compartment, '') def testQuerySetIllegalLimit(self): self.assertRaises(IllegalArgumentException, self.query_request.set_limit, 'IllegalLimit') self.assertRaises(IllegalArgumentException, self.query_request.set_limit, -1) def testQuerySetIllegalMaxReadKb(self): self.assertRaises(IllegalArgumentException, self.query_request.set_max_read_kb, 'IllegalMaxReadKb') self.assertRaises(IllegalArgumentException, self.query_request.set_max_read_kb, -1) def testQuerySetIllegalMaxWriteKb(self): 
self.assertRaises(IllegalArgumentException, self.query_request.set_max_write_kb, 'IllegalMaxWriteKb') self.assertRaises(IllegalArgumentException, self.query_request.set_max_write_kb, -1) def testQuerySetIllegalMaxMemoryConsumption(self): self.assertRaises(IllegalArgumentException, self.query_request.set_max_memory_consumption, 'IllegalMaxMemoryConsumption') self.assertRaises(IllegalArgumentException, self.query_request.set_max_memory_consumption, -1) def testQuerySetIllegalMathContext(self): self.assertRaises(IllegalArgumentException, self.query_request.set_math_context, 'IllegalMathContext') def testQuerySetIllegalConsistency(self): self.assertRaises(IllegalArgumentException, self.query_request.set_consistency, 'IllegalConsistency') def testQuerySetIllegalContinuationKey(self): self.assertRaises(IllegalArgumentException, self.query_request.set_continuation_key, 'IllegalContinuationKey') def testQuerySetIllegalStatement(self): self.assertRaises(IllegalArgumentException, self.query_request.set_statement, {}) self.assertRaises(IllegalArgumentException, self.query_request.set_statement, '') self.query_request.set_statement('IllegalStatement') self.assertRaises(IllegalArgumentException, self.handle.query, self.query_request) self.query_request.set_statement('SELECT fld_id FROM IllegalTableName') self.assertRaises(TableNotFoundException, self.handle.query, self.query_request) def testQuerySetIllegalPreparedStatement(self): self.assertRaises(IllegalArgumentException, self.query_request.set_prepared_statement, 'IllegalPreparedStatement') def testQuerySetIllegalTimeout(self): self.assertRaises(IllegalArgumentException, self.query_request.set_timeout, 'IllegalTimeout') self.assertRaises(IllegalArgumentException, self.query_request.set_timeout, 0) self.assertRaises(IllegalArgumentException, self.query_request.set_timeout, -1) def testQuerySetIllegalDefaults(self): self.assertRaises(IllegalArgumentException, self.query_request.set_defaults, 'IllegalDefaults') def 
testQuerySetDefaults(self): self.query_request.set_defaults(self.handle_config) self.assertEqual(self.query_request.get_timeout(), timeout) self.assertEqual(self.query_request.get_consistency(), Consistency.ABSOLUTE) def testQueryNoStatementAndBothStatement(self): self.assertRaises(IllegalArgumentException, self.handle.query, self.query_request) self.query_request.set_statement(query_statement) self.query_request.set_prepared_statement(self.prepare_result_select) self.assertRaises(IllegalArgumentException, self.handle.query, self.query_request) def testQueryGets(self): continuation_key = bytearray(5) context = Context(prec=10, rounding=ROUND_HALF_EVEN) self.query_request.set_consistency(Consistency.EVENTUAL).set_statement( query_statement).set_prepared_statement( self.prepare_result_select).set_limit(3).set_max_read_kb( 2).set_max_write_kb(3).set_max_memory_consumption( 5).set_math_context(context).set_continuation_key( continuation_key) self.assertIsNone(self.query_request.get_compartment()) self.assertFalse(self.query_request.is_done()) self.assertEqual(self.query_request.get_limit(), 3) self.assertEqual(self.query_request.get_max_read_kb(), 2) self.assertEqual(self.query_request.get_max_write_kb(), 3) self.assertEqual(self.query_request.get_max_memory_consumption(), 5) self.assertEqual(self.query_request.get_math_context(), context) self.assertEqual(self.query_request.get_consistency(), Consistency.EVENTUAL) self.assertEqual(self.query_request.get_continuation_key(), continuation_key) self.assertEqual(self.query_request.get_statement(), query_statement) self.assertEqual(self.query_request.get_prepared_statement(), self.prepare_result_select.get_prepared_statement()) self.assertEqual(self.query_request.get_timeout(), timeout) def testQueryIllegalRequest(self): self.assertRaises(IllegalArgumentException, self.handle.query, 'IllegalRequest') def testQueryStatementSelect(self): num_records = 6 self.query_request.set_statement(query_statement) result = 
self.handle.query(self.query_request) records = self.check_query_result(result, num_records) for idx in range(num_records): self.assertEqual(records[idx], self._expected_row(1, idx)) self.check_cost(result, num_records + prepare_cost, num_records * 2 + prepare_cost, 0, 0) def testQueryStatementSelectWithLimit(self): limit = 3 self.query_request.set_statement(query_statement).set_limit(limit) result = self.handle.query(self.query_request) records = self.check_query_result(result, limit, True) for idx in range(limit): self.assertEqual(records[idx], self._expected_row(1, idx)) self.check_cost(result, limit + prepare_cost, limit * 2 + prepare_cost, 0, 0) def testQueryStatementSelectWithMaxReadKb(self): num_records = 6 max_read_kb = 4 self.query_request.set_statement(query_statement).set_max_read_kb( max_read_kb) result = self.handle.query(self.query_request) # TODO: [#27744] KV doesn't honor max read kb for on-prem proxy because # it has no table limits. if is_onprem(): records = self.check_query_result(result, num_records) else: records = self.check_query_result(result, max_read_kb + 1, True) for idx in range(len(records)): self.assertEqual(records[idx], self._expected_row(1, idx)) self.check_cost(result, max_read_kb + prepare_cost + 1, max_read_kb * 2 + prepare_cost + 2, 0, 0) def testQueryStatementSelectWithConsistency(self): num_records = 6 self.query_request.set_statement(query_statement).set_consistency( Consistency.ABSOLUTE) result = self.handle.query(self.query_request) records = self.check_query_result(result, num_records) for idx in range(num_records): self.assertEqual(records[idx], self._expected_row(1, idx)) self.check_cost(result, num_records + prepare_cost, num_records * 2 + prepare_cost, 0, 0) def testQueryStatementSelectWithContinuationKey(self): num_records = 6 limit = 4 self.query_request.set_statement(query_statement).set_limit(limit) count = 0 while True: completed = count * limit result = self.handle.query(self.query_request) if completed + limit 
<= num_records: num_get = limit read_kb = num_get records = self.check_query_result(result, num_get, True) else: num_get = num_records - completed read_kb = (1 if num_get == 0 else num_get) records = self.check_query_result(result, num_get) for idx in range(num_get): self.assertEqual(records[idx], self._expected_row(1, completed + idx)) self.check_cost(result, read_kb + (prepare_cost if count == 0 else 0), read_kb * 2 + (prepare_cost if count == 0 else 0), 0, 0) count += 1 if result.get_continuation_key() is None: break self.query_request.set_continuation_key( result.get_continuation_key()) self.assertEqual(count, num_records // limit + 1) def testQueryStatementSelectWithDefault(self): num_records = 6 self.query_request.set_statement(query_statement).set_defaults( self.handle_config) result = self.handle.query(self.query_request) records = self.check_query_result(result, num_records) for idx in range(num_records): self.assertEqual(records[idx], self._expected_row(1, idx)) self.check_cost(result, num_records + prepare_cost, num_records * 2 + prepare_cost, 0, 0) def testQueryPreparedStatementUpdate(self): fld_sid = 0 fld_id = 2 fld_long = 2147483649 prepared_statement = self.prepare_result_update.get_prepared_statement( ) # update a non-existing row prepared_statement.set_variable('$fld_sid', 2).set_variable('$fld_id', 0) self.query_request.set_prepared_statement(self.prepare_result_update) result = self.handle.query(self.query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], {'NumRowsUpdated': 0}) self.check_cost(result, 1, 2, 0, 0) # update an existing row prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement(self.prepare_result_update) result = self.handle.query(self.query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.check_cost(result, 2, 4, 4, 4) # check the updated row prepared_statement = 
self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], self._expected_row(fld_sid, fld_id, fld_long)) self.check_cost(result, 1, 2, 0, 0) def testQueryPreparedStatementUpdateWithLimit(self): fld_sid = 1 fld_id = 5 fld_long = 2147483649 prepared_statement = self.prepare_result_update.get_prepared_statement( ) prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement( self.prepare_result_update).set_limit(1) result = self.handle.query(self.query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.check_cost(result, 2, 4, 4, 4) # check the updated row prepared_statement = self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) records = self.check_query_result(result, 1, True) self.assertEqual(records[0], self._expected_row(fld_sid, fld_id, fld_long)) self.check_cost(result, 1, 2, 0, 0) def testQueryPreparedStatementUpdateWithMaxReadKb(self): fld_sid = 0 fld_id = 1 fld_long = 2147483649 # set a small max_read_kb to read a row to update prepared_statement = self.prepare_result_update.get_prepared_statement( ) prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement( self.prepare_result_update).set_max_read_kb(1) if not is_onprem(): self.assertRaises(IllegalArgumentException, self.handle.query, self.query_request) # set a enough max_read_kb to read a row to update self.query_request.set_max_read_kb(2) result = self.handle.query(self.query_request) records = 
self.check_query_result(result, 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.check_cost(result, 2, 4, 4, 4) # check the updated row prepared_statement = self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], self._expected_row(fld_sid, fld_id, fld_long)) self.check_cost(result, 1, 2, 0, 0) def testQueryPreparedStatementUpdateWithConsistency(self): fld_sid = 1 fld_id = 2 fld_long = 2147483649 prepared_statement = self.prepare_result_update.get_prepared_statement( ) prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement( self.prepare_result_update).set_consistency(Consistency.ABSOLUTE) result = self.handle.query(self.query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.check_cost(result, 2, 4, 4, 4) # check the updated row prepared_statement = self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], self._expected_row(fld_sid, fld_id, fld_long)) self.check_cost(result, 1, 2, 0, 0) def testQueryPreparedStatementUpdateWithContinuationKey(self): fld_sid = 1 fld_id = 3 fld_long = 2147483649 num_records = 1 limit = 3 prepared_statement = self.prepare_result_update.get_prepared_statement( ) prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement( self.prepare_result_update).set_limit(limit) count = 0 while True: completed = count * limit result = self.handle.query(self.query_request) records 
= self.check_query_result(result, 1) if completed + limit <= num_records: self.assertEqual(records[0], {'NumRowsUpdated': limit}) read_kb = limit * 2 write_kb = limit * 4 else: num_update = num_records - completed self.assertEqual(records[0], {'NumRowsUpdated': num_update}) read_kb = (1 if num_update == 0 else num_update * 2) write_kb = (0 if num_update == 0 else num_update * 4) self.check_cost(result, read_kb, read_kb * 2, write_kb, write_kb) count += 1 if result.get_continuation_key() is None: break self.query_request.set_continuation_key( result.get_continuation_key()) self.assertEqual(count, 1) # check the updated row prepared_statement = self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) if limit <= num_records: records = self.check_query_result(result, num_records, True) else: records = self.check_query_result(result, num_records) self.assertEqual(records[0], self._expected_row(fld_sid, fld_id, fld_long)) self.check_cost(result, 1, 2, 0, 0) def testQueryPreparedStatementUpdateWithDefault(self): fld_sid = 0 fld_id = 5 fld_long = 2147483649 prepared_statement = self.prepare_result_update.get_prepared_statement( ) prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement( self.prepare_result_update).set_defaults(self.handle_config) result = self.handle.query(self.query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.check_cost(result, 2, 4, 4, 4) # check the updated row prepared_statement = self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) records = self.check_query_result(result, 1) 
self.assertEqual(records[0], self._expected_row(fld_sid, fld_id, fld_long)) self.check_cost(result, 1, 2, 0, 0) def testQueryStatementUpdateTTL(self): hour_in_milliseconds = 60 * 60 * 1000 self.query_request.set_statement( 'UPDATE ' + table_name + ' $u SET TTL CASE WHEN ' + 'remaining_hours($u) < 0 THEN 3 ELSE remaining_hours($u) + 3 END ' + 'HOURS WHERE fld_sid = 1 AND fld_id = 3') result = self.handle.query(self.query_request) ttl = TimeToLive.of_hours(3) expect_expiration = ttl.to_expiration_time(int(round(time() * 1000))) records = self.check_query_result(result, 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.check_cost(result, 2 + prepare_cost, 4 + prepare_cost, 6, 6) # check the record after update ttl request succeed self.get_request.set_key({'fld_sid': 1, 'fld_id': 3}) result = self.handle.get(self.get_request) actual_expiration = result.get_expiration_time() actual_expect_diff = actual_expiration - expect_expiration self.assertGreater(actual_expiration, 0) self.assertLess(actual_expect_diff, hour_in_milliseconds) self.check_cost(result, 1, 2, 0, 0) def testQueryOrderBy(self): num_records = 12 num_ids = 6 # test order by primary index field statement = ('SELECT fld_sid, fld_id FROM ' + table_name + ' ORDER BY fld_sid, fld_id') query_request = QueryRequest().set_statement(statement) count = 0 while True: count += 1 result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_records, rec=records) for idx in range(num_records): self.assertEqual( records[idx], self._expected_row(idx // num_ids, idx % num_ids)) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) self.assertEqual(count, 2) # test order by secondary index field statement = ('SELECT fld_str FROM ' + table_name + ' ORDER BY fld_str') query_request = QueryRequest().set_statement(statement) while 
True: result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_records, rec=records) for idx in range(num_records): self.assertEqual( records[idx], {'fld_str': '{"name": u' + str(idx).zfill(2) + '}'}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) def testQueryFuncMinMaxGroupBy(self): num_sids = 2 # test min function statement = ('SELECT min(fld_time) FROM ' + table_name) query_request = QueryRequest().set_statement(statement) result = self.handle.query(query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], {'Column_1': self.min_time[0]}) self.check_cost(result, prepare_cost, prepare_cost, 0, 0, True) # test max function statement = ('SELECT max(fld_time) FROM ' + table_name) query_request = QueryRequest().set_statement(statement) result = self.handle.query(query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], {'Column_1': self.max_time[1]}) self.check_cost(result, prepare_cost, prepare_cost, 0, 0, True) # test min function group by primary index field statement = ('SELECT min(fld_time) FROM ' + table_name + ' GROUP BY fld_sid') query_request = QueryRequest().set_statement(statement) count = 0 while True: count += 1 result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_sids, rec=records) for idx in range(num_sids): self.assertEqual(records[idx], {'Column_1': self.min_time[idx]}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) self.assertEqual(count, 2) # test max function group by primary index field statement = ('SELECT max(fld_time) FROM ' + table_name + ' GROUP BY fld_sid') query_request = 
QueryRequest().set_statement(statement) count = 0 while True: count += 1 result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_sids, rec=records) for idx in range(num_sids): self.assertEqual(records[idx], {'Column_1': self.max_time[idx]}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) self.assertEqual(count, 2) # test min function group by secondary index field statement = ('SELECT min(fld_time) FROM ' + table_name + ' GROUP BY fld_bool') query_request = QueryRequest().set_statement(statement) while True: result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_sids, rec=records) for idx in range(num_sids): self.assertEqual(records[idx], {'Column_1': self.min_time[idx]}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) # test max function group by secondary index field statement = ('SELECT max(fld_time) FROM ' + table_name + ' GROUP BY fld_bool') query_request = QueryRequest().set_statement(statement) while True: result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_sids, rec=records) for idx in range(num_sids): self.assertEqual(records[idx], {'Column_1': self.max_time[idx]}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) def testQueryFuncSumGroupBy(self): num_records = 12 num_sids = 2 # test sum function statement = ('SELECT sum(fld_double) FROM ' + table_name) query_request = QueryRequest().set_statement(statement) result = self.handle.query(query_request) 
records = self.check_query_result(result, 1) self.assertEqual(records[0], {'Column_1': 3.1415 * num_records}) self.check_cost(result, prepare_cost, prepare_cost, 0, 0, True) # test sum function group by primary index field statement = ('SELECT sum(fld_double) FROM ' + table_name + ' GROUP BY fld_sid') query_request = QueryRequest().set_statement(statement) count = 0 while True: count += 1 result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_sids, rec=records) for idx in range(num_sids): self.assertEqual( records[idx], {'Column_1': 3.1415 * (num_records // num_sids)}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) self.assertEqual(count, 2) # test sum function group by secondary index field statement = ('SELECT sum(fld_double) FROM ' + table_name + ' GROUP BY fld_bool') query_request = QueryRequest().set_statement(statement) while True: result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_sids, rec=records) for idx in range(num_sids): self.assertEqual( records[idx], {'Column_1': 3.1415 * (num_records // num_sids)}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) def testQueryFuncAvgGroupBy(self): num_sids = 2 # test avg function statement = ('SELECT avg(fld_double) FROM ' + table_name) query_request = QueryRequest().set_statement(statement) result = self.handle.query(query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], {'Column_1': 3.1415}) self.check_cost(result, prepare_cost, prepare_cost, 0, 0, True) # test avg function group by primary index field statement = ('SELECT avg(fld_double) FROM ' + table_name + ' GROUP BY 
fld_sid') query_request = QueryRequest().set_statement(statement) count = 0 while True: count += 1 result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_sids, rec=records) for idx in range(num_sids): self.assertEqual(records[idx], {'Column_1': 3.1415}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) self.assertEqual(count, 2) # test avg function group by secondary index field statement = ('SELECT avg(fld_double) FROM ' + table_name + ' GROUP BY fld_bool') query_request = QueryRequest().set_statement(statement) while True: result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_sids, rec=records) for idx in range(num_sids): self.assertEqual(records[idx], {'Column_1': 3.1415}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) def testQueryFuncCountGroupBy(self): num_records = 12 num_sids = 2 # test count function statement = ('SELECT count(*) FROM ' + table_name) query_request = QueryRequest().set_statement(statement) result = self.handle.query(query_request) records = self.check_query_result(result, 1) self.assertEqual(records[0], {'Column_1': num_records}) self.check_cost(result, prepare_cost, prepare_cost, 0, 0, True) # test count function group by primary index field statement = ('SELECT count(*) FROM ' + table_name + ' GROUP BY fld_sid') query_request = QueryRequest().set_statement(statement) count = 0 while True: count += 1 result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_sids, rec=records) for idx in range(num_sids): self.assertEqual(records[idx], {'Column_1': num_records // 
num_sids}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) self.assertEqual(count, 2) # test count function group by secondary index field statement = ('SELECT count(*) FROM ' + table_name + ' GROUP BY fld_bool') query_request = QueryRequest().set_statement(statement) while True: result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_sids, rec=records) for idx in range(num_sids): self.assertEqual(records[idx], {'Column_1': num_records // num_sids}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) def testQueryOrderByWithLimit(self): num_records = 12 limit = 10 # test order by primary index field with limit statement = ('SELECT fld_str FROM ' + table_name + ' ORDER BY fld_sid, fld_id LIMIT 10') query_request = QueryRequest().set_statement(statement) count = 0 while True: count += 1 result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, limit, rec=records) for idx in range(limit): self.assertEqual( records[idx], { 'fld_str': '{"name": u' + str(num_records - idx - 1).zfill(2) + '}' }) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) self.assertEqual(count, 2) # test order by secondary index field with limit statement = ('SELECT fld_str FROM ' + table_name + ' ORDER BY fld_str LIMIT 10') query_request = QueryRequest().set_statement(statement) while True: result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, limit, rec=records) for idx in range(limit): self.assertEqual( 
records[idx], {'fld_str': '{"name": u' + str(idx).zfill(2) + '}'}) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) def testQueryOrderByWithOffset(self): offset = 4 num_get = 8 # test order by primary index field with offset statement = ('DECLARE $offset INTEGER; SELECT fld_str FROM ' + table_name + ' ORDER BY fld_sid, fld_id OFFSET $offset') prepare_request = PrepareRequest().set_statement(statement) prepare_result = self.handle.prepare(prepare_request) prepared_statement = prepare_result.get_prepared_statement() prepared_statement.set_variable('$offset', offset) query_request = QueryRequest().set_prepared_statement( prepared_statement) count = 0 while True: count += 1 result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_get, rec=records) for idx in range(num_get): self.assertEqual( records[idx], { 'fld_str': '{"name": u' + str(num_get - idx - 1).zfill(2) + '}' }) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) self.assertEqual(count, 2) # test order by secondary index field with offset statement = ('DECLARE $offset INTEGER; SELECT fld_str FROM ' + table_name + ' ORDER BY fld_str OFFSET $offset') prepare_request = PrepareRequest().set_statement(statement) prepare_result = self.handle.prepare(prepare_request) prepared_statement = prepare_result.get_prepared_statement() prepared_statement.set_variable('$offset', offset) query_request = QueryRequest().set_prepared_statement( prepared_statement) while True: result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_get, rec=records) for idx in range(num_get): self.assertEqual(records[idx], { 'fld_str': '{"name": u' + 
str(offset + idx).zfill(2) + '}' }) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) def testQueryFuncGeoNear(self): num_get = 6 longitude = 21.547 latitude = 37.291 # test geo_near function statement = ('SELECT tb.fld_json.location FROM ' + table_name + ' tb WHERE geo_near(tb.fld_json.location, ' + '{"type": "point", "coordinates": [' + str(longitude) + ', ' + str(latitude) + ']}, 215000)') query_request = QueryRequest().set_statement(statement) result = self.handle.query(query_request) records = self.check_query_result(result, num_get) for i in range(1, num_get): pre = records[i - 1]['location']['coordinates'] curr = records[i]['location']['coordinates'] self.assertLess(abs(pre[0] - longitude), abs(curr[0] - longitude)) self.assertLess(abs(pre[1] - latitude), abs(curr[1] - latitude)) self.check_cost(result, prepare_cost, prepare_cost, 0, 0, True) # test geo_near function order by primary index field statement = ( 'SELECT fld_str FROM ' + table_name + ' tb WHERE geo_near(' + 'tb.fld_json.location, {"type": "point", "coordinates": [' + str(longitude) + ', ' + str(latitude) + ']}, 215000) ' + 'ORDER BY fld_sid, fld_id') query_request = QueryRequest().set_statement(statement) count = 0 while True: count += 1 result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_get, rec=records) name = [10, 9, 8, 4, 3, 2] for i in range(num_get): self.assertEqual(records[i], { 'fld_str': '{"name": u' + str(name[i]).zfill(2) + '}' }) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) self.assertEqual(count, 2) # test geo_near function order by secondary index field statement = ( 'SELECT fld_str FROM ' + table_name + ' tb WHERE geo_near(' + 'tb.fld_json.location, 
{"type": "point", "coordinates": [' + str(longitude) + ', ' + str(latitude) + ']}, 215000) ' + 'ORDER BY fld_str') query_request = QueryRequest().set_statement(statement) while True: result = self.handle.query(query_request) records = result.get_results() if query_request.is_done(): self.check_query_result(result, num_get, rec=records) name = [2, 3, 4, 8, 9, 10] for i in range(num_get): self.assertEqual(records[i], { 'fld_str': '{"name": u' + str(name[i]).zfill(2) + '}' }) self.check_cost(result, 0, 0, 0, 0) break else: self.check_query_result(result, 0, True, records) self.assertEqual(records, []) self.check_cost(result, 0, 0, 0, 0, True) @staticmethod def _expected_row(fld_sid, fld_id, fld_long=None): expected_row = OrderedDict() expected_row['fld_sid'] = fld_sid expected_row['fld_id'] = fld_id if fld_long is not None: expected_row['fld_long'] = fld_long return expected_row
class TestGet(unittest.TestCase, TestBase):
    """Tests for ``handle.get`` and ``GetRequest`` argument validation.

    ``setUpClass`` creates a table with a 16-hour default TTL and writes a
    single row (``fld_sid=1, fld_id=1``). The row, its expected expiration
    time and one hour in milliseconds are published as the module globals
    ``row``, ``tb_expect_expiration`` and ``hour_in_milliseconds``.
    """

    @classmethod
    def setUpClass(cls):
        TestBase.set_up_class()
        table_ttl = TimeToLive.of_hours(16)
        # Adjacent literals are joined at compile time, so the statement has
        # no stray indentation whitespace inside it.
        create_statement = (
            'CREATE TABLE ' + table_name + '(fld_sid INTEGER, fld_id INTEGER, '
            'fld_long LONG, fld_float FLOAT, fld_double DOUBLE, '
            'fld_bool BOOLEAN, fld_str STRING, fld_bin BINARY, '
            'fld_time TIMESTAMP(7), fld_num NUMBER, fld_json JSON, '
            'fld_arr ARRAY(STRING), fld_map MAP(STRING), '
            'fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), '
            'PRIMARY KEY(SHARD(fld_sid), fld_id)) USING TTL ' + str(table_ttl))
        create_request = TableRequest().set_statement(
            create_statement).set_table_limits(TableLimits(5000, 5000, 50))
        cls._result = TestBase.table_request(create_request, State.ACTIVE)
        global row, tb_expect_expiration, hour_in_milliseconds
        row = {
            'fld_sid': 1,
            'fld_id': 1,
            'fld_long': 2147483648,
            'fld_float': 3.1414999961853027,
            'fld_double': 3.1415,
            'fld_bool': True,
            'fld_str': '{"name": u1, "phone": null}',
            'fld_bin': bytearray(pack('>i', 4)),
            'fld_time': datetime.now(),
            'fld_num': Decimal(5),
            'fld_json': {'a': '1', 'b': None, 'c': '3'},
            'fld_arr': ['a', 'b', 'c'],
            'fld_map': {'a': '1', 'b': '2', 'c': '3'},
            'fld_rec': {'fld_id': 1, 'fld_bool': False, 'fld_str': None}
        }
        put_request = PutRequest().set_value(row).set_table_name(table_name)
        cls._handle.put(put_request)
        # Expected expiration, computed right after the put; actual values
        # are compared against it with a one-hour tolerance.
        tb_expect_expiration = table_ttl.to_expiration_time(
            int(round(time() * 1000)))
        hour_in_milliseconds = 60 * 60 * 1000

    @classmethod
    def tearDownClass(cls):
        TestBase.tear_down_class()

    def setUp(self):
        TestBase.set_up(self)
        self.key = {'fld_sid': 1, 'fld_id': 1}
        self.get_request = GetRequest().set_key(
            self.key).set_table_name(table_name).set_timeout(timeout)

    def tearDown(self):
        TestBase.tear_down(self)

    def testGetSetIllegalKey(self):
        self.assertRaises(IllegalArgumentException, self.get_request.set_key,
                          'IllegalKey')
        # A key missing either primary-key column is rejected by get().
        self.get_request.set_key({'fld_sid': 1})
        self.assertRaises(IllegalArgumentException, self.handle.get,
                          self.get_request)
        self.get_request.set_key({'fld_id': 1})
        self.assertRaises(IllegalArgumentException, self.handle.get,
                          self.get_request)

    def testGetSetIllegalKeyFromJson(self):
        self.assertRaises(ValueError, self.get_request.set_key_from_json,
                          'IllegalJson')
        self.get_request.set_key_from_json('{"invalid_field": "key"}')
        self.assertRaises(IllegalArgumentException, self.handle.get,
                          self.get_request)

    def testGetSetIllegalTableName(self):
        self.assertRaises(IllegalArgumentException,
                          self.get_request.set_table_name,
                          {'name': table_name})
        self.get_request.set_table_name('IllegalTable')
        self.assertRaises(TableNotFoundException, self.handle.get,
                          self.get_request)

    def testGetSetIllegalConsistency(self):
        self.assertRaises(IllegalArgumentException,
                          self.get_request.set_consistency,
                          'IllegalConsistency')

    def testGetSetIllegalTimeout(self):
        self.assertRaises(IllegalArgumentException,
                          self.get_request.set_timeout, 'IllegalTimeout')
        self.assertRaises(IllegalArgumentException,
                          self.get_request.set_timeout, 0)
        self.assertRaises(IllegalArgumentException,
                          self.get_request.set_timeout, -1)

    def testGetWithoutKey(self):
        self.get_request.set_key(None)
        self.assertRaises(IllegalArgumentException, self.handle.get,
                          self.get_request)

    def testGetGets(self):
        self.assertEqual(self.get_request.get_key(), self.key)

    def testGetIllegalRequest(self):
        self.assertRaises(IllegalArgumentException, self.handle.get,
                          'IllegalRequest')

    def testGetNormal(self):
        result = self.handle.get(self.get_request)
        self._check_existing_row(result, 2)

    def testGetEventual(self):
        # An EVENTUAL read is expected to consume half the read units.
        self.get_request.set_consistency(Consistency.EVENTUAL)
        result = self.handle.get(self.get_request)
        self._check_existing_row(result, 1)

    def testGetNonExisting(self):
        self.get_request.set_key({'fld_sid': 2, 'fld_id': 2})
        result = self.handle.get(self.get_request)
        self.assertIsNone(result.get_value())
        self.assertIsNone(result.get_version())
        self.assertEqual(result.get_expiration_time(), 0)
        self.assertEqual(result.get_read_kb(), 1)
        self.assertEqual(result.get_read_units(), 2)
        self.assertEqual(result.get_write_kb(), 0)
        self.assertEqual(result.get_write_units(), 0)

    def _check_existing_row(self, result, read_units):
        """Assert that *result* is a successful read of the setUpClass row.

        Checks the value, that a version is present, that the expiration
        time is positive and less than one hour past
        ``tb_expect_expiration``, and the consumed KB and units.

        :param result: the result returned by ``handle.get``.
        :param read_units: the expected read units; the tests pass 2 for the
            default consistency and 1 for ``Consistency.EVENTUAL``.
        """
        self.assertEqual(result.get_value(), row)
        self.assertIsNotNone(result.get_version())
        actual_expiration = result.get_expiration_time()
        actual_expect_diff = actual_expiration - tb_expect_expiration
        self.assertGreater(actual_expiration, 0)
        self.assertLess(actual_expect_diff, hour_in_milliseconds)
        self.assertEqual(result.get_read_kb(), 1)
        self.assertEqual(result.get_read_units(), read_units)
        self.assertEqual(result.get_write_kb(), 0)
        self.assertEqual(result.get_write_units(), 0)
class TestWriteMultiple(unittest.TestCase, TestBase):
    """Tests for ``handle.write_multiple`` and ``WriteMultipleRequest``.

    ``setUp`` writes six rows (``fld_id`` 0-5) for each of two shard keys
    (``fld_sid`` 0 and 1) with a 16-day TTL. It then builds six
    sub-requests, all targeting shard key 0:

    0. put with a 1-hour TTL -- expected to succeed
    1. put IF_ABSENT on an existing row -- expected to fail
    2. put IF_PRESENT using the table default TTL -- expected to succeed
    3. put IF_VERSION with the row's current version -- expected to succeed
    4. delete -- expected to succeed
    5. delete with row 0's version, which does not match -- expected to fail

    It also builds two sub-requests that must be rejected: one for another
    table and one for another shard key.
    """

    @classmethod
    def setUpClass(cls):
        cls.set_up_class()
        # Adjacent literals are joined at compile time, so the statement has
        # no stray indentation whitespace inside it.
        create_statement = (
            'CREATE TABLE ' + table_name + '(fld_sid INTEGER, fld_id INTEGER, '
            'fld_long LONG, fld_float FLOAT, fld_double DOUBLE, '
            'fld_bool BOOLEAN, fld_str STRING, fld_bin BINARY, '
            'fld_time TIMESTAMP(8), fld_num NUMBER, fld_json JSON, '
            'fld_arr ARRAY(STRING), fld_map MAP(STRING), '
            'fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), '
            'PRIMARY KEY(SHARD(fld_sid), fld_id))')
        limits = TableLimits(50, 50, 1)
        create_request = TableRequest().set_statement(
            create_statement).set_table_limits(limits)
        cls.table_request(create_request)

    @classmethod
    def tearDownClass(cls):
        cls.tear_down_class()

    def setUp(self):
        """Write the fixture rows and build the sub-requests described above."""
        self.set_up()
        self.shardkeys = [0, 1]
        self.ids = [0, 1, 2, 3, 4, 5]
        self.rows = []
        self.new_rows = []
        self.versions = []
        self.requests = []
        self.illegal_requests = []
        ttl = TimeToLive.of_days(16)
        # rows/new_rows/versions are indexed as [shard key][fld_id]; this
        # works because the shard keys are 0 and 1, appended in order.
        for sk in self.shardkeys:
            self.rows.append([])
            self.new_rows.append([])
            self.versions.append([])
            for i in self.ids:
                row = get_row()
                row['fld_sid'] = sk
                row['fld_id'] = i
                new_row = deepcopy(row)
                new_row['fld_long'] = 2147483649
                self.rows[sk].append(row)
                self.new_rows[sk].append(new_row)
                put_request = PutRequest().set_value(row).set_table_name(
                    table_name).set_ttl(ttl)
                self.versions[sk].append(
                    self.handle.put(put_request).get_version())
        self.old_expect_expiration = ttl.to_expiration_time(
            int(round(time() * 1000)))
        self.ttl = TimeToLive.of_hours(1)
        self.ops_sk = 0
        illegal_sk = 1
        ops_rows = self.new_rows[self.ops_sk]
        ops_versions = self.versions[self.ops_sk]
        # 0: unconditional put with a 1-hour TTL.
        self.requests.append(
            PutRequest().set_value(ops_rows[0]).set_table_name(table_name)
            .set_ttl(self.ttl).set_return_row(True))
        # 1: put-if-absent on a row that already exists.
        self.requests.append(
            PutRequest().set_value(ops_rows[1]).set_table_name(table_name)
            .set_option(PutOption.IF_ABSENT).set_ttl(self.ttl)
            .set_return_row(True))
        # 2: put-if-present that uses the table default TTL.
        self.requests.append(
            PutRequest().set_value(ops_rows[2]).set_use_table_default_ttl(True)
            .set_table_name(table_name).set_option(PutOption.IF_PRESENT)
            .set_return_row(True))
        # 3: put-if-version with the row's current version.
        self.requests.append(
            PutRequest().set_value(ops_rows[3]).set_table_name(table_name)
            .set_ttl(self.ttl).set_option(PutOption.IF_VERSION)
            .set_match_version(ops_versions[3]).set_return_row(True))
        # 4: unconditional delete.
        self.requests.append(
            DeleteRequest().set_key({'fld_sid': self.ops_sk, 'fld_id': 4})
            .set_table_name(table_name).set_return_row(True))
        # 5: delete-if-version with row 0's version, so the versions differ.
        self.requests.append(
            DeleteRequest().set_key({'fld_sid': self.ops_sk, 'fld_id': 5})
            .set_table_name(table_name).set_return_row(True)
            .set_match_version(ops_versions[0]))
        # Sub-requests for a different table and a different shard key.
        self.illegal_requests.append(
            DeleteRequest().set_key({'fld_sid': self.ops_sk, 'fld_id': 0})
            .set_table_name('IllegalUsers'))
        self.illegal_requests.append(
            DeleteRequest().set_key({'fld_sid': illegal_sk, 'fld_id': 0})
            .set_table_name(table_name))
        self.write_multiple_request = WriteMultipleRequest().set_timeout(
            timeout)
        self.get_request = GetRequest().set_table_name(table_name)
        self.hour_in_milliseconds = 60 * 60 * 1000
        self.day_in_milliseconds = 24 * 60 * 60 * 1000

    def tearDown(self):
        # Remove every row written under each shard key.
        for sk in self.shardkeys:
            key = {'fld_sid': sk}
            request = MultiDeleteRequest().set_table_name(table_name).set_key(
                key)
            self.handle.multi_delete(request)
        self.tear_down()

    def testWriteMultipleSetIllegalCompartment(self):
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_compartment, {})
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_compartment, '')

    def testWriteMultipleAddIllegalRequestAndAbortIfUnsuccessful(self):
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add, 'IllegalRequest',
                          True)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add, PutRequest(),
                          'IllegalAbortIfUnsuccessful')
        # add two operations with different table name
        self.write_multiple_request.add(self.requests[0], True)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add,
                          self.illegal_requests[0], False)
        self.write_multiple_request.clear()
        # add two operations with different major paths
        self.write_multiple_request.add(self.requests[0], True).add(
            self.illegal_requests[1], False)
        self.assertRaises(IllegalArgumentException,
                          self.handle.write_multiple,
                          self.write_multiple_request)
        if not is_onprem():
            # add operations when the request size exceeded the limit
            self.write_multiple_request.clear()
            for _ in range(64):
                row = get_row()
                row['fld_str'] = self.get_random_str(0.4)
                self.write_multiple_request.add(
                    PutRequest().set_value(row).set_table_name(table_name),
                    True)
            self.assertRaises(RequestSizeLimitException,
                              self.handle.write_multiple,
                              self.write_multiple_request)
        # add operations when sub requests reached the max number
        # NOTE(review): indentation was lost in this copy of the file; this
        # check is assumed to run outside the is_onprem() guard -- confirm.
        self.write_multiple_request.clear()
        for op in range(51):
            row = get_row()
            row['fld_id'] = op
            self.write_multiple_request.add(
                PutRequest().set_value(row).set_table_name(table_name), True)
        self.assertRaises(BatchOperationNumberLimitException,
                          self.handle.write_multiple,
                          self.write_multiple_request)

    def testWriteMultipleGetRequestWithIllegalIndex(self):
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.get_request,
                          'IllegalIndex')
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.get_request, -1)
        self.assertRaises(IndexError,
                          self.write_multiple_request.get_request, 0)

    def testWriteMultipleSetIllegalTimeout(self):
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout,
                          'IllegalTimeout')
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout, 0)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout, -1)

    def testWriteMultipleNoOperations(self):
        self.assertRaises(IllegalArgumentException,
                          self.handle.write_multiple,
                          self.write_multiple_request)

    def testWriteMultipleGets(self):
        num_operations = 6
        for request in self.requests:
            self.write_multiple_request.add(request, True)
        self.assertIsNone(self.write_multiple_request.get_compartment())
        self.assertEqual(self.write_multiple_request.get_table_name(),
                         table_name)
        self.assertEqual(self.write_multiple_request.get_request(2),
                         self.requests[2])
        operations = self.write_multiple_request.get_operations()
        # Check the length first so zip() cannot silently truncate.
        self.assertEqual(len(operations), len(self.requests))
        for operation, expected in zip(operations, self.requests):
            self.assertEqual(operation.get_request(), expected)
            self.assertTrue(operation.is_abort_if_unsuccessful())
        self.assertEqual(self.write_multiple_request.get_num_operations(),
                         num_operations)
        self.assertEqual(self.write_multiple_request.get_timeout(), timeout)
        self.write_multiple_request.clear()
        self.assertIsNone(self.write_multiple_request.get_table_name())
        self.assertEqual(self.write_multiple_request.get_operations(), [])
        self.assertEqual(self.write_multiple_request.get_num_operations(), 0)
        self.assertEqual(self.write_multiple_request.get_timeout(), timeout)

    def testWriteMultipleNormal(self):
        num_operations = 6
        for request in self.requests:
            self.write_multiple_request.add(request, False)
        result = self.handle.write_multiple(self.write_multiple_request)
        expect_expiration = self.ttl.to_expiration_time(
            int(round(time() * 1000)))
        op_results = self._check_write_multiple_result(result, num_operations)
        for idx in range(result.size()):
            if idx in (1, 5):
                # putIfAbsent and deleteIfVersion failed
                self._check_operation_result(
                    op_results[idx],
                    existing_version=self.versions[self.ops_sk][idx],
                    existing_value=self.rows[self.ops_sk][idx])
            elif idx == 4:
                # delete succeed
                self._check_operation_result(op_results[idx], success=True)
            else:
                # put, putIfPresent and putIfVersion succeed
                self._check_operation_result(op_results[idx], True, True)
        self.check_cost(result, 5, 10, 7, 7)
        # check the records after write_multiple request succeed
        for sk in self.shardkeys:
            for i in self.ids:
                self.get_request.set_key({'fld_sid': sk, 'fld_id': i})
                result = self.handle.get(self.get_request)
                if sk == 1 or i == 1 or i == 5:
                    # Rows that were not modified keep the original TTL.
                    self.check_get_result(
                        result, self.rows[sk][i], self.versions[sk][i],
                        self.old_expect_expiration, TimeUnit.DAYS)
                elif i == 4:
                    # The deleted row is gone.
                    self.check_get_result(result)
                elif i == 2:
                    self.check_get_result(result, self.new_rows[sk][i],
                                          self.versions[sk][i], ver_eq=False)
                else:
                    self.check_get_result(
                        result, self.new_rows[sk][i], self.versions[sk][i],
                        expect_expiration, TimeUnit.HOURS, False)
                self.check_cost(result, 1, 2, 0, 0)

    def testWriteMultipleAbortIfUnsuccessful(self):
        # The first failing sub-request (index 1, put IF_ABSENT) aborts the
        # whole batch; the result carries only that operation's outcome.
        failed_idx = 1
        for request in self.requests:
            self.write_multiple_request.add(request, True)
        result = self.handle.write_multiple(self.write_multiple_request)
        op_results = self._check_write_multiple_result(result, 1, True,
                                                       failed_idx, False)
        self._check_operation_result(
            op_results[0],
            existing_version=self.versions[self.ops_sk][failed_idx],
            existing_value=self.rows[self.ops_sk][failed_idx])
        failed_result = result.get_failed_operation_result()
        self._check_operation_result(
            failed_result,
            existing_version=self.versions[self.ops_sk][failed_idx],
            existing_value=self.rows[self.ops_sk][failed_idx])
        self.check_cost(result, 1, 2, 2, 2)
        # check the records after write_multiple request failed: every row
        # must be unchanged
        for sk in self.shardkeys:
            for i in self.ids:
                self.get_request.set_key({'fld_sid': sk, 'fld_id': i})
                result = self.handle.get(self.get_request)
                self.check_get_result(result, self.rows[sk][i],
                                      self.versions[sk][i],
                                      self.old_expect_expiration,
                                      TimeUnit.DAYS)
                self.check_cost(result, 1, 2, 0, 0)

    def testWriteMultipleWithIdentityColumn(self):
        num_operations = 10
        id_table = table_prefix + 'Identity'
        create_request = TableRequest().set_statement(
            'CREATE TABLE ' + id_table + '(sid INTEGER, id LONG GENERATED '
            'ALWAYS AS IDENTITY, name STRING, PRIMARY KEY(SHARD(sid), id))')
        create_request.set_table_limits(TableLimits(50, 50, 1))
        self.table_request(create_request)
        # add ten operations
        row = {'name': 'myname', 'sid': 1}
        for idx in range(num_operations):
            put_request = PutRequest().set_table_name(id_table).set_value(row)
            put_request.set_identity_cache_size(idx)
            self.write_multiple_request.add(put_request, False)
        # execute the write multiple request
        versions = []
        result = self.handle.write_multiple(self.write_multiple_request)
        op_results = self._check_write_multiple_result(result, num_operations)
        generated = 0
        for idx in range(result.size()):
            # Each generated identity must exceed the previous one.
            version, generated = self._check_operation_result(
                op_results[idx], True, True, generated)
            versions.append(version)
        self.check_cost(result, 0, 0, num_operations, num_operations)
        # check the records after write_multiple request succeed
        self.get_request.set_table_name(id_table)
        for idx in range(num_operations):
            # Assumes the identities are consecutive, ending at the last
            # generated value.
            curr_id = generated - num_operations + idx + 1
            self.get_request.set_key({'sid': 1, 'id': curr_id})
            result = self.handle.get(self.get_request)
            expected = OrderedDict()
            expected['sid'] = 1
            expected['id'] = curr_id
            expected['name'] = 'myname'
            self.check_get_result(result, expected, versions[idx])
            self.check_cost(result, 1, 2, 0, 0)

    def _check_operation_result(self, op_result, version=False, success=False,
                                last_generated=None, existing_version=None,
                                existing_value=None):
        """Assert the fields of one write_multiple operation result.

        :param op_result: the operation result to check.
        :param version: True if a version is expected, False if none is.
        :param success: the expected value of ``get_success()``.
        :param last_generated: None if no generated value is expected;
            otherwise the generated value must be greater than this one.
        :param existing_version: the expected existing version (compared by
            its bytes), or None if no existing version is expected.
        :param existing_value: the expected existing row value.
        :returns: a ``(version, generated_value)`` tuple read from
            *op_result*.
        """
        # check version of operation result
        ver = op_result.get_version()
        if version:
            self.assertIsNotNone(ver)
        else:
            self.assertIsNone(ver)
        # check if the operation success
        self.assertEqual(op_result.get_success(), success)
        # check generated value of operation result
        generated = op_result.get_generated_value()
        if last_generated is None:
            self.assertIsNone(generated)
        else:
            # Report a missing value as a test failure, not a TypeError from
            # comparing None with a number.
            self.assertIsNotNone(generated)
            self.assertGreater(generated, last_generated)
        # check existing version
        existing_ver = op_result.get_existing_version()
        if existing_version is None:
            self.assertIsNone(existing_ver)
        else:
            # Avoid an AttributeError on None.get_bytes().
            self.assertIsNotNone(existing_ver)
            self.assertEqual(existing_ver.get_bytes(),
                             existing_version.get_bytes())
        # check existing value
        self.assertEqual(op_result.get_existing_value(), existing_value)
        return ver, generated

    def _check_write_multiple_result(self, result, num_operations,
                                     has_failed_operation=False,
                                     failed_index=-1, success=True):
        """Assert the summary fields of a write_multiple result.

        :param result: the result returned by ``handle.write_multiple``.
        :param num_operations: the expected ``result.size()``.
        :param has_failed_operation: whether a failed operation result is
            expected.
        :param failed_index: the expected failed operation index (-1 if
            none).
        :param success: the expected overall ``get_success()`` value.
        :returns: the per-operation results, from ``result.get_results()``.
        """
        # check number of operations
        self.assertEqual(result.size(), num_operations)
        # check failed operation
        failed_result = result.get_failed_operation_result()
        if has_failed_operation:
            self.assertIsNotNone(failed_result)
        else:
            self.assertIsNone(failed_result)
        # check failed operation index
        self.assertEqual(result.get_failed_operation_index(), failed_index)
        # check operation status
        self.assertEqual(result.get_success(), success)
        return result.get_results()
def _do_rate_limited_ops(self, num_seconds, read_limit, write_limit,
                         max_rows, check_units, use_percent,
                         use_external_limiters):
    """
    Runs puts and gets continuously for N seconds.

    Verify that the resultant RUs/WUs used match the given rate limits
    (within +/-20% of ``limit * use_percent / 100``).

    :param num_seconds: how long to keep issuing operations, in seconds.
    :param read_limit: read limit in RUs per second; when 0 only puts are
        issued.
    :param write_limit: write limit in WUs per second; when 0 only gets are
        issued.
    :param max_rows: row ids are drawn from ``[0, max_rows)``.
    :param check_units: when False, skip the final RU/WU verification.
    :param use_percent: percentage of the limits the operations are
        expected to consume.
    :param use_external_limiters: when True, attach per-request
        SimpleRateLimiter instances sized to ``use_percent`` of the limits;
        otherwise reset the handle's internal limiters for the table.
    """
    if read_limit == 0 and write_limit == 0:
        return
    put_request = PutRequest().set_table_name(table_name)
    get_request = GetRequest().set_table_name(table_name)
    key = dict()
    # TODO: random sizes 0-nKB.
    value = dict()
    value['name'] = 'jane'
    start_time = int(round(time() * 1000))
    end_time = start_time + num_seconds * 1000
    read_units_used = 0
    write_units_used = 0
    # Diagnostics only: accumulated for debugging, not asserted on.
    total_delayed_ms = 0
    throttle_exceptions = 0
    rlim = None
    wlim = None
    max_val = float(read_limit + write_limit)
    if not use_external_limiters:
        # Reset internal limiters so they don't have unused units.
        self.handle.get_client().reset_rate_limiters(table_name)
    else:
        rlim = SimpleRateLimiter(read_limit * use_percent / 100.0, 1)
        wlim = SimpleRateLimiter(write_limit * use_percent / 100.0, 1)
    while True:
        fld_id = int(random() * max_rows)
        if read_limit == 0:
            do_put = True
        elif write_limit == 0:
            do_put = False
        else:
            # Mix puts and gets in proportion to write_limit : read_limit.
            v = int(random() * max_val)
            do_put = v >= read_limit
        try:
            if do_put:
                value['id'] = fld_id
                put_request.set_value(value).set_read_rate_limiter(
                    None).set_write_rate_limiter(wlim)
                pres = self.handle.put(put_request)
                write_units_used += pres.get_write_units()
                total_delayed_ms += pres.get_rate_limit_delayed_ms()
                rs = pres.get_retry_stats()
                if rs is not None:
                    # Look up by the exception class's own name; on a class,
                    # ``__class__.__name__`` is the metaclass name 'type'.
                    throttle_exceptions += rs.get_num_exceptions(
                        WriteThrottlingException.__name__)
            else:
                key['id'] = fld_id
                get_request.set_key(key).set_read_rate_limiter(
                    rlim).set_write_rate_limiter(None)
                gres = self.handle.get(get_request)
                read_units_used += gres.get_read_units()
                total_delayed_ms += gres.get_rate_limit_delayed_ms()
                rs = gres.get_retry_stats()
                if rs is not None:
                    throttle_exceptions += rs.get_num_exceptions(
                        ReadThrottlingException.__name__)
        except ReadThrottlingException:
            self.fail('Expected no read throttling exceptions, got one.')
        except WriteThrottlingException:
            self.fail('Expected no write throttling exceptions, got one.')
        if int(round(time() * 1000)) >= end_time:
            break

    if not check_units:
        return
    # Clamp to 1 ms so a zero-length run cannot divide by zero; use a float
    # divisor so the rate is not truncated by integer division.
    elapsed_ms = max(int(round(time() * 1000)) - start_time, 1)
    elapsed_seconds = elapsed_ms / 1000.0
    rus = read_units_used / elapsed_seconds
    wus = write_units_used / elapsed_seconds
    fraction = use_percent / 100.0
    if (rus < read_limit * fraction * 0.8 or
            rus > read_limit * fraction * 1.2):
        self.fail('Gets: Expected around ' + str(read_limit * fraction) +
                  ' RUs, got ' + str(rus))
    if (wus < write_limit * fraction * 0.8 or
            wus > write_limit * fraction * 1.2):
        self.fail('Puts: Expected around ' + str(write_limit * fraction) +
                  ' WUs, got ' + str(wus))
class TestWriteMultiple(unittest.TestCase, TestBase):
    """Tests for WriteMultipleRequest and NoSQLHandle.write_multiple."""

    @classmethod
    def setUpClass(cls):
        TestBase.set_up_class()
        create_statement = (
            'CREATE TABLE ' + table_name + '(fld_sid INTEGER, fld_id INTEGER, '
            'fld_long LONG, fld_float FLOAT, fld_double DOUBLE, '
            'fld_bool BOOLEAN, fld_str STRING, fld_bin BINARY, '
            'fld_time TIMESTAMP(8), fld_num NUMBER, fld_json JSON, '
            'fld_arr ARRAY(STRING), fld_map MAP(STRING), '
            'fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), '
            'PRIMARY KEY(SHARD(fld_sid), fld_id))')
        limits = TableLimits(5000, 5000, 50)
        create_request = TableRequest().set_statement(
            create_statement).set_table_limits(limits)
        cls._result = TestBase.table_request(create_request, State.ACTIVE)

    @classmethod
    def tearDownClass(cls):
        TestBase.tear_down_class()

    def setUp(self):
        """
        Insert rows for shard keys 0 and 1 (ids 0-5, 16-day TTL) and build
        the six sub-requests used by the tests, all against shard key 0.
        """
        TestBase.set_up(self)
        self.shardkeys = [0, 1]
        self.ids = [0, 1, 2, 3, 4, 5]
        # rows/new_rows/versions are indexed [shard key][id]; this relies on
        # shardkeys and ids being 0-based consecutive integers.
        self.rows = list()
        self.new_rows = list()
        self.versions = list()
        self.requests = list()
        self.illegal_requests = list()
        ttl = TimeToLive.of_days(16)
        for sk in self.shardkeys:
            self.rows.append(list())
            self.new_rows.append(list())
            self.versions.append(list())
            for i in self.ids:
                row = {'fld_sid': sk, 'fld_id': i, 'fld_long': 2147483648,
                       'fld_float': 3.1414999961853027, 'fld_double': 3.1415,
                       'fld_bool': True,
                       'fld_str': '{"name": u1, "phone": null}',
                       'fld_bin': bytearray(pack('>i', 4)),
                       'fld_time': datetime.now(), 'fld_num': Decimal(5),
                       'fld_json': {'a': '1', 'b': None, 'c': '3'},
                       'fld_arr': ['a', 'b', 'c'],
                       'fld_map': {'a': '1', 'b': '2', 'c': '3'},
                       'fld_rec': {'fld_id': 1, 'fld_bool': False,
                                   'fld_str': None}}
                new_row = deepcopy(row)
                new_row.update({'fld_long': 2147483649})
                self.rows[sk].append(row)
                self.new_rows[sk].append(new_row)
                put_request = PutRequest().set_value(row).set_table_name(
                    table_name).set_ttl(ttl)
                self.versions[sk].append(
                    self.handle.put(put_request).get_version())
        self.old_expect_expiration = ttl.to_expiration_time(
            int(round(time() * 1000)))
        self.ttl = TimeToLive.of_hours(1)
        self.ops_sk = 0
        illegal_sk = 1
        # requests[0]: put -> expected to succeed
        self.requests.append(PutRequest().set_value(
            self.new_rows[self.ops_sk][0]).set_table_name(table_name).set_ttl(
            self.ttl).set_return_row(True))
        # requests[1]: putIfAbsent on an existing row -> expected to fail
        self.requests.append(PutRequest().set_value(
            self.new_rows[self.ops_sk][1]).set_table_name(
            table_name).set_option(PutOption.IF_ABSENT).set_ttl(
            self.ttl).set_return_row(True))
        # requests[2]: putIfPresent with the table default TTL -> succeeds
        self.requests.append(PutRequest().set_value(
            self.new_rows[self.ops_sk][2]).set_use_table_default_ttl(
            True).set_table_name(table_name).set_option(
            PutOption.IF_PRESENT).set_return_row(True))
        # requests[3]: putIfVersion with the row's own version -> succeeds
        self.requests.append(PutRequest().set_value(
            self.new_rows[self.ops_sk][3]).set_table_name(
            table_name).set_option(PutOption.IF_VERSION).set_ttl(
            self.ttl).set_match_version(
            self.versions[self.ops_sk][3]).set_return_row(True))
        # requests[4]: delete -> expected to succeed
        self.requests.append(DeleteRequest().set_key(
            {'fld_sid': self.ops_sk, 'fld_id': 4}).set_table_name(
            table_name).set_return_row(True))
        # requests[5]: deleteIfVersion with row 0's version -> fails because
        # the version does not match row 5
        self.requests.append(DeleteRequest().set_key(
            {'fld_sid': self.ops_sk, 'fld_id': 5}).set_table_name(
            table_name).set_match_version(
            self.versions[self.ops_sk][0]).set_return_row(True))
        # illegal_requests[0]: different table name
        self.illegal_requests.append(DeleteRequest().set_key(
            {'fld_sid': self.ops_sk, 'fld_id': 0}).set_table_name(
            'IllegalUsers'))
        # illegal_requests[1]: different shard key (major path)
        self.illegal_requests.append(DeleteRequest().set_key(
            {'fld_sid': illegal_sk, 'fld_id': 0}).set_table_name(table_name))
        self.write_multiple_request = WriteMultipleRequest().set_timeout(
            timeout)
        self.get_request = GetRequest().set_table_name(table_name)
        self.hour_in_milliseconds = 60 * 60 * 1000
        self.day_in_milliseconds = 24 * 60 * 60 * 1000

    def tearDown(self):
        """Delete every row of each shard key, then run base teardown."""
        for sk in self.shardkeys:
            key = {'fld_sid': sk}
            request = MultiDeleteRequest().set_table_name(
                table_name).set_key(key)
            self.handle.multi_delete(request)
        TestBase.tear_down(self)

    def testWriteMultipleAddIllegalRequestAndAbortIfUnsuccessful(self):
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add,
                          'IllegalRequest', True)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add, PutRequest(),
                          'IllegalAbortIfUnsuccessful')
        # add two operations with different table name
        self.write_multiple_request.add(self.requests[0], True)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add,
                          self.illegal_requests[0], False)
        self.write_multiple_request.clear()
        # add two operations with different major paths
        self.write_multiple_request.add(
            self.requests[0], True).add(self.illegal_requests[1], False)
        self.assertRaises(IllegalArgumentException,
                          self.handle.write_multiple,
                          self.write_multiple_request)
        self.write_multiple_request.clear()
        # add operations when sub requests reached the max number (50)
        for _ in range(50):
            self.write_multiple_request.add(self.requests[0], True)
        self.assertRaises(BatchOperationNumberLimitException,
                          self.write_multiple_request.add,
                          self.requests[0], True)

    def testWriteMultipleGetRequestWithIllegalIndex(self):
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.get_request,
                          'IllegalIndex')
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.get_request, -1)
        # index 0 is out of range because no operation has been added
        self.assertRaises(IndexError,
                          self.write_multiple_request.get_request, 0)

    def testWriteMultipleSetIllegalTimeout(self):
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout,
                          'IllegalTimeout')
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout, 0)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout, -1)

    def testWriteMultipleNoOperations(self):
        self.assertRaises(IllegalArgumentException,
                          self.handle.write_multiple,
                          self.write_multiple_request)

    def testWriteMultipleGets(self):
        """Check the request getters before and after clear()."""
        num_operations = 6
        for request in self.requests:
            self.write_multiple_request.add(request, True)
        self.assertEqual(self.write_multiple_request.get_table_name(),
                         table_name)
        self.assertEqual(self.write_multiple_request.get_request(2),
                         self.requests[2])
        operations = self.write_multiple_request.get_operations()
        for idx, operation in enumerate(operations):
            self.assertEqual(operation.get_request(), self.requests[idx])
            self.assertTrue(operation.is_abort_if_unsuccessful())
        self.assertEqual(self.write_multiple_request.get_num_operations(),
                         num_operations)
        self.assertEqual(self.write_multiple_request.get_timeout(), timeout)
        self.write_multiple_request.clear()
        self.assertIsNone(self.write_multiple_request.get_table_name())
        self.assertEqual(self.write_multiple_request.get_operations(), [])
        self.assertEqual(self.write_multiple_request.get_num_operations(), 0)
        # clear() keeps the timeout
        self.assertEqual(self.write_multiple_request.get_timeout(), timeout)

    def testWriteMultipleNormal(self):
        """
        Run all six requests with abort_if_unsuccessful=False: the
        putIfAbsent (1) and deleteIfVersion (5) fail, the rest succeed.
        """
        num_operations = 6
        for request in self.requests:
            self.write_multiple_request.add(request, False)
        result = self.handle.write_multiple(self.write_multiple_request)
        expect_expiration = self.ttl.to_expiration_time(
            int(round(time() * 1000)))
        self.assertEqual(result.size(), num_operations)
        op_results = result.get_results()
        self.assertEqual(len(op_results), num_operations)
        for idx, op_result in enumerate(op_results):
            if idx == 1 or idx == 5:
                # putIfAbsent and deleteIfVersion failed
                self.assertIsNone(op_result.get_version())
                self.assertFalse(op_result.get_success())
                self.assertEqual(
                    op_result.get_existing_version().get_bytes(),
                    self.versions[self.ops_sk][idx].get_bytes())
                self.assertEqual(op_result.get_existing_value(),
                                 self.rows[self.ops_sk][idx])
            elif idx == 4:
                # delete succeed
                self.assertIsNone(op_result.get_version())
                self.assertTrue(op_result.get_success())
                self.assertIsNone(op_result.get_existing_version())
                self.assertIsNone(op_result.get_existing_value())
            else:
                # put, putIfPresent and putIfVersion succeed with a new
                # version; compare the bytes because Version objects are not
                # compared by value elsewhere in this class.
                self.assertIsNotNone(op_result.get_version())
                self.assertNotEqual(
                    op_result.get_version().get_bytes(),
                    self.versions[self.ops_sk][idx].get_bytes())
                self.assertTrue(op_result.get_success())
                self.assertIsNone(op_result.get_existing_version())
                self.assertIsNone(op_result.get_existing_value())
        self.assertIsNone(result.get_failed_operation_result())
        self.assertEqual(result.get_failed_operation_index(), -1)
        self.assertTrue(result.get_success())
        # expected costs, presumably one term per operation in request order
        self.assertEqual(result.get_read_kb(), 0 + 1 + 1 + 1 + 1 + 1)
        self.assertEqual(result.get_read_units(), 0 + 2 + 2 + 2 + 2 + 2)
        self.assertEqual(result.get_write_kb(), 2 + 0 + 2 + 2 + 1 + 0)
        self.assertEqual(result.get_write_units(), 2 + 0 + 2 + 2 + 1 + 0)
        # check the records after write_multiple request succeed
        for sk in self.shardkeys:
            for i in self.ids:
                self.get_request.set_key({'fld_sid': sk, 'fld_id': i})
                result = self.handle.get(self.get_request)
                if sk == 1 or i == 1 or i == 5:
                    # untouched shard key or failed operation: original row,
                    # version and 16-day TTL are unchanged
                    self.assertEqual(result.get_value(), self.rows[sk][i])
                    self.assertEqual(result.get_version().get_bytes(),
                                     self.versions[sk][i].get_bytes())
                    actual_expiration = result.get_expiration_time()
                    actual_expect_diff = (actual_expiration -
                                          self.old_expect_expiration)
                    self.assertGreater(actual_expiration, 0)
                    self.assertLess(actual_expect_diff,
                                    self.day_in_milliseconds)
                elif i == 4:
                    # removed by the delete request
                    self.assertIsNone(result.get_value())
                    self.assertIsNone(result.get_version())
                    self.assertEqual(result.get_expiration_time(), 0)
                else:
                    self.assertEqual(result.get_value(), self.new_rows[sk][i])
                    self.assertIsNotNone(result.get_version())
                    self.assertNotEqual(result.get_version().get_bytes(),
                                        self.versions[sk][i].get_bytes())
                    if i == 2:
                        # table default TTL; the table declares no TTL
                        self.assertEqual(result.get_expiration_time(), 0)
                    else:
                        actual_expiration = result.get_expiration_time()
                        actual_expect_diff = (actual_expiration -
                                              expect_expiration)
                        self.assertGreater(actual_expiration, 0)
                        self.assertLess(actual_expect_diff,
                                        self.hour_in_milliseconds)
                self.assertEqual(result.get_read_kb(), 1)
                self.assertEqual(result.get_read_units(), 2)
                self.assertEqual(result.get_write_kb(), 0)
                self.assertEqual(result.get_write_units(), 0)

    def testWriteMultipleAbortIfUnsuccessful(self):
        """
        Run all six requests with abort_if_unsuccessful=True: the batch
        aborts at the putIfAbsent (index 1) and no row is modified.
        """
        failed_idx = 1
        for request in self.requests:
            self.write_multiple_request.add(request, True)
        result = self.handle.write_multiple(self.write_multiple_request)
        # On abort the results hold only the failing operation's result.
        self.assertEqual(result.size(), 1)
        op_results = result.get_results()
        self.assertIsNone(op_results[0].get_version())
        self.assertFalse(op_results[0].get_success())
        self.assertEqual(op_results[0].get_existing_version().get_bytes(),
                         self.versions[self.ops_sk][failed_idx].get_bytes())
        self.assertEqual(op_results[0].get_existing_value(),
                         self.rows[self.ops_sk][failed_idx])
        failed_result = result.get_failed_operation_result()
        self.assertIsNone(failed_result.get_version())
        self.assertFalse(failed_result.get_success())
        self.assertEqual(failed_result.get_existing_version().get_bytes(),
                         self.versions[self.ops_sk][failed_idx].get_bytes())
        self.assertEqual(failed_result.get_existing_value(),
                         self.rows[self.ops_sk][failed_idx])
        self.assertEqual(result.get_failed_operation_index(), failed_idx)
        self.assertFalse(result.get_success())
        self.assertEqual(result.get_read_kb(), 0 + 1)
        self.assertEqual(result.get_read_units(), 0 + 2)
        self.assertEqual(result.get_write_kb(), 2 + 0)
        self.assertEqual(result.get_write_units(), 2 + 0)
        # check the records after the write_multiple request aborted
        for sk in self.shardkeys:
            for i in self.ids:
                self.get_request.set_key({'fld_sid': sk, 'fld_id': i})
                result = self.handle.get(self.get_request)
                self.assertEqual(result.get_value(), self.rows[sk][i])
                self.assertEqual(result.get_version().get_bytes(),
                                 self.versions[sk][i].get_bytes())
                actual_expiration = result.get_expiration_time()
                actual_expect_diff = (actual_expiration -
                                      self.old_expect_expiration)
                self.assertGreater(actual_expiration, 0)
                self.assertLess(actual_expect_diff, self.day_in_milliseconds)
                self.assertEqual(result.get_read_kb(), 1)
                self.assertEqual(result.get_read_units(), 2)
                self.assertEqual(result.get_write_kb(), 0)
                self.assertEqual(result.get_write_units(), 0)
class TestQuery(unittest.TestCase, TestBase): @classmethod def setUpClass(cls): TestBase.set_up_class() index_name = 'idx_' + table_name create_statement = ('CREATE TABLE ' + table_name + '(fld_sid INTEGER, fld_id INTEGER, \ fld_long LONG, fld_float FLOAT, fld_double DOUBLE, fld_bool BOOLEAN, \ fld_str STRING, fld_bin BINARY, fld_time TIMESTAMP(6), fld_num NUMBER, \ fld_json JSON, fld_arr ARRAY(STRING), fld_map MAP(STRING), \ fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), \ PRIMARY KEY(SHARD(fld_sid), fld_id))') limits = TableLimits(5000, 5000, 50) create_request = TableRequest().set_statement( create_statement).set_table_limits(limits) cls._result = TestBase.table_request(create_request, State.ACTIVE) create_index_statement = ('CREATE INDEX ' + index_name + ' ON ' + table_name + '(fld_long)') create_index_request = TableRequest().set_statement( create_index_statement) cls._result = TestBase.table_request(create_index_request, State.ACTIVE) global prepare_cost prepare_cost = 2 global query_statement query_statement = ('SELECT fld_sid, fld_id FROM ' + table_name + ' WHERE fld_sid = 1') @classmethod def tearDownClass(cls): TestBase.tear_down_class() def setUp(self): TestBase.set_up(self) self.handle_config = get_handle_config(tenant_id) shardkeys = [0, 1] ids = [0, 1, 2, 3, 4, 5] write_multiple_request = WriteMultipleRequest() for sk in shardkeys: for i in ids: row = { 'fld_sid': sk, 'fld_id': i, 'fld_long': 2147483648, 'fld_float': 3.1414999961853027, 'fld_double': 3.1415, 'fld_bool': True, 'fld_str': '{"name": u1, "phone": null}', 'fld_bin': bytearray(pack('>i', 4)), 'fld_time': datetime.now(), 'fld_num': Decimal(5), 'fld_json': { 'a': '1', 'b': None, 'c': '3' }, 'fld_arr': ['a', 'b', 'c'], 'fld_map': { 'a': '1', 'b': '2', 'c': '3' }, 'fld_rec': { 'fld_id': 1, 'fld_bool': False, 'fld_str': None } } write_multiple_request.add( PutRequest().set_value(row).set_table_name(table_name), True) self.handle.write_multiple(write_multiple_request) 
write_multiple_request.clear() prepare_statement_update = ( 'DECLARE $fld_sid INTEGER; $fld_id INTEGER; UPDATE ' + table_name + ' u SET u.fld_long = u.fld_long + 1 WHERE fld_sid = $fld_sid ' + 'AND fld_id = $fld_id') prepare_request_update = PrepareRequest().set_statement( prepare_statement_update) self.prepare_result_update = self.handle.prepare( prepare_request_update) prepare_statement_select = ( 'DECLARE $fld_long LONG; SELECT fld_sid, fld_id, fld_long FROM ' + table_name + ' WHERE fld_long = $fld_long') prepare_request_select = PrepareRequest().set_statement( prepare_statement_select) self.prepare_result_select = self.handle.prepare( prepare_request_select) self.query_request = QueryRequest().set_timeout(timeout) self.get_request = GetRequest().set_table_name(table_name) def tearDown(self): TestBase.tear_down(self) def testQuerySetIllegalLimit(self): self.assertRaises(IllegalArgumentException, self.query_request.set_limit, 'IllegalLimit') self.assertRaises(IllegalArgumentException, self.query_request.set_limit, -1) def testQuerySetIllegalMaxReadKb(self): self.assertRaises(IllegalArgumentException, self.query_request.set_max_read_kb, 'IllegalLimit') self.assertRaises(IllegalArgumentException, self.query_request.set_max_read_kb, -1) self.assertRaises(IllegalArgumentException, self.query_request.set_max_read_kb, 2049) def testQuerySetIllegalConsistency(self): self.assertRaises(IllegalArgumentException, self.query_request.set_consistency, 'IllegalConsistency') def testQuerySetIllegalContinuationKey(self): self.assertRaises(IllegalArgumentException, self.query_request.set_continuation_key, 'IllegalContinuationKey') def testQuerySetIllegalStatement(self): self.query_request.set_statement('IllegalStatement') self.assertRaises(IllegalArgumentException, self.handle.query, self.query_request) self.query_request.set_statement('SELECT fld_id FROM IllegalTableName') self.assertRaises(TableNotFoundException, self.handle.query, self.query_request) def 
testQuerySetIllegalPreparedStatement(self): self.assertRaises(IllegalArgumentException, self.query_request.set_prepared_statement, 'IllegalPreparedStatement') def testQuerySetIllegalTimeout(self): self.assertRaises(IllegalArgumentException, self.query_request.set_timeout, 'IllegalTimeout') self.assertRaises(IllegalArgumentException, self.query_request.set_timeout, 0) self.assertRaises(IllegalArgumentException, self.query_request.set_timeout, -1) def testQuerySetIllegalDefaults(self): self.assertRaises(IllegalArgumentException, self.query_request.set_defaults, 'IllegalDefaults') def testQuerySetDefaults(self): self.query_request.set_defaults(self.handle_config) self.assertEqual(self.query_request.get_timeout(), timeout) self.assertEqual(self.query_request.get_consistency(), Consistency.ABSOLUTE) def testQueryNoStatementAndBothStatement(self): self.assertRaises(IllegalArgumentException, self.handle.query, self.query_request) self.query_request.set_statement(query_statement) self.query_request.set_prepared_statement(self.prepare_result_select) self.assertRaises(IllegalArgumentException, self.handle.query, self.query_request) def testQueryGets(self): continuation_key = bytearray(5) self.query_request.set_consistency(Consistency.EVENTUAL).set_statement( query_statement).set_prepared_statement( self.prepare_result_select).set_limit(3).set_max_read_kb( 2).set_continuation_key(continuation_key) self.assertEqual(self.query_request.get_limit(), 3) self.assertEqual(self.query_request.get_max_read_kb(), 2) self.assertEqual(self.query_request.get_consistency(), Consistency.EVENTUAL) self.assertEqual(self.query_request.get_continuation_key(), continuation_key) self.assertEqual(self.query_request.get_statement(), query_statement) self.assertEqual(self.query_request.get_prepared_statement(), self.prepare_result_select.get_prepared_statement()) self.assertEqual(self.query_request.get_timeout(), timeout) def testQueryIllegalRequest(self): self.assertRaises(IllegalArgumentException, 
self.handle.query, 'IllegalRequest') def testQueryStatementSelect(self): num_records = 6 self.query_request.set_statement(query_statement) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), num_records) for idx in range(num_records): self.assertEqual(records[idx], {'fld_sid': 1, 'fld_id': idx}) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), num_records + prepare_cost) self.assertEqual(result.get_read_units(), num_records * 2 + prepare_cost) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) def testQueryStatementSelectWithLimit(self): limit = 3 self.query_request.set_statement(query_statement).set_limit(limit) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), limit) for idx in range(limit): self.assertEqual(records[idx], {'fld_sid': 1, 'fld_id': idx}) self.assertIsNotNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), limit + prepare_cost) self.assertEqual(result.get_read_units(), limit * 2 + prepare_cost) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) def testQueryStatementSelectWithMaxReadKb(self): max_read_kb = 4 self.query_request.set_statement(query_statement).set_max_read_kb( max_read_kb) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), max_read_kb + 1) for idx in range(len(records)): self.assertEqual(records[idx], {'fld_sid': 1, 'fld_id': idx}) self.assertIsNotNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), max_read_kb + prepare_cost + 1) self.assertEqual(result.get_read_units(), max_read_kb * 2 + prepare_cost + 2) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) def testQueryStatementSelectWithConsistency(self): num_records = 6 
self.query_request.set_statement(query_statement).set_consistency( Consistency.ABSOLUTE) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), num_records) for idx in range(num_records): self.assertEqual(records[idx], {'fld_sid': 1, 'fld_id': idx}) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), num_records + prepare_cost) self.assertEqual(result.get_read_units(), num_records * 2 + prepare_cost) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) def testQueryStatementSelectWithContinuationKey(self): num_records = 6 limit = 4 self.query_request.set_statement(query_statement).set_limit(limit) count = 0 while True: completed = count * limit result = self.handle.query(self.query_request) records = result.get_results() if completed + limit <= num_records: num_get = limit read_kb = num_get self.assertIsNotNone(result.get_continuation_key()) else: num_get = num_records - completed read_kb = (1 if num_get == 0 else num_get) self.assertIsNone(result.get_continuation_key()) self.assertEqual(len(records), num_get) for idx in range(num_get): self.assertEqual(records[idx], { 'fld_sid': 1, 'fld_id': completed + idx }) self.assertEqual(result.get_read_kb(), read_kb + prepare_cost) self.assertEqual(result.get_read_units(), read_kb * 2 + prepare_cost) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) count += 1 if result.get_continuation_key() is None: break self.query_request.set_continuation_key( result.get_continuation_key()) self.assertEqual(count, num_records // limit + 1) def testQueryStatementSelectWithDefault(self): num_records = 6 self.query_request.set_statement(query_statement).set_defaults( self.handle_config) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), num_records) for idx in range(num_records): self.assertEqual(records[idx], 
{'fld_sid': 1, 'fld_id': idx}) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), num_records + prepare_cost) self.assertEqual(result.get_read_units(), num_records * 2 + prepare_cost) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) def testQueryPreparedStatementUpdate(self): fld_sid = 0 fld_id = 2 fld_long = 2147483649 prepared_statement = self.prepare_result_update.get_prepared_statement( ) # update a non-existing row prepared_statement.set_variable('$fld_sid', 2).set_variable('$fld_id', 0) self.query_request.set_prepared_statement(self.prepare_result_update) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], {'NumRowsUpdated': 0}) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 1) self.assertEqual(result.get_read_units(), 2) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) # update an existing row prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement(self.prepare_result_update) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 2) self.assertEqual(result.get_read_units(), 4) self.assertEqual(result.get_write_kb(), 4) self.assertEqual(result.get_write_units(), 4) # check the updated row prepared_statement = self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], { 'fld_sid': fld_sid, 'fld_id': fld_id, 
'fld_long': fld_long }) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 1) self.assertEqual(result.get_read_units(), 2) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) def testQueryPreparedStatementUpdateWithLimit(self): fld_sid = 1 fld_id = 5 fld_long = 2147483649 prepared_statement = self.prepare_result_update.get_prepared_statement( ) prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement( self.prepare_result_update).set_limit(1) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 2) self.assertEqual(result.get_read_units(), 4) self.assertEqual(result.get_write_kb(), 4) self.assertEqual(result.get_write_units(), 4) # check the updated row prepared_statement = self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], { 'fld_sid': fld_sid, 'fld_id': fld_id, 'fld_long': fld_long }) self.assertIsNotNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 1) self.assertEqual(result.get_read_units(), 2) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) def testQueryPreparedStatementUpdateWithMaxReadKb(self): fld_sid = 0 fld_id = 1 fld_long = 2147483649 # set a small max_read_kb to read a row to update prepared_statement = self.prepare_result_update.get_prepared_statement( ) prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement( 
self.prepare_result_update).set_max_read_kb(1) self.assertRaises(IllegalArgumentException, self.handle.query, self.query_request) # set a enough max_read_kb to read a row to update self.query_request.set_max_read_kb(2) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 2) self.assertEqual(result.get_read_units(), 4) self.assertEqual(result.get_write_kb(), 4) self.assertEqual(result.get_write_units(), 4) # check the updated row prepared_statement = self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], { 'fld_sid': fld_sid, 'fld_id': fld_id, 'fld_long': fld_long }) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 1) self.assertEqual(result.get_read_units(), 2) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) def testQueryPreparedStatementUpdateWithConsistency(self): fld_sid = 1 fld_id = 2 fld_long = 2147483649 prepared_statement = self.prepare_result_update.get_prepared_statement( ) prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement( self.prepare_result_update).set_consistency(Consistency.ABSOLUTE) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 2) self.assertEqual(result.get_read_units(), 4) self.assertEqual(result.get_write_kb(), 4) 
self.assertEqual(result.get_write_units(), 4) # check the updated row prepared_statement = self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], { 'fld_sid': fld_sid, 'fld_id': fld_id, 'fld_long': fld_long }) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 1) self.assertEqual(result.get_read_units(), 2) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) def testQueryPreparedStatementUpdateWithContinuationKey(self): fld_sid = 1 fld_id = 3 fld_long = 2147483649 num_records = 1 limit = 3 prepared_statement = self.prepare_result_update.get_prepared_statement( ) prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement( self.prepare_result_update).set_limit(limit) count = 0 while True: completed = count * limit result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) if completed + limit <= num_records: self.assertEqual(records[0], {'NumRowsUpdated': limit}) read_kb = limit * 2 write_kb = limit * 4 else: num_update = num_records - completed self.assertEqual(records[0], {'NumRowsUpdated': num_update}) read_kb = (1 if num_update == 0 else num_update * 2) write_kb = (0 if num_update == 0 else num_update * 4) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), read_kb) self.assertEqual(result.get_read_units(), read_kb * 2) self.assertEqual(result.get_write_kb(), write_kb) self.assertEqual(result.get_write_units(), write_kb) count += 1 if result.get_continuation_key() is None: break self.query_request.set_continuation_key( result.get_continuation_key()) self.assertEqual(count, 1) # check 
the updated row prepared_statement = self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), num_records) self.assertEqual(records[0], { 'fld_sid': fld_sid, 'fld_id': fld_id, 'fld_long': fld_long }) if limit <= num_records: self.assertIsNotNone(result.get_continuation_key()) else: self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 1) self.assertEqual(result.get_read_units(), 2) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) def testQueryPreparedStatementUpdateWithDefault(self): fld_sid = 0 fld_id = 5 fld_long = 2147483649 prepared_statement = self.prepare_result_update.get_prepared_statement( ) prepared_statement.set_variable('$fld_sid', fld_sid).set_variable( '$fld_id', fld_id) self.query_request.set_prepared_statement( self.prepare_result_update).set_defaults(self.handle_config) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 2) self.assertEqual(result.get_read_units(), 4) self.assertEqual(result.get_write_kb(), 4) self.assertEqual(result.get_write_units(), 4) # check the updated row prepared_statement = self.prepare_result_select.get_prepared_statement( ) prepared_statement.set_variable('$fld_long', fld_long) self.query_request.set_prepared_statement(prepared_statement) result = self.handle.query(self.query_request) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], { 'fld_sid': fld_sid, 'fld_id': fld_id, 'fld_long': fld_long }) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 1) 
self.assertEqual(result.get_read_units(), 2) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0) def testQueryStatementUpdateTTL(self): hour_in_milliseconds = 60 * 60 * 1000 self.query_request.set_statement( 'UPDATE ' + table_name + ' $u SET TTL CASE WHEN ' + 'remaining_hours($u) < 0 THEN 3 ELSE remaining_hours($u) + 3 END ' + 'HOURS WHERE fld_sid = 1 AND fld_id = 3') result = self.handle.query(self.query_request) ttl = TimeToLive.of_hours(3) expect_expiration = ttl.to_expiration_time(int(round(time() * 1000))) records = result.get_results() self.assertEqual(len(records), 1) self.assertEqual(records[0], {'NumRowsUpdated': 1}) self.assertIsNone(result.get_continuation_key()) self.assertEqual(result.get_read_kb(), 2 + prepare_cost) self.assertEqual(result.get_read_units(), 4 + prepare_cost) self.assertEqual(result.get_write_kb(), 3) self.assertEqual(result.get_write_units(), 3) # check the record after update ttl request succeed self.get_request.set_key({'fld_sid': 1, 'fld_id': 3}) result = self.handle.get(self.get_request) actual_expiration = result.get_expiration_time() actual_expect_diff = actual_expiration - expect_expiration self.assertGreater(actual_expiration, 0) self.assertLess(actual_expect_diff, hour_in_milliseconds) self.assertEqual(result.get_read_kb(), 1) self.assertEqual(result.get_read_units(), 2) self.assertEqual(result.get_write_kb(), 0) self.assertEqual(result.get_write_units(), 0)