Exemple #1
0
    def test_cluster_settings(self):
        """
        Test connection setting getters and setters.

        For each per-HostDistance setting, check that a fresh Cluster reports
        the driver default and that the setter round-trips through the getter.
        Skipped for protocol v3+, where these settings are not used.
        """
        if PROTOCOL_VERSION >= 3:
            raise unittest.SkipTest("min/max requests and core/max conns aren't used with v3 protocol")

        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        try:
            # (getter, setter, expected default) for every setting under test,
            # checked in the same order as before.
            settings = (
                (cluster.get_min_requests_per_connection,
                 cluster.set_min_requests_per_connection,
                 cassandra.cluster.DEFAULT_MIN_REQUESTS),
                (cluster.get_max_requests_per_connection,
                 cluster.set_max_requests_per_connection,
                 cassandra.cluster.DEFAULT_MAX_REQUESTS),
                (cluster.get_core_connections_per_host,
                 cluster.set_core_connections_per_host,
                 cassandra.cluster.DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST),
                (cluster.get_max_connections_per_host,
                 cluster.set_max_connections_per_host,
                 cassandra.cluster.DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST),
            )
            for getter, setter, default in settings:
                current = getter(HostDistance.LOCAL)
                self.assertEqual(default, current)
                setter(HostDistance.LOCAL, current + 1)
                self.assertEqual(getter(HostDistance.LOCAL), current + 1)
        finally:
            # Always release the cluster, even when an assertion fails.
            cluster.shutdown()
    def test_cluster_settings(self):
        """
        Test connection setting getters and setters.

        For each per-HostDistance setting, check that a fresh Cluster reports
        the driver default and that the setter round-trips through the getter.
        """
        cluster = Cluster()
        try:
            # (getter, setter, expected default) for every setting under test,
            # checked in the same order as before.
            settings = (
                (cluster.get_min_requests_per_connection,
                 cluster.set_min_requests_per_connection,
                 cassandra.cluster.DEFAULT_MIN_REQUESTS),
                (cluster.get_max_requests_per_connection,
                 cluster.set_max_requests_per_connection,
                 cassandra.cluster.DEFAULT_MAX_REQUESTS),
                (cluster.get_core_connections_per_host,
                 cluster.set_core_connections_per_host,
                 cassandra.cluster.DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST),
                (cluster.get_max_connections_per_host,
                 cluster.set_max_connections_per_host,
                 cassandra.cluster.DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST),
            )
            for getter, setter, default in settings:
                current = getter(HostDistance.LOCAL)
                self.assertEqual(default, current)
                setter(HostDistance.LOCAL, current + 1)
                self.assertEqual(getter(HostDistance.LOCAL), current + 1)
        finally:
            # Always release the cluster, even when an assertion fails.
            cluster.shutdown()
    def test_cluster_settings(self):
        """
        Test connection setting getters and setters.

        For each per-HostDistance setting, check that a fresh Cluster reports
        the driver default and that the setter round-trips through the getter.
        Skipped for protocol v3+, where these settings are not used.
        """
        if PROTOCOL_VERSION >= 3:
            raise unittest.SkipTest("min/max requests and core/max conns aren't used with v3 protocol")

        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        try:
            # (getter, setter, expected default) for every setting under test,
            # checked in the same order as before.
            settings = (
                (cluster.get_min_requests_per_connection,
                 cluster.set_min_requests_per_connection,
                 cassandra.cluster.DEFAULT_MIN_REQUESTS),
                (cluster.get_max_requests_per_connection,
                 cluster.set_max_requests_per_connection,
                 cassandra.cluster.DEFAULT_MAX_REQUESTS),
                (cluster.get_core_connections_per_host,
                 cluster.set_core_connections_per_host,
                 cassandra.cluster.DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST),
                (cluster.get_max_connections_per_host,
                 cluster.set_max_connections_per_host,
                 cassandra.cluster.DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST),
            )
            for getter, setter, default in settings:
                current = getter(HostDistance.LOCAL)
                self.assertEqual(default, current)
                setter(HostDistance.LOCAL, current + 1)
                self.assertEqual(getter(HostDistance.LOCAL), current + 1)
        finally:
            # Always release the cluster, even when an assertion fails.
            cluster.shutdown()
Exemple #4
0
def setup(hosts):
    """
    Create keyspace KEYSPACE (SimpleStrategy, replication_factor '2') and the
    table TABLE inside it.

    No IF NOT EXISTS clause is used. The cluster is shut down on the way out,
    whether or not the DDL succeeds.
    """
    log.info("Using 'cassandra' package from %s", cassandra.__path__)

    cluster = Cluster(hosts)
    cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
    try:
        sess = cluster.connect()

        log.debug("Creating keyspace...")
        keyspace_cql = """
            CREATE KEYSPACE %s
            WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
            """ % KEYSPACE
        sess.execute(keyspace_cql)

        log.debug("Setting keyspace...")
        sess.set_keyspace(KEYSPACE)

        log.debug("Creating table...")
        table_cql = """
            CREATE TABLE %s (
                thekey text,
                col1 text,
                col2 text,
                PRIMARY KEY (thekey, col1)
            )
            """ % TABLE
        sess.execute(table_cql)
    finally:
        cluster.shutdown()
Exemple #5
0
def setup(hosts):
    """
    Recreate keyspace KEYSPACE (SimpleStrategy, replication_factor '2') and the
    table TABLE inside it.

    An existing keyspace with the same name is dropped first. The cluster is
    always shut down before returning, even if a statement fails.
    """
    cluster = Cluster(hosts)
    try:
        cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        session = cluster.connect()

        # NOTE(review): system.schema_keyspaces exists only on Cassandra < 3.0
        # (3.0 moved schema tables to system_schema) - confirm the target version.
        rows = session.execute("SELECT keyspace_name FROM system.schema_keyspaces")
        if any(row[0] == KEYSPACE for row in rows):
            log.debug("dropping existing keyspace...")
            session.execute("DROP KEYSPACE " + KEYSPACE)

        log.debug("Creating keyspace...")
        session.execute("""
        CREATE KEYSPACE %s
        WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
        """ % KEYSPACE)

        log.debug("Setting keyspace...")
        session.set_keyspace(KEYSPACE)

        log.debug("Creating table...")
        session.execute("""
        CREATE TABLE %s (
            thekey text,
            col1 text,
            col2 text,
            PRIMARY KEY (thekey, col1)
        )
        """ % TABLE)
    finally:
        # Previously the cluster (and its connections) was never released.
        cluster.shutdown()
    def test_cluster_settings(self):
        """
        Test connection setting getters and setters.

        For each per-HostDistance setting, check that a fresh Cluster reports
        the driver default and that the setter round-trips through the getter.
        """
        cluster = Cluster()
        try:
            # (getter, setter, expected default) for every setting under test,
            # checked in the same order as before.
            settings = (
                (cluster.get_min_requests_per_connection,
                 cluster.set_min_requests_per_connection,
                 cassandra.cluster.DEFAULT_MIN_REQUESTS),
                (cluster.get_max_requests_per_connection,
                 cluster.set_max_requests_per_connection,
                 cassandra.cluster.DEFAULT_MAX_REQUESTS),
                (cluster.get_core_connections_per_host,
                 cluster.set_core_connections_per_host,
                 cassandra.cluster.DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST),
                (cluster.get_max_connections_per_host,
                 cluster.set_max_connections_per_host,
                 cassandra.cluster.DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST),
            )
            for getter, setter, default in settings:
                current = getter(HostDistance.LOCAL)
                self.assertEqual(default, current)
                setter(HostDistance.LOCAL, current + 1)
                self.assertEqual(getter(HostDistance.LOCAL), current + 1)
        finally:
            # Always release the cluster, even when an assertion fails.
            cluster.shutdown()
Exemple #7
0
def setup(hosts):
    """
    Recreate keyspace KEYSPACE (SimpleStrategy, replication_factor '2') and the
    table TABLE inside it.

    An existing keyspace with the same name is dropped first. The cluster is
    always shut down before returning, even if a statement fails.
    """
    cluster = Cluster(hosts)
    try:
        cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        session = cluster.connect()

        # NOTE(review): system.schema_keyspaces exists only on Cassandra < 3.0
        # (3.0 moved schema tables to system_schema) - confirm the target version.
        rows = session.execute("SELECT keyspace_name FROM system.schema_keyspaces")
        if any(row[0] == KEYSPACE for row in rows):
            log.debug("dropping existing keyspace...")
            session.execute("DROP KEYSPACE " + KEYSPACE)

        log.debug("Creating keyspace...")
        session.execute("""
        CREATE KEYSPACE %s
        WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
        """ % KEYSPACE)

        log.debug("Setting keyspace...")
        session.set_keyspace(KEYSPACE)

        log.debug("Creating table...")
        session.execute("""
        CREATE TABLE %s (
            thekey text,
            col1 text,
            col2 text,
            PRIMARY KEY (thekey, col1)
        )
        """ % TABLE)
    finally:
        # Previously the cluster (and its connections) was never released.
        cluster.shutdown()
