class TestPut(unittest.TestCase, TestBase):
    """Tests for PutRequest validation and NoSQLHandle.put.

    setUpClass creates ``table_name`` with a 30 day table default TTL, kept in
    the module global ``table_ttl`` so tests can compute the expiration time
    they expect when that default applies.  Each test works on the row whose
    primary key is ``{'fld_id': 1}``; tearDown deletes that row again.
    """

    @classmethod
    def setUpClass(cls):
        """Create the test table with a 30 day default TTL."""
        cls.set_up_class()
        # Published as a module global so test methods can compute the
        # expiration time expected when the table default TTL applies.
        global table_ttl
        table_ttl = TimeToLive.of_days(30)
        create_statement = ('CREATE TABLE ' + table_name +
                            '(fld_id INTEGER, fld_long LONG, \
fld_float FLOAT, fld_double DOUBLE, fld_bool BOOLEAN, fld_str STRING, \
fld_bin BINARY, fld_time TIMESTAMP(9), fld_num NUMBER, fld_json JSON, \
fld_arr ARRAY(STRING), fld_map MAP(STRING), \
fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), \
PRIMARY KEY(fld_id)) USING TTL ' + str(table_ttl))
        create_request = TableRequest().set_statement(
            create_statement).set_table_limits(TableLimits(50, 50, 1))
        cls.table_request(create_request)

    @classmethod
    def tearDownClass(cls):
        """Run the shared TestBase class-level teardown."""
        cls.tear_down_class()

    def setUp(self):
        """Build a fresh row plus put/get requests for primary key 1."""
        self.set_up()
        # NOTE(review): presumably a row without fld_sid, matching this
        # table's schema (which has no fld_sid column) -- confirm in get_row.
        self.row = get_row(with_sid=False)
        self.key = {'fld_id': 1}
        self.put_request = PutRequest().set_value(
            self.row).set_table_name(table_name).set_timeout(timeout)
        self.get_request = GetRequest().set_key(
            self.key).set_table_name(table_name)
        self.ttl = TimeToLive.of_hours(24)
        self.hour_in_milliseconds = 60 * 60 * 1000
        self.day_in_milliseconds = 24 * 60 * 60 * 1000

    def tearDown(self):
        """Delete the row with primary key 1, then run TestBase teardown."""
        request = DeleteRequest().set_key(self.key).set_table_name(table_name)
        self.handle.delete(request)
        self.tear_down()

    def testPutSetIllegalValue(self):
        self.assertRaises(IllegalArgumentException, self.put_request.set_value,
                          'IllegalValue')

    def testPutSetIllegalValueFromJson(self):
        # Malformed JSON is rejected by the setter; well-formed JSON naming an
        # unknown column is only rejected when the put is executed.
        self.assertRaises(ValueError, self.put_request.set_value_from_json,
                          'IllegalJson')
        self.put_request.set_value_from_json('{"invalid_field": "value"}')
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def testPutSetIllegalCompartment(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_compartment, {})
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_compartment, '')

    def testPutSetIllegalOption(self):
        # The setter accepts any value; the put itself rejects the option.
        self.put_request.set_option('IllegalOption')
        self.assertRaises(IllegalStateException, self.handle.put,
                          self.put_request)

    def testPutSetIllegalMatchVersion(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_match_version,
                          'IllegalMatchVersion')

    def testPutSetIllegalTtl(self):
        self.assertRaises(IllegalArgumentException, self.put_request.set_ttl,
                          'IllegalTtl')

    def testPutSetIllegalUseTableDefaultTtl(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_use_table_default_ttl,
                          'IllegalUseTableDefaultTtl')

    def testPutSetIllegalExactMatch(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_exact_match,
                          'IllegalExactMatch')

    def testPutSetIllegalIdentityCacheSize(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_identity_cache_size,
                          'IllegalIdentityCacheSize')

    def testPutSetIllegalTimeout(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_timeout, 'IllegalTimeout')
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_timeout, 0)
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_timeout, -1)

    def testPutSetIllegalTableName(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_table_name,
                          {'name': table_name})
        self.put_request.set_table_name('IllegalTable')
        self.assertRaises(TableNotFoundException, self.handle.put,
                          self.put_request)

    def testPutSetIllegalReturnRow(self):
        self.assertRaises(IllegalArgumentException,
                          self.put_request.set_return_row, 'IllegalReturnRow')

    def testPutSetLargeSizeValue(self):
        # The value must exceed the cloud service request size limit; the
        # on-prem path is expected to accept it.
        self.row['fld_str'] = self.get_random_str(2)
        self.put_request.set_value(self.row)
        if is_onprem():
            version = self.handle.put(self.put_request).get_version()
            self.assertIsNotNone(version)
        else:
            self.assertRaises(RequestSizeLimitException, self.handle.put,
                              self.put_request)

    def testPutIfVersionWithoutMatchVersion(self):
        self.put_request.set_option(PutOption.IF_VERSION)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def testPutNoVersionWithMatchVersion(self):
        # A match version is only legal together with an if-version put.
        version = self.handle.put(self.put_request).get_version()
        self.put_request.set_option(
            PutOption.IF_ABSENT).set_match_version(version)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)
        self.put_request.set_option(
            PutOption.IF_PRESENT).set_match_version(version)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def testPutSetTtlAndUseTableDefaultTtl(self):
        self.put_request.set_ttl(self.ttl).set_use_table_default_ttl(True)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def testPutGets(self):
        identity_cache_size = 5
        version = self.handle.put(self.put_request).get_version()
        self.put_request.set_option(
            PutOption.IF_ABSENT).set_match_version(version).set_ttl(
                self.ttl).set_use_table_default_ttl(True).set_exact_match(
                    True).set_identity_cache_size(
                        identity_cache_size).set_return_row(True)
        self.assertEqual(self.put_request.get_value(), self.row)
        self.assertEqual(self.put_request.get_compartment(), tenant_id)
        self.assertEqual(self.put_request.get_option(), PutOption.IF_ABSENT)
        self.assertEqual(self.put_request.get_match_version(), version)
        self.assertEqual(self.put_request.get_ttl(), self.ttl)
        self.assertTrue(self.put_request.get_use_table_default_ttl())
        self.assertTrue(self.put_request.get_update_ttl())
        self.assertEqual(self.put_request.get_timeout(), timeout)
        self.assertEqual(self.put_request.get_table_name(), table_name)
        self.assertTrue(self.put_request.get_exact_match())
        self.assertEqual(self.put_request.get_identity_cache_size(),
                         identity_cache_size)
        self.assertTrue(self.put_request.get_return_row())

    def testPutIllegalRequest(self):
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          'IllegalRequest')

    def testPutNormal(self):
        # test put with normal values
        result = self.handle.put(self.put_request)
        tb_expect_expiration = table_ttl.to_expiration_time(
            int(round(time() * 1000)))
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 0, 0, 1, 1)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, tb_expect_expiration,
                              TimeUnit.DAYS)
        self.check_cost(result, 1, 2, 0, 0)
        # put a row with the same primary key to update the row
        self.row['fld_long'] = 2147483649
        self.put_request.set_value(self.row).set_ttl(self.ttl)
        result = self.handle.put(self.put_request)
        expect_expiration = self.ttl.to_expiration_time(
            int(round(time() * 1000)))
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 0, 0, 2, 2)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, expect_expiration,
                              TimeUnit.HOURS)
        self.check_cost(result, 1, 2, 0, 0)
        # update the ttl of the row to never expire
        self.put_request.set_ttl(TimeToLive.of_days(0))
        result = self.handle.put(self.put_request)
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 0, 0, 2, 2)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version)
        self.check_cost(result, 1, 2, 0, 0)

    def testPutIfAbsent(self):
        # test PutIfAbsent with normal values
        self.put_request.set_option(PutOption.IF_ABSENT).set_ttl(
            self.ttl).set_return_row(True)
        result = self.handle.put(self.put_request)
        expect_expiration = self.ttl.to_expiration_time(
            int(round(time() * 1000)))
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 1, 2, 1, 1)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, expect_expiration,
                              TimeUnit.HOURS)
        self.check_cost(result, 1, 2, 0, 0)
        # put a row with the same primary key to update the row, operation
        # should fail, and return the existing row
        result = self.handle.put(self.put_request)
        self._check_put_result(result,
                               False,
                               existing_version=version,
                               existing_value=self.row)
        self.check_cost(result, 1, 2, 0, 0)

    def testPutIfPresent(self):
        # test PutIfPresent with normal values, operation should fail because
        # there is no existing row in store
        self.put_request.set_option(PutOption.IF_PRESENT)
        result = self.handle.put(self.put_request)
        self._check_put_result(result, False)
        self.check_cost(result, 1, 2, 0, 0)
        # insert a row
        self.put_request.set_option(PutOption.IF_ABSENT).set_ttl(self.ttl)
        self.handle.put(self.put_request)
        expect_expiration = self.ttl.to_expiration_time(
            int(round(time() * 1000)))
        # test PutIfPresent with normal values, operation should succeed
        self.row['fld_long'] = 2147483649
        self.put_request.set_value(self.row).set_option(
            PutOption.IF_PRESENT).set_return_row(True)
        result = self.handle.put(self.put_request)
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 1, 2, 2, 2)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, expect_expiration,
                              TimeUnit.HOURS)
        self.check_cost(result, 1, 2, 0, 0)
        # test PutIfPresent with normal values, update the ttl with table
        # default ttl
        self.put_request.set_ttl(None).set_use_table_default_ttl(True)
        result = self.handle.put(self.put_request)
        tb_expect_expiration = table_ttl.to_expiration_time(
            int(round(time() * 1000)))
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 1, 2, 2, 2)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, tb_expect_expiration,
                              TimeUnit.DAYS)
        self.check_cost(result, 1, 2, 0, 0)

    def testPutIfVersion(self):
        # insert a row
        result = self.handle.put(self.put_request)
        version_old = result.get_version()
        # test PutIfVersion with normal values, operation should succeed.
        # NOTE(review): no PutOption is set explicitly; this relies on
        # set_match_version turning the put into an if-version put -- confirm
        # in PutRequest.
        self.row['fld_bool'] = False
        self.put_request.set_value(self.row).set_ttl(
            self.ttl).set_match_version(version_old).set_return_row(True)
        result = self.handle.put(self.put_request)
        expect_expiration = self.ttl.to_expiration_time(
            int(round(time() * 1000)))
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 1, 2, 2, 2)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, expect_expiration,
                              TimeUnit.HOURS)
        self.check_cost(result, 1, 2, 0, 0)
        # test PutIfVersion with normal values, operation should fail because
        # version not match, and return the existing row
        self.put_request.set_ttl(None).set_use_table_default_ttl(True)
        result = self.handle.put(self.put_request)
        self._check_put_result(result,
                               False,
                               existing_version=version,
                               existing_value=self.row)
        self.check_cost(result, 1, 2, 0, 0)

    def testPutWithExactMatch(self):
        # test put a row with an extra field not in the table, by default this
        # will succeed
        row = deepcopy(self.row)
        row.update({'fld_id': 2, 'extra': 5})
        key = {'fld_id': 2}
        # self.row becomes the expected stored row: same id, no 'extra'.
        self.row['fld_id'] = 2
        self.put_request.set_value(row)
        result = self.handle.put(self.put_request)
        tb_expect_expiration = table_ttl.to_expiration_time(
            int(round(time() * 1000)))
        version = result.get_version()
        self._check_put_result(result)
        self.check_cost(result, 0, 0, 1, 1)
        self.get_request.set_key(key)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, self.row, version, tb_expect_expiration,
                              TimeUnit.DAYS)
        self.check_cost(result, 1, 2, 0, 0)
        # test put a row with an extra field not in the table, this will fail
        # because it's not an exact match when we set exact_match=True
        self.put_request.set_exact_match(True)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def testPutWithIdentityColumn(self):
        id_table = table_prefix + 'Identity'
        create_request = TableRequest().set_statement(
            'CREATE TABLE ' + id_table + '(sid INTEGER, id LONG GENERATED \
ALWAYS AS IDENTITY, name STRING, PRIMARY KEY(SHARD(sid), id))')
        create_request.set_table_limits(TableLimits(50, 50, 1))
        self.table_request(create_request)

        # test put a row with an extra field not in the table, by default this
        # will succeed
        row = {'name': 'myname', 'extra': 'extra', 'sid': 1}
        key = {'sid': 1, 'id': 1}
        expected = OrderedDict()
        expected['sid'] = 1
        expected['id'] = 1
        expected['name'] = 'myname'
        self.put_request.set_table_name(id_table).set_value(row)
        result = self.handle.put(self.put_request)
        version = result.get_version()
        self._check_put_result(result, has_generated_value=True)
        self.check_cost(result, 0, 0, 1, 1)
        self.get_request.set_table_name(id_table).set_key(key)
        result = self.handle.get(self.get_request)
        self.check_get_result(result, expected, version)
        self.check_cost(result, 1, 2, 0, 0)
        # test put a row with identity field, this will fail because id is
        # 'generated always' and in that path it is not legal to provide a value
        # for id. Set the value again explicitly instead of relying on the
        # request still holding a reference to the mutated ``row`` dict.
        row['id'] = 1
        self.put_request.set_value(row)
        self.assertRaises(IllegalArgumentException, self.handle.put,
                          self.put_request)

    def _check_put_result(self,
                          result,
                          has_version=True,
                          has_generated_value=False,
                          existing_version=None,
                          existing_value=None):
        """Assert the contents of a put result.

        :param result: the object returned by ``self.handle.put``.
        :param has_version: whether the result must carry a version; the
            tests pass False for conditional puts that are expected to fail.
        :param has_generated_value: whether the result must carry a generated
            (identity column) value.
        :param existing_version: the version the result must report for the
            existing row, or None if no existing version is expected.
        :param existing_value: the existing row the result must return, or
            None if none is expected.
        """
        version = result.get_version()
        if has_version:
            self.assertIsNotNone(version)
        else:
            self.assertIsNone(version)

        generated_value = result.get_generated_value()
        if has_generated_value:
            self.assertIsNotNone(generated_value)
        else:
            self.assertIsNone(generated_value)

        actual_existing_version = result.get_existing_version()
        if existing_version is None:
            self.assertIsNone(actual_existing_version)
        else:
            # Fail cleanly instead of raising AttributeError on None below.
            self.assertIsNotNone(actual_existing_version)
            self.assertEqual(actual_existing_version.get_bytes(),
                             existing_version.get_bytes())

        self.assertEqual(result.get_existing_value(), existing_value)