Exemple #8
0
def setup(hosts):
    """
    Connect to `hosts`, create keyspace KEYSPACE with SimpleStrategy and
    replication_factor '2', switch to it, and create table TABLE.

    The cluster is shut down in a `finally`, so it is released on error too.
    """
    log.info("Using 'cassandra' package from %s", cassandra.__path__)

    cluster = Cluster(hosts)
    cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
    try:
        session = cluster.connect()
        run = session.execute

        log.debug("Creating keyspace...")
        ks_statement = """
            CREATE KEYSPACE %s
            WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
            """ % KEYSPACE
        run(ks_statement)

        log.debug("Setting keyspace...")
        session.set_keyspace(KEYSPACE)

        log.debug("Creating table...")
        tbl_statement = """
            CREATE TABLE %s (
                thekey text,
                col1 text,
                col2 text,
                PRIMARY KEY (thekey, col1)
            )
            """ % TABLE
        run(tbl_statement)
    finally:
        cluster.shutdown()
    def test_idle_heartbeat(self):
        """
        Check that idle connections receive heartbeats, that issuing queries
        clears the idle flag, and that Cluster.get_connection_holders() reports
        one pool per host per session plus the control connection.

        The cluster is shut down in a `finally` so a failing assertion does not
        leak it (or its heartbeat thread) into later tests.
        """
        interval = 1
        cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=interval)
        try:
            if PROTOCOL_VERSION < 3:
                cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
            session = cluster.connect()

            # This test relies on impl details of connection req id management to see if heartbeats
            # are being sent. May need update if impl is changed
            connection_request_ids = {}
            for h in cluster.get_connection_holders():
                for c in h.get_connections():
                    # make sure none are idle (should have startup messages)
                    self.assertFalse(c.is_idle)
                    with c.lock:
                        connection_request_ids[id(c)] = deque(c.request_ids)  # copy of request ids

            # let two heartbeat intervals pass (first one had startup messages in it)
            time.sleep(2 * interval + interval / 10.)

            connections = [c for holder in cluster.get_connection_holders() for c in holder.get_connections()]

            # make sure requests were sent on all connections: the test expects each
            # connection's request-id queue to have rotated by exactly one
            for c in connections:
                expected_ids = connection_request_ids[id(c)]
                expected_ids.rotate(-1)
                with c.lock:
                    self.assertListEqual(list(c.request_ids), list(expected_ids))

            # assert idle status
            self.assertTrue(all(c.is_idle for c in connections))

            # send messages on all connections
            statements_and_params = [("SELECT release_version FROM system.local", ())] * len(cluster.metadata.all_hosts())
            results = execute_concurrent(session, statements_and_params)
            for success, _ in results:
                self.assertTrue(success)

            # assert not idle status (control connections are excluded from this check)
            self.assertFalse(any(c.is_idle and not c.is_control_connection for c in connections))

            # holders include session pools and cc
            holders = cluster.get_connection_holders()
            self.assertIn(cluster.control_connection, holders)
            self.assertEqual(len(holders), len(cluster.metadata.all_hosts()) + 1)  # hosts pools, 1 for cc

            # include additional sessions; keep the reference so the second session
            # stays alive while its pools are counted (NOTE(review): assumes the
            # cluster may track sessions weakly)
            session2 = cluster.connect()

            holders = cluster.get_connection_holders()
            self.assertIn(cluster.control_connection, holders)
            self.assertEqual(len(holders), 2 * len(cluster.metadata.all_hosts()) + 1)  # 2 sessions' hosts pools, 1 for cc

            cluster._idle_heartbeat.stop()
            cluster._idle_heartbeat.join()
            assert_quiescent_pool_state(self, cluster)
        finally:
            cluster.shutdown()
Exemple #10
0
    def test_idle_heartbeat(self):
        """
        Check that idle connections receive heartbeats, that issuing queries
        clears the idle flag, and that Cluster.get_connection_holders() reports
        one pool per host per session plus the control connection.

        The cluster is shut down in a `finally` so a failing assertion does not
        leak it (or its heartbeat thread) into later tests.
        """
        interval = 1
        cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=interval)
        try:
            if PROTOCOL_VERSION < 3:
                cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
            session = cluster.connect()

            # This test relies on impl details of connection req id management to see if heartbeats
            # are being sent. May need update if impl is changed
            connection_request_ids = {}
            for h in cluster.get_connection_holders():
                for c in h.get_connections():
                    # make sure none are idle (should have startup messages)
                    self.assertFalse(c.is_idle)
                    with c.lock:
                        connection_request_ids[id(c)] = deque(c.request_ids)  # copy of request ids

            # let two heartbeat intervals pass (first one had startup messages in it)
            time.sleep(2 * interval + interval / 10.)

            connections = [c for holder in cluster.get_connection_holders() for c in holder.get_connections()]

            # make sure requests were sent on all connections: the test expects each
            # connection's request-id queue to have rotated by exactly one
            for c in connections:
                expected_ids = connection_request_ids[id(c)]
                expected_ids.rotate(-1)
                with c.lock:
                    self.assertListEqual(list(c.request_ids), list(expected_ids))

            # assert idle status
            self.assertTrue(all(c.is_idle for c in connections))

            # send messages on all connections
            statements_and_params = [("SELECT release_version FROM system.local", ())] * len(cluster.metadata.all_hosts())
            results = execute_concurrent(session, statements_and_params)
            for success, _ in results:
                self.assertTrue(success)

            # assert not idle status (control connections are excluded from this check)
            self.assertFalse(any(c.is_idle and not c.is_control_connection for c in connections))

            # holders include session pools and cc
            holders = cluster.get_connection_holders()
            self.assertIn(cluster.control_connection, holders)
            self.assertEqual(len(holders), len(cluster.metadata.all_hosts()) + 1)  # hosts pools, 1 for cc

            # include additional sessions; keep the reference so the second session
            # stays alive while its pools are counted (NOTE(review): assumes the
            # cluster may track sessions weakly)
            session2 = cluster.connect()

            holders = cluster.get_connection_holders()
            self.assertIn(cluster.control_connection, holders)
            self.assertEqual(len(holders), 2 * len(cluster.metadata.all_hosts()) + 1)  # 2 sessions' hosts pools, 1 for cc

            cluster._idle_heartbeat.stop()
            cluster._idle_heartbeat.join()
            assert_quiescent_pool_state(self, cluster)
        finally:
            cluster.shutdown()
Exemple #11
0
class SerialConsistencyTests(unittest.TestCase):
    """
    Conditional (IF ...) updates executed with a serial consistency level.
    Requires native protocol v2+; skipped otherwise.
    """

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Serial Consistency, currently testing against %r"
                % (PROTOCOL_VERSION, ))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            # The core-connections setting is only applied for protocol v1/v2.
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

    def tearDown(self):
        self.cluster.shutdown()

    def test_conditional_update(self):
        """A non-matching IF condition is not applied; a matching one is."""
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        statement = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
            serial_consistency_level=ConsistencyLevel.SERIAL)
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertFalse(result[0].applied)

        statement = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
            serial_consistency_level=ConsistencyLevel.SERIAL)
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertTrue(result[0].applied)

    def test_conditional_update_with_prepared_statements(self):
        """Same as test_conditional_update, via a PreparedStatement and a BoundStatement."""
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        statement = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=2")

        statement.serial_consistency_level = ConsistencyLevel.SERIAL
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertFalse(result[0].applied)

        statement = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0")
        bound = statement.bind(())
        bound.serial_consistency_level = ConsistencyLevel.SERIAL
        # Execute the BoundStatement (not the PreparedStatement) so the serial
        # consistency level set on `bound` is what actually gets exercised.
        result = self.session.execute(bound)
        self.assertEqual(1, len(result))
        self.assertTrue(result[0].applied)

    def test_bad_consistency_level(self):
        """A non-serial level is rejected both by the setter and the constructor."""
        statement = SimpleStatement("foo")
        self.assertRaises(ValueError, setattr, statement,
                          'serial_consistency_level', ConsistencyLevel.ONE)
        self.assertRaises(ValueError,
                          SimpleStatement,
                          'foo',
                          serial_consistency_level=ConsistencyLevel.ONE)
class SerialConsistencyTests(unittest.TestCase):
    """
    Conditional (IF ...) updates executed with a serial consistency level.
    Requires native protocol v2+; skipped otherwise.
    """

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Serial Consistency, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            # The core-connections setting is only applied for protocol v1/v2.
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

    def tearDown(self):
        self.cluster.shutdown()

    def test_conditional_update(self):
        """A non-matching IF condition is not applied; a matching one is."""
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        statement = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
            serial_consistency_level=ConsistencyLevel.SERIAL)
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertFalse(result[0].applied)

        statement = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
            serial_consistency_level=ConsistencyLevel.SERIAL)
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertTrue(result[0].applied)

    def test_conditional_update_with_prepared_statements(self):
        """Same as test_conditional_update, via a PreparedStatement and a BoundStatement."""
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        statement = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=2")

        statement.serial_consistency_level = ConsistencyLevel.SERIAL
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertFalse(result[0].applied)

        statement = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0")
        bound = statement.bind(())
        bound.serial_consistency_level = ConsistencyLevel.SERIAL
        # Execute the BoundStatement (not the PreparedStatement) so the serial
        # consistency level set on `bound` is what actually gets exercised.
        result = self.session.execute(bound)
        self.assertEqual(1, len(result))
        self.assertTrue(result[0].applied)

    def test_bad_consistency_level(self):
        """A non-serial level is rejected both by the setter and the constructor."""
        statement = SimpleStatement("foo")
        self.assertRaises(ValueError, setattr, statement, 'serial_consistency_level', ConsistencyLevel.ONE)
        self.assertRaises(ValueError, SimpleStatement, 'foo', serial_consistency_level=ConsistencyLevel.ONE)
Exemple #13
0
def connect(seeds, keyspace, datacenter=None, port=9042):
    """
    Build a configured Cluster and return a session bound to `keyspace`.

    :param seeds: contact points passed to Cluster.
    :param keyspace: keyspace the returned session is connected to.
    :param datacenter: when given, requests are routed with
        DCAwareRoundRobinPolicy using this as the local DC; otherwise the
        driver's default load-balancing policy is used.
    :param port: native-protocol port of the nodes.
    :returns: the Session from ``cluster.connect(keyspace)``.
    """
    from cassandra.io.libevreactor import LibevConnection
    from cassandra.cluster import Cluster
    from cassandra.policies import (DCAwareRoundRobinPolicy, ExponentialReconnectionPolicy,
                                    HostDistance, RetryPolicy)

    class CustomRetryPolicy(RetryPolicy):
        """Retry timed-out writes at the same consistency, rethrowing after 5 retries."""

        def on_write_timeout(self, query, consistency, write_type,
                             required_responses, received_responses, retry_num):

            # retry at most 5 times regardless of query type
            if retry_num >= 5:
                return (self.RETHROW, None)

            return (self.RETRY, consistency)

    load_balancing_policy = None
    if datacenter:
        # If you are using multiple datacenters it's important to use
        # the DCAwareRoundRobinPolicy. If not then the client will
        # make cross DC connections. This defaults to round robin
        # which means round robin across all nodes irrespective of
        # data center.
        load_balancing_policy = DCAwareRoundRobinPolicy(local_dc=datacenter)

    # NOTE(review): `auth_provider` is not defined in this function; it
    # presumably comes from module scope - verify before reusing this helper.
    cluster = Cluster(contact_points=seeds,
                      port=port,
                      auth_provider=auth_provider,
                      default_retry_policy=CustomRetryPolicy(),
                      reconnection_policy=ExponentialReconnectionPolicy(1, 60),
                      load_balancing_policy=load_balancing_policy,
                      # Passed to the constructor rather than assigned afterwards:
                      # the control connection is created inside Cluster.__init__
                      # with the timeout current at that moment, so a later
                      # `cluster.control_connection_timeout = ...` does not reach it.
                      connection_class=LibevConnection,
                      compression=False,
                      control_connection_timeout=10.0)

    cluster.set_core_connections_per_host(HostDistance.LOCAL, 3)   # local connections
    cluster.set_core_connections_per_host(HostDistance.REMOTE, 0)  # no remote connections
    return cluster.connect(keyspace)
class SerialConsistencyTests(unittest.TestCase):
    """
    Lightweight-transaction (IF ...) updates issued with an explicit serial
    consistency level. Each case checks that the level set on the statement
    reaches the outgoing request message, and checks the `applied` outcome.
    """

    def setUp(self):
        version = PROTOCOL_VERSION
        if version < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Serial Consistency, currently testing against %r"
                % (version,))

        cluster = Cluster(protocol_version=version)
        self.cluster = cluster
        if version < 3:
            cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = cluster.connect()

    def tearDown(self):
        self.cluster.shutdown()

    def _execute_and_verify(self, statement, expected_serial_cl, expect_applied):
        """
        Run `statement` asynchronously and assert that the request message
        carried `expected_serial_cl`, that a non-empty result came back, and
        that the first row's `applied` flag is truthy exactly when
        `expect_applied` is.
        """
        response_future = self.session.execute_async(statement)
        rows = response_future.result()
        self.assertEqual(response_future.message.serial_consistency_level, expected_serial_cl)
        self.assertTrue(rows)
        if expect_applied:
            self.assertTrue(rows[0].applied)
        else:
            self.assertFalse(rows[0].applied)

    def test_conditional_update(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")

        rejected = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
            serial_consistency_level=ConsistencyLevel.SERIAL)
        # crazy test, but PYTHON-299
        # TODO: expand to check more parameters get passed to statement, and on to messages
        self.assertEqual(rejected.serial_consistency_level, ConsistencyLevel.SERIAL)
        self._execute_and_verify(rejected, ConsistencyLevel.SERIAL, expect_applied=False)

        accepted = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
            serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL)
        self.assertEqual(accepted.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL)
        self._execute_and_verify(accepted, ConsistencyLevel.LOCAL_SERIAL, expect_applied=True)

    def test_conditional_update_with_prepared_statements(self):
        session = self.session
        session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")

        never_true = session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=2")
        never_true.serial_consistency_level = ConsistencyLevel.SERIAL
        self._execute_and_verify(never_true, ConsistencyLevel.SERIAL, expect_applied=False)

        holds = session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0")
        bound = holds.bind(())
        bound.serial_consistency_level = ConsistencyLevel.LOCAL_SERIAL
        self._execute_and_verify(bound, ConsistencyLevel.LOCAL_SERIAL, expect_applied=True)

    def test_conditional_update_with_batch_statements(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")

        # (conditional update, serial level, expected `applied`), run in order.
        cases = (
            ("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1", ConsistencyLevel.SERIAL, False),
            ("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0", ConsistencyLevel.LOCAL_SERIAL, True),
        )
        for update_cql, serial_cl, applied in cases:
            batch = BatchStatement(serial_consistency_level=serial_cl)
            batch.add(update_cql)
            self.assertEqual(batch.serial_consistency_level, serial_cl)
            self._execute_and_verify(batch, serial_cl, expect_applied=applied)

    def test_bad_consistency_level(self):
        not_serial = ConsistencyLevel.ONE
        statement = SimpleStatement("foo")
        with self.assertRaises(ValueError):
            statement.serial_consistency_level = not_serial
        with self.assertRaises(ValueError):
            SimpleStatement('foo', serial_consistency_level=not_serial)
Exemple #15
0
class BatchStatementTests(BasicSharedKeyspaceUnitTestCase):
    """
    BATCH execution using plain strings, SimpleStatements, PreparedStatements
    and BoundStatements against test3rf.test. Requires native protocol v2+.
    """

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

    def tearDown(self):
        self.cluster.shutdown()

    def confirm_results(self):
        """Assert that test3rf.test holds exactly keys 0..9 and values 0..9."""
        keys = set()
        values = set()
        # Assuming the test data is inserted at default CL.ONE, we need ALL here to guarantee we see
        # everything inserted
        results = self.session.execute(SimpleStatement("SELECT * FROM test3rf.test",
                                                       consistency_level=ConsistencyLevel.ALL))
        for result in results:
            keys.add(result.k)
            values.add(result.v)

        self.assertEqual(set(range(10)), keys, msg=results)
        self.assertEqual(set(range(10)), values, msg=results)

    def test_string_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_simple_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_prepared_statements(self):
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared, (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_bound_statements(self):
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared.bind((i, i)))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_no_parameters(self):
        """Every statement kind can be added with no parameters or with an empty tuple."""
        batch = BatchStatement(BatchType.LOGGED)
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (1, 1)", ())
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)"))
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (3, 3)"), ())

        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (4, 4)")
        batch.add(prepared)
        batch.add(prepared, ())
        batch.add(prepared.bind([]))
        batch.add(prepared.bind([]), ())

        batch.add("INSERT INTO test3rf.test (k, v) VALUES (5, 5)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (6, 6)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (7, 7)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (8, 8)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (9, 9)", ())

        # Non-empty parameters alongside a BoundStatement must be rejected.
        # Note `(1,)`: a bare `(1)` is the int 1, not a one-element tuple.
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1,))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2, 3))

        self.session.execute(batch)
        self.confirm_results()

    def test_no_parameters_many_times(self):
        for i in range(1000):
            self.test_no_parameters()
            self.session.execute("TRUNCATE test3rf.test")

    def test_unicode(self):
        """A batch carrying non-ASCII text executes without error."""
        ddl = '''
            CREATE TABLE test3rf.testtext (
                k int PRIMARY KEY,
                v text )'''
        self.session.execute(ddl)
        unicode_text = u'Fran\u00E7ois'
        query = u'INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)'
        try:
            batch = BatchStatement(BatchType.LOGGED)
            batch.add(query, (0, unicode_text))
            self.session.execute(batch)
        finally:
            self.session.execute("DROP TABLE test3rf.testtext")
Exemple #16
0
def teardown(hosts):
    """
    Drop keyspace KEYSPACE using a short-lived cluster connection.

    The cluster is shut down afterwards (previously it was leaked), including
    when the DROP fails.
    """
    cluster = Cluster(hosts)
    try:
        cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        session = cluster.connect()
        session.execute("DROP KEYSPACE " + KEYSPACE)
    finally:
        cluster.shutdown()
class BatchStatementTests(BasicSharedKeyspaceUnitTestCase):
    """Integration tests for executing ``BatchStatement`` objects.

    Most tests build a batch that writes rows for keys ``0..9`` into
    ``test3rf.test`` and then verify the table with :meth:`confirm_results`.
    """

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            # Core-connection settings only apply below protocol v3.
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect(wait_for_all_pools=True)

    def tearDown(self):
        self.cluster.shutdown()

    def confirm_results(self):
        """Assert ``test3rf.test`` contains exactly keys and values ``0..9``."""
        keys = set()
        values = set()
        # Assuming the test data is inserted at default CL.ONE, we need ALL here to guarantee we see
        # everything inserted
        results = self.session.execute(SimpleStatement("SELECT * FROM test3rf.test",
                                                       consistency_level=ConsistencyLevel.ALL))
        for result in results:
            keys.add(result.k)
            values.add(result.v)

        self.assertEqual(set(range(10)), keys, msg=results)
        self.assertEqual(set(range(10)), values, msg=results)

    def test_string_statements(self):
        """Plain query strings with ``%s`` parameters, executed sync and async."""
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_simple_statements(self):
        """``SimpleStatement`` entries with parameters, executed sync and async."""
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_prepared_statements(self):
        """Prepared statements with parameters passed to ``batch.add``."""
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared, (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_bound_statements(self):
        """Already-bound statements added without extra parameters."""
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared.bind((i, i)))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_no_parameters(self):
        """Every statement kind can be added with no parameters or an empty tuple.

        Also checks that passing parameters alongside a ``BoundStatement``
        raises ``ValueError``.
        """
        batch = BatchStatement(BatchType.LOGGED)
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (1, 1)", ())
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)"))
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (3, 3)"), ())

        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (4, 4)")
        batch.add(prepared)
        batch.add(prepared, ())
        batch.add(prepared.bind([]))
        batch.add(prepared.bind([]), ())

        batch.add("INSERT INTO test3rf.test (k, v) VALUES (5, 5)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (6, 6)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (7, 7)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (8, 8)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (9, 9)", ())

        # (1,) is a one-element tuple; plain (1) would just be the int 1.
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1,))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2, 3))

        self.session.execute(batch)
        self.confirm_results()

    def test_unicode(self):
        """A LOGGED batch can write a non-ASCII unicode value into a text column."""
        ddl = '''
            CREATE TABLE test3rf.testtext (
                k int PRIMARY KEY,
                v text )'''
        self.session.execute(ddl)
        unicode_text = u'Fran\u00E7ois'
        try:
            batch = BatchStatement(BatchType.LOGGED)
            batch.add(u"INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)", (0, unicode_text))
            self.session.execute(batch)

            # Read at ALL (as confirm_results does) so the check sees the write.
            rows = list(self.session.execute(SimpleStatement(
                "SELECT v FROM test3rf.testtext WHERE k = 0",
                consistency_level=ConsistencyLevel.ALL)))
            self.assertEqual([unicode_text], [row.v for row in rows])
        finally:
            self.session.execute("DROP TABLE test3rf.testtext")

    def test_too_many_statements(self):
        """A batch accepts at most 0xFFFF statements; one more is rejected."""
        max_statements = 0xFFFF
        ss = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        b = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE)

        # max works
        b.add_all([ss] * max_statements, [None] * max_statements)
        self.session.execute(b)

        # max + 1 raises
        self.assertRaises(ValueError, b.add, ss)

        # also would have bombed trying to encode: bypass add()'s check by
        # appending to the private list directly.
        b._statements_and_parameters.append((False, ss.query_string, ()))
        self.assertRaises(NoHostAvailable, self.session.execute, b)