# Example #2
class TestGet(unittest.TestCase, TestBase):
    """Tests for GetRequest validation and NoSQLHandle.get.

    setUpClass creates a table keyed by (fld_sid, fld_id), sharded on fld_sid,
    with a 16 hour table default TTL, and writes a single row.  That row, the
    version returned by the put and the expiration time computed from the
    table TTL are stored in the module globals ``row``, ``version`` and
    ``tb_expect_expiration`` for the tests to compare against.
    """

    @classmethod
    def setUpClass(cls):
        """Create the table and seed the one row the read tests look up."""
        global row, tb_expect_expiration, version
        cls.handle = None
        cls.set_up_class()
        hours_ttl = TimeToLive.of_hours(16)
        ddl = ('CREATE TABLE ' + table_name +
               '(fld_sid INTEGER, fld_id INTEGER, \
fld_long LONG, fld_float FLOAT, fld_double DOUBLE, fld_bool BOOLEAN, \
fld_str STRING, fld_bin BINARY, fld_time TIMESTAMP(7), fld_num NUMBER, \
fld_json JSON, fld_arr ARRAY(STRING), fld_map MAP(STRING), \
fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), \
PRIMARY KEY(SHARD(fld_sid), fld_id)) USING TTL ' + str(hours_ttl))
        limits = TableLimits(100, 100, 1)
        cls.table_request(TableRequest().set_statement(ddl).set_table_limits(
            limits))
        row = get_row()
        seed_request = PutRequest().set_value(row).set_table_name(table_name)
        version = cls.handle.put(seed_request).get_version()
        # Expected expiration from the table default TTL, measured right
        # after the seeding put.
        now_ms = int(round(time() * 1000))
        tb_expect_expiration = hours_ttl.to_expiration_time(now_ms)

    @classmethod
    def tearDownClass(cls):
        """Run the shared TestBase class-level teardown."""
        cls.tear_down_class()

    def setUp(self):
        """Build a GetRequest for the seeded row's primary key."""
        self.set_up()
        self.key = {'fld_sid': 1, 'fld_id': 1}
        self.get_request = (GetRequest()
                            .set_key(self.key)
                            .set_table_name(table_name)
                            .set_timeout(timeout))

    def tearDown(self):
        """Run the shared TestBase per-test teardown."""
        self.tear_down()

    def testGetSetIllegalKey(self):
        with self.assertRaises(IllegalArgumentException):
            self.get_request.set_key('IllegalKey')
        # A key missing either primary key column is only rejected when the
        # get is executed.
        for partial_key in ({'fld_sid': 1}, {'fld_id': 1}):
            self.get_request.set_key(partial_key)
            with self.assertRaises(IllegalArgumentException):
                self.handle.get(self.get_request)

    def testGetSetIllegalKeyFromJson(self):
        with self.assertRaises(ValueError):
            self.get_request.set_key_from_json('IllegalJson')
        self.get_request.set_key_from_json('{"invalid_field": "key"}')
        with self.assertRaises(IllegalArgumentException):
            self.handle.get(self.get_request)

    def testGetSetIllegalTableName(self):
        with self.assertRaises(IllegalArgumentException):
            self.get_request.set_table_name({'name': table_name})
        self.get_request.set_table_name('IllegalTable')
        with self.assertRaises(TableNotFoundException):
            self.handle.get(self.get_request)

    def testGetSetIllegalCompartment(self):
        for bad_compartment in ({}, ''):
            with self.assertRaises(IllegalArgumentException):
                self.get_request.set_compartment(bad_compartment)

    def testGetSetIllegalConsistency(self):
        with self.assertRaises(IllegalArgumentException):
            self.get_request.set_consistency('IllegalConsistency')

    def testGetSetIllegalTimeout(self):
        for bad_timeout in ('IllegalTimeout', 0, -1):
            with self.assertRaises(IllegalArgumentException):
                self.get_request.set_timeout(bad_timeout)

    def testGetWithoutKey(self):
        self.get_request.set_key(None)
        with self.assertRaises(IllegalArgumentException):
            self.handle.get(self.get_request)

    def testGetGets(self):
        request = self.get_request
        self.assertEqual(request.get_key(), self.key)
        self.assertIsNone(request.get_compartment())

    def testGetIllegalRequest(self):
        with self.assertRaises(IllegalArgumentException):
            self.handle.get('IllegalRequest')

    def testGetNormal(self):
        fetched = self.handle.get(self.get_request)
        self.check_get_result(fetched, row, version, tb_expect_expiration,
                              TimeUnit.HOURS)
        self.check_cost(fetched, 1, 2, 0, 0)

    def testGetEventual(self):
        self.get_request.set_consistency(Consistency.EVENTUAL)
        fetched = self.handle.get(self.get_request)
        self.check_get_result(fetched,
                              row,
                              expect_expiration=tb_expect_expiration,
                              timeunit=TimeUnit.HOURS,
                              ver_eq=False)
        self.check_cost(fetched, 1, 1, 0, 0)

    def testGetNonExisting(self):
        self.get_request.set_key({'fld_sid': 2, 'fld_id': 2})
        missing = self.handle.get(self.get_request)
        self.check_get_result(missing)
        self.check_cost(missing, 1, 2, 0, 0)