class BatchStatementTests(unittest.TestCase):
    """Integration tests for executing ``BatchStatement`` objects.

    ``setUp`` truncates ``test3rf.test``; each test writes keys ``0..9`` and
    verifies them with :meth:`confirm_results`.
    """

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION, ))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            # Core-connection settings only apply below protocol v3.
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

        self.session.execute("TRUNCATE test3rf.test")

    def tearDown(self):
        self.cluster.shutdown()

    def confirm_results(self):
        """Assert ``test3rf.test`` contains exactly keys and values ``0..9``."""
        keys = set()
        values = set()
        # Assuming the test data is inserted at default CL.ONE, we need ALL here to guarantee we see
        # everything inserted
        results = self.session.execute(
            SimpleStatement("SELECT * FROM test3rf.test",
                            consistency_level=ConsistencyLevel.ALL))
        for result in results:
            keys.add(result.k)
            values.add(result.v)

        self.assertEqual(set(range(10)), keys, msg=results)
        self.assertEqual(set(range(10)), values, msg=results)

    def test_string_statements(self):
        """Plain query strings with ``%s`` parameters, executed sync and async."""
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
                      (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_simple_statements(self):
        """``SimpleStatement`` entries with parameters, executed sync and async."""
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(
                SimpleStatement(
                    "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_prepared_statements(self):
        """Prepared statements with parameters passed to ``batch.add``."""
        prepared = self.session.prepare(
            "INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared, (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_bound_statements(self):
        """Already-bound statements added without extra parameters."""
        prepared = self.session.prepare(
            "INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared.bind((i, i)))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_no_parameters(self):
        """Every statement kind can be added with no parameters or an empty tuple.

        Also checks that passing parameters alongside a ``BoundStatement``
        raises ``ValueError``.
        """
        batch = BatchStatement(BatchType.LOGGED)
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (1, 1)", ())
        batch.add(
            SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)"))
        batch.add(
            SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (3, 3)"),
            ())

        prepared = self.session.prepare(
            "INSERT INTO test3rf.test (k, v) VALUES (4, 4)")
        batch.add(prepared)
        batch.add(prepared, ())
        batch.add(prepared.bind([]))
        batch.add(prepared.bind([]), ())

        batch.add("INSERT INTO test3rf.test (k, v) VALUES (5, 5)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (6, 6)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (7, 7)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (8, 8)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (9, 9)", ())

        # (1,) is a one-element tuple; plain (1) would just be the int 1.
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1,))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2, 3))

        self.session.execute(batch)
        self.confirm_results()

    def test_no_parameters_many_times(self):
        """Repeat ``test_no_parameters`` 1000 times, truncating between runs."""
        for _ in range(1000):
            self.test_no_parameters()
            self.session.execute("TRUNCATE test3rf.test")
Exemple #19
0
class QueryPagingTests(unittest.TestCase):
    """Integration tests for result paging against ``test3rf``.

    ``setUp`` truncates ``test3rf.test`` so every test starts from an empty
    table. Most tests insert 100 rows, then read them back while varying
    ``session.default_fetch_size`` over ``_FETCH_SIZES``.
    """

    # Page sizes that divide 100 evenly, leave a remainder, equal it,
    # straddle it by one, or exceed it entirely.
    _FETCH_SIZES = (2, 3, 7, 10, 99, 100, 101, 10000)

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Paging state, currently testing against %r"
                % (PROTOCOL_VERSION, ))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect(wait_for_all_pools=True)
        self.session.execute("TRUNCATE test3rf.test")

    def tearDown(self):
        self.cluster.shutdown()

    # -- helpers ---------------------------------------------------------

    def _insert_rows(self, query, params_list):
        """Run *query* once per parameter tuple in *params_list*, concurrently.

        The (statement, parameters) pairs are always passed to
        ``execute_concurrent`` as a list, never as a one-shot iterator.
        """
        execute_concurrent(self.session,
                           [(query, params) for params in params_list])

    def _populate_test_table(self):
        """Insert rows ``(k=i, v=0)`` for ``i`` in ``range(100)`` into ``test3rf.test``."""
        self._insert_rows("INSERT INTO test3rf.test (k, v) VALUES (%s, 0)",
                          [(i, ) for i in range(100)])

    def _create_and_populate_composite_table(self, table):
        """Create ``test3rf.<table>`` with key ``(k1, k2)`` and fill it.

        Inserts rows ``(k1=0, k2=i, v=i + 1)`` for ``i`` in ``range(100)``.
        *table* is an internal constant, never external input.
        """
        self.session.execute('''
            CREATE TABLE test3rf.%s (
                k1 int,
                k2 int,
                v int,
                PRIMARY KEY(k1, k2)
            )''' % table)
        self._insert_rows("INSERT INTO test3rf." + table +
                          " (k1, k2, v) VALUES (0, %s, %s)",
                          [(i, i + 1) for i in range(100)])

    def _assert_simple_rows(self, results):
        """Assert *results* covers every key in ``range(100)`` and only value 0."""
        keys = set()
        values = set()
        for row in results:
            keys.add(row.k)
            values.add(row.v)
        self.assertEqual(set(range(100)), keys)
        self.assertEqual(set([0]), values)

    def _assert_composite_rows(self, results):
        """Assert *results* yields ``k2 = 0..99`` with ``v = k2 + 1``, in that order."""
        clustering_keys = []
        values = []
        for row in results:
            clustering_keys.append(row.k2)
            values.append(row.v)
        self.assertSequenceEqual(range(100), clustering_keys)
        self.assertSequenceEqual(range(1, 101), values)

    def _count_pages_via_callbacks(self, future):
        """Drive *future* through all its pages using callbacks.

        Returns ``(pages, rows)``: how many times the page callback fired and
        how many rows it saw in total. Errback errors are recorded and turned
        into a test failure here, on the test thread, because ``self.fail``
        raised inside a driver callback would not fail the test.
        """
        done = Event()
        errors = []
        pages = count()
        rows = count()

        def handle_page(page_rows):
            next(pages)
            for _ in page_rows:
                next(rows)

            if future.has_more_pages:
                future.start_fetching_next_page()
            else:
                done.set()

        def handle_error(err):
            errors.append(err)
            done.set()

        future.add_callbacks(callback=handle_page, errback=handle_error)
        done.wait()
        if errors:
            self.fail(errors[0])
        # count() yields how many times next() was called before this call.
        return next(pages), next(rows)

    # -- tests -----------------------------------------------------------

    def test_paging(self):
        self._populate_test_table()

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self._FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            self.assertEqual(
                100,
                len(list(self.session.execute("SELECT * FROM test3rf.test"))))

            statement = SimpleStatement("SELECT * FROM test3rf.test")
            self.assertEqual(100, len(list(self.session.execute(statement))))

            self.assertEqual(100, len(list(self.session.execute(prepared))))

    def test_paging_state(self):
        """
        Test to validate paging state api
        @since 3.7.0
        @jira_ticket PYTHON-200
        @expected_result the returned paging state should be accurate and allow queries to be resumed.

        @test_category queries
        """
        self._populate_test_table()

        list_all_results = []
        self.session.default_fetch_size = 3

        result_set = self.session.execute("SELECT * FROM test3rf.test")
        while result_set.has_more_pages:
            for row in result_set.current_rows:
                self.assertNotIn(row, list_all_results)
            list_all_results.extend(result_set.current_rows)
            page_state = result_set.paging_state
            result_set = self.session.execute("SELECT * FROM test3rf.test",
                                              paging_state=page_state)

        # The final page still holds rows. Extend with them individually:
        # appending the list would count a whole page as a single row.
        for row in result_set.current_rows:
            self.assertNotIn(row, list_all_results)
        list_all_results.extend(result_set.current_rows)
        self.assertEqual(len(list_all_results), 100)

    def test_paging_verify_writes(self):
        self._populate_test_table()

        query = "SELECT * FROM test3rf.test"
        prepared = self.session.prepare(query)

        for fetch_size in self._FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            self._assert_simple_rows(self.session.execute(query))
            self._assert_simple_rows(
                self.session.execute(SimpleStatement(query)))
            self._assert_simple_rows(self.session.execute(prepared))

    def test_paging_verify_with_composite_keys(self):
        table = "test_paging_verify_2"
        self._create_and_populate_composite_table(table)
        try:
            query = "SELECT * FROM test3rf.test_paging_verify_2"
            prepared = self.session.prepare(query)

            for fetch_size in self._FETCH_SIZES:
                self.session.default_fetch_size = fetch_size
                self._assert_composite_rows(self.session.execute(query))
                self._assert_composite_rows(
                    self.session.execute(SimpleStatement(query)))
                self._assert_composite_rows(self.session.execute(prepared))
        finally:
            # Drop the table so the test can be re-run in the same keyspace.
            self.session.execute("DROP TABLE test3rf.test_paging_verify_2")

    def test_async_paging(self):
        self._populate_test_table()

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self._FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            self.assertEqual(
                100,
                len(list(self.session.execute_async(
                    "SELECT * FROM test3rf.test").result())))

            statement = SimpleStatement("SELECT * FROM test3rf.test")
            self.assertEqual(
                100, len(list(self.session.execute_async(statement).result())))

            self.assertEqual(
                100, len(list(self.session.execute_async(prepared).result())))

    def test_async_paging_verify_writes(self):
        table = "test_async_paging_verify"
        self._create_and_populate_composite_table(table)
        try:
            query = "SELECT * FROM test3rf.test_async_paging_verify"
            prepared = self.session.prepare(query)

            for fetch_size in self._FETCH_SIZES:
                self.session.default_fetch_size = fetch_size
                self._assert_composite_rows(
                    self.session.execute_async(query).result())
                self._assert_composite_rows(
                    self.session.execute_async(SimpleStatement(query)).result())
                self._assert_composite_rows(
                    self.session.execute_async(prepared).result())
        finally:
            # Drop the table so the test can be re-run in the same keyspace.
            self.session.execute("DROP TABLE test3rf.test_async_paging_verify")

    def test_paging_callbacks(self):
        """
        Test to validate callback api
        @since 3.9.0
        @jira_ticket PYTHON-733
        @expected_result callbacks shouldn't be called twice per message
        and the fetch_size should be handled in a transparent way to the user

        @test_category queries
        """
        self._populate_test_table()

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self._FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            # One callback per full page, plus one for the final page (which
            # may be empty when fetch_size divides 100 exactly).
            expected_pages = 100 // fetch_size + 1

            # Plain string, simple statement, then prepared statement.
            for statement in ("SELECT * FROM test3rf.test",
                              SimpleStatement("SELECT * FROM test3rf.test"),
                              prepared):
                future = self.session.execute_async(statement, timeout=20)
                pages, rows = self._count_pages_via_callbacks(future)
                self.assertEqual(pages, expected_pages)
                self.assertEqual(rows, 100)

    def test_concurrent_with_paging(self):
        self._populate_test_table()

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self._FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            results = execute_concurrent_with_args(self.session, prepared,
                                                   [None] * 10)
            self.assertEqual(10, len(results))
            for (success, result) in results:
                self.assertTrue(success)
                self.assertEqual(100, len(list(result)))

    def test_fetch_size(self):
        """
        Ensure per-statement fetch_sizes override the default fetch size.
        """
        self._populate_test_table()

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        self.session.default_fetch_size = 10
        result = self.session.execute(prepared, [])
        self.assertTrue(result.has_more_pages)

        self.session.default_fetch_size = 2000
        result = self.session.execute(prepared, [])
        self.assertFalse(result.has_more_pages)

        self.session.default_fetch_size = None
        result = self.session.execute(prepared, [])
        self.assertFalse(result.has_more_pages)

        self.session.default_fetch_size = 10

        prepared.fetch_size = 2000
        result = self.session.execute(prepared, [])
        self.assertFalse(result.has_more_pages)

        prepared.fetch_size = None
        result = self.session.execute(prepared, [])
        self.assertFalse(result.has_more_pages)

        prepared.fetch_size = 10
        result = self.session.execute(prepared, [])
        self.assertTrue(result.has_more_pages)

        prepared.fetch_size = 2000
        bound = prepared.bind([])
        result = self.session.execute(bound, [])
        self.assertFalse(result.has_more_pages)

        prepared.fetch_size = None
        bound = prepared.bind([])
        result = self.session.execute(bound, [])
        self.assertFalse(result.has_more_pages)

        prepared.fetch_size = 10
        bound = prepared.bind([])
        result = self.session.execute(bound, [])
        self.assertTrue(result.has_more_pages)

        bound.fetch_size = 2000
        result = self.session.execute(bound, [])
        self.assertFalse(result.has_more_pages)

        bound.fetch_size = None
        result = self.session.execute(bound, [])
        self.assertFalse(result.has_more_pages)

        bound.fetch_size = 10
        result = self.session.execute(bound, [])
        self.assertTrue(result.has_more_pages)

        s = SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None)
        result = self.session.execute(s, [])
        self.assertFalse(result.has_more_pages)

        s = SimpleStatement("SELECT * FROM test3rf.test")
        result = self.session.execute(s, [])
        self.assertTrue(result.has_more_pages)

        s = SimpleStatement("SELECT * FROM test3rf.test")
        s.fetch_size = None
        result = self.session.execute(s, [])
        self.assertFalse(result.has_more_pages)