# Example #3
class TestQuery(unittest.TestCase, TestBase):
    @classmethod
    def setUpClass(cls):
        """Create the query table, its four secondary indexes, and the
        module globals ``prepare_cost`` and ``query_statement``."""
        global prepare_cost, query_statement
        cls.set_up_class()
        index_prefix = 'idx_' + table_name
        ddl = ('CREATE TABLE ' + table_name +
               '(fld_sid INTEGER, fld_id INTEGER, \
fld_long LONG, fld_float FLOAT, fld_double DOUBLE, fld_bool BOOLEAN, \
fld_str STRING, fld_bin BINARY, fld_time TIMESTAMP(6), fld_num NUMBER, \
fld_json JSON, fld_arr ARRAY(STRING), fld_map MAP(STRING), \
fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), \
PRIMARY KEY(SHARD(fld_sid), fld_id))')
        cls.table_request(TableRequest().set_statement(ddl).set_table_limits(
            TableLimits(100, 100, 1)))

        # Indexes are named <prefix>1 .. <prefix>4; one request object is
        # reused for all of them.
        index_request = TableRequest()
        indexed_fields = ('fld_long', 'fld_str', 'fld_bool',
                          'fld_json.location as point')
        for suffix, field in enumerate(indexed_fields, start=1):
            index_request.set_statement(
                'CREATE INDEX ' + index_prefix + str(suffix) + ' ON ' +
                table_name + '(' + field + ')')
            cls.table_request(index_request)

        prepare_cost = 2
        query_statement = ('SELECT fld_sid, fld_id FROM ' + table_name +
                           ' WHERE fld_sid = 1')

    @classmethod
    def tearDownClass(cls):
        """Run the shared TestBase class-level teardown."""
        cls.tear_down_class()

    def setUp(self):
        """Load 2 shard keys x 6 rows and prepare the update/select queries.

        For each shard key the fld_time of the first and last generated row
        is recorded in ``self.min_time`` / ``self.max_time``.
        """
        self.set_up()
        self.handle_config = get_handle_config(tenant_id)
        self.min_time = []
        self.max_time = []
        num_shards = 2
        rows_per_shard = 6
        batch = WriteMultipleRequest()
        for shard in range(num_shards):
            for row_id in range(rows_per_shard):
                new_row = get_row()
                if row_id == 0:
                    self.min_time.append(new_row['fld_time'])
                elif row_id == rows_per_shard - 1:
                    self.max_time.append(new_row['fld_time'])
                new_row['fld_sid'] = shard
                new_row['fld_id'] = row_id
                new_row['fld_bool'] = shard != 0
                # Counts down from 11 to 0 across all rows.
                rank = (num_shards - shard) * rows_per_shard - row_id - 1
                new_row['fld_str'] = '{"name": u' + str(rank).zfill(2) + '}'
                new_row['fld_json']['location']['coordinates'] = [
                    23.549 - shard * 0.5 - row_id,
                    35.2908 + shard * 0.5 + row_id,
                ]
                batch.add(
                    PutRequest().set_value(new_row).set_table_name(table_name),
                    True)
            # One write_multiple call per shard key.
            self.handle.write_multiple(batch)
            batch.clear()

        update_sql = (
            'DECLARE $fld_sid INTEGER; $fld_id INTEGER; UPDATE ' + table_name +
            ' u SET u.fld_long = u.fld_long + 1 WHERE fld_sid = $fld_sid ' +
            'AND fld_id = $fld_id')
        self.prepare_result_update = self.handle.prepare(
            PrepareRequest().set_statement(update_sql))
        select_sql = (
            'DECLARE $fld_long LONG; SELECT fld_sid, fld_id, fld_long FROM ' +
            table_name + ' WHERE fld_long = $fld_long')
        self.prepare_result_select = self.handle.prepare(
            PrepareRequest().set_statement(select_sql))
        self.query_request = QueryRequest().set_timeout(timeout)
        self.get_request = GetRequest().set_table_name(table_name)

    def tearDown(self):
        """Run the shared TestBase per-test teardown."""
        self.tear_down()

    def testQuerySetIllegalCompartment(self):
        """set_compartment rejects a dict and an empty string."""
        for bad_compartment in ({}, ''):
            with self.assertRaises(IllegalArgumentException):
                self.query_request.set_compartment(bad_compartment)

    def testQuerySetIllegalLimit(self):
        """set_limit rejects a string and a negative value."""
        for bad_limit in ('IllegalLimit', -1):
            with self.assertRaises(IllegalArgumentException):
                self.query_request.set_limit(bad_limit)

    def testQuerySetIllegalMaxReadKb(self):
        """set_max_read_kb rejects a string and a negative value."""
        for bad_kb in ('IllegalMaxReadKb', -1):
            with self.assertRaises(IllegalArgumentException):
                self.query_request.set_max_read_kb(bad_kb)

    def testQuerySetIllegalMaxWriteKb(self):
        """set_max_write_kb rejects a string and a negative value."""
        for bad_kb in ('IllegalMaxWriteKb', -1):
            with self.assertRaises(IllegalArgumentException):
                self.query_request.set_max_write_kb(bad_kb)

    def testQuerySetIllegalMaxMemoryConsumption(self):
        """set_max_memory_consumption rejects a string and a negative value."""
        for bad_amount in ('IllegalMaxMemoryConsumption', -1):
            with self.assertRaises(IllegalArgumentException):
                self.query_request.set_max_memory_consumption(bad_amount)

    def testQuerySetIllegalMathContext(self):
        """set_math_context rejects a value that is not a math context."""
        with self.assertRaises(IllegalArgumentException):
            self.query_request.set_math_context('IllegalMathContext')

    def testQuerySetIllegalConsistency(self):
        """set_consistency rejects a plain string."""
        with self.assertRaises(IllegalArgumentException):
            self.query_request.set_consistency('IllegalConsistency')

    def testQuerySetIllegalContinuationKey(self):
        """set_continuation_key rejects a plain string."""
        with self.assertRaises(IllegalArgumentException):
            self.query_request.set_continuation_key('IllegalContinuationKey')

    def testQuerySetIllegalStatement(self):
        """Bad statements are rejected by the setter or by the query call."""
        for bad_statement in ({}, ''):
            with self.assertRaises(IllegalArgumentException):
                self.query_request.set_statement(bad_statement)
        # Invalid SQL text is accepted by the setter and only fails when the
        # query is executed.
        self.query_request.set_statement('IllegalStatement')
        with self.assertRaises(IllegalArgumentException):
            self.handle.query(self.query_request)
        self.query_request.set_statement('SELECT fld_id FROM IllegalTableName')
        with self.assertRaises(TableNotFoundException):
            self.handle.query(self.query_request)

    def testQuerySetIllegalPreparedStatement(self):
        """set_prepared_statement rejects a plain string."""
        with self.assertRaises(IllegalArgumentException):
            self.query_request.set_prepared_statement(
                'IllegalPreparedStatement')

    def testQuerySetIllegalTimeout(self):
        """set_timeout rejects a string, zero and a negative value."""
        for bad_timeout in ('IllegalTimeout', 0, -1):
            with self.assertRaises(IllegalArgumentException):
                self.query_request.set_timeout(bad_timeout)

    def testQuerySetIllegalDefaults(self):
        """set_defaults rejects a value that is not a handle config."""
        with self.assertRaises(IllegalArgumentException):
            self.query_request.set_defaults('IllegalDefaults')

    def testQuerySetDefaults(self):
        """set_defaults copies the handle config's timeout and consistency."""
        request = self.query_request
        request.set_defaults(self.handle_config)
        self.assertEqual(request.get_timeout(), timeout)
        self.assertEqual(request.get_consistency(), Consistency.ABSOLUTE)

    def testQueryNoStatementAndBothStatement(self):
        """A query needs exactly one of statement / prepared statement."""
        # Neither is set.
        with self.assertRaises(IllegalArgumentException):
            self.handle.query(self.query_request)
        # Both are set.
        self.query_request.set_statement(query_statement)
        self.query_request.set_prepared_statement(self.prepare_result_select)
        with self.assertRaises(IllegalArgumentException):
            self.handle.query(self.query_request)

    def testQueryGets(self):
        """Every QueryRequest getter returns what its setter stored."""
        continuation = bytearray(5)
        math_ctx = Context(prec=10, rounding=ROUND_HALF_EVEN)
        request = self.query_request
        (request.set_consistency(Consistency.EVENTUAL)
         .set_statement(query_statement)
         .set_prepared_statement(self.prepare_result_select)
         .set_limit(3)
         .set_max_read_kb(2)
         .set_max_write_kb(3)
         .set_max_memory_consumption(5)
         .set_math_context(math_ctx)
         .set_continuation_key(continuation))
        self.assertIsNone(request.get_compartment())
        self.assertFalse(request.is_done())
        self.assertEqual(request.get_limit(), 3)
        self.assertEqual(request.get_max_read_kb(), 2)
        self.assertEqual(request.get_max_write_kb(), 3)
        self.assertEqual(request.get_max_memory_consumption(), 5)
        self.assertEqual(request.get_math_context(), math_ctx)
        self.assertEqual(request.get_consistency(), Consistency.EVENTUAL)
        self.assertEqual(request.get_continuation_key(), continuation)
        self.assertEqual(request.get_statement(), query_statement)
        # The request stores the PreparedStatement, not the PrepareResult.
        self.assertEqual(request.get_prepared_statement(),
                         self.prepare_result_select.get_prepared_statement())
        self.assertEqual(request.get_timeout(), timeout)

    def testQueryIllegalRequest(self):
        """handle.query rejects an argument that is not a QueryRequest."""
        with self.assertRaises(IllegalArgumentException):
            self.handle.query('IllegalRequest')

    def testQueryStatementSelect(self):
        """Run the SELECT statement and check its six rows and its cost."""
        expected_count = 6
        self.query_request.set_statement(query_statement)
        result = self.handle.query(self.query_request)
        rows = self.check_query_result(result, expected_count)
        for position in range(expected_count):
            self.assertEqual(rows[position], self._expected_row(1, position))
        # The unprepared statement also incurs prepare_cost.
        self.check_cost(result,
                        expected_count + prepare_cost,
                        expected_count * 2 + prepare_cost,
                        0, 0)

    def testQueryStatementSelectWithLimit(self):
        """SELECT limited to 3 rows returns just those rows."""
        row_limit = 3
        self.query_request.set_statement(query_statement).set_limit(row_limit)
        result = self.handle.query(self.query_request)
        rows = self.check_query_result(result, row_limit, True)
        for position in range(row_limit):
            self.assertEqual(rows[position], self._expected_row(1, position))
        self.check_cost(result,
                        row_limit + prepare_cost,
                        row_limit * 2 + prepare_cost,
                        0, 0)

    def testQueryStatementSelectWithMaxReadKb(self):
        """SELECT bounded by max_read_kb returns a partial batch of rows.

        The on-prem proxy does not honor max_read_kb (it has no table
        limits), so on-prem all six rows are expected in one batch.
        """
        num_records = 6
        max_read_kb = 4
        self.query_request.set_statement(query_statement).set_max_read_kb(
            max_read_kb)
        result = self.handle.query(self.query_request)
        # TODO: [#27744] KV doesn't honor max read kb for on-prem proxy because
        # it has no table limits.
        if is_onprem():
            records = self.check_query_result(result, num_records)
        else:
            # Expect max_read_kb + 1 rows in this first batch.
            records = self.check_query_result(result, max_read_kb + 1, True)
        for idx, record in enumerate(records):
            self.assertEqual(record, self._expected_row(1, idx))
        self.check_cost(result, max_read_kb + prepare_cost + 1,
                        max_read_kb * 2 + prepare_cost + 2, 0, 0)

    def testQueryStatementSelectWithConsistency(self):
        """SELECT under ABSOLUTE consistency returns the six expected rows."""
        expected_count = 6
        self.query_request.set_statement(query_statement).set_consistency(
            Consistency.ABSOLUTE)
        result = self.handle.query(self.query_request)
        rows = self.check_query_result(result, expected_count)
        for position in range(expected_count):
            self.assertEqual(rows[position], self._expected_row(1, position))
        self.check_cost(result,
                        expected_count + prepare_cost,
                        expected_count * 2 + prepare_cost,
                        0, 0)

    def testQueryStatementSelectWithContinuationKey(self):
        """Page through the SELECT results with a limit and continuation keys.

        With 6 matching rows and a page limit of 4, two pages are expected:
        one full page of 4 rows, then a final page of the remaining 2.
        """
        num_records = 6
        limit = 4
        self.query_request.set_statement(query_statement).set_limit(limit)
        pages = 0
        while True:
            offset = pages * limit
            result = self.handle.query(self.query_request)
            remaining = num_records - offset
            if remaining >= limit:
                # A full page; more results are expected afterwards.
                batch = limit
                records = self.check_query_result(result, batch, True)
            else:
                # The last (possibly empty) page.
                batch = remaining
                records = self.check_query_result(result, batch)
            # An empty page is still expected to cost 1 KB of reads.
            read_kb = batch or 1
            for idx in range(batch):
                self.assertEqual(records[idx],
                                 self._expected_row(1, offset + idx))
            # Only the first round pays for preparing the statement.
            extra = prepare_cost if pages == 0 else 0
            self.check_cost(result, read_kb + extra, read_kb * 2 + extra,
                            0, 0)
            pages += 1
            next_key = result.get_continuation_key()
            if next_key is None:
                break
            self.query_request.set_continuation_key(next_key)
        self.assertEqual(pages, num_records // limit + 1)

    def testQueryStatementSelectWithDefault(self):
        """SELECT after applying the handle config's defaults to the request."""
        expected_count = 6
        self.query_request.set_statement(query_statement).set_defaults(
            self.handle_config)
        result = self.handle.query(self.query_request)
        rows = self.check_query_result(result, expected_count)
        for position in range(expected_count):
            self.assertEqual(rows[position], self._expected_row(1, position))
        self.check_cost(result,
                        expected_count + prepare_cost,
                        expected_count * 2 + prepare_cost,
                        0, 0)

    def testQueryPreparedStatementUpdate(self):
        """Run the prepared UPDATE on a missing row, then on an existing one."""
        fld_sid = 0
        fld_id = 2
        fld_long = 2147483649
        update_ps = self.prepare_result_update.get_prepared_statement()

        # A key that matches no row: nothing is updated.
        update_ps.set_variable('$fld_sid', 2).set_variable('$fld_id', 0)
        self.query_request.set_prepared_statement(self.prepare_result_update)
        result = self.handle.query(self.query_request)
        self.assertEqual(self.check_query_result(result, 1)[0],
                         {'NumRowsUpdated': 0})
        self.check_cost(result, 1, 2, 0, 0)

        # A key that matches an existing row: exactly one row is updated.
        update_ps.set_variable('$fld_sid', fld_sid).set_variable(
            '$fld_id', fld_id)
        self.query_request.set_prepared_statement(self.prepare_result_update)
        result = self.handle.query(self.query_request)
        self.assertEqual(self.check_query_result(result, 1)[0],
                         {'NumRowsUpdated': 1})
        self.check_cost(result, 2, 4, 4, 4)

        # Bind $fld_long on the prepared SELECT and read the updated row back.
        select_ps = self.prepare_result_select.get_prepared_statement()
        select_ps.set_variable('$fld_long', fld_long)
        self.query_request.set_prepared_statement(select_ps)
        result = self.handle.query(self.query_request)
        self.assertEqual(self.check_query_result(result, 1)[0],
                         self._expected_row(fld_sid, fld_id, fld_long))
        self.check_cost(result, 1, 2, 0, 0)

    def testQueryPreparedStatementUpdateWithLimit(self):
        """Run the prepared UPDATE with the request limit set to 1."""
        fld_sid = 1
        fld_id = 5
        fld_long = 2147483649
        update_ps = self.prepare_result_update.get_prepared_statement()
        update_ps.set_variable('$fld_sid', fld_sid).set_variable(
            '$fld_id', fld_id)
        self.query_request.set_prepared_statement(
            self.prepare_result_update).set_limit(1)
        result = self.handle.query(self.query_request)
        self.assertEqual(self.check_query_result(result, 1)[0],
                         {'NumRowsUpdated': 1})
        self.check_cost(result, 2, 4, 4, 4)

        # Read the updated row back. The request still carries limit 1 from
        # above, so the result is checked with the True flag.
        select_ps = self.prepare_result_select.get_prepared_statement()
        select_ps.set_variable('$fld_long', fld_long)
        self.query_request.set_prepared_statement(select_ps)
        result = self.handle.query(self.query_request)
        self.assertEqual(self.check_query_result(result, 1, True)[0],
                         self._expected_row(fld_sid, fld_id, fld_long))
        self.check_cost(result, 1, 2, 0, 0)

    def testQueryPreparedStatementUpdateWithMaxReadKb(self):
        """The prepared UPDATE needs a large enough max_read_kb to succeed."""
        fld_sid = 0
        fld_id = 1
        fld_long = 2147483649
        update_ps = self.prepare_result_update.get_prepared_statement()
        update_ps.set_variable('$fld_sid', fld_sid).set_variable(
            '$fld_id', fld_id)
        self.query_request.set_prepared_statement(
            self.prepare_result_update).set_max_read_kb(1)
        # Except on-prem, a 1 KB read budget is too small to read the row
        # being updated, so the query is rejected.
        if not is_onprem():
            with self.assertRaises(IllegalArgumentException):
                self.handle.query(self.query_request)

        # Retry with a read budget large enough for the row.
        self.query_request.set_max_read_kb(2)
        result = self.handle.query(self.query_request)
        self.assertEqual(self.check_query_result(result, 1)[0],
                         {'NumRowsUpdated': 1})
        self.check_cost(result, 2, 4, 4, 4)

        # Bind $fld_long on the prepared SELECT and read the updated row back.
        select_ps = self.prepare_result_select.get_prepared_statement()
        select_ps.set_variable('$fld_long', fld_long)
        self.query_request.set_prepared_statement(select_ps)
        result = self.handle.query(self.query_request)
        self.assertEqual(self.check_query_result(result, 1)[0],
                         self._expected_row(fld_sid, fld_id, fld_long))
        self.check_cost(result, 1, 2, 0, 0)

    def testQueryPreparedStatementUpdateWithConsistency(self):
        """Run the prepared UPDATE with ABSOLUTE consistency on the request."""
        fld_sid = 1
        fld_id = 2
        fld_long = 2147483649
        update_ps = self.prepare_result_update.get_prepared_statement()
        update_ps.set_variable('$fld_sid', fld_sid).set_variable(
            '$fld_id', fld_id)
        self.query_request.set_prepared_statement(
            self.prepare_result_update).set_consistency(Consistency.ABSOLUTE)
        result = self.handle.query(self.query_request)
        self.assertEqual(self.check_query_result(result, 1)[0],
                         {'NumRowsUpdated': 1})
        self.check_cost(result, 2, 4, 4, 4)

        # Bind $fld_long on the prepared SELECT and read the updated row back.
        select_ps = self.prepare_result_select.get_prepared_statement()
        select_ps.set_variable('$fld_long', fld_long)
        self.query_request.set_prepared_statement(select_ps)
        result = self.handle.query(self.query_request)
        self.assertEqual(self.check_query_result(result, 1)[0],
                         self._expected_row(fld_sid, fld_id, fld_long))
        self.check_cost(result, 1, 2, 0, 0)

    def testQueryPreparedStatementUpdateWithContinuationKey(self):
        """Run the prepared UPDATE with a limit, following continuation keys.

        Only one row matches, so a single round is expected.
        """
        fld_sid = 1
        fld_id = 3
        fld_long = 2147483649
        num_records = 1
        limit = 3
        update_ps = self.prepare_result_update.get_prepared_statement()
        update_ps.set_variable('$fld_sid', fld_sid).set_variable(
            '$fld_id', fld_id)
        self.query_request.set_prepared_statement(
            self.prepare_result_update).set_limit(limit)
        rounds = 0
        while True:
            already_done = rounds * limit
            result = self.handle.query(self.query_request)
            records = self.check_query_result(result, 1)
            # Rows this round should update: a full batch of `limit`, or
            # whatever is left.
            updated = min(limit, num_records - already_done)
            self.assertEqual(records[0], {'NumRowsUpdated': updated})
            # A round that updates nothing is still expected to read 1 KB.
            read_kb = updated * 2 if updated else 1
            write_kb = updated * 4
            self.check_cost(result, read_kb, read_kb * 2, write_kb, write_kb)
            rounds += 1
            next_key = result.get_continuation_key()
            if next_key is None:
                break
            self.query_request.set_continuation_key(next_key)
        self.assertEqual(rounds, 1)

        # Bind $fld_long on the prepared SELECT and read the updated row back.
        select_ps = self.prepare_result_select.get_prepared_statement()
        select_ps.set_variable('$fld_long', fld_long)
        self.query_request.set_prepared_statement(select_ps)
        result = self.handle.query(self.query_request)
        if limit > num_records:
            records = self.check_query_result(result, num_records)
        else:
            records = self.check_query_result(result, num_records, True)
        self.assertEqual(records[0],
                         self._expected_row(fld_sid, fld_id, fld_long))
        self.check_cost(result, 1, 2, 0, 0)

    def testQueryPreparedStatementUpdateWithDefault(self):
        """Run the prepared UPDATE after applying the handle config defaults."""
        fld_sid = 0
        fld_id = 5
        fld_long = 2147483649
        update_ps = self.prepare_result_update.get_prepared_statement()
        update_ps.set_variable('$fld_sid', fld_sid).set_variable(
            '$fld_id', fld_id)
        self.query_request.set_prepared_statement(
            self.prepare_result_update).set_defaults(self.handle_config)
        result = self.handle.query(self.query_request)
        self.assertEqual(self.check_query_result(result, 1)[0],
                         {'NumRowsUpdated': 1})
        self.check_cost(result, 2, 4, 4, 4)

        # Bind $fld_long on the prepared SELECT and read the updated row back.
        select_ps = self.prepare_result_select.get_prepared_statement()
        select_ps.set_variable('$fld_long', fld_long)
        self.query_request.set_prepared_statement(select_ps)
        result = self.handle.query(self.query_request)
        self.assertEqual(self.check_query_result(result, 1)[0],
                         self._expected_row(fld_sid, fld_id, fld_long))
        self.check_cost(result, 1, 2, 0, 0)

    def testQueryStatementUpdateTTL(self):
        """Extend a row's TTL by 3 hours with an UPDATE ... SET TTL statement.

        The new expiration time, read back with a get, must be positive and
        less than one hour after the client-side 3-hour estimate.
        """
        one_hour_ms = 60 * 60 * 1000
        update_sql = (
            'UPDATE ' + table_name + ' $u SET TTL CASE WHEN ' +
            'remaining_hours($u) < 0 THEN 3 ELSE remaining_hours($u) + 3 END '
            + 'HOURS WHERE fld_sid = 1 AND fld_id = 3')
        self.query_request.set_statement(update_sql)
        update_result = self.handle.query(self.query_request)
        three_hours = TimeToLive.of_hours(3)
        now_ms = int(round(time() * 1000))
        expect_expiration = three_hours.to_expiration_time(now_ms)
        records = self.check_query_result(update_result, 1)
        self.assertEqual(records[0], {'NumRowsUpdated': 1})
        self.check_cost(update_result, 2 + prepare_cost, 4 + prepare_cost,
                        6, 6)

        # Fetch the row to confirm its expiration time was updated.
        self.get_request.set_key({'fld_sid': 1, 'fld_id': 3})
        get_result = self.handle.get(self.get_request)
        actual_expiration = get_result.get_expiration_time()
        drift = actual_expiration - expect_expiration
        self.assertGreater(actual_expiration, 0)
        self.assertLess(drift, one_hour_ms)
        self.check_cost(get_result, 1, 2, 0, 0)

    def testQueryOrderBy(self):
        """Check ORDER BY on the primary-key fields and on fld_str."""
        num_records = 12
        num_ids = 6

        def drain(request):
            # Re-issue `request` until it reports done. Every intermediate
            # round is expected to return no rows. Returns the final round's
            # result, its rows, and how many rounds were issued.
            rounds = 0
            while True:
                rounds += 1
                result = self.handle.query(request)
                rows = result.get_results()
                if request.is_done():
                    return result, rows, rounds
                self.check_query_result(result, 0, True, rows)
                self.assertEqual(rows, [])
                self.check_cost(result, 0, 0, 0, 0, True)

        # Order by the primary-key fields: two rounds are expected.
        result, rows, rounds = drain(QueryRequest().set_statement(
            'SELECT fld_sid, fld_id FROM ' + table_name +
            ' ORDER BY fld_sid, fld_id'))
        self.check_query_result(result, num_records, rec=rows)
        for idx in range(num_records):
            sid, fid = divmod(idx, num_ids)
            self.assertEqual(rows[idx], self._expected_row(sid, fid))
        self.check_cost(result, 0, 0, 0, 0)
        self.assertEqual(rounds, 2)

        # Order by the secondary-index field fld_str.
        result, rows, _ = drain(QueryRequest().set_statement(
            'SELECT fld_str FROM ' + table_name + ' ORDER BY fld_str'))
        self.check_query_result(result, num_records, rec=rows)
        for idx in range(num_records):
            self.assertEqual(
                rows[idx],
                {'fld_str': '{"name": u' + str(idx).zfill(2) + '}'})
        self.check_cost(result, 0, 0, 0, 0)

    def testQueryFuncMinMaxGroupBy(self):
        """Check min()/max() of fld_time, over the whole table and grouped."""
        num_sids = 2

        def drain(request):
            # Re-issue `request` until it reports done. Every intermediate
            # round is expected to return no rows. Returns the final round's
            # result, its rows, and how many rounds were issued.
            rounds = 0
            while True:
                rounds += 1
                result = self.handle.query(request)
                rows = result.get_results()
                if request.is_done():
                    return result, rows, rounds
                self.check_query_result(result, 0, True, rows)
                self.assertEqual(rows, [])
                self.check_cost(result, 0, 0, 0, 0, True)

        def check_total(statement, expected):
            # Whole-table aggregate: one row from a single query call.
            result = self.handle.query(QueryRequest().set_statement(statement))
            records = self.check_query_result(result, 1)
            self.assertEqual(records[0], {'Column_1': expected})
            self.check_cost(result, prepare_cost, prepare_cost, 0, 0, True)

        def check_grouped(statement, expected_by_group):
            # Grouped aggregate: row idx must equal expected_by_group[idx].
            result, rows, rounds = drain(
                QueryRequest().set_statement(statement))
            self.check_query_result(result, num_sids, rec=rows)
            for idx in range(num_sids):
                self.assertEqual(rows[idx],
                                 {'Column_1': expected_by_group[idx]})
            self.check_cost(result, 0, 0, 0, 0)
            return rounds

        check_total('SELECT min(fld_time) FROM ' + table_name,
                    self.min_time[0])
        check_total('SELECT max(fld_time) FROM ' + table_name,
                    self.max_time[1])

        # Grouped by the primary-key field: two rounds are expected.
        rounds = check_grouped(
            'SELECT min(fld_time) FROM ' + table_name + ' GROUP BY fld_sid',
            self.min_time)
        self.assertEqual(rounds, 2)
        rounds = check_grouped(
            'SELECT max(fld_time) FROM ' + table_name + ' GROUP BY fld_sid',
            self.max_time)
        self.assertEqual(rounds, 2)

        # Grouped by the secondary-index field.
        check_grouped(
            'SELECT min(fld_time) FROM ' + table_name + ' GROUP BY fld_bool',
            self.min_time)
        check_grouped(
            'SELECT max(fld_time) FROM ' + table_name + ' GROUP BY fld_bool',
            self.max_time)

    def testQueryFuncSumGroupBy(self):
        """Check sum(fld_double), over the whole table and grouped.

        The expected values assume every row's fld_double is 3.1415, so each
        sum should be 3.1415 times the number of rows summed. Sums are
        compared with assertAlmostEqual rather than exact equality. The exact
        floating-point result of a sum depends on the order in which values
        are added, and this test does not control that order. An exact
        comparison against ``3.1415 * n`` can therefore fail on a last-bit
        rounding difference.
        """
        num_records = 12
        num_sids = 2

        def drain(request):
            # Re-issue `request` until it reports done. Every intermediate
            # round is expected to return no rows. Returns the final round's
            # result, its rows, and how many rounds were issued.
            rounds = 0
            while True:
                rounds += 1
                result = self.handle.query(request)
                rows = result.get_results()
                if request.is_done():
                    return result, rows, rounds
                self.check_query_result(result, 0, True, rows)
                self.assertEqual(rows, [])
                self.check_cost(result, 0, 0, 0, 0, True)

        def assert_column_close(record, expected):
            # The row must hold exactly one column, 'Column_1', whose value
            # matches `expected` up to floating-point rounding.
            self.assertEqual(list(record), ['Column_1'])
            self.assertAlmostEqual(record['Column_1'], expected)

        def check_grouped(statement):
            # Each group is expected to sum num_records // num_sids rows.
            result, rows, rounds = drain(
                QueryRequest().set_statement(statement))
            self.check_query_result(result, num_sids, rec=rows)
            for idx in range(num_sids):
                assert_column_close(rows[idx],
                                    3.1415 * (num_records // num_sids))
            self.check_cost(result, 0, 0, 0, 0)
            return rounds

        # Whole-table sum: a single row.
        query_request = QueryRequest().set_statement(
            'SELECT sum(fld_double) FROM ' + table_name)
        result = self.handle.query(query_request)
        records = self.check_query_result(result, 1)
        assert_column_close(records[0], 3.1415 * num_records)
        self.check_cost(result, prepare_cost, prepare_cost, 0, 0, True)

        # Grouped by the primary-key field: two rounds are expected.
        rounds = check_grouped('SELECT sum(fld_double) FROM ' + table_name +
                               ' GROUP BY fld_sid')
        self.assertEqual(rounds, 2)

        # Grouped by the secondary-index field.
        check_grouped('SELECT sum(fld_double) FROM ' + table_name +
                      ' GROUP BY fld_bool')

    def testQueryFuncAvgGroupBy(self):
        """Check avg(fld_double), over the whole table and grouped.

        The expected value of every average is 3.1415. An average is a
        floating-point sum divided by a count, and the sum's last bit depends
        on the order in which values are added. The comparison therefore
        uses assertAlmostEqual rather than exact equality.
        """
        num_sids = 2

        def drain(request):
            # Re-issue `request` until it reports done. Every intermediate
            # round is expected to return no rows. Returns the final round's
            # result, its rows, and how many rounds were issued.
            rounds = 0
            while True:
                rounds += 1
                result = self.handle.query(request)
                rows = result.get_results()
                if request.is_done():
                    return result, rows, rounds
                self.check_query_result(result, 0, True, rows)
                self.assertEqual(rows, [])
                self.check_cost(result, 0, 0, 0, 0, True)

        def assert_column_close(record, expected):
            # The row must hold exactly one column, 'Column_1', whose value
            # matches `expected` up to floating-point rounding.
            self.assertEqual(list(record), ['Column_1'])
            self.assertAlmostEqual(record['Column_1'], expected)

        def check_grouped(statement):
            result, rows, rounds = drain(
                QueryRequest().set_statement(statement))
            self.check_query_result(result, num_sids, rec=rows)
            for idx in range(num_sids):
                assert_column_close(rows[idx], 3.1415)
            self.check_cost(result, 0, 0, 0, 0)
            return rounds

        # Whole-table average: a single row.
        query_request = QueryRequest().set_statement(
            'SELECT avg(fld_double) FROM ' + table_name)
        result = self.handle.query(query_request)
        records = self.check_query_result(result, 1)
        assert_column_close(records[0], 3.1415)
        self.check_cost(result, prepare_cost, prepare_cost, 0, 0, True)

        # Grouped by the primary-key field: two rounds are expected.
        rounds = check_grouped('SELECT avg(fld_double) FROM ' + table_name +
                               ' GROUP BY fld_sid')
        self.assertEqual(rounds, 2)

        # Grouped by the secondary-index field.
        check_grouped('SELECT avg(fld_double) FROM ' + table_name +
                      ' GROUP BY fld_bool')

    def testQueryFuncCountGroupBy(self):
        """Check count(*), over the whole table and grouped."""
        num_records = 12
        num_sids = 2
        per_group = num_records // num_sids

        def drain(request):
            # Re-issue `request` until it reports done. Every intermediate
            # round is expected to return no rows. Returns the final round's
            # result, its rows, and how many rounds were issued.
            rounds = 0
            while True:
                rounds += 1
                result = self.handle.query(request)
                rows = result.get_results()
                if request.is_done():
                    return result, rows, rounds
                self.check_query_result(result, 0, True, rows)
                self.assertEqual(rows, [])
                self.check_cost(result, 0, 0, 0, 0, True)

        def check_grouped(statement):
            # Every group is expected to contain per_group rows.
            result, rows, rounds = drain(
                QueryRequest().set_statement(statement))
            self.check_query_result(result, num_sids, rec=rows)
            for idx in range(num_sids):
                self.assertEqual(rows[idx], {'Column_1': per_group})
            self.check_cost(result, 0, 0, 0, 0)
            return rounds

        # Whole-table count: a single row.
        result = self.handle.query(QueryRequest().set_statement(
            'SELECT count(*) FROM ' + table_name))
        records = self.check_query_result(result, 1)
        self.assertEqual(records[0], {'Column_1': num_records})
        self.check_cost(result, prepare_cost, prepare_cost, 0, 0, True)

        # Grouped by the primary-key field: two rounds are expected.
        rounds = check_grouped('SELECT count(*) FROM ' + table_name +
                               ' GROUP BY fld_sid')
        self.assertEqual(rounds, 2)

        # Grouped by the secondary-index field.
        check_grouped('SELECT count(*) FROM ' + table_name +
                      ' GROUP BY fld_bool')

    def testQueryOrderByWithLimit(self):
        """Check ORDER BY ... LIMIT 10 on the primary key and on fld_str."""
        num_records = 12
        limit = 10

        def drain(request):
            # Re-issue `request` until it reports done. Every intermediate
            # round is expected to return no rows. Returns the final round's
            # result, its rows, and how many rounds were issued.
            rounds = 0
            while True:
                rounds += 1
                result = self.handle.query(request)
                rows = result.get_results()
                if request.is_done():
                    return result, rows, rounds
                self.check_query_result(result, 0, True, rows)
                self.assertEqual(rows, [])
                self.check_cost(result, 0, 0, 0, 0, True)

        def name_row(n):
            # Expected {'fld_str': ...} row whose name suffix is n, 2 digits.
            return {'fld_str': '{"name": u' + str(n).zfill(2) + '}'}

        # Primary-key order: fld_str suffixes are expected to descend.
        result, rows, rounds = drain(QueryRequest().set_statement(
            'SELECT fld_str FROM ' + table_name +
            ' ORDER BY fld_sid, fld_id LIMIT 10'))
        self.check_query_result(result, limit, rec=rows)
        for idx in range(limit):
            self.assertEqual(rows[idx], name_row(num_records - idx - 1))
        self.check_cost(result, 0, 0, 0, 0)
        self.assertEqual(rounds, 2)

        # Secondary-index order: fld_str suffixes ascend from 00.
        result, rows, _ = drain(QueryRequest().set_statement(
            'SELECT fld_str FROM ' + table_name +
            ' ORDER BY fld_str LIMIT 10'))
        self.check_query_result(result, limit, rec=rows)
        for idx in range(limit):
            self.assertEqual(rows[idx], name_row(idx))
        self.check_cost(result, 0, 0, 0, 0)

    def testQueryOrderByWithOffset(self):
        """Check ORDER BY ... OFFSET $offset on the primary key and fld_str."""
        offset = 4
        num_get = 8

        def drain(request):
            # Re-issue `request` until it reports done. Every intermediate
            # round is expected to return no rows. Returns the final round's
            # result, its rows, and how many rounds were issued.
            rounds = 0
            while True:
                rounds += 1
                result = self.handle.query(request)
                rows = result.get_results()
                if request.is_done():
                    return result, rows, rounds
                self.check_query_result(result, 0, True, rows)
                self.assertEqual(rows, [])
                self.check_cost(result, 0, 0, 0, 0, True)

        def offset_request(statement):
            # Prepare `statement`, bind $offset, and wrap it in a request.
            prepare_request = PrepareRequest().set_statement(statement)
            prepared = self.handle.prepare(
                prepare_request).get_prepared_statement()
            prepared.set_variable('$offset', offset)
            return QueryRequest().set_prepared_statement(prepared)

        def name_row(n):
            # Expected {'fld_str': ...} row whose name suffix is n, 2 digits.
            return {'fld_str': '{"name": u' + str(n).zfill(2) + '}'}

        # Primary-key order after skipping `offset` rows: suffixes descend.
        result, rows, rounds = drain(offset_request(
            'DECLARE $offset INTEGER; SELECT fld_str FROM ' +
            table_name + ' ORDER BY fld_sid, fld_id OFFSET $offset'))
        self.check_query_result(result, num_get, rec=rows)
        for idx in range(num_get):
            self.assertEqual(rows[idx], name_row(num_get - idx - 1))
        self.check_cost(result, 0, 0, 0, 0)
        self.assertEqual(rounds, 2)

        # Secondary-index order after skipping `offset` rows: suffixes ascend.
        result, rows, _ = drain(offset_request(
            'DECLARE $offset INTEGER; SELECT fld_str FROM ' +
            table_name + ' ORDER BY fld_str OFFSET $offset'))
        self.check_query_result(result, num_get, rec=rows)
        for idx in range(num_get):
            self.assertEqual(rows[idx], name_row(offset + idx))
        self.check_cost(result, 0, 0, 0, 0)

    def testQueryFuncGeoNear(self):
        """Exercise geo_near() alone and combined with ORDER BY clauses.

        Every query searches within 215000 of the same point and expects
        ``num_get`` matching rows.
        """
        num_get = 6
        longitude = 21.547
        latitude = 37.291
        point = ('{"type": "point", "coordinates": [' + str(longitude) +
                 ', ' + str(latitude) + ']}')
        near = 'geo_near(tb.fld_json.location, ' + point + ', 215000)'

        def run_sorted(statement, expected_names):
            """Poll ``statement`` until the request is done.

            Rounds before completion must return no records. The final round
            must hold ``fld_str`` values built from ``expected_names``, in
            order. Returns the number of rounds issued.
            """
            request = QueryRequest().set_statement(statement)
            rounds = 0
            while True:
                rounds += 1
                result = self.handle.query(request)
                records = result.get_results()
                if not request.is_done():
                    self.check_query_result(result, 0, True, records)
                    self.assertEqual(records, [])
                    self.check_cost(result, 0, 0, 0, 0, True)
                    continue
                self.check_query_result(result, num_get, rec=records)
                for pos, name in enumerate(expected_names):
                    expected = '{"name": u' + str(name).zfill(2) + '}'
                    self.assertEqual(records[pos], {'fld_str': expected})
                self.check_cost(result, 0, 0, 0, 0)
                return rounds

        # Plain geo_near: each successive location must be farther from the
        # search point on both the longitude and the latitude axis.
        statement = ('SELECT tb.fld_json.location FROM ' + table_name +
                     ' tb WHERE ' + near)
        result = self.handle.query(QueryRequest().set_statement(statement))
        records = self.check_query_result(result, num_get)
        for idx in range(1, num_get):
            before = records[idx - 1]['location']['coordinates']
            after = records[idx]['location']['coordinates']
            for axis, origin in enumerate((longitude, latitude)):
                self.assertLess(abs(before[axis] - origin),
                                abs(after[axis] - origin))
        self.check_cost(result, prepare_cost, prepare_cost, 0, 0, True)

        select_str = 'SELECT fld_str FROM ' + table_name + ' tb WHERE ' + near

        # geo_near ordered by the primary key: expected to take two rounds.
        self.assertEqual(
            run_sorted(select_str + ' ORDER BY fld_sid, fld_id',
                       [10, 9, 8, 4, 3, 2]), 2)

        # geo_near ordered by the fld_str column.
        run_sorted(select_str + ' ORDER BY fld_str', [2, 3, 4, 8, 9, 10])

    @staticmethod
    def _expected_row(fld_sid, fld_id, fld_long=None):
        expected_row = OrderedDict()
        expected_row['fld_sid'] = fld_sid
        expected_row['fld_id'] = fld_id
        if fld_long is not None:
            expected_row['fld_long'] = fld_long
        return expected_row
# Example #4
# 0
class TestGet(unittest.TestCase, TestBase):
    """Tests for GetRequest validation and ``handle.get`` results.

    setUpClass creates a table with a 16-hour table TTL and stores a single
    row with key ``{'fld_sid': 1, 'fld_id': 1}``, which the tests read back.
    """

    @classmethod
    def setUpClass(cls):
        """Create the table, insert the fixture row, record its expiration.

        The row, the expected expiration time and a one-hour tolerance are
        published as module globals for use by the test methods.
        """
        global row, tb_expect_expiration, hour_in_milliseconds
        TestBase.set_up_class()
        table_ttl = TimeToLive.of_hours(16)
        create_statement = ('CREATE TABLE ' + table_name +
                            '(fld_sid INTEGER, fld_id INTEGER, \
fld_long LONG, fld_float FLOAT, fld_double DOUBLE, fld_bool BOOLEAN, \
fld_str STRING, fld_bin BINARY, fld_time TIMESTAMP(7), fld_num NUMBER, \
fld_json JSON, fld_arr ARRAY(STRING), fld_map MAP(STRING), \
fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), \
PRIMARY KEY(SHARD(fld_sid), fld_id)) USING TTL ' + str(table_ttl))
        table_limits = TableLimits(5000, 5000, 50)
        cls._result = TestBase.table_request(
            TableRequest().set_statement(create_statement).set_table_limits(
                table_limits),
            State.ACTIVE)
        row = {
            'fld_sid': 1,
            'fld_id': 1,
            'fld_long': 2147483648,
            'fld_float': 3.1414999961853027,
            'fld_double': 3.1415,
            'fld_bool': True,
            'fld_str': '{"name": u1, "phone": null}',
            'fld_bin': bytearray(pack('>i', 4)),
            'fld_time': datetime.now(),
            'fld_num': Decimal(5),
            'fld_json': {'a': '1', 'b': None, 'c': '3'},
            'fld_arr': ['a', 'b', 'c'],
            'fld_map': {'a': '1', 'b': '2', 'c': '3'},
            'fld_rec': {'fld_id': 1, 'fld_bool': False, 'fld_str': None},
        }
        cls._handle.put(PutRequest().set_value(row).set_table_name(table_name))
        # The expected expiration is derived from the time just after the put.
        # The tests require the server's value to be less than one hour past it.
        written_at_ms = int(round(time() * 1000))
        tb_expect_expiration = table_ttl.to_expiration_time(written_at_ms)
        hour_in_milliseconds = 60 * 60 * 1000

    @classmethod
    def tearDownClass(cls):
        TestBase.tear_down_class()

    def setUp(self):
        TestBase.set_up(self)
        self.key = {'fld_sid': 1, 'fld_id': 1}
        request = GetRequest().set_key(self.key)
        self.get_request = request.set_table_name(table_name).set_timeout(
            timeout)

    def tearDown(self):
        TestBase.tear_down(self)

    def _assert_table_ttl_expiration(self, result):
        """Assert the expiration is positive and within an hour of expected."""
        expiration = result.get_expiration_time()
        self.assertGreater(expiration, 0)
        self.assertLess(expiration - tb_expect_expiration,
                        hour_in_milliseconds)

    def _assert_get_usage(self, result, read_units):
        """Assert 1 read KB, ``read_units`` read units and no write usage."""
        self.assertEqual(result.get_read_kb(), 1)
        self.assertEqual(result.get_read_units(), read_units)
        self.assertEqual(result.get_write_kb(), 0)
        self.assertEqual(result.get_write_units(), 0)

    def testGetSetIllegalKey(self):
        self.assertRaises(IllegalArgumentException, self.get_request.set_key,
                          'IllegalKey')
        # A key that lacks either primary-key column is rejected by get().
        for partial_key in ({'fld_sid': 1}, {'fld_id': 1}):
            self.get_request.set_key(partial_key)
            self.assertRaises(IllegalArgumentException, self.handle.get,
                              self.get_request)

    def testGetSetIllegalKeyFromJson(self):
        self.assertRaises(ValueError, self.get_request.set_key_from_json,
                          'IllegalJson')
        self.get_request.set_key_from_json('{"invalid_field": "key"}')
        self.assertRaises(IllegalArgumentException, self.handle.get,
                          self.get_request)

    def testGetSetIllegalTableName(self):
        self.assertRaises(IllegalArgumentException,
                          self.get_request.set_table_name,
                          {'name': table_name})
        self.get_request.set_table_name('IllegalTable')
        self.assertRaises(TableNotFoundException, self.handle.get,
                          self.get_request)

    def testGetSetIllegalConsistency(self):
        self.assertRaises(IllegalArgumentException,
                          self.get_request.set_consistency,
                          'IllegalConsistency')

    def testGetSetIllegalTimeout(self):
        for bad_timeout in ('IllegalTimeout', 0, -1):
            self.assertRaises(IllegalArgumentException,
                              self.get_request.set_timeout, bad_timeout)

    def testGetWithoutKey(self):
        self.get_request.set_key(None)
        self.assertRaises(IllegalArgumentException, self.handle.get,
                          self.get_request)

    def testGetGets(self):
        self.assertEqual(self.get_request.get_key(), self.key)

    def testGetIllegalRequest(self):
        self.assertRaises(IllegalArgumentException, self.handle.get,
                          'IllegalRequest')

    def testGetNormal(self):
        result = self.handle.get(self.get_request)
        self.assertEqual(result.get_value(), row)
        self._assert_table_ttl_expiration(result)
        self._assert_get_usage(result, read_units=2)

    def testGetEventual(self):
        self.get_request.set_consistency(Consistency.EVENTUAL)
        result = self.handle.get(self.get_request)
        self.assertEqual(result.get_value(), row)
        self.assertIsNotNone(result.get_version())
        self._assert_table_ttl_expiration(result)
        self._assert_get_usage(result, read_units=1)

    def testGetNonExisting(self):
        self.get_request.set_key({'fld_sid': 2, 'fld_id': 2})
        result = self.handle.get(self.get_request)
        self.assertIsNone(result.get_value())
        self.assertIsNone(result.get_version())
        self.assertEqual(result.get_expiration_time(), 0)
        self._assert_get_usage(result, read_units=2)


class TestWriteMultiple(unittest.TestCase, TestBase):
    """Tests for WriteMultipleRequest and ``handle.write_multiple``.

    setUp writes twelve rows (shard keys 0 and 1, ids 0-5) with a 16-day TTL
    and builds six sub-requests against shard key 0:

    0. put with a 1-hour TTL                     -> expected to succeed
    1. put IF_ABSENT on an existing row          -> expected to fail
    2. put IF_PRESENT with use_table_default_ttl -> expected to succeed
    3. put IF_VERSION with the row's version     -> expected to succeed
    4. delete                                    -> expected to succeed
    5. delete matching row 0's version           -> expected to fail

    These are the outcomes asserted in testWriteMultipleNormal.
    """

    # NOTE(review): this module defines a second class named TestWriteMultiple
    # further down. That later definition rebinds the name, so this class is
    # shadowed and unittest will not collect its tests. Rename one of them.

    @classmethod
    def setUpClass(cls):
        """Create the test table (no table-level TTL clause)."""
        cls.set_up_class()
        create_statement = ('CREATE TABLE ' + table_name +
                            '(fld_sid INTEGER, fld_id INTEGER, \
fld_long LONG, fld_float FLOAT, fld_double DOUBLE, fld_bool BOOLEAN, \
fld_str STRING, fld_bin BINARY, fld_time TIMESTAMP(8), fld_num NUMBER, \
fld_json JSON, fld_arr ARRAY(STRING), fld_map MAP(STRING), \
fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), \
PRIMARY KEY(SHARD(fld_sid), fld_id))')
        limits = TableLimits(50, 50, 1)
        create_request = TableRequest().set_statement(
            create_statement).set_table_limits(limits)
        cls.table_request(create_request)

    @classmethod
    def tearDownClass(cls):
        cls.tear_down_class()

    def setUp(self):
        """Write the fixture rows and build the sub-requests used by tests."""
        self.set_up()
        self.shardkeys = [0, 1]
        self.ids = [0, 1, 2, 3, 4, 5]
        # rows, new_rows and versions are indexed as [shard key][id].
        self.rows = list()
        self.new_rows = list()
        self.versions = list()
        self.requests = list()
        self.illegal_requests = list()
        ttl = TimeToLive.of_days(16)
        for sk in self.shardkeys:
            self.rows.append(list())
            self.new_rows.append(list())
            self.versions.append(list())
            for i in self.ids:
                row = get_row()
                row['fld_sid'] = sk
                row['fld_id'] = i
                # new_row differs from the stored row only in fld_long.
                new_row = deepcopy(row)
                new_row['fld_long'] = 2147483649
                self.rows[sk].append(row)
                self.new_rows[sk].append(new_row)
                put_request = PutRequest().set_value(row).set_table_name(
                    table_name).set_ttl(ttl)
                self.versions[sk].append(
                    self.handle.put(put_request).get_version())
        self.old_expect_expiration = ttl.to_expiration_time(
            int(round(time() * 1000)))
        self.ttl = TimeToLive.of_hours(1)
        self.ops_sk = 0
        illegal_sk = 1
        self.requests.append(PutRequest().set_value(
            self.new_rows[self.ops_sk][0]).set_table_name(table_name).set_ttl(
                self.ttl).set_return_row(True))
        self.requests.append(PutRequest().set_value(self.new_rows[
            self.ops_sk][1]).set_table_name(table_name).set_option(
                PutOption.IF_ABSENT).set_ttl(self.ttl).set_return_row(True))
        self.requests.append(PutRequest().set_value(
            self.new_rows[self.ops_sk][2]).set_use_table_default_ttl(
                True).set_table_name(table_name).set_option(
                    PutOption.IF_PRESENT).set_return_row(True))
        self.requests.append(PutRequest().set_value(
            self.new_rows[self.ops_sk][3]).set_table_name(table_name).set_ttl(
                self.ttl).set_option(PutOption.IF_VERSION).set_match_version(
                    self.versions[self.ops_sk][3]).set_return_row(True))
        self.requests.append(DeleteRequest().set_key({
            'fld_sid': self.ops_sk,
            'fld_id': 4
        }).set_table_name(table_name).set_return_row(True))
        # Deliberately matches row 0's version, so this delete of row 5 fails.
        self.requests.append(DeleteRequest().set_key({
            'fld_sid': self.ops_sk,
            'fld_id': 5
        }).set_table_name(table_name).set_return_row(True).set_match_version(
            self.versions[self.ops_sk][0]))
        # [0]: a different table; [1]: a different shard key (major path).
        self.illegal_requests.append(DeleteRequest().set_key({
            'fld_sid': self.ops_sk,
            'fld_id': 0
        }).set_table_name('IllegalUsers'))
        self.illegal_requests.append(DeleteRequest().set_key({
            'fld_sid': illegal_sk,
            'fld_id': 0
        }).set_table_name(table_name))
        self.write_multiple_request = WriteMultipleRequest().set_timeout(
            timeout)
        self.get_request = GetRequest().set_table_name(table_name)
        self.hour_in_milliseconds = 60 * 60 * 1000
        self.day_in_milliseconds = 24 * 60 * 60 * 1000

    def tearDown(self):
        """Delete every row of each fixture shard key."""
        for sk in self.shardkeys:
            key = {'fld_sid': sk}
            request = MultiDeleteRequest().set_table_name(table_name).set_key(
                key)
            self.handle.multi_delete(request)
        self.tear_down()

    def testWriteMultipleSetIllegalCompartment(self):
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_compartment, {})
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_compartment, '')

    def testWriteMultipleAddIllegalRequestAndAbortIfUnsuccessful(self):
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add, 'IllegalRequest',
                          True)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add, PutRequest(),
                          'IllegalAbortIfUnsuccessful')
        # add two operations with different table name
        self.write_multiple_request.add(self.requests[0], True)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add,
                          self.illegal_requests[0], False)
        self.write_multiple_request.clear()
        # add two operations with different major paths
        self.write_multiple_request.add(self.requests[0],
                                        True).add(self.illegal_requests[1],
                                                  False)
        self.assertRaises(IllegalArgumentException, self.handle.write_multiple,
                          self.write_multiple_request)
        if not is_onprem():
            # add operations when the request size exceeded the limit
            self.write_multiple_request.clear()
            for _ in range(64):
                row = get_row()
                row['fld_str'] = self.get_random_str(0.4)
                self.write_multiple_request.add(
                    PutRequest().set_value(row).set_table_name(table_name),
                    True)
            self.assertRaises(RequestSizeLimitException,
                              self.handle.write_multiple,
                              self.write_multiple_request)
            # add operations when sub requests reached the max number
            self.write_multiple_request.clear()
            for op in range(51):
                row = get_row()
                row['fld_id'] = op
                self.write_multiple_request.add(
                    PutRequest().set_value(row).set_table_name(table_name),
                    True)
            self.assertRaises(BatchOperationNumberLimitException,
                              self.handle.write_multiple,
                              self.write_multiple_request)

    def testWriteMultipleGetRequestWithIllegalIndex(self):
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.get_request,
                          'IllegalIndex')
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.get_request, -1)
        self.assertRaises(IndexError, self.write_multiple_request.get_request,
                          0)

    def testWriteMultipleSetIllegalTimeout(self):
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout,
                          'IllegalTimeout')
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout, 0)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout, -1)

    def testWriteMultipleNoOperations(self):
        self.assertRaises(IllegalArgumentException, self.handle.write_multiple,
                          self.write_multiple_request)

    def testWriteMultipleGets(self):
        num_operations = 6
        for request in self.requests:
            self.write_multiple_request.add(request, True)
        self.assertIsNone(self.write_multiple_request.get_compartment())
        self.assertEqual(self.write_multiple_request.get_table_name(),
                         table_name)
        self.assertEqual(self.write_multiple_request.get_request(2),
                         self.requests[2])
        operations = self.write_multiple_request.get_operations()
        for idx, operation in enumerate(operations):
            self.assertEqual(operation.get_request(), self.requests[idx])
            self.assertTrue(operation.is_abort_if_unsuccessful())
        self.assertEqual(self.write_multiple_request.get_num_operations(),
                         num_operations)
        self.assertEqual(self.write_multiple_request.get_timeout(), timeout)
        # clear() drops the operations and table name but keeps the timeout.
        self.write_multiple_request.clear()
        self.assertIsNone(self.write_multiple_request.get_table_name())
        self.assertEqual(self.write_multiple_request.get_operations(), [])
        self.assertEqual(self.write_multiple_request.get_num_operations(), 0)
        self.assertEqual(self.write_multiple_request.get_timeout(), timeout)

    def testWriteMultipleNormal(self):
        num_operations = 6
        for request in self.requests:
            self.write_multiple_request.add(request, False)
        result = self.handle.write_multiple(self.write_multiple_request)
        expect_expiration = self.ttl.to_expiration_time(
            int(round(time() * 1000)))
        op_results = self._check_write_multiple_result(result, num_operations)
        for idx in range(result.size()):
            if idx == 1 or idx == 5:
                # putIfAbsent and deleteIfVersion failed
                self._check_operation_result(
                    op_results[idx],
                    existing_version=self.versions[self.ops_sk][idx],
                    existing_value=self.rows[self.ops_sk][idx])
            elif idx == 4:
                # delete succeed
                self._check_operation_result(op_results[idx], success=True)
            else:
                # put, putIfPresent and putIfVersion succeed
                self._check_operation_result(op_results[idx], True, True)
        self.check_cost(result, 5, 10, 7, 7)
        # check the records after write_multiple request succeed
        for sk in self.shardkeys:
            for i in self.ids:
                self.get_request.set_key({'fld_sid': sk, 'fld_id': i})
                result = self.handle.get(self.get_request)
                if sk == 1 or i == 1 or i == 5:
                    self.check_get_result(result, self.rows[sk][i],
                                          self.versions[sk][i],
                                          self.old_expect_expiration,
                                          TimeUnit.DAYS)
                elif i == 4:
                    self.check_get_result(result)
                elif i == 2:
                    self.check_get_result(result,
                                          self.new_rows[sk][i],
                                          self.versions[sk][i],
                                          ver_eq=False)
                else:
                    self.check_get_result(result, self.new_rows[sk][i],
                                          self.versions[sk][i],
                                          expect_expiration, TimeUnit.HOURS,
                                          False)
                self.check_cost(result, 1, 2, 0, 0)

    def testWriteMultipleAbortIfUnsuccessful(self):
        failed_idx = 1
        for request in self.requests:
            self.write_multiple_request.add(request, True)
        result = self.handle.write_multiple(self.write_multiple_request)
        op_results = self._check_write_multiple_result(result, 1, True,
                                                       failed_idx, False)
        self._check_operation_result(
            op_results[0],
            existing_version=self.versions[self.ops_sk][failed_idx],
            existing_value=self.rows[self.ops_sk][failed_idx])
        failed_result = result.get_failed_operation_result()
        self._check_operation_result(
            failed_result,
            existing_version=self.versions[self.ops_sk][failed_idx],
            existing_value=self.rows[self.ops_sk][failed_idx])
        self.check_cost(result, 1, 2, 2, 2)
        # check the records after write_multiple request failed
        for sk in self.shardkeys:
            for i in self.ids:
                self.get_request.set_key({'fld_sid': sk, 'fld_id': i})
                result = self.handle.get(self.get_request)
                self.check_get_result(result, self.rows[sk][i],
                                      self.versions[sk][i],
                                      self.old_expect_expiration,
                                      TimeUnit.DAYS)
                self.check_cost(result, 1, 2, 0, 0)

    def testWriteMultipleWithIdentityColumn(self):
        num_operations = 10
        id_table = table_prefix + 'Identity'
        create_request = TableRequest().set_statement(
            'CREATE TABLE ' + id_table + '(sid INTEGER, id LONG GENERATED \
ALWAYS AS IDENTITY, name STRING, PRIMARY KEY(SHARD(sid), id))')
        create_request.set_table_limits(TableLimits(50, 50, 1))
        self.table_request(create_request)

        # add ten operations
        row = {'name': 'myname', 'sid': 1}
        for idx in range(num_operations):
            put_request = PutRequest().set_table_name(id_table).set_value(row)
            put_request.set_identity_cache_size(idx)
            self.write_multiple_request.add(put_request, False)
        # execute the write multiple request
        versions = []
        result = self.handle.write_multiple(self.write_multiple_request)
        op_results = self._check_write_multiple_result(result, num_operations)
        generated = 0
        for idx in range(result.size()):
            # Each generated identity must exceed the previous one.
            version, generated = self._check_operation_result(
                op_results[idx], True, True, generated)
            versions.append(version)
        self.check_cost(result, 0, 0, num_operations, num_operations)
        # check the records after write_multiple request succeed
        self.get_request.set_table_name(id_table)
        for idx in range(num_operations):
            curr_id = generated - num_operations + idx + 1
            self.get_request.set_key({'sid': 1, 'id': curr_id})
            result = self.handle.get(self.get_request)
            expected = OrderedDict()
            expected['sid'] = 1
            expected['id'] = curr_id
            expected['name'] = 'myname'
            self.check_get_result(result, expected, versions[idx])
            self.check_cost(result, 1, 2, 0, 0)

    def _check_operation_result(self,
                                op_result,
                                version=False,
                                success=False,
                                last_generated=None,
                                existing_version=None,
                                existing_value=None):
        """Assert the fields of a single write_multiple operation result.

        :param op_result: the operation result to check.
        :param version: True if a version is expected; False if None is.
        :param success: the expected success flag.
        :param last_generated: None if no generated value is expected;
            otherwise the generated value must be greater than this.
        :param existing_version: the expected existing version, compared by
            ``get_bytes()``, or None if no existing version is expected.
        :param existing_value: the expected existing value.
        :returns: tuple ``(version, generated_value)`` read from op_result.
        """
        ver = op_result.get_version()
        if version:
            self.assertIsNotNone(ver)
        else:
            self.assertIsNone(ver)
        self.assertEqual(op_result.get_success(), success)
        generated = op_result.get_generated_value()
        if last_generated is None:
            self.assertIsNone(generated)
        else:
            self.assertGreater(generated, last_generated)
        existing_ver = op_result.get_existing_version()
        if existing_version is None:
            self.assertIsNone(existing_ver)
        else:
            self.assertEqual(existing_ver.get_bytes(),
                             existing_version.get_bytes())
        self.assertEqual(op_result.get_existing_value(), existing_value)
        return ver, generated

    def _check_write_multiple_result(self,
                                     result,
                                     num_operations,
                                     has_failed_operation=False,
                                     failed_index=-1,
                                     success=True):
        """Assert the summary fields of a write_multiple result.

        :param result: the write_multiple result to check.
        :param num_operations: the expected ``result.size()``.
        :param has_failed_operation: whether a failed operation result is
            expected to be present.
        :param failed_index: the expected failed operation index.
        :param success: the expected overall success flag.
        :returns: the per-operation results (``result.get_results()``).
        """
        self.assertEqual(result.size(), num_operations)
        failed_result = result.get_failed_operation_result()
        if has_failed_operation:
            self.assertIsNotNone(failed_result)
        else:
            self.assertIsNone(failed_result)
        self.assertEqual(result.get_failed_operation_index(), failed_index)
        self.assertEqual(result.get_success(), success)
        return result.get_results()


class TestWriteMultiple(unittest.TestCase, TestBase):
    """Tests for WriteMultipleRequest and ``self.handle.write_multiple``.

    Each test starts with six rows (fld_id 0-5) under each of the shard
    keys 0 and 1, all written in setUp with a 16-day TTL, plus a list of
    six sub-requests that all target shard key ``self.ops_sk``.
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared table keyed on (SHARD(fld_sid), fld_id)."""
        TestBase.set_up_class()
        create_statement = (
            'CREATE TABLE ' + table_name + '(fld_sid INTEGER, fld_id INTEGER, \
fld_long LONG, fld_float FLOAT, fld_double DOUBLE, fld_bool BOOLEAN, \
fld_str STRING, fld_bin BINARY, fld_time TIMESTAMP(8), fld_num NUMBER, \
fld_json JSON, fld_arr ARRAY(STRING), fld_map MAP(STRING), \
fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), \
PRIMARY KEY(SHARD(fld_sid), fld_id))')
        limits = TableLimits(5000, 5000, 50)
        create_request = TableRequest().set_statement(
            create_statement).set_table_limits(limits)
        cls._result = TestBase.table_request(create_request, State.ACTIVE)

    @classmethod
    def tearDownClass(cls):
        """Delegate class-level cleanup to TestBase.tear_down_class."""
        TestBase.tear_down_class()

    def setUp(self):
        """Write the initial rows and build the sub-requests under test.

        Sub-requests in ``self.requests`` (all on shard key ``self.ops_sk``),
        with the outcome testWriteMultipleNormal expects for each:
          0: put                           -> succeeds
          1: put IF_ABSENT                 -> fails (row already exists)
          2: put IF_PRESENT, table TTL     -> succeeds
          3: put IF_VERSION, matching      -> succeeds
          4: delete                        -> succeeds
          5: delete with row 0's version   -> fails (version mismatch)
        ``self.illegal_requests`` holds a delete on a different table and a
        delete on a different shard key.
        """
        TestBase.set_up(self)
        self.shardkeys = [0, 1]
        self.ids = [0, 1, 2, 3, 4, 5]
        self.rows = list()
        self.new_rows = list()
        self.versions = list()
        self.requests = list()
        self.illegal_requests = list()
        ttl = TimeToLive.of_days(16)
        for sk in self.shardkeys:
            # Shard keys are 0 and 1, so they double as list indexes here.
            self.rows.append(list())
            self.new_rows.append(list())
            self.versions.append(list())
            for i in self.ids:
                row = {'fld_sid': sk, 'fld_id': i, 'fld_long': 2147483648,
                       'fld_float': 3.1414999961853027, 'fld_double': 3.1415,
                       'fld_bool': True,
                       'fld_str': '{"name": u1, "phone": null}',
                       'fld_bin': bytearray(pack('>i', 4)),
                       'fld_time': datetime.now(), 'fld_num': Decimal(5),
                       'fld_json': {'a': '1', 'b': None, 'c': '3'},
                       'fld_arr': ['a', 'b', 'c'],
                       'fld_map': {'a': '1', 'b': '2', 'c': '3'},
                       'fld_rec': {'fld_id': 1, 'fld_bool': False,
                                   'fld_str': None}}
                # The replacement row differs from the original only in
                # fld_long, so successful puts are visible on read-back.
                new_row = deepcopy(row)
                new_row.update({'fld_long': 2147483649})
                self.rows[sk].append(row)
                self.new_rows[sk].append(new_row)
                put_request = PutRequest().set_value(row).set_table_name(
                    table_name).set_ttl(ttl)
                self.versions[sk].append(
                    self.handle.put(put_request).get_version())
        # Reference expiration for rows that the tests expect to be untouched.
        self.old_expect_expiration = ttl.to_expiration_time(
            int(round(time() * 1000)))
        self.ttl = TimeToLive.of_hours(1)
        self.ops_sk = 0
        illegal_sk = 1
        self.requests.append(PutRequest().set_value(
            self.new_rows[self.ops_sk][0]).set_table_name(table_name).set_ttl(
            self.ttl).set_return_row(True))
        self.requests.append(PutRequest().set_value(
            self.new_rows[self.ops_sk][1]).set_table_name(
            table_name).set_option(PutOption.IF_ABSENT).set_ttl(
            self.ttl).set_return_row(True))
        self.requests.append(PutRequest().set_value(
            self.new_rows[self.ops_sk][2]).set_use_table_default_ttl(
            True).set_table_name(table_name).set_option(
            PutOption.IF_PRESENT).set_return_row(True))
        self.requests.append(PutRequest().set_value(
            self.new_rows[self.ops_sk][3]).set_table_name(
            table_name).set_option(PutOption.IF_VERSION).set_ttl(
            self.ttl).set_match_version(
            self.versions[self.ops_sk][3]).set_return_row(True))
        self.requests.append(DeleteRequest().set_key(
            {'fld_sid': self.ops_sk, 'fld_id': 4}).set_table_name(
            table_name).set_return_row(True))
        # Deliberately matches against row 0's version so the delete of
        # row 5 fails its version check.
        self.requests.append(DeleteRequest().set_key(
            {'fld_sid': self.ops_sk, 'fld_id': 5}).set_table_name(
                table_name).set_match_version(
                self.versions[self.ops_sk][0]).set_return_row(True))
        self.illegal_requests.append(DeleteRequest().set_key(
            {'fld_sid': self.ops_sk, 'fld_id': 0}).set_table_name(
            'IllegalUsers'))
        self.illegal_requests.append(DeleteRequest().set_key(
            {'fld_sid': illegal_sk, 'fld_id': 0}).set_table_name(table_name))
        self.write_multiple_request = WriteMultipleRequest().set_timeout(
            timeout)
        self.get_request = GetRequest().set_table_name(table_name)
        self.hour_in_milliseconds = 60 * 60 * 1000
        self.day_in_milliseconds = 24 * 60 * 60 * 1000

    def tearDown(self):
        """Remove every row under each shard key, then run base cleanup."""
        for sk in self.shardkeys:
            key = {'fld_sid': sk}
            request = MultiDeleteRequest().set_table_name(
                table_name).set_key(key)
            self.handle.multi_delete(request)
        TestBase.tear_down(self)

    def testWriteMultipleAddIllegalRequestAndAbortIfUnsuccessful(self):
        """add() rejects bad arguments, mixed tables and too many ops."""
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add,
                          'IllegalRequest', True)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add,
                          PutRequest(), 'IllegalAbortIfUnsuccessful')
        # add two operations with different table name
        self.write_multiple_request.add(self.requests[0], True)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.add,
                          self.illegal_requests[0], False)
        self.write_multiple_request.clear()
        # add two operations with different major paths; add() accepts
        # them, the mismatch is only reported by write_multiple
        self.write_multiple_request.add(
            self.requests[0], True).add(self.illegal_requests[1], False)
        self.assertRaises(IllegalArgumentException, self.handle.write_multiple,
                          self.write_multiple_request)
        self.write_multiple_request.clear()
        # fill the request with 50 sub-requests; the next add must fail
        for _ in range(50):
            self.write_multiple_request.add(self.requests[0], True)
        self.assertRaises(BatchOperationNumberLimitException,
                          self.write_multiple_request.add,
                          self.requests[0], True)

    def testWriteMultipleGetRequestWithIllegalIndex(self):
        """get_request() validates its index argument."""
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.get_request,
                          'IllegalIndex')
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.get_request, -1)
        # A well-typed index past the end of an empty request.
        self.assertRaises(IndexError, self.write_multiple_request.get_request,
                          0)

    def testWriteMultipleSetIllegalTimeout(self):
        """set_timeout() rejects non-integer, zero and negative values."""
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout,
                          'IllegalTimeout')
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout, 0)
        self.assertRaises(IllegalArgumentException,
                          self.write_multiple_request.set_timeout, -1)

    def testWriteMultipleNoOperations(self):
        """An empty WriteMultipleRequest cannot be executed."""
        self.assertRaises(IllegalArgumentException, self.handle.write_multiple,
                          self.write_multiple_request)

    def testWriteMultipleGets(self):
        """Getters reflect the added operations; clear() resets them."""
        num_operations = 6
        for request in self.requests:
            self.write_multiple_request.add(request, True)
        self.assertEqual(self.write_multiple_request.get_table_name(),
                         table_name)
        self.assertEqual(self.write_multiple_request.get_request(2),
                         self.requests[2])
        operations = self.write_multiple_request.get_operations()
        for idx, operation in enumerate(operations):
            self.assertEqual(operation.get_request(), self.requests[idx])
            self.assertTrue(operation.is_abort_if_unsuccessful())
        self.assertEqual(self.write_multiple_request.get_num_operations(),
                         num_operations)
        self.assertEqual(self.write_multiple_request.get_timeout(), timeout)
        self.write_multiple_request.clear()
        self.assertIsNone(self.write_multiple_request.get_table_name())
        self.assertEqual(self.write_multiple_request.get_operations(), [])
        self.assertEqual(self.write_multiple_request.get_num_operations(), 0)
        # clear() keeps the timeout.
        self.assertEqual(self.write_multiple_request.get_timeout(), timeout)

    def testWriteMultipleNormal(self):
        """Without abort-on-failure every sub-request runs independently."""
        num_operations = 6
        for request in self.requests:
            self.write_multiple_request.add(request, False)
        result = self.handle.write_multiple(self.write_multiple_request)
        expect_expiration = self.ttl.to_expiration_time(
            int(round(time() * 1000)))
        self.assertEqual(result.size(), num_operations)
        op_results = result.get_results()
        for idx in range(result.size()):
            op_result = op_results[idx]
            if idx in (1, 5):
                # putIfAbsent and deleteIfVersion failed
                self.assertIsNone(op_result.get_version())
                self.assertFalse(op_result.get_success())
                self.assertEqual(
                    op_result.get_existing_version().get_bytes(),
                    self.versions[self.ops_sk][idx].get_bytes())
                self.assertEqual(op_result.get_existing_value(),
                                 self.rows[self.ops_sk][idx])
            elif idx == 4:
                # delete succeed
                self.assertIsNone(op_result.get_version())
                self.assertTrue(op_result.get_success())
                self.assertIsNone(op_result.get_existing_version())
                self.assertIsNone(op_result.get_existing_value())
            else:
                # put, putIfPresent and putIfVersion succeed with a new
                # version; compare the raw bytes, as elsewhere in this class,
                # rather than the Version objects themselves.
                self.assertIsNotNone(op_result.get_version())
                self.assertNotEqual(
                    op_result.get_version().get_bytes(),
                    self.versions[self.ops_sk][idx].get_bytes())
                self.assertTrue(op_result.get_success())
                self.assertIsNone(op_result.get_existing_version())
                self.assertIsNone(op_result.get_existing_value())
        self.assertIsNone(result.get_failed_operation_result())
        self.assertEqual(result.get_failed_operation_index(), -1)
        self.assertTrue(result.get_success())
        # Per-operation costs, in sub-request order.
        self.assertEqual(result.get_read_kb(), 0 + 1 + 1 + 1 + 1 + 1)
        self.assertEqual(result.get_read_units(), 0 + 2 + 2 + 2 + 2 + 2)
        self.assertEqual(result.get_write_kb(), 2 + 0 + 2 + 2 + 1 + 0)
        self.assertEqual(result.get_write_units(), 2 + 0 + 2 + 2 + 1 + 0)
        # check the records after write_multiple request succeed
        for sk in self.shardkeys:
            for i in self.ids:
                self.get_request.set_key({'fld_sid': sk, 'fld_id': i})
                get_result = self.handle.get(self.get_request)
                if sk == 1 or i == 1 or i == 5:
                    # untouched rows keep their original value and version
                    self.assertEqual(get_result.get_value(), self.rows[sk][i])
                    self.assertEqual(get_result.get_version().get_bytes(),
                                     self.versions[sk][i].get_bytes())
                    actual_expiration = get_result.get_expiration_time()
                    actual_expect_diff = (actual_expiration -
                                          self.old_expect_expiration)
                    self.assertGreater(actual_expiration, 0)
                    self.assertLess(actual_expect_diff,
                                    self.day_in_milliseconds)
                elif i == 4:
                    self.assertIsNone(get_result.get_value())
                    self.assertIsNone(get_result.get_version())
                    self.assertEqual(get_result.get_expiration_time(), 0)
                else:
                    self.assertEqual(get_result.get_value(),
                                     self.new_rows[sk][i])
                    self.assertIsNotNone(get_result.get_version())
                    self.assertNotEqual(get_result.get_version().get_bytes(),
                                        self.versions[sk][i].get_bytes())
                    if i == 2:
                        # put with the table default TTL; the test expects
                        # no expiration on read-back
                        self.assertEqual(get_result.get_expiration_time(), 0)
                    else:
                        actual_expiration = get_result.get_expiration_time()
                        actual_expect_diff = (actual_expiration -
                                              expect_expiration)
                        self.assertGreater(actual_expiration, 0)
                        self.assertLess(actual_expect_diff,
                                        self.hour_in_milliseconds)
                self.assertEqual(get_result.get_read_kb(), 1)
                self.assertEqual(get_result.get_read_units(), 2)
                self.assertEqual(get_result.get_write_kb(), 0)
                self.assertEqual(get_result.get_write_units(), 0)

    def testWriteMultipleAbortIfUnsuccessful(self):
        """With abort-on-failure the first failing sub-request aborts all."""
        failed_idx = 1
        for request in self.requests:
            self.write_multiple_request.add(request, True)
        result = self.handle.write_multiple(self.write_multiple_request)
        # Only the failing operation's result is returned.
        self.assertEqual(result.size(), 1)
        op_results = result.get_results()
        self.assertIsNone(op_results[0].get_version())
        self.assertFalse(op_results[0].get_success())
        self.assertEqual(op_results[0].get_existing_version().get_bytes(),
                         self.versions[self.ops_sk][failed_idx].get_bytes())
        self.assertEqual(op_results[0].get_existing_value(),
                         self.rows[self.ops_sk][failed_idx])
        failed_result = result.get_failed_operation_result()
        self.assertIsNone(failed_result.get_version())
        self.assertFalse(failed_result.get_success())
        self.assertEqual(failed_result.get_existing_version().get_bytes(),
                         self.versions[self.ops_sk][failed_idx].get_bytes())
        self.assertEqual(failed_result.get_existing_value(),
                         self.rows[self.ops_sk][failed_idx])
        self.assertEqual(result.get_failed_operation_index(), failed_idx)
        self.assertFalse(result.get_success())
        self.assertEqual(result.get_read_kb(), 0 + 1)
        self.assertEqual(result.get_read_units(), 0 + 2)
        self.assertEqual(result.get_write_kb(), 2 + 0)
        self.assertEqual(result.get_write_units(), 2 + 0)
        # check the records after the write_multiple request failed: no
        # row may have changed
        for sk in self.shardkeys:
            for i in self.ids:
                self.get_request.set_key({'fld_sid': sk, 'fld_id': i})
                get_result = self.handle.get(self.get_request)
                self.assertEqual(get_result.get_value(), self.rows[sk][i])
                self.assertEqual(get_result.get_version().get_bytes(),
                                 self.versions[sk][i].get_bytes())
                actual_expiration = get_result.get_expiration_time()
                actual_expect_diff = (actual_expiration -
                                      self.old_expect_expiration)
                self.assertGreater(actual_expiration, 0)
                self.assertLess(actual_expect_diff, self.day_in_milliseconds)
                self.assertEqual(get_result.get_read_kb(), 1)
                self.assertEqual(get_result.get_read_units(), 2)
                self.assertEqual(get_result.get_write_kb(), 0)
                self.assertEqual(get_result.get_write_units(), 0)