class QueryPagingTests(unittest.TestCase):

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Paging state, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect(wait_for_all_pools=True)
        self.session.execute("TRUNCATE test3rf.test")

    def tearDown(self):
        self.cluster.shutdown()

    def test_paging(self):
        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
                                    [(i, ) for i in range(100)])
        execute_concurrent(self.session, list(statements_and_params))

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = fetch_size
            self.assertEqual(100, len(list(self.session.execute("SELECT * FROM test3rf.test"))))

            statement = SimpleStatement("SELECT * FROM test3rf.test")
            self.assertEqual(100, len(list(self.session.execute(statement))))

            self.assertEqual(100, len(list(self.session.execute(prepared))))

    def test_paging_state(self):
        """
        Test to validate paging state api
        @since 3.7.0
        @jira_ticket PYTHON-200
        @expected_result paging state should returned should be accurate, and allow for queries to be resumed.

        @test_category queries
        """
        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
                                    [(i, ) for i in range(100)])
        execute_concurrent(self.session, list(statements_and_params))

        list_all_results = []
        self.session.default_fetch_size = 3

        result_set = self.session.execute("SELECT * FROM test3rf.test")
        while(result_set.has_more_pages):
            for row in result_set.current_rows:
                self.assertNotIn(row, list_all_results)
            list_all_results.extend(result_set.current_rows)
            page_state = result_set.paging_state
            result_set = self.session.execute("SELECT * FROM test3rf.test", paging_state=page_state)

        if(len(result_set.current_rows) > 0):
            list_all_results.append(result_set.current_rows)
        self.assertEqual(len(list_all_results), 100)

    def test_paging_verify_writes(self):
        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
                                    [(i, ) for i in range(100)])
        execute_concurrent(self.session, statements_and_params)

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = fetch_size
            results = self.session.execute("SELECT * FROM test3rf.test")
            result_array = set()
            result_set = set()
            for result in results:
                result_array.add(result.k)
                result_set.add(result.v)

            self.assertEqual(set(range(100)), result_array)
            self.assertEqual(set([0]), result_set)

            statement = SimpleStatement("SELECT * FROM test3rf.test")
            results = self.session.execute(statement)
            result_array = set()
            result_set = set()
            for result in results:
                result_array.add(result.k)
                result_set.add(result.v)

            self.assertEqual(set(range(100)), result_array)
            self.assertEqual(set([0]), result_set)

            results = self.session.execute(prepared)
            result_array = set()
            result_set = set()
            for result in results:
                result_array.add(result.k)
                result_set.add(result.v)

            self.assertEqual(set(range(100)), result_array)
            self.assertEqual(set([0]), result_set)

    def test_paging_verify_with_composite_keys(self):
        ddl = '''
            CREATE TABLE test3rf.test_paging_verify_2 (
                k1 int,
                k2 int,
                v int,
                PRIMARY KEY(k1, k2)
            )'''
        self.session.execute(ddl)

        statements_and_params = zip(cycle(["INSERT INTO test3rf.test_paging_verify_2 "
                                           "(k1, k2, v) VALUES (0, %s, %s)"]),
                                    [(i, i + 1) for i in range(100)])
        execute_concurrent(self.session, statements_and_params)

        prepared = self.session.prepare("SELECT * FROM test3rf.test_paging_verify_2")

        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = fetch_size
            results = self.session.execute("SELECT * FROM test3rf.test_paging_verify_2")
            result_array = []
            value_array = []
            for result in results:
                result_array.append(result.k2)
                value_array.append(result.v)

            self.assertSequenceEqual(range(100), result_array)
            self.assertSequenceEqual(range(1, 101), value_array)

            statement = SimpleStatement("SELECT * FROM test3rf.test_paging_verify_2")
            results = self.session.execute(statement)
            result_array = []
            value_array = []
            for result in results:
                result_array.append(result.k2)
                value_array.append(result.v)

            self.assertSequenceEqual(range(100), result_array)
            self.assertSequenceEqual(range(1, 101), value_array)

            results = self.session.execute(prepared)
            result_array = []
            value_array = []
            for result in results:
                result_array.append(result.k2)
                value_array.append(result.v)

            self.assertSequenceEqual(range(100), result_array)
            self.assertSequenceEqual(range(1, 101), value_array)

    def test_async_paging(self):
        """
        Every statement kind returns all 100 rows through execute_async for
        each session fetch size.
        """
        insert_cql = "INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"
        execute_concurrent(self.session, [(insert_cql, (key, )) for key in range(100)])

        select_cql = "SELECT * FROM test3rf.test"
        prepared = self.session.prepare(select_cql)

        for page_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = page_size
            # Plain string, SimpleStatement and PreparedStatement, in that order.
            for query in (select_cql, SimpleStatement(select_cql), prepared):
                rows = list(self.session.execute_async(query).result())
                self.assertEqual(100, len(rows))

    def test_async_paging_verify_writes(self):
        """
        Paged async reads of partition k1=0 must return k2 = 0..99 in order,
        each with v == k2 + 1, for every session fetch size and for string,
        SimpleStatement and PreparedStatement queries.
        """
        # IF NOT EXISTS keeps the test re-runnable against a keyspace that
        # still holds this table (e.g. from a previous run or another test
        # using the same table name); the inserts below rewrite identical rows.
        ddl = '''
            CREATE TABLE IF NOT EXISTS test3rf.test_async_paging_verify (
                k1 int,
                k2 int,
                v int,
                PRIMARY KEY(k1, k2)
            )'''
        self.session.execute(ddl)

        statements_and_params = zip(cycle(["INSERT INTO test3rf.test_async_paging_verify "
                                           "(k1, k2, v) VALUES (0, %s, %s)"]),
                                    [(i, i + 1) for i in range(100)])
        execute_concurrent(self.session, statements_and_params)

        select_cql = "SELECT * FROM test3rf.test_async_paging_verify"
        prepared = self.session.prepare(select_cql)

        def assert_all_rows(results):
            # Materialize once so the paged result is consumed a single time.
            rows = list(results)
            self.assertSequenceEqual(list(range(100)), [row.k2 for row in rows])
            self.assertSequenceEqual(list(range(1, 101)), [row.v for row in rows])

        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = fetch_size
            for query in (select_cql, SimpleStatement(select_cql), prepared):
                assert_all_rows(self.session.execute_async(query).result())

    def test_paging_callbacks(self):
        """
        Test to validate callback api
        @since 3.9.0
        @jira_ticket PYTHON-733
        @expected_result callbacks shouldn't be called twice per message
        and the fetch_size should be handled in a transparent way to the user

        @test_category queries
        """
        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
                                    [(i, ) for i in range(100)])
        execute_concurrent(self.session, list(statements_and_params))

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        # Upper bound (seconds) on callback-driven paging, so a lost callback
        # fails the test instead of hanging it forever.
        wait_timeout = 120

        def page_through(future):
            """Drive every page of *future* via callbacks; return (callback_count, row_count)."""
            event = Event()
            counter = count()
            number_of_calls = count()
            errors = []

            def handle_page(rows, future, counter, number_of_calls):
                next(number_of_calls)
                for _ in rows:
                    next(counter)

                if future.has_more_pages:
                    future.start_fetching_next_page()
                else:
                    event.set()

            def handle_error(err):
                # Callbacks may run on a driver thread, where self.fail() would
                # not fail the test; record the error and report it below.
                errors.append(err)
                event.set()

            future.add_callbacks(callback=handle_page, callback_args=(future, counter, number_of_calls),
                                 errback=handle_error)
            self.assertTrue(event.wait(wait_timeout), "timed out waiting for paging callbacks")
            if errors:
                self.fail(errors[0])
            return next(number_of_calls), next(counter)

        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = fetch_size
            # The expected callback count includes one more page than the
            # number of full pages, even when 100 divides evenly.
            expected_calls = 100 // fetch_size + 1

            # plain string, simple statement, prepared statement
            for query in ("SELECT * FROM test3rf.test",
                          SimpleStatement("SELECT * FROM test3rf.test"),
                          prepared):
                future = self.session.execute_async(query, timeout=20)
                calls, rows_seen = page_through(future)
                self.assertEqual(calls, expected_calls)
                self.assertEqual(rows_seen, 100)

    def test_concurrent_with_paging(self):
        """
        Ten concurrent executions of a prepared SELECT each succeed and page
        through all 100 rows, for every session fetch size.
        """
        insert_cql = "INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"
        execute_concurrent(self.session, [(insert_cql, (key, )) for key in range(100)])

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for page_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = page_size
            outcomes = execute_concurrent_with_args(self.session, prepared, [None] * 10)
            self.assertEqual(10, len(outcomes))
            for ok, rows in outcomes:
                self.assertTrue(ok)
                self.assertEqual(100, len(list(rows)))

    def test_fetch_size(self):
        """
        Ensure per-statement fetch_sizes override the default fetch size.
        """
        insert_cql = "INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"
        execute_concurrent(self.session, [(insert_cql, (key, )) for key in range(100)])

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        def check(statement, expect_more_pages):
            # 100 rows are inserted above, so a fetch size of 10 must page
            # and 2000 / None must not.
            outcome = self.session.execute(statement, [])
            if expect_more_pages:
                self.assertTrue(outcome.has_more_pages)
            else:
                self.assertFalse(outcome.has_more_pages)

        # Session default only.
        for default_size, expect_more in ((10, True), (2000, False), (None, False)):
            self.session.default_fetch_size = default_size
            check(prepared, expect_more)

        self.session.default_fetch_size = 10

        # PreparedStatement.fetch_size takes precedence over the session default.
        for size, expect_more in ((2000, False), (None, False), (10, True)):
            prepared.fetch_size = size
            check(prepared, expect_more)

        # A BoundStatement bound after setting prepared.fetch_size follows it...
        for size, expect_more in ((2000, False), (None, False), (10, True)):
            prepared.fetch_size = size
            bound = prepared.bind([])
            check(bound, expect_more)

        # ...and BoundStatement.fetch_size can then override it.
        for size, expect_more in ((2000, False), (None, False), (10, True)):
            bound.fetch_size = size
            check(bound, expect_more)

        # SimpleStatement: explicit None vs. unset (session default is still 10).
        check(SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None), False)

        check(SimpleStatement("SELECT * FROM test3rf.test"), True)

        unset_then_none = SimpleStatement("SELECT * FROM test3rf.test")
        unset_then_none.fetch_size = None
        check(unset_then_none, False)