# Example #8
class TestQuery(unittest.TestCase, TestBase):
    @classmethod
    def setUpClass(cls):
        TestBase.set_up_class()
        index_name = 'idx_' + table_name
        create_statement = ('CREATE TABLE ' + table_name +
                            '(fld_sid INTEGER, fld_id INTEGER, \
fld_long LONG, fld_float FLOAT, fld_double DOUBLE, fld_bool BOOLEAN, \
fld_str STRING, fld_bin BINARY, fld_time TIMESTAMP(6), fld_num NUMBER, \
fld_json JSON, fld_arr ARRAY(STRING), fld_map MAP(STRING), \
fld_rec RECORD(fld_id LONG, fld_bool BOOLEAN, fld_str STRING), \
PRIMARY KEY(SHARD(fld_sid), fld_id))')
        limits = TableLimits(5000, 5000, 50)
        create_request = TableRequest().set_statement(
            create_statement).set_table_limits(limits)
        cls._result = TestBase.table_request(create_request, State.ACTIVE)
        create_index_statement = ('CREATE INDEX ' + index_name + ' ON ' +
                                  table_name + '(fld_long)')
        create_index_request = TableRequest().set_statement(
            create_index_statement)
        cls._result = TestBase.table_request(create_index_request,
                                             State.ACTIVE)
        global prepare_cost
        prepare_cost = 2
        global query_statement
        query_statement = ('SELECT fld_sid, fld_id FROM ' + table_name +
                           ' WHERE fld_sid = 1')

    @classmethod
    def tearDownClass(cls):
        """Delegate class-level cleanup to TestBase.tear_down_class."""
        TestBase.tear_down_class()

    def setUp(self):
        """Load 6 rows under each shard key 0 and 1, then prepare queries.

        Rows go in with one write_multiple call per shard key, and every row
        starts with fld_long == 2147483648. Two prepared statements are kept:
        an UPDATE that increments fld_long by primary key, and a SELECT that
        filters on fld_long.
        """
        TestBase.set_up(self)
        self.handle_config = get_handle_config(tenant_id)
        batch = WriteMultipleRequest()
        for shard in (0, 1):
            for row_id in range(6):
                record = {'fld_sid': shard, 'fld_id': row_id,
                          'fld_long': 2147483648,
                          'fld_float': 3.1414999961853027,
                          'fld_double': 3.1415, 'fld_bool': True,
                          'fld_str': '{"name": u1, "phone": null}',
                          'fld_bin': bytearray(pack('>i', 4)),
                          'fld_time': datetime.now(), 'fld_num': Decimal(5),
                          'fld_json': {'a': '1', 'b': None, 'c': '3'},
                          'fld_arr': ['a', 'b', 'c'],
                          'fld_map': {'a': '1', 'b': '2', 'c': '3'},
                          'fld_rec': {'fld_id': 1, 'fld_bool': False,
                                      'fld_str': None}}
                put = PutRequest().set_value(record).set_table_name(
                    table_name)
                batch.add(put, True)
            # One batch per shard key.
            self.handle.write_multiple(batch)
            batch.clear()
        update_sql = ('DECLARE $fld_sid INTEGER; $fld_id INTEGER; UPDATE ' +
                      table_name + ' u SET u.fld_long = u.fld_long + 1 ' +
                      'WHERE fld_sid = $fld_sid AND fld_id = $fld_id')
        self.prepare_result_update = self.handle.prepare(
            PrepareRequest().set_statement(update_sql))
        select_sql = ('DECLARE $fld_long LONG; ' +
                      'SELECT fld_sid, fld_id, fld_long FROM ' + table_name +
                      ' WHERE fld_long = $fld_long')
        self.prepare_result_select = self.handle.prepare(
            PrepareRequest().set_statement(select_sql))
        self.query_request = QueryRequest().set_timeout(timeout)
        self.get_request = GetRequest().set_table_name(table_name)

    def tearDown(self):
        """Run TestBase per-test cleanup; rows are rewritten by setUp."""
        TestBase.tear_down(self)

    def testQuerySetIllegalLimit(self):
        """set_limit rejects a string and a negative limit."""
        for bad_limit in ('IllegalLimit', -1):
            with self.assertRaises(IllegalArgumentException):
                self.query_request.set_limit(bad_limit)

    def testQuerySetIllegalMaxReadKb(self):
        """set_max_read_kb rejects a string, a negative value and 2049."""
        for bad_kb in ('IllegalLimit', -1, 2049):
            with self.assertRaises(IllegalArgumentException):
                self.query_request.set_max_read_kb(bad_kb)

    def testQuerySetIllegalConsistency(self):
        """set_consistency rejects a plain string."""
        with self.assertRaises(IllegalArgumentException):
            self.query_request.set_consistency('IllegalConsistency')

    def testQuerySetIllegalContinuationKey(self):
        """set_continuation_key rejects a plain string."""
        with self.assertRaises(IllegalArgumentException):
            self.query_request.set_continuation_key('IllegalContinuationKey')

    def testQuerySetIllegalStatement(self):
        """Bad statements fail when the query is executed."""
        # Not a parseable query.
        self.query_request.set_statement('IllegalStatement')
        with self.assertRaises(IllegalArgumentException):
            self.handle.query(self.query_request)
        # Well-formed query against a table that does not exist.
        self.query_request.set_statement('SELECT fld_id FROM IllegalTableName')
        with self.assertRaises(TableNotFoundException):
            self.handle.query(self.query_request)

    def testQuerySetIllegalPreparedStatement(self):
        """set_prepared_statement rejects a plain string."""
        with self.assertRaises(IllegalArgumentException):
            self.query_request.set_prepared_statement(
                'IllegalPreparedStatement')

    def testQuerySetIllegalTimeout(self):
        """set_timeout rejects a string, zero and a negative timeout."""
        for bad_timeout in ('IllegalTimeout', 0, -1):
            with self.assertRaises(IllegalArgumentException):
                self.query_request.set_timeout(bad_timeout)

    def testQuerySetIllegalDefaults(self):
        """set_defaults rejects something that is not a handle config."""
        with self.assertRaises(IllegalArgumentException):
            self.query_request.set_defaults('IllegalDefaults')

    def testQuerySetDefaults(self):
        """set_defaults yields the test timeout and ABSOLUTE consistency."""
        request = self.query_request
        request.set_defaults(self.handle_config)
        self.assertEqual(request.get_timeout(), timeout)
        self.assertEqual(request.get_consistency(), Consistency.ABSOLUTE)

    def testQueryNoStatementAndBothStatement(self):
        """A query needs exactly one of statement / prepared statement."""
        # Neither has been set yet.
        with self.assertRaises(IllegalArgumentException):
            self.handle.query(self.query_request)
        # Both are set at once.
        self.query_request.set_statement(query_statement)
        self.query_request.set_prepared_statement(self.prepare_result_select)
        with self.assertRaises(IllegalArgumentException):
            self.handle.query(self.query_request)

    def testQueryGets(self):
        """Every QueryRequest getter returns what its setter stored."""
        req = self.query_request
        cont_key = bytearray(5)
        req.set_consistency(Consistency.EVENTUAL)
        req.set_statement(query_statement)
        req.set_prepared_statement(self.prepare_result_select)
        req.set_limit(3)
        req.set_max_read_kb(2)
        req.set_continuation_key(cont_key)
        self.assertEqual(req.get_limit(), 3)
        self.assertEqual(req.get_max_read_kb(), 2)
        self.assertEqual(req.get_consistency(), Consistency.EVENTUAL)
        self.assertEqual(req.get_continuation_key(), cont_key)
        self.assertEqual(req.get_statement(), query_statement)
        # A PrepareResult is stored as its underlying prepared statement.
        self.assertEqual(req.get_prepared_statement(),
                         self.prepare_result_select.get_prepared_statement())
        self.assertEqual(req.get_timeout(), timeout)

    def testQueryIllegalRequest(self):
        """handle.query rejects an argument that is not a QueryRequest."""
        with self.assertRaises(IllegalArgumentException):
            self.handle.query('IllegalRequest')

    def testQueryStatementSelect(self):
        """A plain SELECT returns all six rows of shard 1 in one batch."""
        expected_count = 6
        self.query_request.set_statement(query_statement)
        outcome = self.handle.query(self.query_request)
        rows = outcome.get_results()
        self.assertEqual(len(rows), expected_count)
        for position, row in enumerate(rows):
            self.assertEqual(row, {'fld_sid': 1, 'fld_id': position})
        self.assertIsNone(outcome.get_continuation_key())
        # Expected cost: 1 KB / 2 RUs per row, plus prepare_cost.
        self.assertEqual(outcome.get_read_kb(),
                         expected_count + prepare_cost)
        self.assertEqual(outcome.get_read_units(),
                         expected_count * 2 + prepare_cost)
        self.assertEqual(outcome.get_write_kb(), 0)
        self.assertEqual(outcome.get_write_units(), 0)

    def testQueryStatementSelectWithLimit(self):
        """A row limit truncates the batch and yields a continuation key."""
        max_rows = 3
        self.query_request.set_statement(query_statement).set_limit(max_rows)
        outcome = self.handle.query(self.query_request)
        rows = outcome.get_results()
        self.assertEqual(len(rows), max_rows)
        for position, row in enumerate(rows):
            self.assertEqual(row, {'fld_sid': 1, 'fld_id': position})
        # More rows remain, so the query must be resumable.
        self.assertIsNotNone(outcome.get_continuation_key())
        self.assertEqual(outcome.get_read_kb(), max_rows + prepare_cost)
        self.assertEqual(outcome.get_read_units(),
                         max_rows * 2 + prepare_cost)
        self.assertEqual(outcome.get_write_kb(), 0)
        self.assertEqual(outcome.get_write_units(), 0)

    def testQueryStatementSelectWithMaxReadKb(self):
        """A KB cap truncates the batch and yields a continuation key."""
        kb_cap = 4
        request = self.query_request.set_statement(query_statement)
        request.set_max_read_kb(kb_cap)
        outcome = self.handle.query(self.query_request)
        rows = outcome.get_results()
        # The assertions expect one row (and its 1 KB / 2 RUs) past the cap.
        self.assertEqual(len(rows), kb_cap + 1)
        for position, row in enumerate(rows):
            self.assertEqual(row, {'fld_sid': 1, 'fld_id': position})
        self.assertIsNotNone(outcome.get_continuation_key())
        self.assertEqual(outcome.get_read_kb(), kb_cap + prepare_cost + 1)
        self.assertEqual(outcome.get_read_units(),
                         kb_cap * 2 + prepare_cost + 2)
        self.assertEqual(outcome.get_write_kb(), 0)
        self.assertEqual(outcome.get_write_units(), 0)

    def testQueryStatementSelectWithConsistency(self):
        """An explicit ABSOLUTE consistency still returns all six rows."""
        expected_count = 6
        request = self.query_request.set_statement(query_statement)
        request.set_consistency(Consistency.ABSOLUTE)
        outcome = self.handle.query(self.query_request)
        rows = outcome.get_results()
        self.assertEqual(len(rows), expected_count)
        for position, row in enumerate(rows):
            self.assertEqual(row, {'fld_sid': 1, 'fld_id': position})
        self.assertIsNone(outcome.get_continuation_key())
        self.assertEqual(outcome.get_read_kb(),
                         expected_count + prepare_cost)
        self.assertEqual(outcome.get_read_units(),
                         expected_count * 2 + prepare_cost)
        self.assertEqual(outcome.get_write_kb(), 0)
        self.assertEqual(outcome.get_write_units(), 0)

    def testQueryStatementSelectWithContinuationKey(self):
        """Page through six rows, four at a time, via continuation keys."""
        total_rows = 6
        page_size = 4
        self.query_request.set_statement(query_statement).set_limit(page_size)
        pages = 0
        while True:
            offset = pages * page_size
            page = self.handle.query(self.query_request)
            rows = page.get_results()
            is_full_page = offset + page_size <= total_rows
            if is_full_page:
                expected_rows = page_size
                charged_kb = expected_rows
                self.assertIsNotNone(page.get_continuation_key())
            else:
                expected_rows = total_rows - offset
                # An empty final page is still charged 1 KB.
                charged_kb = 1 if expected_rows == 0 else expected_rows
                self.assertIsNone(page.get_continuation_key())
            self.assertEqual(len(rows), expected_rows)
            for position, row in enumerate(rows):
                self.assertEqual(row, {
                    'fld_sid': 1,
                    'fld_id': offset + position
                })
            self.assertEqual(page.get_read_kb(), charged_kb + prepare_cost)
            self.assertEqual(page.get_read_units(),
                             charged_kb * 2 + prepare_cost)
            self.assertEqual(page.get_write_kb(), 0)
            self.assertEqual(page.get_write_units(), 0)
            pages += 1
            next_key = page.get_continuation_key()
            if next_key is None:
                break
            self.query_request.set_continuation_key(next_key)
        self.assertEqual(pages, total_rows // page_size + 1)

    def testQueryStatementSelectWithDefault(self):
        """A SELECT using handle-config defaults returns all six rows."""
        expected_count = 6
        request = self.query_request.set_statement(query_statement)
        request.set_defaults(self.handle_config)
        outcome = self.handle.query(self.query_request)
        rows = outcome.get_results()
        self.assertEqual(len(rows), expected_count)
        for position, row in enumerate(rows):
            self.assertEqual(row, {'fld_sid': 1, 'fld_id': position})
        self.assertIsNone(outcome.get_continuation_key())
        self.assertEqual(outcome.get_read_kb(),
                         expected_count + prepare_cost)
        self.assertEqual(outcome.get_read_units(),
                         expected_count * 2 + prepare_cost)
        self.assertEqual(outcome.get_write_kb(), 0)
        self.assertEqual(outcome.get_write_units(), 0)

    def testQueryPreparedStatementUpdate(self):
        """Run the prepared UPDATE on a missing row and on an existing row.

        Then read the incremented fld_long back through the prepared SELECT.
        """
        target_sid, target_id = 0, 2
        updated_long = 2147483649

        def assert_costs(res, read_kb, read_units, write_kb, write_units):
            # Check the four throughput counters of one query result.
            self.assertEqual(res.get_read_kb(), read_kb)
            self.assertEqual(res.get_read_units(), read_units)
            self.assertEqual(res.get_write_kb(), write_kb)
            self.assertEqual(res.get_write_units(), write_units)

        update_stmt = self.prepare_result_update.get_prepared_statement()
        # Shard key 2 is never written by setUp, so nothing matches.
        update_stmt.set_variable('$fld_sid', 2).set_variable('$fld_id', 0)
        self.query_request.set_prepared_statement(self.prepare_result_update)
        outcome = self.handle.query(self.query_request)
        rows = outcome.get_results()
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0], {'NumRowsUpdated': 0})
        self.assertIsNone(outcome.get_continuation_key())
        assert_costs(outcome, 1, 2, 0, 0)
        # An existing primary key: exactly one row is updated.
        update_stmt.set_variable('$fld_sid', target_sid).set_variable(
            '$fld_id', target_id)
        self.query_request.set_prepared_statement(self.prepare_result_update)
        outcome = self.handle.query(self.query_request)
        rows = outcome.get_results()
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0], {'NumRowsUpdated': 1})
        self.assertIsNone(outcome.get_continuation_key())
        assert_costs(outcome, 2, 4, 4, 4)
        # Only the updated row carries the incremented fld_long.
        select_stmt = self.prepare_result_select.get_prepared_statement()
        select_stmt.set_variable('$fld_long', updated_long)
        self.query_request.set_prepared_statement(select_stmt)
        outcome = self.handle.query(self.query_request)
        rows = outcome.get_results()
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0], {
            'fld_sid': target_sid,
            'fld_id': target_id,
            'fld_long': updated_long
        })
        self.assertIsNone(outcome.get_continuation_key())
        assert_costs(outcome, 1, 2, 0, 0)

    def testQueryPreparedStatementUpdateWithLimit(self):
        """A limit of 1 still allows the prepared UPDATE of one row."""
        target_sid, target_id = 1, 5
        updated_long = 2147483649

        def assert_costs(res, read_kb, read_units, write_kb, write_units):
            # Check the four throughput counters of one query result.
            self.assertEqual(res.get_read_kb(), read_kb)
            self.assertEqual(res.get_read_units(), read_units)
            self.assertEqual(res.get_write_kb(), write_kb)
            self.assertEqual(res.get_write_units(), write_units)

        update_stmt = self.prepare_result_update.get_prepared_statement()
        update_stmt.set_variable('$fld_sid', target_sid).set_variable(
            '$fld_id', target_id)
        self.query_request.set_prepared_statement(
            self.prepare_result_update).set_limit(1)
        outcome = self.handle.query(self.query_request)
        rows = outcome.get_results()
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0], {'NumRowsUpdated': 1})
        self.assertIsNone(outcome.get_continuation_key())
        assert_costs(outcome, 2, 4, 4, 4)
        # Read the row back; the limit of 1 is still set on the request.
        select_stmt = self.prepare_result_select.get_prepared_statement()
        select_stmt.set_variable('$fld_long', updated_long)
        self.query_request.set_prepared_statement(select_stmt)
        outcome = self.handle.query(self.query_request)
        rows = outcome.get_results()
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0], {
            'fld_sid': target_sid,
            'fld_id': target_id,
            'fld_long': updated_long
        })
        # With the limit still at 1, the test expects a continuation key.
        self.assertIsNotNone(outcome.get_continuation_key())
        assert_costs(outcome, 1, 2, 0, 0)

    def testQueryPreparedStatementUpdateWithMaxReadKb(self):
        """Check how max_read_kb bounds a prepared UPDATE.

        With max_read_kb=1 the query is expected to raise
        IllegalArgumentException. With max_read_kb=2 the update of one row
        succeeds. The row is then read back on the same request object.
        """
        sid, row_id, new_long = 0, 1, 2147483649
        request = self.query_request

        update_stmt = self.prepare_result_update.get_prepared_statement()
        update_stmt.set_variable('$fld_sid', sid).set_variable(
            '$fld_id', row_id)
        # A 1 KB read budget is too small to read the row being updated.
        request.set_prepared_statement(
            self.prepare_result_update).set_max_read_kb(1)
        self.assertRaises(IllegalArgumentException, self.handle.query,
                          request)

        # A 2 KB budget is enough for the update to go through.
        request.set_max_read_kb(2)
        update_result = self.handle.query(request)

        update_rows = update_result.get_results()
        self.assertEqual(len(update_rows), 1)
        self.assertEqual(update_rows[0], {'NumRowsUpdated': 1})
        self.assertIsNone(update_result.get_continuation_key())
        update_costs = ((update_result.get_read_kb(), 2),
                        (update_result.get_read_units(), 4),
                        (update_result.get_write_kb(), 4),
                        (update_result.get_write_units(), 4))
        for actual, expected in update_costs:
            self.assertEqual(actual, expected)

        # Read the row back by binding the long value the update is expected
        # to have written.
        select_stmt = self.prepare_result_select.get_prepared_statement()
        select_stmt.set_variable('$fld_long', new_long)
        request.set_prepared_statement(select_stmt)
        select_result = self.handle.query(request)

        select_rows = select_result.get_results()
        self.assertEqual(len(select_rows), 1)
        self.assertEqual(select_rows[0], {
            'fld_sid': sid,
            'fld_id': row_id,
            'fld_long': new_long
        })
        self.assertIsNone(select_result.get_continuation_key())
        select_costs = ((select_result.get_read_kb(), 1),
                        (select_result.get_read_units(), 2),
                        (select_result.get_write_kb(), 0),
                        (select_result.get_write_units(), 0))
        for actual, expected in select_costs:
            self.assertEqual(actual, expected)

    def testQueryPreparedStatementUpdateWithConsistency(self):
        """Run a prepared UPDATE with ABSOLUTE consistency, then read it back.

        The update is expected to touch exactly one row with no continuation
        key. The SELECT that follows reuses the same request object (so it
        also carries ABSOLUTE consistency).
        """
        sid, row_id, new_long = 1, 2, 2147483649
        request = self.query_request

        update_stmt = self.prepare_result_update.get_prepared_statement()
        update_stmt.set_variable('$fld_sid', sid).set_variable(
            '$fld_id', row_id)
        request.set_prepared_statement(
            self.prepare_result_update).set_consistency(Consistency.ABSOLUTE)
        update_result = self.handle.query(request)

        update_rows = update_result.get_results()
        self.assertEqual(len(update_rows), 1)
        self.assertEqual(update_rows[0], {'NumRowsUpdated': 1})
        self.assertIsNone(update_result.get_continuation_key())
        update_costs = ((update_result.get_read_kb(), 2),
                        (update_result.get_read_units(), 4),
                        (update_result.get_write_kb(), 4),
                        (update_result.get_write_units(), 4))
        for actual, expected in update_costs:
            self.assertEqual(actual, expected)

        # Read the row back by binding the long value the update is expected
        # to have written.
        select_stmt = self.prepare_result_select.get_prepared_statement()
        select_stmt.set_variable('$fld_long', new_long)
        request.set_prepared_statement(select_stmt)
        select_result = self.handle.query(request)

        select_rows = select_result.get_results()
        self.assertEqual(len(select_rows), 1)
        self.assertEqual(select_rows[0], {
            'fld_sid': sid,
            'fld_id': row_id,
            'fld_long': new_long
        })
        self.assertIsNone(select_result.get_continuation_key())
        select_costs = ((select_result.get_read_kb(), 1),
                        (select_result.get_read_units(), 2),
                        (select_result.get_write_kb(), 0),
                        (select_result.get_write_units(), 0))
        for actual, expected in select_costs:
            self.assertEqual(actual, expected)

    def testQueryPreparedStatementUpdateWithContinuationKey(self):
        """Run a prepared UPDATE with a limit and check it needs no paging.

        The update is issued with a limit (``limit``) larger than the number
        of rows it touches (``num_records``). It is expected to finish in a
        single call and return no continuation key, so there is nothing to
        page through. The row is then read back with a SELECT that reuses
        the same request object, and therefore the same limit.
        """
        fld_sid = 1
        fld_id = 3
        fld_long = 2147483649
        num_records = 1
        limit = 3
        prepared_statement = self.prepare_result_update.get_prepared_statement(
        )
        prepared_statement.set_variable('$fld_sid', fld_sid).set_variable(
            '$fld_id', fld_id)
        self.query_request.set_prepared_statement(
            self.prepare_result_update).set_limit(limit)
        # An UPDATE is expected to complete in one call (continuation key is
        # asserted to be None below), so it is issued exactly once rather
        # than inside a continuation loop that could never repeat.
        result = self.handle.query(self.query_request)
        records = result.get_results()
        self.assertEqual(len(records), 1)
        # Rows updated by this single call are capped by the limit.
        num_update = min(limit, num_records)
        self.assertEqual(records[0], {'NumRowsUpdated': num_update})
        # Expected costs: 2 KB read and 4 KB written per updated row; an
        # update that touches no row is still expected to read 1 KB.
        read_kb = 1 if num_update == 0 else num_update * 2
        write_kb = 0 if num_update == 0 else num_update * 4
        self.assertIsNone(result.get_continuation_key())
        self.assertEqual(result.get_read_kb(), read_kb)
        self.assertEqual(result.get_read_units(), read_kb * 2)
        self.assertEqual(result.get_write_kb(), write_kb)
        self.assertEqual(result.get_write_units(), write_kb)
        # Check the updated row. The reused request still carries ``limit``,
        # so at most ``limit`` rows come back, and a continuation key is
        # expected only when the limit was reached.
        prepared_statement = self.prepare_result_select.get_prepared_statement(
        )
        prepared_statement.set_variable('$fld_long', fld_long)
        self.query_request.set_prepared_statement(prepared_statement)
        result = self.handle.query(self.query_request)
        records = result.get_results()
        self.assertEqual(len(records), min(num_records, limit))
        self.assertEqual(records[0], {
            'fld_sid': fld_sid,
            'fld_id': fld_id,
            'fld_long': fld_long
        })
        if limit <= num_records:
            self.assertIsNotNone(result.get_continuation_key())
        else:
            self.assertIsNone(result.get_continuation_key())
        self.assertEqual(result.get_read_kb(), 1)
        self.assertEqual(result.get_read_units(), 2)
        self.assertEqual(result.get_write_kb(), 0)
        self.assertEqual(result.get_write_units(), 0)

    def testQueryPreparedStatementUpdateWithDefault(self):
        """Run a prepared UPDATE using the handle config defaults.

        The request's defaults are taken from ``self.handle_config``. The
        update is expected to touch exactly one row, and the row is then
        read back on the same request object.
        """
        sid, row_id, new_long = 0, 5, 2147483649
        request = self.query_request

        update_stmt = self.prepare_result_update.get_prepared_statement()
        update_stmt.set_variable('$fld_sid', sid).set_variable(
            '$fld_id', row_id)
        request.set_prepared_statement(
            self.prepare_result_update).set_defaults(self.handle_config)
        update_result = self.handle.query(request)

        update_rows = update_result.get_results()
        self.assertEqual(len(update_rows), 1)
        self.assertEqual(update_rows[0], {'NumRowsUpdated': 1})
        self.assertIsNone(update_result.get_continuation_key())
        update_costs = ((update_result.get_read_kb(), 2),
                        (update_result.get_read_units(), 4),
                        (update_result.get_write_kb(), 4),
                        (update_result.get_write_units(), 4))
        for actual, expected in update_costs:
            self.assertEqual(actual, expected)

        # Read the row back by binding the long value the update is expected
        # to have written.
        select_stmt = self.prepare_result_select.get_prepared_statement()
        select_stmt.set_variable('$fld_long', new_long)
        request.set_prepared_statement(select_stmt)
        select_result = self.handle.query(request)

        select_rows = select_result.get_results()
        self.assertEqual(len(select_rows), 1)
        self.assertEqual(select_rows[0], {
            'fld_sid': sid,
            'fld_id': row_id,
            'fld_long': new_long
        })
        self.assertIsNone(select_result.get_continuation_key())
        select_costs = ((select_result.get_read_kb(), 1),
                        (select_result.get_read_units(), 2),
                        (select_result.get_write_kb(), 0),
                        (select_result.get_write_units(), 0))
        for actual, expected in select_costs:
            self.assertEqual(actual, expected)

    def testQueryStatementUpdateTTL(self):
        """Update a row's TTL with a non-prepared UPDATE statement.

        For the row (fld_sid=1, fld_id=3), the statement sets the TTL to
        3 hours when ``remaining_hours($u)`` is negative. Otherwise it sets
        the TTL to the remaining hours plus 3. The row is then fetched with
        a get request, and its expiration time is compared against a 3-hour
        TTL computed from the local clock.
        """
        hour_in_milliseconds = 60 * 60 * 1000
        # Plain statement (set_statement, not set_prepared_statement).
        self.query_request.set_statement(
            'UPDATE ' + table_name + ' $u SET TTL CASE WHEN ' +
            'remaining_hours($u) < 0 THEN 3 ELSE remaining_hours($u) + 3 END '
            + 'HOURS WHERE fld_sid = 1 AND fld_id = 3')
        result = self.handle.query(self.query_request)
        # The expected TTL of exactly 3 hours assumes the row had no
        # remaining TTL before this update (the "< 0" branch).
        # NOTE(review): confirm against the table/row setup.
        ttl = TimeToLive.of_hours(3)
        # The reference time (ms since the epoch) is taken after the query
        # returns, so it is not earlier than the moment the update ran,
        # assuming client and server clocks agree.
        expect_expiration = ttl.to_expiration_time(int(round(time() * 1000)))
        records = result.get_results()
        self.assertEqual(len(records), 1)
        self.assertEqual(records[0], {'NumRowsUpdated': 1})
        self.assertIsNone(result.get_continuation_key())
        # prepare_cost is defined elsewhere in this file. It presumably
        # accounts for compiling the unprepared statement (TODO confirm).
        self.assertEqual(result.get_read_kb(), 2 + prepare_cost)
        self.assertEqual(result.get_read_units(), 4 + prepare_cost)
        self.assertEqual(result.get_write_kb(), 3)
        self.assertEqual(result.get_write_units(), 3)
        # Read the row back and check its expiration time.
        self.get_request.set_key({'fld_sid': 1, 'fld_id': 3})
        result = self.handle.get(self.get_request)
        actual_expiration = result.get_expiration_time()
        actual_expect_diff = actual_expiration - expect_expiration
        self.assertGreater(actual_expiration, 0)
        # Only an upper bound is checked: the stored expiration may exceed
        # the locally computed one by less than an hour. This is presumably
        # slack for the server rounding the expiration, which needs confirming.
        self.assertLess(actual_expect_diff, hour_in_milliseconds)
        self.assertEqual(result.get_read_kb(), 1)
        self.assertEqual(result.get_read_units(), 2)
        self.assertEqual(result.get_write_kb(), 0)
        self.assertEqual(result.get_write_units(), 0)