class BatchStatementTests(unittest.TestCase):
    """
    Integration tests for BatchStatement built from query strings,
    SimpleStatements, PreparedStatements and BoundStatements.
    """

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        # Core-connection tuning only applies to protocol v1/v2.
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

        self.session.execute("TRUNCATE test3rf.test")

    def tearDown(self):
        self.cluster.shutdown()

    def confirm_results(self):
        """Assert that the keys and the values in test3rf.test are each exactly {0..9}."""
        keys = set()
        values = set()
        results = self.session.execute("SELECT * FROM test3rf.test")
        for result in results:
            keys.add(result.k)
            values.add(result.v)

        self.assertEqual(set(range(10)), keys)
        self.assertEqual(set(range(10)), values)

    def _execute_and_confirm(self, batch):
        """Run *batch* synchronously, then once more via execute_async, and verify the table."""
        self.session.execute(batch)
        # The inserts are idempotent, so re-running the batch leaves the same rows.
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_string_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", (i, i))

        self._execute_and_confirm(batch)

    def test_simple_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i))

        self._execute_and_confirm(batch)

    def test_prepared_statements(self):
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared, (i, i))

        self._execute_and_confirm(batch)

    def test_bound_statements(self):
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared.bind((i, i)))

        self._execute_and_confirm(batch)

    def test_no_parameters(self):
        """Statements without bind markers may be added with no or empty parameters."""
        batch = BatchStatement(BatchType.LOGGED)
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (1, 1)", ())
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)"))
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (3, 3)"), ())

        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (4, 4)")
        batch.add(prepared)
        batch.add(prepared, ())
        batch.add(prepared.bind([]))
        batch.add(prepared.bind([]), ())

        batch.add("INSERT INTO test3rf.test (k, v) VALUES (5, 5)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (6, 6)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (7, 7)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (8, 8)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (9, 9)", ())

        # Non-empty parameters alongside a BoundStatement are rejected.
        # Note the trailing comma: `(1)` would be the int 1, not a 1-tuple.
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1,))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2, 3))

        self.session.execute(batch)
        self.confirm_results()
class QueryPagingTests(unittest.TestCase):
    """
    Integration tests for transparent result paging (requires protocol v2+).

    Each test writes 100 rows and reads them back under a range of session
    fetch sizes, using a plain query string, a SimpleStatement and a
    PreparedStatement.
    """

    # Fetch sizes smaller than, equal to, and larger than the 100-row data set.
    FETCH_SIZES = (2, 3, 7, 10, 99, 100, 101, 10000)

    # Upper bound (seconds) on callback-driven paging, so a lost callback
    # fails the test instead of hanging it forever.
    CALLBACK_TIMEOUT = 120

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Paging state, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        # Core-connection tuning only applies to protocol v1/v2.
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()
        self.session.execute("TRUNCATE test3rf.test")

    def tearDown(self):
        self.cluster.shutdown()

    # -- helpers -----------------------------------------------------------

    def _insert_test_rows(self):
        """Write rows (k, 0) for k in 0..99 into test3rf.test."""
        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
                                    [(i, ) for i in range(100)])
        execute_concurrent(self.session, list(statements_and_params))

    def _assert_test_rows(self, results):
        """Assert *results* holds keys 0..99 (in any order), all with v == 0."""
        rows = list(results)
        self.assertEqual(set(range(100)), set(row.k for row in rows))
        self.assertEqual(set([0]), set(row.v for row in rows))

    def _assert_composite_rows(self, results):
        """Assert *results* yields k2 = 0..99 in order, each with v == k2 + 1."""
        rows = list(results)
        self.assertSequenceEqual(list(range(100)), [row.k2 for row in rows])
        self.assertSequenceEqual(list(range(1, 101)), [row.v for row in rows])

    def _count_rows_via_callbacks(self, future):
        """Page through *future* with add_callbacks and return the number of rows seen."""
        event = Event()
        counter = count()
        errors = []

        def handle_page(rows, future, counter):
            for _ in rows:
                next(counter)

            if future.has_more_pages:
                future.start_fetching_next_page()
            else:
                event.set()

        def handle_error(err):
            # Callbacks may run on a driver thread, where self.fail() would not
            # fail the test; record the error and report it from this thread.
            errors.append(err)
            event.set()

        future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
        self.assertTrue(event.wait(self.CALLBACK_TIMEOUT), "timed out waiting for paging callbacks")
        if errors:
            self.fail(errors[0])
        return next(counter)

    def _assert_more_pages(self, statement, expected):
        """Execute *statement* and assert whether its first page reports more pages."""
        result = self.session.execute(statement, [])
        if expected:
            self.assertTrue(result.has_more_pages)
        else:
            self.assertFalse(result.has_more_pages)

    # -- tests -------------------------------------------------------------

    def test_paging(self):
        self._insert_test_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            for query in ("SELECT * FROM test3rf.test",
                          SimpleStatement("SELECT * FROM test3rf.test"),
                          prepared):
                self.assertEqual(100, len(list(self.session.execute(query))))

    def test_paging_verify_writes(self):
        self._insert_test_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            for query in ("SELECT * FROM test3rf.test",
                          SimpleStatement("SELECT * FROM test3rf.test"),
                          prepared):
                self._assert_test_rows(self.session.execute(query))

    def test_paging_verify_with_composite_keys(self):
        # IF NOT EXISTS keeps the test re-runnable against a keyspace that
        # still holds this table; the inserts below rewrite identical rows.
        ddl = '''
            CREATE TABLE IF NOT EXISTS test3rf.test_paging_verify_2 (
                k1 int,
                k2 int,
                v int,
                PRIMARY KEY(k1, k2)
            )'''
        self.session.execute(ddl)

        statements_and_params = zip(cycle(["INSERT INTO test3rf.test_paging_verify_2 "
                                           "(k1, k2, v) VALUES (0, %s, %s)"]),
                                    [(i, i + 1) for i in range(100)])
        execute_concurrent(self.session, statements_and_params)

        prepared = self.session.prepare("SELECT * FROM test3rf.test_paging_verify_2")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            for query in ("SELECT * FROM test3rf.test_paging_verify_2",
                          SimpleStatement("SELECT * FROM test3rf.test_paging_verify_2"),
                          prepared):
                self._assert_composite_rows(self.session.execute(query))

    def test_async_paging(self):
        self._insert_test_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            for query in ("SELECT * FROM test3rf.test",
                          SimpleStatement("SELECT * FROM test3rf.test"),
                          prepared):
                self.assertEqual(100, len(list(self.session.execute_async(query).result())))

    def test_async_paging_verify_writes(self):
        # IF NOT EXISTS keeps the test re-runnable against a keyspace that
        # still holds this table; the inserts below rewrite identical rows.
        ddl = '''
            CREATE TABLE IF NOT EXISTS test3rf.test_async_paging_verify (
                k1 int,
                k2 int,
                v int,
                PRIMARY KEY(k1, k2)
            )'''
        self.session.execute(ddl)

        statements_and_params = zip(cycle(["INSERT INTO test3rf.test_async_paging_verify "
                                           "(k1, k2, v) VALUES (0, %s, %s)"]),
                                    [(i, i + 1) for i in range(100)])
        execute_concurrent(self.session, statements_and_params)

        prepared = self.session.prepare("SELECT * FROM test3rf.test_async_paging_verify")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            for query in ("SELECT * FROM test3rf.test_async_paging_verify",
                          SimpleStatement("SELECT * FROM test3rf.test_async_paging_verify"),
                          prepared):
                self._assert_composite_rows(self.session.execute_async(query).result())

    def test_paging_callbacks(self):
        self._insert_test_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            # plain string, simple statement, prepared statement
            for query in ("SELECT * FROM test3rf.test",
                          SimpleStatement("SELECT * FROM test3rf.test"),
                          prepared):
                future = self.session.execute_async(query)
                self.assertEqual(self._count_rows_via_callbacks(future), 100)

    def test_concurrent_with_paging(self):
        self._insert_test_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            results = execute_concurrent_with_args(self.session, prepared, [None] * 10)
            self.assertEqual(10, len(results))
            for (success, result) in results:
                self.assertTrue(success)
                self.assertEqual(100, len(list(result)))

    def test_fetch_size(self):
        """
        Ensure per-statement fetch_sizes override the default fetch size.
        """
        self._insert_test_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        # Session default alone decides paging.
        self.session.default_fetch_size = 10
        self._assert_more_pages(prepared, True)
        self.session.default_fetch_size = 2000
        self._assert_more_pages(prepared, False)
        self.session.default_fetch_size = None
        self._assert_more_pages(prepared, False)

        self.session.default_fetch_size = 10

        # PreparedStatement.fetch_size takes precedence over the session default.
        prepared.fetch_size = 2000
        self._assert_more_pages(prepared, False)
        prepared.fetch_size = None
        self._assert_more_pages(prepared, False)
        prepared.fetch_size = 10
        self._assert_more_pages(prepared, True)

        # A BoundStatement bound after setting prepared.fetch_size follows it...
        prepared.fetch_size = 2000
        self._assert_more_pages(prepared.bind([]), False)
        prepared.fetch_size = None
        self._assert_more_pages(prepared.bind([]), False)
        prepared.fetch_size = 10
        bound = prepared.bind([])
        self._assert_more_pages(bound, True)

        # ...and BoundStatement.fetch_size can then override it.
        bound.fetch_size = 2000
        self._assert_more_pages(bound, False)
        bound.fetch_size = None
        self._assert_more_pages(bound, False)
        bound.fetch_size = 10
        self._assert_more_pages(bound, True)

        # SimpleStatement: explicit None vs. unset (session default is still 10).
        self._assert_more_pages(SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None), False)
        self._assert_more_pages(SimpleStatement("SELECT * FROM test3rf.test"), True)
        s = SimpleStatement("SELECT * FROM test3rf.test")
        s.fetch_size = None
        self._assert_more_pages(s, False)
Example #23
0
class ClusterTests(unittest.TestCase):
    """Integration tests for execute_concurrent and execute_concurrent_with_args."""

    def setUp(self):
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        # Core-connection tuning only applies to protocol v1/v2; guard it the
        # same way the other test classes in this file do.
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()
        # Rows come back as plain tuples, so "SELECT v" yields [(v,)].
        self.session.row_factory = tuple_factory

    def tearDown(self):
        # Release the Cluster (and its connections) opened in setUp.
        self.cluster.shutdown()

    def _assert_only_index_failed(self, results, bad_index, exc_type):
        """Assert result *bad_index* failed with *exc_type*; every other one succeeded with None."""
        for i, (success, result) in enumerate(results):
            if i == bad_index:
                self.assertFalse(success)
                self.assertIsInstance(result, exc_type)
            else:
                self.assertTrue(success)
                self.assertIsNone(result)

    def test_execute_concurrent(self):
        for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
            # write
            statement = SimpleStatement(
                "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
                consistency_level=ConsistencyLevel.QUORUM)
            statements = cycle((statement, ))
            parameters = [(i, i) for i in range(num_statements)]

            results = execute_concurrent(self.session, list(zip(statements, parameters)))
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, None)] * num_statements, results)

            # read
            statement = SimpleStatement(
                "SELECT v FROM test3rf.test WHERE k=%s",
                consistency_level=ConsistencyLevel.QUORUM)
            statements = cycle((statement, ))
            parameters = [(i, ) for i in range(num_statements)]

            results = execute_concurrent(self.session, list(zip(statements, parameters)))
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)

    def test_execute_concurrent_with_args(self):
        for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
            # write
            statement = SimpleStatement(
                "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
                consistency_level=ConsistencyLevel.QUORUM)
            parameters = [(i, i) for i in range(num_statements)]

            results = execute_concurrent_with_args(self.session, statement, parameters)
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, None)] * num_statements, results)

            # read
            statement = SimpleStatement(
                "SELECT v FROM test3rf.test WHERE k=%s",
                consistency_level=ConsistencyLevel.QUORUM)
            parameters = [(i, ) for i in range(num_statements)]

            results = execute_concurrent_with_args(self.session, statement, parameters)
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)

    def test_first_failure(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
        parameters = [(i, i) for i in range(100)]

        # we'll get an error back from the server
        parameters[57] = ('efefef', 'awefawefawef')

        self.assertRaises(
            InvalidRequest,
            execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)

    def test_first_failure_client_side(self):
        statement = SimpleStatement(
            "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
            consistency_level=ConsistencyLevel.QUORUM)
        statements = cycle((statement, ))
        parameters = [(i, i) for i in range(100)]

        # the driver will raise an error when binding the params
        parameters[57] = 1

        self.assertRaises(
            TypeError,
            execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)

    def test_no_raise_on_first_failure(self):
        statement = SimpleStatement(
            "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
            consistency_level=ConsistencyLevel.QUORUM)
        statements = cycle((statement, ))
        parameters = [(i, i) for i in range(100)]

        # we'll get an error back from the server
        parameters[57] = ('efefef', 'awefawefawef')

        results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
        self._assert_only_index_failed(results, 57, InvalidRequest)

    def test_no_raise_on_first_failure_client_side(self):
        statement = SimpleStatement(
            "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
            consistency_level=ConsistencyLevel.QUORUM)
        statements = cycle((statement, ))
        parameters = [(i, i) for i in range(100)]

        # the driver will raise an error when binding the params
        parameters[57] = 1

        results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
        self._assert_only_index_failed(results, 57, TypeError)
class QueryPagingTests(unittest.TestCase):
    """
    Integration tests for result paging (protocol v2+).

    Rows are written to ``test3rf`` tables and read back under several
    ``default_fetch_size`` values through a plain query string, a
    ``SimpleStatement`` and a prepared statement.
    """

    # Fetch sizes below, at and above the 100 rows each test writes.
    FETCH_SIZES = (2, 3, 7, 10, 99, 100, 101, 10000)

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Paging state, currently testing against %r"
                % (PROTOCOL_VERSION, ))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()
        # Shut the cluster down after every test. Cleanups run LIFO, so table
        # drops registered later by individual tests run while it is still up.
        self.addCleanup(self.cluster.shutdown)
        self.session.execute("TRUNCATE test3rf.test")

    def _insert_rows(self):
        """Write keys 0..99, each with v=0, into test3rf.test."""
        statements_and_params = zip(
            cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
            [(i, ) for i in range(100)])
        # Materialize: zip() is a one-shot iterator on Python 3.
        execute_concurrent(self.session, list(statements_and_params))

    def _create_composite_table(self, table):
        """
        Create *table* with PRIMARY KEY(k1, k2) and drop it after the test.

        Fills partition k1=0 with rows (k2=i, v=i + 1) for i in 0..99.
        """
        self.session.execute('''
            CREATE TABLE ''' + table + ''' (
                k1 int,
                k2 int,
                v int,
                PRIMARY KEY(k1, k2)
            )''')
        # Without this, re-running the suite fails on CREATE TABLE.
        self.addCleanup(self.session.execute, "DROP TABLE " + table)

        statements_and_params = zip(
            cycle(["INSERT INTO " + table + " (k1, k2, v) VALUES (0, %s, %s)"]),
            [(i, i + 1) for i in range(100)])
        execute_concurrent(self.session, list(statements_and_params))

    def _check_composite_rows(self, results):
        """Assert *results* hold k2 = 0..99 in order, with v = k2 + 1."""
        rows = list(results)
        self.assertSequenceEqual(range(100), [row.k2 for row in rows])
        self.assertSequenceEqual(range(1, 101), [row.v for row in rows])

    def _count_rows_via_callbacks(self, future):
        """
        Page through *future* with add_callbacks and return the row count.

        Errors reported to the errback are collected and asserted here, in
        the test thread. NOTE(review): callbacks may run on a driver thread,
        where calling self.fail() would not fail the test reliably.
        """
        event = Event()
        counter = count()
        errors = []

        def handle_page(rows):
            for _ in rows:
                next(counter)

            if future.has_more_pages:
                future.start_fetching_next_page()
            else:
                event.set()

        def handle_error(err):
            errors.append(err)
            event.set()

        future.add_callbacks(callback=handle_page, errback=handle_error)
        event.wait()
        self.assertEqual([], errors)
        return next(counter)

    def test_paging(self):
        self._insert_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            self.assertEqual(
                100,
                len(list(self.session.execute("SELECT * FROM test3rf.test"))))

            statement = SimpleStatement("SELECT * FROM test3rf.test")
            self.assertEqual(100, len(list(self.session.execute(statement))))

            self.assertEqual(100, len(list(self.session.execute(prepared))))

    def test_paging_verify_writes(self):
        self._insert_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            queries = ("SELECT * FROM test3rf.test",
                       SimpleStatement("SELECT * FROM test3rf.test"),
                       prepared)
            for query in queries:
                rows = list(self.session.execute(query))
                self.assertEqual(set(range(100)), {row.k for row in rows})
                self.assertEqual(set([0]), {row.v for row in rows})

    def test_paging_verify_with_composite_keys(self):
        table = "test3rf.test_paging_verify_2"
        self._create_composite_table(table)
        prepared = self.session.prepare("SELECT * FROM " + table)

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            queries = ("SELECT * FROM " + table,
                       SimpleStatement("SELECT * FROM " + table),
                       prepared)
            for query in queries:
                self._check_composite_rows(self.session.execute(query))

    def test_async_paging(self):
        self._insert_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            self.assertEqual(
                100,
                len(
                    list(
                        self.session.execute_async(
                            "SELECT * FROM test3rf.test").result())))

            statement = SimpleStatement("SELECT * FROM test3rf.test")
            self.assertEqual(
                100, len(list(self.session.execute_async(statement).result())))

            self.assertEqual(
                100, len(list(self.session.execute_async(prepared).result())))

    def test_async_paging_verify_writes(self):
        table = "test3rf.test_async_paging_verify"
        self._create_composite_table(table)
        prepared = self.session.prepare("SELECT * FROM " + table)

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            queries = ("SELECT * FROM " + table,
                       SimpleStatement("SELECT * FROM " + table),
                       prepared)
            for query in queries:
                self._check_composite_rows(
                    self.session.execute_async(query).result())

    def test_paging_callbacks(self):
        self._insert_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            # plain string, simple statement, prepared statement
            queries = ("SELECT * FROM test3rf.test",
                       SimpleStatement("SELECT * FROM test3rf.test"),
                       prepared)
            for query in queries:
                future = self.session.execute_async(query)
                self.assertEqual(100, self._count_rows_via_callbacks(future))

    def test_concurrent_with_paging(self):
        self._insert_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in self.FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            results = execute_concurrent_with_args(self.session, prepared,
                                                   [None] * 10)
            self.assertEqual(10, len(results))
            for (success, result) in results:
                self.assertTrue(success)
                self.assertEqual(100, len(list(result)))

    def test_fetch_size(self):
        """
        Ensure per-statement fetch_sizes override the default fetch size.
        """
        self._insert_rows()
        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        self.session.default_fetch_size = 10
        result = self.session.execute(prepared, [])
        self.assertIsInstance(result, PagedResult)

        self.session.default_fetch_size = 2000
        result = self.session.execute(prepared, [])
        self.assertIsInstance(result, list)

        self.session.default_fetch_size = None
        result = self.session.execute(prepared, [])
        self.assertIsInstance(result, list)

        self.session.default_fetch_size = 10

        prepared.fetch_size = 2000
        result = self.session.execute(prepared, [])
        self.assertIsInstance(result, list)

        prepared.fetch_size = None
        result = self.session.execute(prepared, [])
        self.assertIsInstance(result, list)

        prepared.fetch_size = 10
        result = self.session.execute(prepared, [])
        self.assertIsInstance(result, PagedResult)

        prepared.fetch_size = 2000
        bound = prepared.bind([])
        result = self.session.execute(bound, [])
        self.assertIsInstance(result, list)

        prepared.fetch_size = None
        bound = prepared.bind([])
        result = self.session.execute(bound, [])
        self.assertIsInstance(result, list)

        prepared.fetch_size = 10
        bound = prepared.bind([])
        result = self.session.execute(bound, [])
        self.assertIsInstance(result, PagedResult)

        bound.fetch_size = 2000
        result = self.session.execute(bound, [])
        self.assertIsInstance(result, list)

        bound.fetch_size = None
        result = self.session.execute(bound, [])
        self.assertIsInstance(result, list)

        bound.fetch_size = 10
        result = self.session.execute(bound, [])
        self.assertIsInstance(result, PagedResult)

        s = SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None)
        result = self.session.execute(s, [])
        self.assertIsInstance(result, list)

        s = SimpleStatement("SELECT * FROM test3rf.test")
        result = self.session.execute(s, [])
        self.assertIsInstance(result, PagedResult)

        s = SimpleStatement("SELECT * FROM test3rf.test")
        s.fetch_size = None
        result = self.session.execute(s, [])
        self.assertIsInstance(result, list)
# Example #25
# 0
class SerialConsistencyTests(unittest.TestCase):
    """
    Check that serial consistency levels set on simple, prepared, bound and
    batch statements are carried on the outgoing message of conditional
    updates, and that the conditions apply as expected.
    """

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Serial Consistency, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

    def tearDown(self):
        self.cluster.shutdown()

    def _assert_conditional_outcome(self, statement, serial_level, applied):
        """
        Execute *statement* asynchronously and check the result.

        Asserts that the sent message carries *serial_level*, that a result
        came back, and that the first row's ``applied`` flag is truthy when
        *applied* is true and falsy otherwise.
        """
        future = self.session.execute_async(statement)
        result = future.result()
        self.assertEqual(future.message.serial_consistency_level, serial_level)
        self.assertTrue(result)
        check_applied = self.assertTrue if applied else self.assertFalse
        check_applied(result[0].applied)

    def test_conditional_update(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")

        # crazy test, but PYTHON-299
        # TODO: expand to check more parameters get passed to statement, and on to messages
        serial = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
            serial_consistency_level=ConsistencyLevel.SERIAL)
        self.assertEqual(serial.serial_consistency_level, ConsistencyLevel.SERIAL)
        # v was just written as 0, so "IF v=1" must not apply
        self._assert_conditional_outcome(serial, ConsistencyLevel.SERIAL, False)

        local_serial = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
            serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL)
        self.assertEqual(local_serial.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL)
        self._assert_conditional_outcome(local_serial, ConsistencyLevel.LOCAL_SERIAL, True)

    def test_conditional_update_with_prepared_statements(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")

        # level set on the prepared statement itself
        prepared = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=2")
        prepared.serial_consistency_level = ConsistencyLevel.SERIAL
        self._assert_conditional_outcome(prepared, ConsistencyLevel.SERIAL, False)

        # level set on a bound statement
        bound = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0").bind(())
        bound.serial_consistency_level = ConsistencyLevel.LOCAL_SERIAL
        self._assert_conditional_outcome(bound, ConsistencyLevel.LOCAL_SERIAL, True)

    def test_conditional_update_with_batch_statements(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")

        cases = (
            ("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1", ConsistencyLevel.SERIAL, False),
            ("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0", ConsistencyLevel.LOCAL_SERIAL, True),
        )
        for query, level, applied in cases:
            batch = BatchStatement(serial_consistency_level=level)
            batch.add(query)
            self.assertEqual(batch.serial_consistency_level, level)
            self._assert_conditional_outcome(batch, level, applied)

    def test_bad_consistency_level(self):
        # ONE must be rejected both via the attribute and the constructor
        statement = SimpleStatement("foo")
        with self.assertRaises(ValueError):
            statement.serial_consistency_level = ConsistencyLevel.ONE
        with self.assertRaises(ValueError):
            SimpleStatement('foo', serial_consistency_level=ConsistencyLevel.ONE)
# Example #26
# 0
def teardown(hosts):
    """
    Drop the test keyspace ``KEYSPACE`` through a short-lived cluster.

    :param hosts: contact points passed to ``Cluster``.
    """
    cluster = Cluster(hosts)
    cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
    try:
        session = cluster.connect()
        session.execute("DROP KEYSPACE " + KEYSPACE)
    finally:
        # Shut down even if connecting or the DROP fails, so nothing leaks.
        cluster.shutdown()
class ClusterTests(unittest.TestCase):
    """
    Integration tests for ``execute_concurrent`` and
    ``execute_concurrent_with_args`` against ``test3rf.test``.
    """

    def setUp(self):
        self.cluster = Cluster()
        self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()
        # rows come back as plain tuples, so reads compare against [(v,)]
        self.session.row_factory = tuple_factory

    def tearDown(self):
        # release the connections opened in setUp
        self.cluster.shutdown()

    def test_execute_concurrent(self):
        for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
            # write
            statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",))
            parameters = [(i, i) for i in range(num_statements)]

            # materialize the pairs: zip() is a one-shot iterator on Python 3
            results = execute_concurrent(self.session, list(zip(statements, parameters)))
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, None)] * num_statements, results)

            # read
            statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s",))
            parameters = [(i,) for i in range(num_statements)]

            results = execute_concurrent(self.session, list(zip(statements, parameters)))
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)

    def test_execute_concurrent_with_args(self):
        for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
            statement = "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"
            parameters = [(i, i) for i in range(num_statements)]

            results = execute_concurrent_with_args(self.session, statement, parameters)
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, None)] * num_statements, results)

            # read
            statement = "SELECT v FROM test3rf.test WHERE k=%s"
            parameters = [(i,) for i in range(num_statements)]

            results = execute_concurrent_with_args(self.session, statement, parameters)
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)

    def test_first_failure(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",))
        parameters = [(i, i) for i in range(100)]

        # we'll get an error back from the server
        parameters[57] = ("efefef", "awefawefawef")

        self.assertRaises(
            InvalidRequest, execute_concurrent, self.session, list(zip(statements, parameters)),
            raise_on_first_error=True
        )

    def test_first_failure_client_side(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",))
        parameters = [(i, i) for i in range(100)]

        # the driver will raise an error when binding the params
        parameters[57] = 1

        self.assertRaises(
            TypeError, execute_concurrent, self.session, list(zip(statements, parameters)),
            raise_on_first_error=True
        )

    def test_no_raise_on_first_failure(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",))
        parameters = [(i, i) for i in range(100)]

        # we'll get an error back from the server
        parameters[57] = ("efefef", "awefawefawef")

        results = execute_concurrent(
            self.session, list(zip(statements, parameters)), raise_on_first_error=False)
        for i, (success, result) in enumerate(results):
            if i == 57:
                self.assertFalse(success)
                self.assertIsInstance(result, InvalidRequest)
            else:
                self.assertTrue(success)
                self.assertEqual(None, result)

    def test_no_raise_on_first_failure_client_side(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",))
        parameters = [(i, i) for i in range(100)]

        # the driver will raise an error when binding the params
        # (a bare int, not a tuple; the comprehension's `i` is not in scope
        # here on Python 3)
        parameters[57] = 1

        results = execute_concurrent(
            self.session, list(zip(statements, parameters)), raise_on_first_error=False)
        for i, (success, result) in enumerate(results):
            if i == 57:
                self.assertFalse(success)
                self.assertIsInstance(result, TypeError)
            else:
                self.assertTrue(success)
                self.assertEqual(None, result)
class ClusterTests(unittest.TestCase):
    """
    Integration tests for ``execute_concurrent`` and
    ``execute_concurrent_with_args``, including paged results and error
    reporting, against ``test3rf.test``.
    """

    def setUp(self):
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()
        # rows come back as plain tuples, so reads compare against [(v,)]
        self.session.row_factory = tuple_factory

    def tearDown(self):
        # release the connections opened in setUp
        self.cluster.shutdown()

    def test_execute_concurrent(self):
        for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
            # write
            statement = SimpleStatement(
                "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
                consistency_level=ConsistencyLevel.QUORUM)
            statements = cycle((statement, ))
            parameters = [(i, i) for i in range(num_statements)]

            results = execute_concurrent(self.session, list(zip(statements, parameters)))
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, None)] * num_statements, results)

            # read
            statement = SimpleStatement(
                "SELECT v FROM test3rf.test WHERE k=%s",
                consistency_level=ConsistencyLevel.QUORUM)
            statements = cycle((statement, ))
            parameters = [(i, ) for i in range(num_statements)]

            results = execute_concurrent(self.session, list(zip(statements, parameters)))
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)

    def test_execute_concurrent_with_args(self):
        for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
            statement = SimpleStatement(
                "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
                consistency_level=ConsistencyLevel.QUORUM)
            parameters = [(i, i) for i in range(num_statements)]

            results = execute_concurrent_with_args(self.session, statement, parameters)
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, None)] * num_statements, results)

            # read
            statement = SimpleStatement(
                "SELECT v FROM test3rf.test WHERE k=%s",
                consistency_level=ConsistencyLevel.QUORUM)
            parameters = [(i, ) for i in range(num_statements)]

            results = execute_concurrent_with_args(self.session, statement, parameters)
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)

    def test_execute_concurrent_paged_result(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2+ is required for Paging, currently testing against %r"
                % (PROTOCOL_VERSION,))

        num_statements = 201
        statement = SimpleStatement(
            "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
            consistency_level=ConsistencyLevel.QUORUM)
        parameters = [(i, i) for i in range(num_statements)]

        results = execute_concurrent_with_args(self.session, statement, parameters)
        self.assertEqual(num_statements, len(results))
        self.assertEqual([(True, None)] * num_statements, results)

        # read: a fetch size below the row count forces a paged result
        statement = SimpleStatement(
            "SELECT * FROM test3rf.test LIMIT %s",
            consistency_level=ConsistencyLevel.QUORUM,
            fetch_size=num_statements // 2)

        results = execute_concurrent_with_args(self.session, statement, [(num_statements,)])
        self.assertEqual(1, len(results))
        self.assertTrue(results[0][0])
        result = results[0][1]
        self.assertIsInstance(result, PagedResult)
        self.assertEqual(num_statements, sum(1 for _ in result))

    def test_first_failure(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
        parameters = [(i, i) for i in range(100)]

        # we'll get an error back from the server
        parameters[57] = ('efefef', 'awefawefawef')

        self.assertRaises(
            InvalidRequest,
            execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)

    def test_first_failure_client_side(self):
        statement = SimpleStatement(
            "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
            consistency_level=ConsistencyLevel.QUORUM)
        statements = cycle((statement, ))
        parameters = [(i, i) for i in range(100)]

        # the driver will raise an error when binding the params
        parameters[57] = 1

        self.assertRaises(
            TypeError,
            execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)

    def test_no_raise_on_first_failure(self):
        statement = SimpleStatement(
            "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
            consistency_level=ConsistencyLevel.QUORUM)
        statements = cycle((statement, ))
        parameters = [(i, i) for i in range(100)]

        # we'll get an error back from the server
        parameters[57] = ('efefef', 'awefawefawef')

        results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
        for i, (success, result) in enumerate(results):
            if i == 57:
                self.assertFalse(success)
                self.assertIsInstance(result, InvalidRequest)
            else:
                self.assertTrue(success)
                self.assertEqual(None, result)

    def test_no_raise_on_first_failure_client_side(self):
        statement = SimpleStatement(
            "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
            consistency_level=ConsistencyLevel.QUORUM)
        statements = cycle((statement, ))
        parameters = [(i, i) for i in range(100)]

        # the driver will raise an error when binding the params
        parameters[57] = 1

        results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
        for i, (success, result) in enumerate(results):
            if i == 57:
                self.assertFalse(success)
                self.assertIsInstance(result, TypeError)
            else:
                self.assertTrue(success)
                self.assertEqual(None, result)