def step(self):
        # Assumes: from cassandra.cluster import Cluster; from kafka import KafkaConsumer

        # Connect to Cassandra
        cluster = Cluster(['192.168.3.2'], port=9042)
        session = cluster.connect()

        # Link to kafka
        consumer = KafkaConsumer('observation-persist',
                                 bootstrap_servers="192.168.3.5:9092")


        # Process observations
        for msg in consumer:
            # msg.value carries the raw payload; split on the "::" delimiter
            split_msg = msg.value.split("::")

            if len(split_msg) == 16:

                session.execute(
                    """
                    INSERT INTO observation.observations_numeric (feature, procedure, observableproperty,
                    year, month, phenomenontimestart, phenomenontimeend, value, quality, accuracy, status,
                    processing, uncertml, comment, location, parameters)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    """,
                    (split_msg[0],split_msg[1],split_msg[2],int(split_msg[3]),int(split_msg[4]),int(split_msg[5]),int(split_msg[6]),
                     float(split_msg[7]),split_msg[8],float(split_msg[9]),split_msg[10],split_msg[11],split_msg[12],
                     split_msg[13],split_msg[14],split_msg[15])
                )

        # Close the Kafka consumer and the Cassandra connection
        consumer.close()
        cluster.shutdown()
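
For reference, a minimal sketch of a matching producer, assuming the kafka-python package and the same 16-field "::"-delimited payload the consumer above parses (all field values are illustrative):

from kafka import KafkaProducer

producer = KafkaProducer(bootstrap_servers="192.168.3.5:9092")
# 16 fields, in the column order of observation.observations_numeric
fields = ["feature-1", "procedure-1", "temperature",   # feature, procedure, observableproperty
          "2016", "7", "1467331200", "1467334800",     # year, month, phenomenontimestart/end
          "21.5", "good", "0.1", "ok",                 # value, quality, accuracy, status
          "raw", "", "no comment", "POINT(0 0)", ""]   # processing, uncertml, comment, location, parameters
producer.send('observation-persist', "::".join(fields))
producer.flush()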
Example #2
    def test_udts_with_nulls(self):
        """
        Test UDTs with null and empty string fields.
        """
        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        s.execute("""
            CREATE KEYSPACE test_udts_with_nulls
            WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
            """)
        s.set_keyspace("test_udts_with_nulls")
        s.execute("CREATE TYPE user (a text, b int, c uuid, d blob)")
        User = namedtuple('user', ('a', 'b', 'c', 'd'))
        c.register_user_type("test_udts_with_nulls", "user", User)

        s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")

        insert = s.prepare("INSERT INTO mytable (a, b) VALUES (0, ?)")
        s.execute(insert, [User(None, None, None, None)])

        results = s.execute("SELECT b FROM mytable WHERE a=0")
        self.assertEqual((None, None, None, None), results[0].b)

        select = s.prepare("SELECT b FROM mytable WHERE a=0")
        self.assertEqual((None, None, None, None), s.execute(select)[0].b)

        # also test empty strings
        s.execute(insert, [User('', None, None, '')])
        results = s.execute("SELECT b FROM mytable WHERE a=0")
        self.assertEqual(('', None, None, ''), results[0].b)
        self.assertEqual(('', None, None, ''), s.execute(select)[0].b)

        c.shutdown()
def insert_rows(starting_partition, ending_partition, rows_per_partition, counter, counter_lock):
    cluster = Cluster(['127.0.0.1'], load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()))
    try:
        session = cluster.connect('ks')
        try:
            statement = session.prepare('INSERT INTO tbl (a, b, c, d) VALUES (?, ?, ?, ?)')
            for partition_key in xrange(starting_partition, ending_partition):
                batch = None
                batch_size = 0
                for cluster_column in xrange(rows_per_partition):
                    if batch is None:
                        batch = BatchStatement(batch_type=BatchType.UNLOGGED)
                    value1 = random.randint(1, 1000000)
                    value2 = random.randint(1, 1000000)
                    batch.add(statement, [partition_key, cluster_column, value1, value2])
                    batch_size += 1
                    if (batch_size == MAX_BATCH_SIZE) or (cluster_column + 1 == rows_per_partition):
                        with counter_lock:
                            counter.value += batch_size
                        session.execute(batch)
                        batch = None
                        batch_size = 0
        finally:
            session.shutdown()
    finally:
        cluster.shutdown()
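
The shared counter and lock suggest this function is meant to run in worker processes; a hypothetical driver (the original does not show the invocation, and MAX_BATCH_SIZE must be defined at module level):

from multiprocessing import Process, Value, Lock

counter = Value('i', 0)  # shared row counter, incremented under counter_lock above
counter_lock = Lock()
workers = [Process(target=insert_rows,
                   args=(p * 100, (p + 1) * 100, 50, counter, counter_lock))
           for p in range(4)]
for w in workers:
    w.start()
for w in workers:
    w.join()
print "inserted %d rows" % counter.value  # Python 2 print, matching the xrange usage above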
Example #4
def validate_ssl_options(ssl_options):
        # Try to connect with the given ssl_options, retrying up to five times
        tries = 0
        while True:
            if tries > 5:
                raise RuntimeError("Failed to connect to SSL cluster after 5 attempts")
            try:
                cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options=ssl_options)
                session = cluster.connect()
                break
            except Exception:
                ex_type, ex, tb = sys.exc_info()
                log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb)))
                del tb
                tries += 1

        # attempt a few simple commands.
        insert_keyspace = """CREATE KEYSPACE ssltest
            WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}
            """
        statement = SimpleStatement(insert_keyspace)
        statement.consistency_level = ConsistencyLevel.THREE
        session.execute(statement)

        drop_keyspace = "DROP KEYSPACE ssltest"
        statement = SimpleStatement(drop_keyspace)
        statement.consistency_level = ConsistencyLevel.ANY
        session.execute(statement)

        cluster.shutdown()
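
A plausible call site, reusing the certificate-path constants that appear in the next example (the exact paths are assumptions):

import os
import ssl

ssl_options = {'ca_certs': os.path.abspath(CLIENT_CA_CERTS),
               'ssl_version': ssl.PROTOCOL_TLSv1}
validate_ssl_options(ssl_options)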
Example #5
    def test_cannot_connect_with_bad_client_auth(self):
        """
         Test to validate that we cannot connect with invalid client auth.

        This test will use bad keys/certs to preform client authentication. It will then attempt to connect
        to a server that has client authentication enabled.


        @since 2.7.0
        @expected_result The client will throw an exception on connect

        @test_category connection:ssl
        """

        # Setup absolute paths to key/cert files
        abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS)
        abs_driver_keyfile = os.path.abspath(DRIVER_KEYFILE)
        abs_driver_certfile = os.path.abspath(DRIVER_CERTFILE_BAD)

        cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options={'ca_certs': abs_path_ca_cert_path,
                                                                          'ssl_version': ssl.PROTOCOL_TLSv1,
                                                                          'keyfile': abs_driver_keyfile,
                                                                          'certfile': abs_driver_certfile})
        with self.assertRaises(NoHostAvailable) as context:
            cluster.connect()
        cluster.shutdown()
    def test_white_list(self):
        use_singledc()
        keyspace = 'test_white_list'

        cluster = Cluster(('127.0.0.2',), load_balancing_policy=WhiteListRoundRobinPolicy((IP_FORMAT % 2,)),
                          protocol_version=PROTOCOL_VERSION, topology_event_refresh_window=0,
                          status_event_refresh_window=0)
        session = cluster.connect()
        self._wait_for_nodes_up([1, 2, 3])

        create_schema(cluster, session, keyspace)
        self._insert(session, keyspace)
        self._query(session, keyspace)

        self.coordinator_stats.assert_query_count_equals(self, 1, 0)
        self.coordinator_stats.assert_query_count_equals(self, 2, 12)
        self.coordinator_stats.assert_query_count_equals(self, 3, 0)

        # white list policy should not allow reconnecting to ignored hosts
        force_stop(3)
        self._wait_for_nodes_down([3])
        self.assertFalse(cluster.metadata._hosts[IP_FORMAT % 3].is_currently_reconnecting())

        self.coordinator_stats.reset_counts()
        force_stop(2)
        self._wait_for_nodes_down([2])

        try:
            self._query(session, keyspace)
            self.fail()
        except NoHostAvailable:
            pass

        cluster.shutdown()
Example #7
def setup_test_keyspace():
    cluster = Cluster()
    session = cluster.connect()

    try:
        results = session.execute("SELECT keyspace_name FROM system.schema_keyspaces")
        existing_keyspaces = [row[0] for row in results]
        for ksname in ('test1rf', 'test2rf', 'test3rf'):
            if ksname in existing_keyspaces:
                session.execute("DROP KEYSPACE %s" % ksname)

        ddl = '''
            CREATE KEYSPACE test3rf
            WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}'''
        session.execute(ddl)

        ddl = '''
            CREATE KEYSPACE test2rf
            WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '2'}'''
        session.execute(ddl)

        ddl = '''
            CREATE KEYSPACE test1rf
            WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}'''
        session.execute(ddl)

        ddl = '''
            CREATE TABLE test3rf.test (
                k int PRIMARY KEY,
                v int )'''
        session.execute(ddl)
    finally:
        cluster.shutdown()
    def test_numpy_results_paged(self):
        """
        Test Numpy-based parser that returns a NumPy array
        """
        # arrays = { 'a': arr1, 'b': arr2, ... }
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect(keyspace="testspace")
        session.row_factory = tuple_factory
        session.client_protocol_handler = NumpyProtocolHandler
        session.default_fetch_size = 2

        expected_pages = (self.N_ITEMS + session.default_fetch_size - 1) // session.default_fetch_size

        self.assertLess(session.default_fetch_size, self.N_ITEMS)

        results = session.execute("SELECT * FROM test_table")

        self.assertTrue(results.has_more_pages)
        for count, page in enumerate(results, 1):
            self.assertIsInstance(page, dict)
            for colname, arr in page.items():
                if count <= expected_pages:
                    self.assertGreater(len(arr), 0, "page count: %d" % (count,))
                    self.assertLessEqual(len(arr), session.default_fetch_size)
                else:
                    # we get one extra item out of this iteration because of the way NumpyParser returns results
                    # The last page is returned as a dict with zero-length arrays
                    self.assertEqual(len(arr), 0)
            self.assertEqual(self._verify_numpy_page(page), len(arr))
        self.assertEqual(count, expected_pages + 1)  # see note about extra 'page' above

        cluster.shutdown()
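
Because the NumPy protocol handler yields one dict of column arrays per page, a full result set can be reassembled by concatenating the pages column by column; a sketch (assumes numpy and a session configured as above):

import numpy as np

pages = list(session.execute("SELECT * FROM test_table"))
# skip the trailing zero-length page noted above, then merge per column
merged = {col: np.concatenate([page[col] for page in pages if len(page[col]) > 0])
          for col in pages[0]}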
def insert_into_cassandra(partition):
    if partition:
        if (USE_REDIS):
            r1 = redis.StrictRedis(host=REDIS_NODE, port=6379, db=1) # find post by user on batch layer
            r2 = redis.StrictRedis(host=REDIS_NODE, port=6379, db=2) # find user by post on batch layer

        if (USE_CASSANDRA):
            cluster = Cluster(CASSANDRA_CLUSTER_IP_LIST)
            session = cluster.connect(KEY_SPACE)
            user_post_stmt = session.prepare("INSERT INTO user_post_table (user, created_utc, url, subreddit, title, year_month, body) VALUES (?,?,?,?,?,?,?)")
            post_user_stmt = session.prepare("INSERT INTO post_user_table (url, user, created_utc, subreddit, title, year_month, body) VALUES (?, ?, ?, ?, ?, ?, ?)")

        for item in partition:
            if (USE_REDIS):
                agg2Redis(r1, item[0], item[10])
                agg2Redis(r2, item[10], item[0])


            if (USE_CASSANDRA):
                                                # author  created_utc            url     subreddit  id   year_month body
                session.execute(user_post_stmt, (item[0], long(item[2]) * 1000, item[10], item[3], item[9], item[1], item[5]))
                session.execute(post_user_stmt, (item[10], item[0], long(item[2]) * 1000, item[3], item[9], item[1], item[5]))
        if (USE_CASSANDRA):
            session.shutdown()
            cluster.shutdown()
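
The (partition) signature matches Spark's foreachPartition hook; hypothetical driver code (the path and parsing below are assumptions):

# Assumes `sc` is an existing SparkContext and each record is a tuple
# indexed the same way as `item` in the function above.
records = sc.textFile("hdfs:///reddit/comments.tsv") \
            .map(lambda line: tuple(line.split("\t")))
records.foreachPartition(insert_into_cassandra)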
Example #10
    def test_can_insert_nested_registered_udts_with_different_namedtuples(self):
        """
        Test for ensuring nested UDTs are inserted correctly when the
        created namedtuples use names that differ from the CQL type names.
        """

        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect(self.keyspace_name, wait_for_all_pools=True)
        s.row_factory = dict_factory

        MAX_NESTING_DEPTH = 16

        # create the schema
        self.nested_udt_schema_helper(s, MAX_NESTING_DEPTH)

        # create and register the seed udt type
        udts = []
        udt = namedtuple('level_0', ('age', 'name'))
        udts.append(udt)
        c.register_user_type(self.keyspace_name, "depth_0", udts[0])

        # create and register the nested udt types
        for i in range(MAX_NESTING_DEPTH):
            udt = namedtuple('level_{0}'.format(i + 1), ('value',))
            udts.append(udt)
            c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1])

        # insert udts and verify inserts with reads
        self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts)

        c.shutdown()
Example #11
class CassandraClient(object):

    # Cassandra 2.1 only supports protocol versions 3 and lower.
    NATIVE_PROTOCOL_VERSION = 3

    def __init__(self, contact_points, user, password, keyspace):
        super(CassandraClient, self).__init__()
        self._cluster = None
        self._session = None
        self._cluster = Cluster(
            contact_points=contact_points,
            auth_provider=PlainTextAuthProvider(user, password),
            protocol_version=self.NATIVE_PROTOCOL_VERSION)
        self._session = self._connect(keyspace)

    def _connect(self, keyspace):
        if not self._cluster.is_shutdown:
            return self._cluster.connect(keyspace)
        else:
            raise Exception("Cannot perform this operation on a terminated "
                            "cluster.")

    @property
    def session(self):
        return self._session

    def __del__(self):
        if self._cluster is not None:
            self._cluster.shutdown()

        if self._session is not None:
            self._session.shutdown()
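
Example usage, with placeholder contact points and credentials:

client = CassandraClient(contact_points=['127.0.0.1'], user='cassandra',
                         password='cassandra', keyspace='system')
row = client.session.execute("SELECT release_version FROM system.local")[0]
print row.release_version  # Python 2 print, consistent with these examples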
Example #12
class ConnectionTimeoutTest(unittest.TestCase):

    def setUp(self):
        self.defaultInFlight = Connection.max_in_flight
        Connection.max_in_flight = 2
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1']))
        self.session = self.cluster.connect()

    def tearDown(self):
        Connection.max_in_flight = self.defaultInFlight
        self.cluster.shutdown()

    def test_in_flight_timeout(self):
        """
        Test to ensure that connection ID fetching will block when max_id is reached.

        In previous versions of the driver this test would cause a
        NoHostAvailable exception to be thrown when the number of in-flight requests was restricted.

        @since 3.3
        @jira_ticket PYTHON-514
        @expected_result When many requests are run on a single node connection acquisition should block
        until connection is available or the request times out.

        @test_category connection timeout
        """
        futures = []
        query = '''SELECT * FROM system.local'''
        for i in range(100):
            futures.append(self.session.execute_async(query))

        for future in futures:
            future.result()
Example #13
    def test_can_insert_udts_with_nulls(self):
        """
        Test the insertion of UDTs with null and empty string fields
        """

        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect(self.keyspace_name, wait_for_all_pools=True)

        s.execute("CREATE TYPE user (a text, b int, c uuid, d blob)")
        User = namedtuple('user', ('a', 'b', 'c', 'd'))
        c.register_user_type(self.keyspace_name, "user", User)

        s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")

        insert = s.prepare("INSERT INTO mytable (a, b) VALUES (0, ?)")
        s.execute(insert, [User(None, None, None, None)])

        results = s.execute("SELECT b FROM mytable WHERE a=0")
        self.assertEqual((None, None, None, None), results[0].b)

        select = s.prepare("SELECT b FROM mytable WHERE a=0")
        self.assertEqual((None, None, None, None), s.execute(select)[0].b)

        # also test empty strings
        s.execute(insert, [User('', None, None, six.binary_type())])
        results = s.execute("SELECT b FROM mytable WHERE a=0")
        self.assertEqual(('', None, None, six.binary_type()), results[0].b)

        c.shutdown()
Example #14
    def test_submit_schema_refresh(self):
        """
        Ensure new schema is refreshed after submit_schema_refresh()
        """

        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        cluster.connect()
        self.assertNotIn("newkeyspace", cluster.metadata.keyspaces)

        other_cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = other_cluster.connect()
        session.execute(
            """
            CREATE KEYSPACE newkeyspace
            WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
            """)

        future = cluster.submit_schema_refresh()
        future.result()

        self.assertIn("newkeyspace", cluster.metadata.keyspaces)

        session.execute("DROP KEYSPACE newkeyspace")
        cluster.shutdown()
        other_cluster.shutdown()
Example #15
    def test_pool_management(self):
        # Ensure that in_flight and request_ids quiesce after cluster operations
        cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0)  # no idle heartbeat here, pool management is tested in test_idle_heartbeat
        session = cluster.connect()
        session2 = cluster.connect()

        # prepare
        p = session.prepare("SELECT * FROM system.local WHERE key=?")
        self.assertTrue(session.execute(p, ('local',)))

        # simple
        self.assertTrue(session.execute("SELECT * FROM system.local WHERE key='local'"))

        # set keyspace
        session.set_keyspace('system')
        session.set_keyspace('system_traces')

        # use keyspace
        session.execute('USE system')
        session.execute('USE system_traces')

        # refresh schema
        cluster.refresh_schema_metadata()
        cluster.refresh_schema_metadata(max_schema_agreement_wait=0)

        # submit schema refresh
        future = cluster.submit_schema_refresh()
        future.result()

        assert_quiescent_pool_state(self, cluster)

        cluster.shutdown()
Example #16
class DuplicateRpcTest(unittest.TestCase):

    load_balancing_policy = WhiteListRoundRobinPolicy(['127.0.0.1'])

    def setUp(self):
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=self.load_balancing_policy)
        self.session = self.cluster.connect()
        self.session.execute("UPDATE system.peers SET rpc_address = '127.0.0.1' WHERE peer='127.0.0.2'")

    def tearDown(self):
        self.session.execute("UPDATE system.peers SET rpc_address = '127.0.0.2' WHERE peer='127.0.0.2'")
        self.cluster.shutdown()

    def test_duplicate(self):
        """
        Test duplicate RPC addresses.

        Modifies the system.peers table to make hosts have the same rpc address. Ensures such hosts are filtered out and a message is logged.

        @since 3.4
        @jira_ticket PYTHON-366
        @expected_result only one host's metadata will be populated

        @test_category metadata
        """
        mock_handler = MockLoggingHandler()
        logger = logging.getLogger(cassandra.cluster.__name__)
        logger.addHandler(mock_handler)
        test_cluster = self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=self.load_balancing_policy)
        test_cluster.connect()
        warnings = mock_handler.messages.get("warning")
        self.assertEqual(len(warnings), 1)
        self.assertTrue('multiple' in warnings[0])
        logger.removeHandler(mock_handler)
Example #17
    def test_connect_to_already_shutdown_cluster(self):
        """
        Ensure you cannot connect to a cluster that has been shut down
        """
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        cluster.shutdown()
        self.assertRaises(Exception, cluster.connect)
Example #18
    def test_tuples_with_nulls(self):
        """
        Test tuples with null and empty string fields.
        """
        if self._cass_version < (2, 1, 0):
            raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")

        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        s.execute("""CREATE KEYSPACE test_tuples_with_nulls
            WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
        s.set_keyspace("test_tuples_with_nulls")

        s.execute("CREATE TABLE mytable (k int PRIMARY KEY, t tuple<text, int, uuid, blob>)")

        insert = s.prepare("INSERT INTO mytable (k, t) VALUES (0, ?)")
        s.execute(insert, [(None, None, None, None)])

        result = s.execute("SELECT * FROM mytable WHERE k=0")
        self.assertEqual((None, None, None, None), result[0].t)

        read = s.prepare("SELECT * FROM mytable WHERE k=0")
        self.assertEqual((None, None, None, None), s.execute(read)[0].t)

        # also test empty strings where compatible
        s.execute(insert, [('', None, None, '')])
        result = s.execute("SELECT * FROM mytable WHERE k=0")
        self.assertEqual(('', None, None, ''), result[0].t)
        self.assertEqual(('', None, None, ''), s.execute(read)[0].t)

        c.shutdown()
    def test_custom_raw_row_results_all_types(self):
        """
        Test to validate that custom protocol handlers work with varying types of
        results

        Connect, create a table with all sorts of data. Query the data, and make sure the custom result handler is
        used correctly.

        @since 2.7
        @jira_ticket PYTHON-313
        @expected_result custom protocol handler is invoked with various result types

        @test_category data_types:serialization
        """
        # Connect using a custom protocol handler that tracks the various types the result message is used with.
        session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes")
        session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked
        session.row_factory = tuple_factory

        colnames = create_table_with_all_types("alltypes", session, 1)
        columns_string = ", ".join(colnames)

        # verify data
        params = get_all_primitive_params(0)
        results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string))[0]
        for expected, actual in zip(params, results):
            self.assertEqual(actual, expected)
        # Ensure we have covered the various primitive types
        self.assertEqual(len(CustomResultMessageTracked.checked_rev_row_set), len(PRIMITIVE_DATATYPES)-1)
        session.shutdown()
    def test_none_values_dicts(self):
        """
        Ensure binding None is handled correctly with dict bindings
        """

        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        # test with new dict binding
        prepared = session.prepare(
            """
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        bound = prepared.bind({'k': 1, 'v': None})
        session.execute(bound)

        prepared = session.prepare(
            """
            SELECT * FROM test3rf.test WHERE k=?
            """)
        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind({'k': 1})
        results = session.execute(bound)
        self.assertEqual(results[0].v, None)

        cluster.shutdown()
Example #21
    def _test_downgrading_cl(self, keyspace, rf, accepted):
        cluster = Cluster(
            load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
            default_retry_policy=DowngradingConsistencyRetryPolicy(),
            protocol_version=PROTOCOL_VERSION)
        session = cluster.connect(wait_for_all_pools=True)

        create_schema(cluster, session, keyspace, replication_factor=rf)
        self._insert(session, keyspace, 1)
        self._query(session, keyspace, 1)
        self.coordinator_stats.assert_query_count_equals(self, 1, 0)
        self.coordinator_stats.assert_query_count_equals(self, 2, 1)
        self.coordinator_stats.assert_query_count_equals(self, 3, 0)

        try:
            force_stop(2)
            wait_for_down(cluster, 2)

            self._assert_writes_succeed(session, keyspace, accepted)
            self._assert_reads_succeed(session, keyspace,
                                       accepted - set([ConsistencyLevel.ANY]))
            self._assert_writes_fail(session, keyspace,
                                     SINGLE_DC_CONSISTENCY_LEVELS - accepted)
            self._assert_reads_fail(session, keyspace,
                                    SINGLE_DC_CONSISTENCY_LEVELS - accepted)
        finally:
            start(2)
            wait_for_up(cluster, 2)

        cluster.shutdown()
    def test_async_binding_dicts(self):
        """
        Ensure None binding over async queries with dict bindings
        """

        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare(
            """
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        future = session.execute_async(prepared, {'k': 873, 'v': None})
        future.result()

        prepared = session.prepare(
            """
            SELECT * FROM test3rf.test WHERE k=?
            """)
        self.assertIsInstance(prepared, PreparedStatement)

        future = session.execute_async(prepared, {'k': 873})
        results = future.result()
        self.assertEqual(results[0].v, None)

        cluster.shutdown()
Example #23
    def test_session_no_cluster(self):
        """
        Test session context without cluster context.

        @since 3.4
        @jira_ticket PYTHON-521
        @expected_result session should be created correctly. Session should shutdown correctly outside of context

        @test_category configuration
        """
        cluster = Cluster(**self.cluster_kwargs)
        unmanaged_session = cluster.connect()
        with cluster.connect() as session:
            self.assertFalse(cluster.is_shutdown)
            self.assertFalse(session.is_shutdown)
            self.assertFalse(unmanaged_session.is_shutdown)
            self.assertTrue(session.execute('select release_version from system.local')[0])
        self.assertTrue(session.is_shutdown)
        self.assertFalse(cluster.is_shutdown)
        self.assertFalse(unmanaged_session.is_shutdown)
        unmanaged_session.shutdown()
        self.assertTrue(unmanaged_session.is_shutdown)
        self.assertFalse(cluster.is_shutdown)
        cluster.shutdown()
        self.assertTrue(cluster.is_shutdown)
Example #24
def setup(hosts):
    log.info("Using 'cassandra' package from %s", cassandra.__path__)

    cluster = Cluster(hosts)
    cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
    try:
        session = cluster.connect()

        log.debug("Creating keyspace...")
        session.execute("""
            CREATE KEYSPACE %s
            WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
            """ % KEYSPACE)

        log.debug("Setting keyspace...")
        session.set_keyspace(KEYSPACE)

        log.debug("Creating table...")
        session.execute("""
            CREATE TABLE %s (
                thekey text,
                col1 text,
                col2 text,
                PRIMARY KEY (thekey, col1)
            )
            """ % TABLE)
    finally:
        cluster.shutdown()
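
setup() references module-level log, KEYSPACE, and TABLE names that the excerpt omits; a minimal sketch of those definitions (the values are assumptions):

import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)
KEYSPACE = "testkeyspace"
TABLE = "testtable"

setup(['127.0.0.1'])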
Example #25
    def handle_noargs(self, **options):
        cluster = Cluster()
        session = cluster.connect()

        # Check whether the keyspace already exists
        query = "SELECT * FROM system.schema_keyspaces WHERE keyspace_name='%s';" % KEYSPACE_NAME
        result = session.execute(query)
        if len(result) != 0:
            msg = 'Looks like you already have a %s keyspace.\nDo you want to delete it and recreate it? All current data will be deleted! (y/n): ' % KEYSPACE_NAME
            resp = raw_input(msg)
            if not resp or resp[0] != 'y':
                print "Ok, then we're done here."
                return

            query = "DROP KEYSPACE %s" % KEYSPACE_NAME
            session.execute(query)

        # Create the keyspace
        query = "CREATE KEYSPACE %s WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1};" % KEYSPACE_NAME
        session.execute(query)

        # Create the tables
        query = "USE %s;" % KEYSPACE_NAME
        session.execute(query)

        query = "CREATE TABLE emotiv_eeg_record (test_id int, time double, AF3 double, F7 double, F3 double, FC5 double, T7 double, P7 double, O1 double, O2 double, P8 double, T8 double, FC6 double, F4 double, F8 double, AF4 double, PRIMARY KEY (test_id, time));"
        session.execute(query)

        cluster.shutdown()

        print 'All done!'
Example #26
    def test_can_insert_tuples_all_primitive_datatypes(self):
        """
        Ensure tuple subtypes are appropriately handled.
        """

        if self.cass_version < (2, 1, 0):
            raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")

        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect(self.keyspace_name)
        s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple

        s.execute("CREATE TABLE tuple_primitive ("
                  "k int PRIMARY KEY, "
                  "v frozen<tuple<%s>>)" % ','.join(PRIMITIVE_DATATYPES))

        values = []
        type_count = len(PRIMITIVE_DATATYPES)
        for i, data_type in enumerate(PRIMITIVE_DATATYPES):
            # create tuples to be written and ensure they match with the expected response
            # responses have trailing None values for every element that has not been written
            values.append(get_sample(data_type))
            expected = tuple(values + [None] * (type_count - len(values)))
            s.execute("INSERT INTO tuple_primitive (k, v) VALUES (%s, %s)", (i, tuple(values)))
            result = s.execute("SELECT v FROM tuple_primitive WHERE k=%s", (i,))[0]
            self.assertEqual(result.v, expected)
        c.shutdown()
    def test_raise_error_on_prepared_statement_execution_dropped_table(self):
        """
        Test for an error when executing a prepared statement on a dropped table.

        test_raise_error_on_execute_prepared_statement_dropped_table tests that an InvalidRequest is raised when a
        prepared statement is executed after its corresponding table is dropped. This happens because the driver
        attempts to automatically re-prepare an invalidated statement, which fails against a non-existent table.

        @expected_errors InvalidRequest If a prepared statement is executed on a dropped table

        @since 2.6.0
        @jira_ticket PYTHON-207
        @expected_result InvalidRequest error should be raised upon prepared statement execution.

        @test_category prepared_statements
        """
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect("test3rf")

        session.execute("CREATE TABLE error_test (k int PRIMARY KEY, v int)")
        prepared = session.prepare("SELECT * FROM error_test WHERE k=?")
        session.execute("DROP TABLE error_test")

        with self.assertRaises(InvalidRequest):
            session.execute(prepared, [0])

        cluster.shutdown()
Example #28
    def test_basic(self):
        cluster = Cluster()
        session = cluster.connect()
        result = session.execute(
            """
            CREATE KEYSPACE clustertests
            WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
            """)
        self.assertEqual(None, result)

        result = session.execute(
            """
            CREATE TABLE clustertests.cf0 (
                a text,
                b text,
                c text,
                PRIMARY KEY (a, b)
            )
            """)
        self.assertEqual(None, result)

        result = session.execute(
            """
            INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c')
            """)
        self.assertEqual(None, result)

        result = session.execute("SELECT * FROM clustertests.cf0")
        self.assertEqual([('a', 'b', 'c')], result)

        cluster.shutdown()
Example #29
    def teardown_class(cls):
        cluster = Cluster(['127.0.0.1'])
        session = cluster.connect()
        try:
            session.execute("DROP KEYSPACE %s" % cls.ksname)
        finally:
            cluster.shutdown()
def insert_graph(rdd):
    if rdd:
        if (USE_REDIS):
            r4 = redis.StrictRedis(host=REDIS_NODE, port=6379, db=4) # read modify write on realtime graph
            for item in rdd:
                agg2graph(r4, item[0], item[1], item[2])
                agg2graph(r4, item[2], item[1], item[0])


        if (USE_CASSANDRA):
            cluster = Cluster(CASSANDRA_CLUSTER_IP_LIST)
            session = cluster.connect(KEY_SPACE)
            graph_stmt = session.prepare("INSERT INTO user_graph_realtime (user1, nCommonPosts, user2) VALUES (?,?,?)")

            for item in rdd:
                rows = list(session.execute("SELECT * FROM user_graph_realtime WHERE user1=%s and user2=%s ALLOW FILTERING", parameters=[item[0], item[2]]))
                if not rows: # insert new entry into realtime graph
                    session.execute(graph_stmt, (item[0], int(item[1]), item[2]))
                    session.execute(graph_stmt, (item[2], int(item[1]), item[0]))
                else: # update entry in realtime graph
                    # unquoted CQL identifiers fold to lowercase, so the row attribute is ncommonposts
                    oldEdgeWeight = rows[0].ncommonposts
                    # UPDATE takes %s placeholders and does not accept ALLOW FILTERING
                    session.execute("UPDATE user_graph_realtime SET nCommonPosts=%s WHERE user1=%s and user2=%s", parameters=[int(item[1]) + int(oldEdgeWeight), item[0], item[2]])
                    session.execute("UPDATE user_graph_realtime SET nCommonPosts=%s WHERE user1=%s and user2=%s", parameters=[int(item[1]) + int(oldEdgeWeight), item[2], item[0]])

            session.shutdown()
            cluster.shutdown()
Example #31
class DatastoreProxy(AppDBInterface):
  """ 
    Cassandra implementation of the AppDBInterface
  """
  def __init__(self, log_level=logging.INFO):
    """
    Constructor.
    """
    class_name = self.__class__.__name__
    self.logger = logging.getLogger(class_name)
    self.logger.setLevel(log_level)
    self.logger.info('Starting {}'.format(class_name))

    self.hosts = appscale_info.get_db_ips()
    self.retry_policy = IdempotentRetryPolicy()
    self.no_retries = FallthroughRetryPolicy()

    remaining_retries = INITIAL_CONNECT_RETRIES
    while True:
      try:
        self.cluster = Cluster(self.hosts,
                               default_retry_policy=self.retry_policy)
        self.session = self.cluster.connect(KEYSPACE)
        break
      except cassandra.cluster.NoHostAvailable as connection_error:
        remaining_retries -= 1
        if remaining_retries < 0:
          raise connection_error
        time.sleep(3)

    self.session.default_consistency_level = ConsistencyLevel.QUORUM

  def close(self):
    """ Close all sessions and connections to Cassandra. """
    self.cluster.shutdown()

  def batch_get_entity(self, table_name, row_keys, column_names):
    """
    Takes in batches of keys and retrieves their corresponding rows.
    
    Args:
      table_name: The table to access
      row_keys: A list of keys to access
      column_names: A list of columns to access
    Returns:
      A dictionary of rows and columns/values of those rows. The format 
      looks like such: {key:{column_name:value,...}}
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_get could not be performed due to
        an error with Cassandra.
    """
    if not isinstance(table_name, str): raise TypeError("Expected a str")
    if not isinstance(column_names, list): raise TypeError("Expected a list")
    if not isinstance(row_keys, list): raise TypeError("Expected a list")

    row_keys_bytes = [bytearray(row_key) for row_key in row_keys]

    statement = 'SELECT * FROM "{table}" '\
                'WHERE {key} IN %s and {column} IN %s'.format(
                  table=table_name,
                  key=ThriftColumn.KEY,
                  column=ThriftColumn.COLUMN_NAME,
                )
    query = SimpleStatement(statement, retry_policy=self.retry_policy)
    parameters = (ValueSequence(row_keys_bytes), ValueSequence(column_names))

    try:
      results = self.session.execute(query, parameters=parameters)

      results_dict = {row_key: {} for row_key in row_keys}
      for (key, column, value) in results:
        if key not in results_dict:
          results_dict[key] = {}
        results_dict[key][column] = value

      return results_dict
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during batch_get_entity'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def batch_put_entity(self, table_name, row_keys, column_names, cell_values,
                       ttl=None):
    """
    Allows callers to store multiple rows with a single call. A row can 
    have multiple columns and values with them. We refer to each row as 
    an entity.
   
    Args: 
      table_name: The table to mutate
      row_keys: A list of keys to store on
      column_names: A list of columns to mutate
      cell_values: A dict of key/value pairs
      ttl: The number of seconds to keep the row.
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_put could not be performed due to
        an error with Cassandra.
    """
    if not isinstance(table_name, str):
      raise TypeError("Expected a str")
    if not isinstance(column_names, list):
      raise TypeError("Expected a list")
    if not isinstance(row_keys, list):
      raise TypeError("Expected a list")
    if not isinstance(cell_values, dict):
      raise TypeError("Expected a dict")

    insert_str = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (?, ?, ?)
    """.format(table=table_name,
               key=ThriftColumn.KEY,
               column=ThriftColumn.COLUMN_NAME,
               value=ThriftColumn.VALUE)

    if ttl is not None:
      insert_str += 'USING TTL {}'.format(ttl)

    statement = self.session.prepare(insert_str)

    statements_and_params = []
    for row_key in row_keys:
      for column in column_names:
        params = (bytearray(row_key), column,
                  bytearray(cell_values[row_key][column]))
        statements_and_params.append((statement, params))

    try:
      execute_concurrent(self.session, statements_and_params,
                         raise_on_first_error=True)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during batch_put_entity'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def prepare_insert(self, table):
    """ Prepare an insert statement.

    Args:
      table: A string containing the table name.
    Returns:
      A PreparedStatement object.
    """
    statement = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (?, ?, ?)
    """.format(table=table,
               key=ThriftColumn.KEY,
               column=ThriftColumn.COLUMN_NAME,
               value=ThriftColumn.VALUE)
    return self.session.prepare(statement)

  def prepare_delete(self, table):
    """ Prepare a delete statement.

    Args:
      table: A string containing the table name.
    Returns:
      A PreparedStatement object.
    """
    statement = """
      DELETE FROM "{table}" WHERE {key} = ?
    """.format(table=table, key=ThriftColumn.KEY)
    return self.session.prepare(statement)

  def _normal_batch(self, mutations):
    """ Use Cassandra's native batch statement to apply mutations atomically.

    Args:
      mutations: A list of dictionaries representing mutations.
    """
    self.logger.debug('Normal batch: {} mutations'.format(len(mutations)))
    batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM,
                           retry_policy=self.retry_policy)
    prepared_statements = {'insert': {}, 'delete': {}}
    for mutation in mutations:
      table = mutation['table']
      if mutation['operation'] == TxnActions.PUT:
        if table not in prepared_statements['insert']:
          prepared_statements['insert'][table] = self.prepare_insert(table)
        values = mutation['values']
        for column in values:
          batch.add(
            prepared_statements['insert'][table],
            (bytearray(mutation['key']), column, bytearray(values[column]))
          )
      elif mutation['operation'] == TxnActions.DELETE:
        if table not in prepared_statements['delete']:
          prepared_statements['delete'][table] = self.prepare_delete(table)
        batch.add(
          prepared_statements['delete'][table],
          (bytearray(mutation['key']),)
        )

    try:
      self.session.execute(batch)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during batch_mutate'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def apply_mutations(self, mutations):
    """ Apply mutations across tables.

    Args:
      mutations: A list of dictionaries representing mutations.
    """
    prepared_statements = {'insert': {}, 'delete': {}}
    statements_and_params = []
    for mutation in mutations:
      table = mutation['table']
      if mutation['operation'] == TxnActions.PUT:
        if table not in prepared_statements['insert']:
          prepared_statements['insert'][table] = self.prepare_insert(table)
        values = mutation['values']
        for column in values:
          params = (bytearray(mutation['key']), column,
                    bytearray(values[column]))
          statements_and_params.append(
            (prepared_statements['insert'][table], params))
      elif mutation['operation'] == TxnActions.DELETE:
        if table not in prepared_statements['delete']:
          prepared_statements['delete'][table] = self.prepare_delete(table)
        params = (bytearray(mutation['key']),)
        statements_and_params.append(
          (prepared_statements['delete'][table], params))

    execute_concurrent(self.session, statements_and_params,
                       raise_on_first_error=True)

  def _large_batch(self, app, mutations, entity_changes, txn):
    """ Insert or delete multiple rows across tables in an atomic statement.

    Args:
      app: A string containing the application ID.
      mutations: A list of dictionaries representing mutations.
      entity_changes: A list of changes at the entity level.
      txn: A transaction ID handler.
    Raises:
      FailedBatch if a concurrent process modifies the batch status.
      AppScaleDBConnectionError if a database connection error was encountered.
    """
    self.logger.debug('Large batch: transaction {}, {} mutations'.
                      format(txn, len(mutations)))
    set_status = """
      INSERT INTO batch_status (app, transaction, applied)
      VALUES (%(app)s, %(transaction)s, False)
      IF NOT EXISTS
    """
    parameters = {'app': app, 'transaction': txn}
    result = self.session.execute(set_status, parameters)
    if not result.was_applied:
      raise FailedBatch('A batch for transaction {} already exists'.
                        format(txn))

    insert_item = """
      INSERT INTO batches (app, transaction, namespace, path,
                           old_value, new_value)
      VALUES (?, ?, ?, ?, ?, ?)
    """
    insert_statement = self.session.prepare(insert_item)

    statements_and_params = []
    for entity_change in entity_changes:
      old_value = None
      if entity_change['old'] is not None:
        old_value = bytearray(entity_change['old'].Encode())
      new_value = None
      if entity_change['new'] is not None:
        new_value = bytearray(entity_change['new'].Encode())

      parameters = (app, txn, entity_change['key'].name_space(),
                    bytearray(entity_change['key'].path().Encode()), old_value,
                    new_value)
      statements_and_params.append((insert_statement, parameters))

    try:
      execute_concurrent(self.session, statements_and_params,
                         raise_on_first_error=True)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during large batch'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

    update_status = """
      UPDATE batch_status
      SET applied = True
      WHERE app = %(app)s
      AND transaction = %(transaction)s
      IF applied = False
    """
    parameters = {'app': app, 'transaction': txn}
    result = self.session.execute(update_status, parameters)
    if not result.was_applied:
      raise FailedBatch('Another process modified batch for transaction {}'.
                        format(txn))

    try:
      self.apply_mutations(mutations)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during large batch'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

    clear_batch = """
      DELETE FROM batches
      WHERE app = %(app)s AND transaction = %(transaction)s
    """
    parameters = {'app': app, 'transaction': txn}
    self.session.execute(clear_batch, parameters)

    clear_status = """
      DELETE FROM batch_status
      WHERE app = %(app)s and transaction = %(transaction)s
    """
    parameters = {'app': app, 'transaction': txn}
    self.session.execute(clear_status, parameters)

  def batch_mutate(self, app, mutations, entity_changes, txn):
    """ Insert or delete multiple rows across tables in an atomic statement.

    Args:
      app: A string containing the application ID.
      mutations: A list of dictionaries representing mutations.
      entity_changes: A list of changes at the entity level.
      txn: A transaction ID handler.
    """
    size = batch_size(mutations)
    self.logger.debug('batch_size: {}'.format(size))
    if size > LARGE_BATCH_THRESHOLD:
      self._large_batch(app, mutations, entity_changes, txn)
    else:
      self._normal_batch(mutations)

  def batch_delete(self, table_name, row_keys, column_names=()):
    """
    Remove a set of rows corresponding to a set of keys.
     
    Args:
      table_name: Table to delete rows from
      row_keys: A list of keys to remove
      column_names: Not used
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_delete could not be performed due
        to an error with Cassandra.
    """ 
    if not isinstance(table_name, str): raise TypeError("Expected a str")
    if not isinstance(row_keys, list): raise TypeError("Expected a list")

    row_keys_bytes = [bytearray(row_key) for row_key in row_keys]

    statement = 'DELETE FROM "{table}" WHERE {key} IN %s'.\
      format(
        table=table_name,
        key=ThriftColumn.KEY
      )
    query = SimpleStatement(statement, retry_policy=self.retry_policy)
    parameters = (ValueSequence(row_keys_bytes),)

    try:
      self.session.execute(query, parameters=parameters)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during batch_delete'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def delete_table(self, table_name):
    """ 
    Drops a given table (aka column family in Cassandra)
  
    Args:
      table_name: A string name of the table to drop
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the delete_table could not be performed due
        to an error with Cassandra.
    """
    if not isinstance(table_name, str): raise TypeError("Expected a str")

    statement = 'DROP TABLE IF EXISTS "{table}"'.format(table=table_name)
    query = SimpleStatement(statement, retry_policy=self.retry_policy)

    try:
      self.session.execute(query)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during delete_table'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def create_table(self, table_name, column_names):
    """ 
    Creates a table if it doesn't already exist.
    
    Args:
      table_name: The column family name
      column_names: Not used but here to match the interface
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the create_table could not be performed due
        to an error with Cassandra.
    """
    if not isinstance(table_name, str): raise TypeError("Expected a str")
    if not isinstance(column_names, list): raise TypeError("Expected a list")

    statement = 'CREATE TABLE IF NOT EXISTS "{table}" ('\
        '{key} blob,'\
        '{column} text,'\
        '{value} blob,'\
        'PRIMARY KEY ({key}, {column})'\
      ') WITH COMPACT STORAGE'.format(
        table=table_name,
        key=ThriftColumn.KEY,
        column=ThriftColumn.COLUMN_NAME,
        value=ThriftColumn.VALUE
      )
    query = SimpleStatement(statement, retry_policy=self.no_retries)

    try:
      self.session.execute(query)
    except cassandra.OperationTimedOut:
      logging.warning('Encountered an operation timeout while creating a '
                      'table. Waiting 1 minute for schema to settle.')
      time.sleep(60)
      raise AppScaleDBConnectionError('Exception during create_table')
    except tuple(error for error in dbconstants.TRANSIENT_CASSANDRA_ERRORS
                 if error is not cassandra.OperationTimedOut):
      message = 'Exception during create_table'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def range_query(self,
                  table_name,
                  column_names, 
                  start_key, 
                  end_key, 
                  limit, 
                  offset=0, 
                  start_inclusive=True, 
                  end_inclusive=True,
                  keys_only=False):
    """ 
    Gets a dense range ordered by keys. Returns an ordered list of 
    a dictionary of [key:{column1:value1, column2:value2},...]
    or a list of keys if keys only.
     
    Args:
      table_name: Name of table to access
      column_names: Columns which get returned within the key range
      start_key: String for which the query starts at
      end_key: String for which the query ends at
      limit: Maximum number of results to return
      offset: Cuts off these many from the results [offset:]
      start_inclusive: Boolean if results should include the start_key
      end_inclusive: Boolean if results should include the end_key
      keys_only: Boolean if to only keys and not values
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the range_query could not be performed due
        to an error with Cassandra.
    Returns:
      An ordered list of dictionaries of key=>columns/values
    """
    if not isinstance(table_name, str):
      raise TypeError('table_name must be a string')
    if not isinstance(column_names, list):
      raise TypeError('column_names must be a list')
    if not isinstance(start_key, str):
      raise TypeError('start_key must be a string')
    if not isinstance(end_key, str):
      raise TypeError('end_key must be a string')
    if not isinstance(limit, (int, long)) and limit is not None:
      raise TypeError('limit must be int, long, or NoneType')
    if not isinstance(offset, (int, long)):
      raise TypeError('offset must be int or long')

    if start_inclusive:
      gt_compare = '>='
    else:
      gt_compare = '>'

    if end_inclusive:
      lt_compare = '<='
    else:
      lt_compare = '<'

    query_limit = ''
    if limit is not None:
      query_limit = 'LIMIT {}'.format(len(column_names) * limit)

    statement = """
      SELECT * FROM "{table}" WHERE
      token({key}) {gt_compare} %s AND
      token({key}) {lt_compare} %s AND
      {column} IN %s
      {limit}
      ALLOW FILTERING
    """.format(table=table_name,
               key=ThriftColumn.KEY,
               gt_compare=gt_compare,
               lt_compare=lt_compare,
               column=ThriftColumn.COLUMN_NAME,
               limit=query_limit)

    query = SimpleStatement(statement, retry_policy=self.retry_policy)
    parameters = (bytearray(start_key), bytearray(end_key),
                  ValueSequence(column_names))

    try:
      results = self.session.execute(query, parameters=parameters)

      results_list = []
      current_item = {}
      current_key = None
      for (key, column, value) in results:
        if keys_only:
          results_list.append(key)
          continue

        if key != current_key:
          if current_item:
            results_list.append({current_key: current_item})
          current_item = {}
          current_key = key

        current_item[column] = value
      if current_item:
        results_list.append({current_key: current_item})
      return results_list[offset:]
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during range_query'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def get_metadata(self, key):
    """ Retrieve a value from the datastore metadata table.

    Args:
      key: A string containing the key to fetch.
    Returns:
      A string containing the value or None if the key is not present.
    """
    statement = """
      SELECT {value} FROM "{table}"
      WHERE {key} = %s
      AND {column} = %s
    """.format(
      value=ThriftColumn.VALUE,
      table=dbconstants.DATASTORE_METADATA_TABLE,
      key=ThriftColumn.KEY,
      column=ThriftColumn.COLUMN_NAME
    )
    try:
      results = self.session.execute(statement, (bytearray(key), key))
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Unable to fetch {} from datastore metadata'.format(key)
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

    try:
      return results[0].value
    except IndexError:
      return None

  def set_metadata(self, key, value):
    """ Set a datastore metadata value.

    Args:
      key: A string containing the key to set.
      value: A string containing the value to set.
    """
    if not isinstance(key, str):
      raise TypeError('key should be a string')

    if not isinstance(value, str):
      raise TypeError('value should be a string')

    statement = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (%(key)s, %(column)s, %(value)s)
    """.format(
      table=dbconstants.DATASTORE_METADATA_TABLE,
      key=ThriftColumn.KEY,
      column=ThriftColumn.COLUMN_NAME,
      value=ThriftColumn.VALUE
    )
    parameters = {'key': bytearray(key),
                  'column': key,
                  'value': bytearray(value)}
    try:
      self.session.execute(statement, parameters)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Unable to set datastore metadata for {}'.format(key)
      logging.exception(message)
      raise AppScaleDBConnectionError(message)
    except cassandra.InvalidRequest:
      self.create_table(dbconstants.DATASTORE_METADATA_TABLE,
                        dbconstants.DATASTORE_METADATA_SCHEMA)
      self.session.execute(statement, parameters)

  def get_indices(self, app_id):
    """ Gets the indices of the given application.

    Args:
      app_id: Name of the application.
    Returns:
      Returns a list of encoded entity_pb.CompositeIndex objects.
    """
    start_key = dbconstants.KEY_DELIMITER.join([app_id, 'index', ''])
    end_key = dbconstants.KEY_DELIMITER.join(
      [app_id, 'index', dbconstants.TERMINATING_STRING])
    result = self.range_query(
      dbconstants.METADATA_TABLE,
      dbconstants.METADATA_SCHEMA,
      start_key,
      end_key,
      dbconstants.MAX_NUMBER_OF_COMPOSITE_INDEXES,
      offset=0,
      start_inclusive=True,
      end_inclusive=True)
    list_result = []
    for list_item in result:
      for key, value in list_item.iteritems():
        list_result.append(value['data'])
    return list_result
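  # Key-range illustration (delimiter and terminator values are assumptions):
  # if KEY_DELIMITER were '\x00' and TERMINATING_STRING a high-sorting
  # sentinel, app_id 'guestbook' would scan every row whose key starts with
  # 'guestbook\x00index\x00', i.e. all of that application's composite indexes.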

  def valid_data_version(self):
    """ Checks whether or not the data layout can be used.

    Returns:
      A boolean.
    """
    try:
      version = self.get_metadata(VERSION_INFO_KEY)
    except cassandra.InvalidRequest:
      return False

    return version is not None and float(version) == EXPECTED_DATA_VERSION
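  # Illustration (values hypothetical): with EXPECTED_DATA_VERSION = 1.0, a
  # stored metadata value of '1.0' validates, while '0.9' or an absent
  # VERSION_INFO_KEY does not.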
                fr = bool(row[6])
                price = float(row[7])
                session.execute(
                    """
                INSERT INTO Application (id, name, category, rating, reviews, size, installs, free, price_dollar, content_rating, genres, last_update, current_ver, android_ver)
                values (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                """,
                    (uuid.uuid1(), row[0], row[1], rating, rev, row[4], row[5],
                     fr, price, row[8], row[9], row[10], row[11], row[12]))
            line_count += 1
    return


if __name__ == "__main__":
    args = parse_arguments()
    if args.file is None or not os.path.exists(args.file):
        exit(0)
    # Connect to the cluster
    auth_provider = PlainTextAuthProvider(username='******',
                                          password='******')
    cluster = Cluster(['cassan2'], auth_provider=auth_provider, port=9042)
    session = cluster.connect()

    # Enter the keyspace
    session.execute("USE Customer1;")

    fill_database(session, args.file)

    # Close the cluster connection
    cluster.shutdown()
class PreparedStatementTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.cass_version = get_server_versions()

    def setUp(self):
        self.cluster = Cluster(metrics_enabled=True,
                               protocol_version=PROTOCOL_VERSION)
        self.session = self.cluster.connect()

    def tearDown(self):
        self.cluster.shutdown()

    def test_basic(self):
        """
        Test basic PreparedStatement usage
        """
        self.session.execute("""
            DROP KEYSPACE IF EXISTS preparedtests
            """)
        self.session.execute("""
            CREATE KEYSPACE preparedtests
            WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
            """)

        self.session.set_keyspace("preparedtests")
        self.session.execute("""
            CREATE TABLE cf0 (
                a text,
                b text,
                c text,
                PRIMARY KEY (a, b)
            )
            """)

        prepared = self.session.prepare("""
            INSERT INTO cf0 (a, b, c) VALUES  (?, ?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        bound = prepared.bind(('a', 'b', 'c'))

        self.session.execute(bound)

        prepared = self.session.prepare("""
            SELECT * FROM cf0 WHERE a=?
            """)
        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind(('a',))
        results = self.session.execute(bound)
        self.assertEqual(results, [('a', 'b', 'c')])

        # test with new dict binding
        prepared = self.session.prepare("""
            INSERT INTO cf0 (a, b, c) VALUES  (?, ?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        bound = prepared.bind({'a': 'x', 'b': 'y', 'c': 'z'})

        self.session.execute(bound)

        prepared = self.session.prepare("""
            SELECT * FROM cf0 WHERE a=?
            """)

        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind({'a': 'x'})
        results = self.session.execute(bound)
        self.assertEqual(results, [('x', 'y', 'z')])

    def test_missing_primary_key(self):
        """
        Ensure an InvalidRequest is thrown
        when prepared statements are missing the primary key
        """

        self._run_missing_primary_key(self.session)

    def _run_missing_primary_key(self, session):
        statement_to_prepare = """INSERT INTO test3rf.test (v) VALUES  (?)"""
        # logic needed to work with changes in CASSANDRA-6237
        if self.cass_version[0] >= (3, 0, 0):
            self.assertRaises(InvalidRequest, session.prepare,
                              statement_to_prepare)
        else:
            prepared = session.prepare(statement_to_prepare)
            self.assertIsInstance(prepared, PreparedStatement)
            bound = prepared.bind((1, ))
            self.assertRaises(InvalidRequest, session.execute, bound)

    def test_missing_primary_key_dicts(self):
        """
        Ensure an InvalidRequest is thrown
        when prepared statements are missing the primary key
        with dict bindings
        """
        self._run_missing_primary_key_dicts(self.session)

    def _run_missing_primary_key_dicts(self, session):
        statement_to_prepare = """ INSERT INTO test3rf.test (v) VALUES  (?)"""
        # logic needed to work with changes in CASSANDRA-6237
        if self.cass_version[0] >= (3, 0, 0):
            self.assertRaises(InvalidRequest, session.prepare,
                              statement_to_prepare)
        else:
            prepared = session.prepare(statement_to_prepare)
            self.assertIsInstance(prepared, PreparedStatement)
            bound = prepared.bind({'v': 1})
            self.assertRaises(InvalidRequest, session.execute, bound)

    def test_too_many_bind_values(self):
        """
        Ensure a ValueError is thrown when attempting to bind too many variables
        """
        self._run_too_many_bind_values(self.session)

    def _run_too_many_bind_values(self, session):
        statement_to_prepare = """ INSERT INTO test3rf.test (v) VALUES  (?)"""
        # logic needed to work with changes in CASSANDRA-6237
        if self.cass_version[0] >= (3, 0, 0):
            self.assertRaises(InvalidRequest, session.prepare,
                              statement_to_prepare)
        else:
            prepared = session.prepare(statement_to_prepare)
            self.assertIsInstance(prepared, PreparedStatement)
            self.assertRaises(ValueError, prepared.bind, (1, 2))

    def test_imprecise_bind_values_dicts(self):
        """
        Ensure an error is thrown when attempting to bind the wrong values
        with dict bindings
        """

        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)

        # too many values is ok - others are ignored
        prepared.bind({'k': 1, 'v': 2, 'v2': 3})

        # right number, but one does not belong
        if PROTOCOL_VERSION < 4:
            # pre v4, the driver bails with key error when 'v' is found missing
            self.assertRaises(KeyError, prepared.bind, {'k': 1, 'v2': 3})
        else:
            # post v4, the driver uses UNSET_VALUE for 'v' and 'v2' is ignored
            prepared.bind({'k': 1, 'v2': 3})

        # also catch too few variables with dicts
        self.assertIsInstance(prepared, PreparedStatement)
        if PROTOCOL_VERSION < 4:
            self.assertRaises(KeyError, prepared.bind, {})
        else:
            # post v4, the driver attempts to use UNSET_VALUE for unspecified keys
            self.assertRaises(ValueError, prepared.bind, {})

    def test_none_values(self):
        """
        Ensure binding None is handled correctly
        """

        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        bound = prepared.bind((1, None))
        self.session.execute(bound)

        prepared = self.session.prepare("""
            SELECT * FROM test3rf.test WHERE k=?
            """)
        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind((1, ))
        results = self.session.execute(bound)
        self.assertEqual(results[0].v, None)

    def test_unset_values(self):
        """
        Test to validate that UNSET_VALUEs are bound, and have the expected effect

        Prepare a statement and insert all values. Then follow with execute excluding
        parameters. Verify that the original values are unaffected.

        @since 2.6.0

        @jira_ticket PYTHON-317
        @expected_result UNSET_VALUE is implicitly added to bind parameters, and properly encoded, leaving unset values unaffected.

        @test_category prepared_statements:binding
        """
        if PROTOCOL_VERSION < 4:
            raise unittest.SkipTest(
                "Binding UNSET values is not supported in protocol version < 4"
            )

        # table with at least two values so one can be used as a marker
        self.session.execute(
            "CREATE TABLE IF NOT EXISTS test1rf.test_unset_values (k int PRIMARY KEY, v0 int, v1 int)"
        )
        insert = self.session.prepare(
            "INSERT INTO test1rf.test_unset_values (k, v0, v1) VALUES  (?, ?, ?)"
        )
        select = self.session.prepare(
            "SELECT * FROM test1rf.test_unset_values WHERE k=?")

        bind_expected = [
            # initial condition
            ((0, 0, 0), (0, 0, 0)),
            # unset implicit
            ((
                0,
                1,
            ), (0, 1, 0)),
            ({
                'k': 0,
                'v0': 2
            }, (0, 2, 0)),
            ({
                'k': 0,
                'v1': 1
            }, (0, 2, 1)),
            # unset explicit
            ((0, 3, UNSET_VALUE), (0, 3, 1)),
            ((0, UNSET_VALUE, 2), (0, 3, 2)),
            ({
                'k': 0,
                'v0': 4,
                'v1': UNSET_VALUE
            }, (0, 4, 2)),
            ({
                'k': 0,
                'v0': UNSET_VALUE,
                'v1': 3
            }, (0, 4, 3)),
            # nulls still work
            ((0, None, None), (0, None, None)),
        ]

        for params, expected in bind_expected:
            self.session.execute(insert, params)
            results = self.session.execute(select, (0, ))
            self.assertEqual(results[0], expected)

        self.assertRaises(ValueError, self.session.execute, select,
                          (UNSET_VALUE, 0, 0))

    def test_no_meta(self):

        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (0, 0)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        bound = prepared.bind(None)
        bound.consistency_level = ConsistencyLevel.ALL
        self.session.execute(bound)

        prepared = self.session.prepare("""
            SELECT * FROM test3rf.test WHERE k=0
            """)
        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind(None)
        bound.consistency_level = ConsistencyLevel.ALL
        results = self.session.execute(bound)
        self.assertEqual(results[0].v, 0)

    def test_none_values_dicts(self):
        """
        Ensure binding None is handled correctly with dict bindings
        """

        # test with new dict binding
        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        bound = prepared.bind({'k': 1, 'v': None})
        self.session.execute(bound)

        prepared = self.session.prepare("""
            SELECT * FROM test3rf.test WHERE k=?
            """)
        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind({'k': 1})
        results = self.session.execute(bound)
        self.assertEqual(results[0].v, None)

    def test_async_binding(self):
        """
        Ensure None binding over async queries
        """

        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        future = self.session.execute_async(prepared, (873, None))
        future.result()

        prepared = self.session.prepare("""
            SELECT * FROM test3rf.test WHERE k=?
            """)
        self.assertIsInstance(prepared, PreparedStatement)

        future = self.session.execute_async(prepared, (873, ))
        results = future.result()
        self.assertEqual(results[0].v, None)

    def test_async_binding_dicts(self):
        """
        Ensure None binding over async queries with dict bindings
        """
        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        future = self.session.execute_async(prepared, {'k': 873, 'v': None})
        future.result()

        prepared = self.session.prepare("""
            SELECT * FROM test3rf.test WHERE k=?
            """)
        self.assertIsInstance(prepared, PreparedStatement)

        future = self.session.execute_async(prepared, {'k': 873})
        results = future.result()
        self.assertEqual(results[0].v, None)

    def test_raise_error_on_prepared_statement_execution_dropped_table(self):
        """
        test for error in executing prepared statement on a dropped table

        test_raise_error_on_prepared_statement_execution_dropped_table tests that an InvalidRequest is raised when a
        prepared statement is executed after its corresponding table is dropped. This happens because, once the
        statement is invalidated, the driver automatically attempts to re-prepare it against a table that no longer exists.

        @expected_errors InvalidRequest If a prepared statement is executed on a dropped table

        @since 2.6.0
        @jira_ticket PYTHON-207
        @expected_result InvalidRequest error should be raised upon prepared statement execution.

        @test_category prepared_statements
        """

        self.session.execute(
            "CREATE TABLE test3rf.error_test (k int PRIMARY KEY, v int)")
        prepared = self.session.prepare(
            "SELECT * FROM test3rf.error_test WHERE k=?")
        self.session.execute("DROP TABLE test3rf.error_test")

        with self.assertRaises(InvalidRequest):
            self.session.execute(prepared, [0])

    # TODO revisit this test
    @unittest.skip
    def test_invalidated_result_metadata(self):
        """
        Tests to make sure cached metadata is updated when an invalidated prepared statement is reprepared.

        @since 2.7.0
        @jira_ticket PYTHON-621

        Prior to this fix, the request would blow up with a protocol error when the result was decoded expecting a different
        number of columns.
        """
        s = self.session
        s.result_factory = tuple_factory

        table = "test1rf.%s" % self._testMethodName.lower()

        s.execute("DROP TABLE IF EXISTS %s" % table)
        s.execute("CREATE TABLE %s (k int PRIMARY KEY, a int, b int, c int)" %
                  table)
        s.execute("INSERT INTO %s (k, a, b, c) VALUES (0, 0, 0, 0)" % table)

        wildcard_prepared = s.prepare("SELECT * FROM %s" % table)
        original_result_metadata = wildcard_prepared.result_metadata
        self.assertEqual(len(original_result_metadata), 4)

        r = s.execute(wildcard_prepared)
        self.assertEqual(r[0], (0, 0, 0, 0))

        s.execute("ALTER TABLE %s DROP c" % table)

        # Get a bunch of requests in the pipeline with varying states of result_meta, reprepare, resolved
        futures = set(
            s.execute_async(wildcard_prepared.bind(None)) for _ in range(200))
        for f in futures:
            self.assertEqual(f.result()[0], (0, 0, 0))
        self.assertIsNot(wildcard_prepared.result_metadata,
                         original_result_metadata)
        s.execute("DROP TABLE %s" % table)
Example #34
    def test_export_keyspace_schema_udts(self):
        """
        Test UDT exports
        """

        if get_server_versions()[0] < (2, 1, 0):
            raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1')

        if PROTOCOL_VERSION < 3:
            raise unittest.SkipTest(
                "Protocol 3.0+ is required for UDT change events, currently testing against %r"
                % (PROTOCOL_VERSION, ))

        if sys.version_info[:2] != (2, 7):
            raise unittest.SkipTest(
                'This test compares static strings generated from dict items, which may change orders. Test with 2.7.'
            )

        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        session.execute("""
            CREATE KEYSPACE export_udts
            WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
            AND durable_writes = true;
        """)
        session.execute("""
            CREATE TYPE export_udts.street (
                street_number int,
                street_name text)
        """)
        session.execute("""
            CREATE TYPE export_udts.zip (
                zipcode int,
                zip_plus_4 int)
        """)
        session.execute("""
            CREATE TYPE export_udts.address (
                street_address frozen<street>,
                zip_code frozen<zip>)
        """)
        session.execute("""
            CREATE TABLE export_udts.users (
            user text PRIMARY KEY,
            addresses map<text, frozen<address>>)
        """)

        expected_string = """CREATE KEYSPACE export_udts WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}  AND durable_writes = true;

CREATE TYPE export_udts.street (
    street_number int,
    street_name text
);

CREATE TYPE export_udts.zip (
    zipcode int,
    zip_plus_4 int
);

CREATE TYPE export_udts.address (
    street_address frozen<street>,
    zip_code frozen<zip>
);

CREATE TABLE export_udts.users (
    user text PRIMARY KEY,
    addresses map<text, frozen<address>>
) WITH bloom_filter_fp_chance = 0.01
    AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}'
    AND comment = ''
    AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'}
    AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'}
    AND dclocal_read_repair_chance = 0.1
    AND default_time_to_live = 0
    AND gc_grace_seconds = 864000
    AND max_index_interval = 2048
    AND memtable_flush_period_in_ms = 0
    AND min_index_interval = 128
    AND read_repair_chance = 0.0
    AND speculative_retry = '99.0PERCENTILE';"""

        self.assert_equal_diff(
            cluster.metadata.keyspaces['export_udts'].export_as_string(),
            expected_string)

        table_meta = cluster.metadata.keyspaces['export_udts'].tables['users']

        expected_string = """CREATE TABLE export_udts.users (
    user text PRIMARY KEY,
    addresses map<text, frozen<address>>
) WITH bloom_filter_fp_chance = 0.01
    AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}'
    AND comment = ''
    AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'}
    AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'}
    AND dclocal_read_repair_chance = 0.1
    AND default_time_to_live = 0
    AND gc_grace_seconds = 864000
    AND max_index_interval = 2048
    AND memtable_flush_period_in_ms = 0
    AND min_index_interval = 128
    AND read_repair_chance = 0.0
    AND speculative_retry = '99.0PERCENTILE';"""

        self.assert_equal_diff(table_meta.export_as_string(), expected_string)

        cluster.shutdown()
Example #35
    def test_set_keyspace_twice(self):
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        session.execute("USE system")
        session.execute("USE system")
        cluster.shutdown()
Example #36
class TimeoutTimerTest(unittest.TestCase):
    def setUp(self):
        """
        Setup sessions and pause node1
        """
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.session = self.cluster.connect()

        self.node1 = get_node(1)

        ddl = '''
            CREATE TABLE test3rf.timeout (
                k int PRIMARY KEY,
                v int )'''
        self.session.execute(ddl)
        self.node1.pause()

    def tearDown(self):
        """
        Shutdown cluster and resume node1
        """
        self.node1.resume()
        self.session.execute("DROP TABLE test3rf.timeout")
        self.cluster.shutdown()

    def test_async_timeouts(self):
        """
        Test to validate that timeouts are honored


        Exercise the underlying timeouts, by attempting a query that will timeout. Ensure the default timeout is still
        honored. Make sure that user timeouts are also honored.

        @since 2.7.0
        @jira_ticket PYTHON-108
        @expected_result timeouts should be honored

        @test_category

        """

        # Because node1 is stopped these statements will all timeout
        ss = SimpleStatement('SELECT * FROM test3rf.test',
                             consistency_level=ConsistencyLevel.ALL)

        # Test with default timeout (should be 10)
        start_time = time.time()
        future = self.session.execute_async(ss)
        with self.assertRaises(OperationTimedOut):
            future.result()
        end_time = time.time()
        total_time = end_time - start_time
        expected_time = self.session.default_timeout
        # check timeout and ensure it's within a reasonable range
        self.assertAlmostEqual(expected_time, total_time, delta=.05)

        # Test with user defined timeout (Should be 1)
        start_time = time.time()
        future = self.session.execute_async(ss, timeout=1)
        mock_callback = Mock(return_value=None)
        mock_errorback = Mock(return_value=None)
        future.add_callback(mock_callback)
        future.add_errback(mock_errorback)

        with self.assertRaises(OperationTimedOut):
            future.result()
        end_time = time.time()
        total_time = end_time - start_time
        expected_time = 1
        # check timeout and ensure it's within a reasonable range
        self.assertAlmostEqual(expected_time, total_time, delta=.05)
        self.assertTrue(mock_errorback.called)
        self.assertFalse(mock_callback.called)
Example #37
class TimeoutTimerTest(unittest.TestCase):
    def setUp(self):
        """
        Setup sessions and pause node1
        """
        self.cluster = Cluster(
            protocol_version=PROTOCOL_VERSION,
            execution_profiles={
                EXEC_PROFILE_DEFAULT:
                ExecutionProfile(load_balancing_policy=HostFilterPolicy(
                    RoundRobinPolicy(),
                    lambda host: host.address == "127.0.0.1"))
            })

        self.session = self.cluster.connect(wait_for_all_pools=True)

        self.control_connection_host_number = 1
        self.node_to_stop = get_node(self.control_connection_host_number)

        ddl = '''
            CREATE TABLE test3rf.timeout (
                k int PRIMARY KEY,
                v int )'''
        self.session.execute(ddl)
        self.node_to_stop.pause()

    def tearDown(self):
        """
        Shutdown cluster and resume node1
        """
        self.node_to_stop.resume()
        self.session.execute("DROP TABLE test3rf.timeout")
        self.cluster.shutdown()

    def test_async_timeouts(self):
        """
        Test to validate that timeouts are honored


        Exercise the underlying timeouts, by attempting a query that will timeout. Ensure the default timeout is still
        honored. Make sure that user timeouts are also honored.

        @since 2.7.0
        @jira_ticket PYTHON-108
        @expected_result timeouts should be honored

        @test_category

        """

        # Because node1 is stopped these statements will all timeout
        ss = SimpleStatement('SELECT * FROM test3rf.test',
                             consistency_level=ConsistencyLevel.ALL)

        # Test with default timeout (should be 10)
        start_time = time.time()
        future = self.session.execute_async(ss)
        with self.assertRaises(OperationTimedOut):
            future.result()
        end_time = time.time()
        total_time = end_time - start_time
        expected_time = self.cluster.profile_manager.default.request_timeout
        # check timeout and ensure it's within a reasonable range
        self.assertAlmostEqual(expected_time, total_time, delta=.05)

        # Test with user defined timeout (Should be 1)
        expected_time = 1
        start_time = time.time()
        future = self.session.execute_async(ss, timeout=expected_time)
        mock_callback = Mock(return_value=None)
        mock_errorback = Mock(return_value=None)
        future.add_callback(mock_callback)
        future.add_errback(mock_errorback)

        with self.assertRaises(OperationTimedOut):
            future.result()
        end_time = time.time()
        total_time = end_time - start_time
        # check timeout and ensure it's within a reasonable range
        self.assertAlmostEqual(expected_time, total_time, delta=.05)
        self.assertTrue(mock_errorback.called)
        self.assertFalse(mock_callback.called)
Example #38
def benchmark(thread_class):
    options, args = parse_options()
    for conn_class in options.supported_reactors:
        setup(options.hosts)
        log.info("==== %s ====" % (conn_class.__name__, ))

        kwargs = {
            'metrics_enabled': options.enable_metrics,
            'connection_class': conn_class
        }
        if options.protocol_version:
            kwargs['protocol_version'] = options.protocol_version
        cluster = Cluster(options.hosts, **kwargs)
        session = cluster.connect(KEYSPACE)

        log.debug("Sleeping for two seconds...")
        time.sleep(2.0)

        query = session.prepare("""
            INSERT INTO {table} (thekey, col1, col2) VALUES (?, ?, ?)
            """.format(table=TABLE))
        values = ('key', 'a', 'b')

        per_thread = options.num_ops // options.threads
        threads = []

        log.debug("Beginning inserts...")
        start = time.time()
        try:
            for i in range(options.threads):
                thread = thread_class(i, session, query, values, per_thread,
                                      cluster.protocol_version,
                                      options.profile)
                thread.daemon = True
                threads.append(thread)

            for thread in threads:
                thread.start()

            for thread in threads:
                while thread.is_alive():
                    thread.join(timeout=0.5)

            end = time.time()
        finally:
            cluster.shutdown()
            teardown(options.hosts)

        total = end - start
        log.info("Total time: %0.2fs" % total)
        log.info("Average throughput: %0.2f/sec" % (options.num_ops / total))
        if options.enable_metrics:
            stats = scales.getStats()['cassandra']
            log.info("Connection errors: %d", stats['connection_errors'])
            log.info("Write timeouts: %d", stats['write_timeouts'])
            log.info("Read timeouts: %d", stats['read_timeouts'])
            log.info("Unavailables: %d", stats['unavailables'])
            log.info("Other errors: %d", stats['other_errors'])
            log.info("Retries: %d", stats['retries'])

            request_timer = stats['request_timer']
            log.info("Request latencies:")
            log.info("  min: %0.4fs", request_timer['min'])
            log.info("  max: %0.4fs", request_timer['max'])
            log.info("  mean: %0.4fs", request_timer['mean'])
            log.info("  stddev: %0.4fs", request_timer['stddev'])
            log.info("  median: %0.4fs", request_timer['median'])
            log.info("  75th: %0.4fs", request_timer['75percentile'])
            log.info("  95th: %0.4fs", request_timer['95percentile'])
            log.info("  98th: %0.4fs", request_timer['98percentile'])
            log.info("  99th: %0.4fs", request_timer['99percentile'])
            log.info("  99.9th: %0.4fs", request_timer['999percentile'])
def main() -> None:
    """Main function."""

    args = create_parser().parse_args()

    cluster = Cluster(args.db_nodes)
    session = cluster.connect(args.keyspace)

    # default start and end date
    start_date = args.start_date
    end_date = args.end_date

    if datetime.fromisoformat(start_date) < datetime.fromisoformat(MIN_START):
        start_date = MIN_START

    # query most recent data
    if not args.force:
        most_recent_date = query_most_recent_date(session, args.keyspace,
                                                  args.table)
        if most_recent_date is not None:
            start_date = most_recent_date

    print(f"*** Starting exchange rate ingest for {args.cryptocurrency} ***")
    print(f"Start date: {start_date}")
    print(f"End date: {end_date}")
    print(f"Target fiat currencies: {args.fiat_currencies}")

    if datetime.fromisoformat(start_date) > datetime.fromisoformat(end_date):
        print("Error: start date after end date.")
        cluster.shutdown()
        raise SystemExit

    # fetch cryptocurrency exchange rates in USD
    cmc_rates = fetch_cmc_rates(start_date, end_date, args.cryptocurrency)

    ecb_rates = fetch_ecb_rates(args.fiat_currencies)
    # query conversion rates and merge converted values in exchange rates
    exchange_rates = cmc_rates
    date_range = pd.date_range(date.fromisoformat(start_date),
                               date.fromisoformat(end_date))
    date_range = pd.DataFrame(date_range, columns=["date"])
    date_range = date_range["date"].dt.strftime("%Y-%m-%d")

    for fiat_currency in set(args.fiat_currencies) - set(["USD"]):
        ecb_rate = ecb_rates[["date", fiat_currency
                              ]].rename(columns={fiat_currency: "fx_rate"})
        merged_df = cmc_rates.merge(ecb_rate, on="date",
                                    how="left").merge(date_range, how="right")
        # fill gaps over weekends
        merged_df["fx_rate"].fillna(method="ffill", inplace=True)
        merged_df["fx_rate"].fillna(method="bfill", inplace=True)
        merged_df[fiat_currency] = merged_df["USD"] * merged_df["fx_rate"]
        merged_df = merged_df[["date", fiat_currency]]
        exchange_rates = exchange_rates.merge(merged_df, on="date")

    # insert final exchange rates into Cassandra
    if "USD" not in args.fiat_currencies:
        exchange_rates.drop("USD", axis=1, inplace=True)
    exchange_rates["fiat_values"] = exchange_rates.drop(
        "date", axis=1).to_dict(orient="records")
    exchange_rates.drop(args.fiat_currencies, axis=1, inplace=True)

    print(f"{exchange_rates.iloc[0].date} - {exchange_rates.iloc[-1].date}")

    # insert exchange rates into Cassandra table
    insert_exchange_rates(session, args.keyspace, args.table, exchange_rates)
    print(f"Inserted rates for {len(exchange_rates)} days: ", end="")
    print(f"{exchange_rates.iloc[0].date} - {exchange_rates.iloc[-1].date}")

    cluster.shutdown()
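
# insert_exchange_rates is referenced above but not defined in this snippet; a
# minimal sketch of one plausible implementation (column names follow the
# DataFrame built above; the table layout is an assumption):
#
#   def insert_exchange_rates(session, keyspace, table, exchange_rates):
#       stmt = session.prepare(
#           f"INSERT INTO {keyspace}.{table} (date, fiat_values) VALUES (?, ?)")
#       for row in exchange_rates.itertuples(index=False):
#           session.execute(stmt, (row.date, row.fiat_values))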
Example #40
class CustomPayloadTests(unittest.TestCase):
    def setUp(self):
        if PROTOCOL_VERSION < 4:
            raise unittest.SkipTest(
                "Native protocol 4,0+ is required for custom payloads, currently using %r"
                % (PROTOCOL_VERSION, ))
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.session = self.cluster.connect()

    def tearDown(self):

        self.cluster.shutdown()

    def test_custom_query_basic(self):
        """
        Test to validate that custom payloads work with simple queries

        creates a simple query and ensures that custom payloads are passed to C*. A custom
        query provider is used with C* so we can validate that the same custom payloads are sent back
        with the results


        @since 2.6
        @jira_ticket PYTHON-280
        @expected_result valid custom payloads should be sent and received

        @test_category queries:custom_payload
        """

        # Create a simple query statement
        query = "SELECT * FROM system.local"
        statement = SimpleStatement(query)
        # Validate that various types of custom payloads are sent and received okay
        self.validate_various_custom_payloads(statement=statement)

    def test_custom_query_batching(self):
        """
        Test to validate that custom payloads work with batch queries

        creates a batch query and ensures that custom payloads are passed to C*. A custom
        query provider is used with C* so we can validate that the same custom payloads are sent back
        with the results


        @since 2.6
        @jira_ticket PYTHON-280
        @expected_result valid custom payloads should be sent and received

        @test_category queries:custom_payload
        """

        # Construct Batch Statement
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(
                SimpleStatement(
                    "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i))

        # Validate that various types of custom payloads are sent and received okay
        self.validate_various_custom_payloads(statement=batch)

    def test_custom_query_prepared(self):
        """
        Test to validate that custom payloads work with prepared queries

        creates a prepared query and ensures that custom payloads are passed to C*. A custom
        query provider is used with C* so we can validate that the same custom payloads are sent back
        with the results


        @since 2.6
        @jira_ticket PYTHON-280
        @expected_result valid custom payloads should be sent and received

        @test_category queries:custom_payload
        """

        # Construct prepared statement
        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        bound = prepared.bind((1, None))

        # Validate that various custom payloads are validated correctly
        self.validate_various_custom_payloads(statement=bound)

    def validate_various_custom_payloads(self, statement):
        """
        This is a utility method that given a statement will attempt
        to submit the statement with various custom payloads. It will
        validate that the custom payloads are sent and received correctly.

        @param statement The statement to validate the custom queries in conjunction with
        """

        # Simple key value
        custom_payload = {'test': b'test_return'}
        self.execute_async_validate_custom_payload(
            statement=statement, custom_payload=custom_payload)

        # no key value
        custom_payload = {'': b''}
        self.execute_async_validate_custom_payload(
            statement=statement, custom_payload=custom_payload)

        # Space value
        custom_payload = {' ': b' '}
        self.execute_async_validate_custom_payload(
            statement=statement, custom_payload=custom_payload)

        # Long key value pair
        key_value = "x" * 10
        custom_payload = {key_value: six.b(key_value)}
        self.execute_async_validate_custom_payload(
            statement=statement, custom_payload=custom_payload)

        # The v4 binary protocol caps the payload map at 65535 entries (unsigned
        # short max); one entry is already present, so add 65534 more to reach
        # the limit exactly
        for i in range(65534):
            custom_payload[str(i)] = six.b('x')
        self.execute_async_validate_custom_payload(
            statement=statement, custom_payload=custom_payload)

        # Adding one more key-value pair exceeds the 65535-entry limit and should fail
        custom_payload[str(65535)] = six.b('x')
        with self.assertRaises(ValueError):
            self.execute_async_validate_custom_payload(
                statement=statement, custom_payload=custom_payload)

    def execute_async_validate_custom_payload(self, statement, custom_payload):
        """
        This is just a simple method that submits a statement with a payload, and validates
        that the custom payload we submitted matches the one that we got back
        @param statement The statement to execute
        @param custom_payload The custom payload to submit with
        """

        # Submit the statement with our custom payload. Validate the one
        # we receive from the server matches
        response_future = self.session.execute_async(
            statement, custom_payload=custom_payload)
        response_future.result()
        returned_custom_payload = response_future.custom_payload
        self.assertEqual(custom_payload, returned_custom_payload)
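
# Minimal standalone sketch of the same payload round-trip outside the test
# harness (assumes a reachable node running a custom QueryHandler that echoes
# payloads, and native protocol v4+):
#
#   cluster = Cluster(protocol_version=4)
#   session = cluster.connect()
#   future = session.execute_async(SimpleStatement("SELECT * FROM system.local"),
#                                  custom_payload={'k': b'v'})
#   future.result()
#   print(future.custom_payload)  # payload echoed back by the custom handler
#   cluster.shutdown()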
Example #41
    def test_nonprimitive_datatypes(self):
        """
        Test for inserting various types of DATA_TYPE_NON_PRIMITIVE into UDTs
        """
        raise unittest.SkipTest("Collections are not allowed in UDTs")
        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        # create keyspace
        s.execute("""
            CREATE KEYSPACE test_nonprimitive_datatypes
            WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
            """)
        s.set_keyspace("test_nonprimitive_datatypes")

        # create UDT
        alpha_type_list = []
        start_index = ord('a')
        for i, nonprim_datatype in enumerate(DATA_TYPE_NON_PRIMITIVE_NAMES):
            for j, datatype in enumerate(DATA_TYPE_PRIMITIVES):
                if nonprim_datatype == "map":
                    type_string = "{0}_{1} {2}<{3}, {3}>".format(
                        chr(start_index + i), chr(start_index + j),
                        nonprim_datatype, datatype)
                else:
                    type_string = "{0}_{1} {2}<{3}>".format(
                        chr(start_index + i), chr(start_index + j),
                        nonprim_datatype, datatype)
                alpha_type_list.append(type_string)

        s.execute("""
            CREATE TYPE alldatatypes ({0})
        """.format(', '.join(alpha_type_list)))

        s.execute(
            "CREATE TABLE mytable (a int PRIMARY KEY, b frozen<alldatatypes>)")

        # register UDT
        alphabet_list = []
        for i in range(ord('a'),
                       ord('a') + len(DATA_TYPE_NON_PRIMITIVE_NAMES)):
            for j in range(ord('a'), ord('a') + len(DATA_TYPE_PRIMITIVES)):
                alphabet_list.append('{0}_{1}'.format(chr(i), chr(j)))

        Alldatatypes = namedtuple("alldatatypes", alphabet_list)
        c.register_user_type("test_nonprimitive_datatypes", "alldatatypes",
                             Alldatatypes)

        # insert UDT data
        params = []
        for nonprim_datatype in DATA_TYPE_NON_PRIMITIVE_NAMES:
            for datatype in DATA_TYPE_PRIMITIVES:
                params.append((get_nonprim_sample(nonprim_datatype, datatype)))

        insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
        s.execute(insert, (0, Alldatatypes(*params)))

        # retrieve and verify data
        results = s.execute("SELECT * FROM mytable")
        self.assertEqual(1, len(results))

        row = results[0].b
        for expected, actual in zip(params, row):
            self.assertEqual(expected, actual)

        c.shutdown()
Example #42
    def test_legacy_tables(self):

        if get_server_versions()[0] < (2, 1, 0):
            raise unittest.SkipTest(
                'Test schema output assumes 2.1.0+ options')

        if sys.version_info[:2] != (2, 7):
            raise unittest.SkipTest(
                'This test compares static strings generated from dict items, which may change orders. Test with 2.7.'
            )

        cli_script = """CREATE KEYSPACE legacy
WITH placement_strategy = 'SimpleStrategy'
AND strategy_options = {replication_factor:1};

USE legacy;

CREATE COLUMN FAMILY simple_no_col
 WITH comparator = UTF8Type
 AND key_validation_class = UUIDType
 AND default_validation_class = UTF8Type;

CREATE COLUMN FAMILY simple_with_col
 WITH comparator = UTF8Type
 and key_validation_class = UUIDType
 and default_validation_class = UTF8Type
 AND column_metadata = [
 {column_name: col_with_meta, validation_class: UTF8Type}
 ];

CREATE COLUMN FAMILY composite_partition_no_col
 WITH comparator = UTF8Type
 AND key_validation_class = 'CompositeType(UUIDType,UTF8Type)'
 AND default_validation_class = UTF8Type;

CREATE COLUMN FAMILY composite_partition_with_col
 WITH comparator = UTF8Type
 AND key_validation_class = 'CompositeType(UUIDType,UTF8Type)'
 AND default_validation_class = UTF8Type
 AND column_metadata = [
 {column_name: col_with_meta, validation_class: UTF8Type}
 ];

CREATE COLUMN FAMILY nested_composite_key
 WITH comparator = UTF8Type
 and key_validation_class = 'CompositeType(CompositeType(UUIDType,UTF8Type), LongType)'
 and default_validation_class = UTF8Type
 AND column_metadata = [
 {column_name: full_name, validation_class: UTF8Type}
 ];

create column family composite_comp_no_col
  with column_type = 'Standard'
  and comparator = 'DynamicCompositeType(t=>org.apache.cassandra.db.marshal.TimeUUIDType,s=>org.apache.cassandra.db.marshal.UTF8Type,b=>org.apache.cassandra.db.marshal.BytesType)'
  and default_validation_class = 'BytesType'
  and key_validation_class = 'BytesType'
  and read_repair_chance = 0.0
  and dclocal_read_repair_chance = 0.1
  and gc_grace = 864000
  and min_compaction_threshold = 4
  and max_compaction_threshold = 32
  and compaction_strategy = 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy'
  and caching = 'KEYS_ONLY'
  and cells_per_row_to_cache = '0'
  and default_time_to_live = 0
  and speculative_retry = 'NONE'
  and comment = 'Stores file meta data';

create column family composite_comp_with_col
  with column_type = 'Standard'
  and comparator = 'DynamicCompositeType(t=>org.apache.cassandra.db.marshal.TimeUUIDType,s=>org.apache.cassandra.db.marshal.UTF8Type,b=>org.apache.cassandra.db.marshal.BytesType)'
  and default_validation_class = 'BytesType'
  and key_validation_class = 'BytesType'
  and read_repair_chance = 0.0
  and dclocal_read_repair_chance = 0.1
  and gc_grace = 864000
  and min_compaction_threshold = 4
  and max_compaction_threshold = 32
  and compaction_strategy = 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy'
  and caching = 'KEYS_ONLY'
  and cells_per_row_to_cache = '0'
  and default_time_to_live = 0
  and speculative_retry = 'NONE'
  and comment = 'Stores file meta data'
  and column_metadata = [
    {column_name : 'b@6d616d6d616a616d6d61',
    validation_class : BytesType,
    index_name : 'idx_one',
    index_type : 0},
    {column_name : 'b@6869746d65776974686d75736963',
    validation_class : BytesType,
    index_name : 'idx_two',
    index_type : 0}]
  and compression_options = {'sstable_compression' : 'org.apache.cassandra.io.compress.LZ4Compressor'};"""

        # note: the inner key type for legacy.nested_composite_key
        # (org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UUIDType, org.apache.cassandra.db.marshal.UTF8Type))
        # is a bit strange, but it replays in CQL with desired results
        expected_string = """CREATE KEYSPACE legacy WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}  AND durable_writes = true;

/*
Warning: Table legacy.composite_comp_with_col omitted because it has constructs not compatible with CQL (was created via legacy API).

Approximate structure, for reference:
(this should not be used to reproduce this schema)

CREATE TABLE legacy.composite_comp_with_col (
    key blob,
    t timeuuid,
    b blob,
    s text,
    "b@6869746d65776974686d75736963" blob,
    "b@6d616d6d616a616d6d61" blob,
    PRIMARY KEY (key, t, b, s)
) WITH COMPACT STORAGE
    AND CLUSTERING ORDER BY (t ASC, b ASC, s ASC)
    AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}'
    AND comment = 'Stores file meta data'
    AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'}
    AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'}
    AND dclocal_read_repair_chance = 0.1
    AND default_time_to_live = 0
    AND gc_grace_seconds = 864000
    AND max_index_interval = 2048
    AND memtable_flush_period_in_ms = 0
    AND min_index_interval = 128
    AND read_repair_chance = 0.0
    AND speculative_retry = 'NONE';
CREATE INDEX idx_two ON legacy.composite_comp_with_col ("b@6869746d65776974686d75736963");
CREATE INDEX idx_one ON legacy.composite_comp_with_col ("b@6d616d6d616a616d6d61");
*/

CREATE TABLE legacy.nested_composite_key (
    key 'org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UUIDType, org.apache.cassandra.db.marshal.UTF8Type)',
    key2 bigint,
    full_name text,
    PRIMARY KEY ((key, key2))
) WITH COMPACT STORAGE
    AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}'
    AND comment = ''
    AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'}
    AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'}
    AND dclocal_read_repair_chance = 0.1
    AND default_time_to_live = 0
    AND gc_grace_seconds = 864000
    AND max_index_interval = 2048
    AND memtable_flush_period_in_ms = 0
    AND min_index_interval = 128
    AND read_repair_chance = 0.0
    AND speculative_retry = 'NONE';

CREATE TABLE legacy.composite_partition_with_col (
    key uuid,
    key2 text,
    col_with_meta text,
    PRIMARY KEY ((key, key2))
) WITH COMPACT STORAGE
    AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}'
    AND comment = ''
    AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'}
    AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'}
    AND dclocal_read_repair_chance = 0.1
    AND default_time_to_live = 0
    AND gc_grace_seconds = 864000
    AND max_index_interval = 2048
    AND memtable_flush_period_in_ms = 0
    AND min_index_interval = 128
    AND read_repair_chance = 0.0
    AND speculative_retry = 'NONE';

CREATE TABLE legacy.composite_partition_no_col (
    key uuid,
    key2 text,
    column1 text,
    value text,
    PRIMARY KEY ((key, key2), column1)
) WITH COMPACT STORAGE
    AND CLUSTERING ORDER BY (column1 ASC)
    AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}'
    AND comment = ''
    AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'}
    AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'}
    AND dclocal_read_repair_chance = 0.1
    AND default_time_to_live = 0
    AND gc_grace_seconds = 864000
    AND max_index_interval = 2048
    AND memtable_flush_period_in_ms = 0
    AND min_index_interval = 128
    AND read_repair_chance = 0.0
    AND speculative_retry = 'NONE';

CREATE TABLE legacy.simple_with_col (
    key uuid PRIMARY KEY,
    col_with_meta text
) WITH COMPACT STORAGE
    AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}'
    AND comment = ''
    AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'}
    AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'}
    AND dclocal_read_repair_chance = 0.1
    AND default_time_to_live = 0
    AND gc_grace_seconds = 864000
    AND max_index_interval = 2048
    AND memtable_flush_period_in_ms = 0
    AND min_index_interval = 128
    AND read_repair_chance = 0.0
    AND speculative_retry = 'NONE';

CREATE TABLE legacy.simple_no_col (
    key uuid,
    column1 text,
    value text,
    PRIMARY KEY (key, column1)
) WITH COMPACT STORAGE
    AND CLUSTERING ORDER BY (column1 ASC)
    AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}'
    AND comment = ''
    AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'}
    AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'}
    AND dclocal_read_repair_chance = 0.1
    AND default_time_to_live = 0
    AND gc_grace_seconds = 864000
    AND max_index_interval = 2048
    AND memtable_flush_period_in_ms = 0
    AND min_index_interval = 128
    AND read_repair_chance = 0.0
    AND speculative_retry = 'NONE';

/*
Warning: Table legacy.composite_comp_no_col omitted because it has constructs not compatible with CQL (was created via legacy API).

Approximate structure, for reference:
(this should not be used to reproduce this schema)

CREATE TABLE legacy.composite_comp_no_col (
    key blob,
    column1 'org.apache.cassandra.db.marshal.DynamicCompositeType(org.apache.cassandra.db.marshal.TimeUUIDType, org.apache.cassandra.db.marshal.BytesType, org.apache.cassandra.db.marshal.UTF8Type)',
    column2 text,
    value blob,
    PRIMARY KEY (key, column1, column1, column2)
) WITH COMPACT STORAGE
    AND CLUSTERING ORDER BY (column1 ASC, column1 ASC, column2 ASC)
    AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}'
    AND comment = 'Stores file meta data'
    AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'}
    AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'}
    AND dclocal_read_repair_chance = 0.1
    AND default_time_to_live = 0
    AND gc_grace_seconds = 864000
    AND max_index_interval = 2048
    AND memtable_flush_period_in_ms = 0
    AND min_index_interval = 128
    AND read_repair_chance = 0.0
    AND speculative_retry = 'NONE';
*/"""

        ccm = get_cluster()
        ccm.run_cli(cli_script)

        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        legacy_meta = cluster.metadata.keyspaces['legacy']
        self.assert_equal_diff(legacy_meta.export_as_string(), expected_string)

        session.execute('DROP KEYSPACE legacy')

        cluster.shutdown()
Example #43
class ClientExceptionTests(unittest.TestCase):
    def setUp(self):
        """
        Test is skipped if run with native protocol version <4
        """
        self.support_v5 = True
        if PROTOCOL_VERSION < 4:
            raise unittest.SkipTest(
                "Native protocol 4,0+ is required for custom payloads, currently using %r"
                % (PROTOCOL_VERSION, ))
        try:
            self.cluster = Cluster(
                protocol_version=ProtocolVersion.MAX_SUPPORTED,
                allow_beta_protocol_version=True)
            self.session = self.cluster.connect()
        except NoHostAvailable:
            log.info("Protocol Version 5 not supported,")
            self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
            self.session = self.cluster.connect()
            self.support_v5 = False

        self.nodes_currently_failing = []
        self.node1, self.node2, self.node3 = get_cluster().nodes.values()

    def tearDown(self):

        self.cluster.shutdown()
        failing_nodes = []

        # Restart the nodes so they are fully functional again
        self.setFailingNodes(failing_nodes, "testksfail")

    def execute_helper(self, session, query):
        tries = 0
        while tries < 100:
            try:
                return session.execute(query)
            except OperationTimedOut:
                ex_type, ex, tb = sys.exc_info()
                log.warn("{0}: {1} Backtrace: {2}".format(
                    ex_type.__name__, ex, traceback.extract_tb(tb)))
                del tb
                tries += 1

        raise RuntimeError(
            "Failed to execute query after 100 attempts: {0}".format(query))

    def execute_concurrent_args_helper(self, session, query, params):
        tries = 0
        while tries < 100:
            try:
                return execute_concurrent_with_args(session,
                                                    query,
                                                    params,
                                                    concurrency=50)
            except (ReadTimeout, WriteTimeout, OperationTimedOut, ReadFailure,
                    WriteFailure):
                ex_type, ex, tb = sys.exc_info()
                log.warn("{0}: {1} Backtrace: {2}".format(
                    ex_type.__name__, ex, traceback.extract_tb(tb)))
                del tb
                tries += 1

        raise RuntimeError(
            "Failed to execute query after 100 attempts: {0}".format(query))

    def setFailingNodes(self, failing_nodes, keyspace):
        """
        This method will take in a set of failing nodes, and toggle all of the nodes in the provided list to fail
        writes.
        @param failing_nodes A definitive list of nodes that should fail writes
        @param keyspace The keyspace to enable failures on

        """

        # Ensure all of the nodes on the list have failures enabled
        for node in failing_nodes:
            if node not in self.nodes_currently_failing:
                node.stop(wait_other_notice=True, gently=False)
                node.start(
                    jvm_args=[" -Dcassandra.test.fail_writes_ks=" + keyspace],
                    wait_for_binary_proto=True,
                    wait_other_notice=True)
                self.nodes_currently_failing.append(node)

        # Ensure all nodes not on the list, but currently set to failing, are re-enabled
        for node in self.nodes_currently_failing:
            if node not in failing_nodes:
                node.stop(wait_other_notice=True, gently=False)
                node.start(wait_for_binary_proto=True, wait_other_notice=True)
                self.nodes_currently_failing.remove(node)

    def _perform_cql_statement(self,
                               text,
                               consistency_level,
                               expected_exception,
                               session=None):
        """
        Simple helper method to perform CQL statements and check for an expected exception
        @param text CQL statement to execute
        @param consistency_level Consistency level at which it is to be executed
        @param expected_exception Exception expected to be thrown, or None
        """
        if session is None:
            session = self.session
        statement = SimpleStatement(text)
        statement.consistency_level = consistency_level

        if expected_exception is None:
            self.execute_helper(session, statement)
        else:
            with self.assertRaises(expected_exception) as cm:
                self.execute_helper(session, statement)
            if self.support_v5 and (isinstance(cm.exception, WriteFailure)
                                    or isinstance(cm.exception, ReadFailure)):
                if isinstance(cm.exception, ReadFailure):
                    self.assertEqual(
                        list(cm.exception.error_code_map.values())[0], 1)
                else:
                    self.assertEqual(
                        list(cm.exception.error_code_map.values())[0], 0)

    def test_write_failures_from_coordinator(self):
        """
        Test to validate that write failures from the coordinator are surfaced appropriately.

        test_write_failures_from_coordinator Enable write failures on the various nodes using a custom jvm flag,
        cassandra.test.fail_writes_ks. This will cause writes to fail on that specific node. Depending on the replication
        factor of the keyspace, and the consistency level, we will expect the coordinator to send WriteFailure, or not.


        @since 2.6.0, 3.7.0
        @jira_ticket PYTHON-238, PYTHON-619
        @expected_result Appropriate write failures from the coordinator

        @test_category queries:basic
        """

        # Setup temporary keyspace.
        self._perform_cql_statement("""
            CREATE KEYSPACE testksfail
            WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=None)

        # create table
        self._perform_cql_statement("""
            CREATE TABLE testksfail.test (
                k int PRIMARY KEY,
                v int )
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=None)

        # Disable one node
        failing_nodes = [self.node1]
        self.setFailingNodes(failing_nodes, "testksfail")

        # With one node disabled, we expect a write failure at ConsistencyLevel.ALL
        self._perform_cql_statement("""
            INSERT INTO testksfail.test (k, v) VALUES  (1, 0 )
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=WriteFailure)

        # We have two nodes left so a write with consistency level of QUORUM should complete as expected
        self._perform_cql_statement("""
            INSERT INTO testksfail.test (k, v) VALUES  (1, 0 )
            """,
                                    consistency_level=ConsistencyLevel.QUORUM,
                                    expected_exception=None)

        failing_nodes = []

        # Restart the nodes so they are fully functional again
        self.setFailingNodes(failing_nodes, "testksfail")

        # Drop temporary keyspace
        self._perform_cql_statement("""
            DROP KEYSPACE testksfail
            """,
                                    consistency_level=ConsistencyLevel.ANY,
                                    expected_exception=None)

    def test_tombstone_overflow_read_failure(self):
        """
        Test to validate that a ReadFailure is returned from the node when a specified threshold of tombstones is
        reached.

        test_tombstone_overflow_read_failure First sets the tombstone failure threshold down to a level that allows it
        to be more easily encountered. We then create some wide rows and ensure they are deleted appropriately. This
        produces the expected number of tombstones. Upon making a simple query we expect to get a read failure back
        from the coordinator.

        @since 2.6.0, 3.7.0
        @jira_ticket PYTHON-238, PYTHON-619
        @expected_result Appropriate read failure from the coordinator

        @test_category queries:basic
        """

        # Setup table for "wide row"
        self._perform_cql_statement("""
            CREATE TABLE test3rf.test2 (
                k int,
                v0 int,
                v1 int, PRIMARY KEY (k,v0))
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=None)

        statement = self.session.prepare(
            "INSERT INTO test3rf.test2 (k, v0,v1) VALUES  (1,?,1)")
        parameters = [(x, ) for x in range(3000)]
        self.execute_concurrent_args_helper(self.session, statement,
                                            parameters)

        statement = self.session.prepare(
            "DELETE v1 FROM test3rf.test2 WHERE k = 1 AND v0 =?")
        parameters = [(x, ) for x in range(2001)]
        self.execute_concurrent_args_helper(self.session, statement,
                                            parameters)

        self._perform_cql_statement("""
            SELECT * FROM test3rf.test2 WHERE k = 1
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=ReadFailure)

        self._perform_cql_statement("""
            DROP TABLE test3rf.test2;
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=None)

    def test_user_function_failure(self):
        """
        Test to validate that exceptions in user defined functions are correctly surfaced by the driver.

        test_user_function_failure First creates a function that will throw an exception when invoked, and a table
        to use for testing. It then invokes the function and expects a FunctionFailure. Finally it performs
        cleanup operations.

        @since 2.6.0
        @jira_ticket PYTHON-238
        @expected_result Function failure when the UDF throws an exception

        @test_category queries:basic
        """

        # create UDF that throws an exception
        self._perform_cql_statement("""
            CREATE FUNCTION test3rf.test_failure(d double)
            RETURNS NULL ON NULL INPUT
            RETURNS double
            LANGUAGE java AS 'throw new RuntimeException("failure");';
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=None)

        # Create test table
        self._perform_cql_statement("""
            CREATE TABLE  test3rf.d (k int PRIMARY KEY , d double);
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=None)

        # Insert some values
        self._perform_cql_statement("""
            INSERT INTO test3rf.d (k,d) VALUES (0, 5.12);
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=None)

        # Run the function expect a function failure exception
        self._perform_cql_statement("""
            SELECT test_failure(d) FROM test3rf.d WHERE k = 0;
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=FunctionFailure)

        self._perform_cql_statement("""
            DROP FUNCTION test3rf.test_failure;
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=None)

        self._perform_cql_statement("""
            DROP TABLE test3rf.d;
            """,
                                    consistency_level=ConsistencyLevel.ALL,
                                    expected_exception=None)
Example #44
class PreparedStatementTests(unittest.TestCase):
    def setUp(self):
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.session = self.cluster.connect()

    def tearDown(self):
        self.cluster.shutdown()

    def test_routing_key(self):
        """
        Simple code coverage to ensure routing_keys can be accessed
        """
        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        bound = prepared.bind((1, None))
        self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')

    def test_empty_routing_key_indexes(self):
        """
        Ensure that when routing_key_indexes is blank,
        the routing key is None
        """
        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)
        prepared.routing_key_indexes = None

        self.assertIsInstance(prepared, PreparedStatement)
        bound = prepared.bind((1, None))
        self.assertEqual(bound.routing_key, None)

    def test_predefined_routing_key(self):
        """
        Basic test that ensures _set_routing_key()
        overrides the current routing key
        """
        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        bound = prepared.bind((1, None))
        bound._set_routing_key('fake_key')
        self.assertEqual(bound.routing_key, 'fake_key')

    def test_multiple_routing_key_indexes(self):
        """
        Basic test that uses fake routing_key_indexes
        """
        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)
        self.assertIsInstance(prepared, PreparedStatement)

        prepared.routing_key_indexes = [0, 1]
        bound = prepared.bind((1, 2))
        self.assertEqual(
            bound.routing_key,
            b'\x00\x04\x00\x00\x00\x01\x00\x00\x04\x00\x00\x00\x02\x00')

        prepared.routing_key_indexes = [1, 0]
        bound = prepared.bind((1, 2))
        self.assertEqual(
            bound.routing_key,
            b'\x00\x04\x00\x00\x00\x02\x00\x00\x04\x00\x00\x00\x01\x00')

    def test_bound_keyspace(self):
        """
        Ensure that bound.keyspace works as expected
        """
        prepared = self.session.prepare("""
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        bound = prepared.bind((1, 2))
        self.assertEqual(bound.keyspace, 'test3rf')
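
# For reference, the routing-key byte strings asserted in
# test_multiple_routing_key_indexes follow Cassandra's composite-key encoding:
# each component is a two-byte big-endian length, the serialized value, and a
# zero terminator byte. A minimal sketch (the helper name is illustrative, not
# a driver API):
import struct

def composite_routing_key(int_components):
    parts = []
    for v in int_components:
        b = struct.pack('>i', v)  # each component here is a big-endian int32
        parts.append(struct.pack('>H', len(b)) + b + b'\x00')
    return b''.join(parts)

assert composite_routing_key([1, 2]) == b'\x00\x04\x00\x00\x00\x01\x00\x00\x04\x00\x00\x00\x02\x00'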
Example #45
    def test_refresh_schema_no_wait(self):

        contact_points = ['127.0.0.1']
        cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=10,
                          contact_points=contact_points, load_balancing_policy=WhiteListRoundRobinPolicy(contact_points))
        session = cluster.connect()

        schema_ver = session.execute("SELECT schema_version FROM system.local WHERE key='local'")[0][0]

        # create a schema disagreement
        session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (uuid4(),))

        try:
            agreement_timeout = 1

            # cluster agreement wait exceeded
            c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=agreement_timeout)
            start_time = time.time()
            s = c.connect()
            end_time = time.time()
            self.assertGreaterEqual(end_time - start_time, agreement_timeout)
            self.assertTrue(c.metadata.keyspaces)

            # cluster agreement wait used for refresh
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            self.assertRaisesRegexp(Exception, r"Schema was not refreshed.*", c.refresh_schema)
            end_time = time.time()
            self.assertGreaterEqual(end_time - start_time, agreement_timeout)
            self.assertIs(original_meta, c.metadata.keyspaces)
            
            # refresh wait overrides cluster value
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            c.refresh_schema(max_schema_agreement_wait=0)
            end_time = time.time()
            self.assertLess(end_time - start_time, agreement_timeout)
            self.assertIsNot(original_meta, c.metadata.keyspaces)
            self.assertEqual(original_meta, c.metadata.keyspaces)

            c.shutdown()

            refresh_threshold = 0.5
            # cluster agreement bypass
            c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=0)
            start_time = time.time()
            s = c.connect()
            end_time = time.time()
            self.assertLess(end_time - start_time, refresh_threshold)
            self.assertTrue(c.metadata.keyspaces)

            # cluster agreement wait used for refresh
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            c.refresh_schema()
            end_time = time.time()
            self.assertLess(end_time - start_time, refresh_threshold)
            self.assertIsNot(original_meta, c.metadata.keyspaces)
            self.assertEqual(original_meta, c.metadata.keyspaces)
            
            # refresh wait overrides cluster value
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            self.assertRaisesRegexp(Exception, r"Schema was not refreshed.*", c.refresh_schema, max_schema_agreement_wait=agreement_timeout)
            end_time = time.time()
            self.assertGreaterEqual(end_time - start_time, agreement_timeout)
            self.assertIs(original_meta, c.metadata.keyspaces)

            c.shutdown()
        finally:
            session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (schema_ver,))

        cluster.shutdown()
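
# In summary: max_schema_agreement_wait bounds how long connect() and
# refresh_schema() wait for all hosts to converge on a single schema version,
# a value of 0 bypasses the check entirely, and the keyword argument passed to
# refresh_schema() overrides the cluster-wide setting for that call.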
Example #46
class SchemaMetadataTests(unittest.TestCase):

    ksname = "schemametadatatest"

    @property
    def cfname(self):
        return self._testMethodName.lower()

    @classmethod
    def setup_class(cls):
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        try:
            results = session.execute(
                "SELECT keyspace_name FROM system.schema_keyspaces")
            existing_keyspaces = [row[0] for row in results]
            if cls.ksname in existing_keyspaces:
                session.execute("DROP KEYSPACE %s" % cls.ksname)

            session.execute("""
                CREATE KEYSPACE %s
                WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'};
                """ % cls.ksname)
        finally:
            cluster.shutdown()

    @classmethod
    def teardown_class(cls):
        cluster = Cluster(['127.0.0.1'], protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        try:
            session.execute("DROP KEYSPACE %s" % cls.ksname)
        finally:
            cluster.shutdown()

    def setUp(self):
        self.cluster = Cluster(['127.0.0.1'],
                               protocol_version=PROTOCOL_VERSION)
        self.session = self.cluster.connect()

    def tearDown(self):
        try:
            self.session.execute("""
                DROP TABLE {ksname}.{cfname}
                """.format(ksname=self.ksname, cfname=self.cfname))
        finally:
            self.cluster.shutdown()

    def make_create_statement(self,
                              partition_cols,
                              clustering_cols=None,
                              other_cols=None,
                              compact=False):
        clustering_cols = clustering_cols or []
        other_cols = other_cols or []

        statement = "CREATE TABLE %s.%s (" % (self.ksname, self.cfname)
        if len(partition_cols) == 1 and not clustering_cols:
            statement += "%s text PRIMARY KEY, " % partition_cols[0]
        else:
            statement += ", ".join("%s text" % col for col in partition_cols)
            statement += ", "

        statement += ", ".join("%s text" % col
                               for col in clustering_cols + other_cols)

        if len(partition_cols) != 1 or clustering_cols:
            statement += ", PRIMARY KEY ("

            if len(partition_cols) > 1:
                statement += "(" + ", ".join(partition_cols) + ")"
            else:
                statement += partition_cols[0]

            if clustering_cols:
                statement += ", "
                statement += ", ".join(clustering_cols)

            statement += ")"

        statement += ")"
        if compact:
            statement += " WITH COMPACT STORAGE"

        return statement
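
    # A worked example of the helper above (with hypothetical ksname='ks' and
    # cfname='cf'):
    #   make_create_statement(["a", "b"], ["c"], ["d"])
    # returns:
    #   "CREATE TABLE ks.cf (a text, b text, c text, d text, PRIMARY KEY ((a, b), c))"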

    def check_create_statement(self, tablemeta, original):
        recreate = tablemeta.as_cql_query(formatted=False)
        self.assertEqual(original, recreate[:len(original)])
        self.session.execute("DROP TABLE %s.%s" % (self.ksname, self.cfname))
        self.session.execute(recreate)

        # create the table again, but with formatting enabled
        self.session.execute("DROP TABLE %s.%s" % (self.ksname, self.cfname))
        recreate = tablemeta.as_cql_query(formatted=True)
        self.session.execute(recreate)

    def get_table_metadata(self):
        self.cluster.control_connection.refresh_schema()
        return self.cluster.metadata.keyspaces[self.ksname].tables[self.cfname]

    def test_basic_table_meta_properties(self):
        create_statement = self.make_create_statement(["a"], [], ["b", "c"])
        self.session.execute(create_statement)

        self.cluster.control_connection.refresh_schema()

        meta = self.cluster.metadata
        self.assertNotEqual(meta.cluster_ref, None)
        self.assertNotEqual(meta.cluster_name, None)
        self.assertTrue(self.ksname in meta.keyspaces)
        ksmeta = meta.keyspaces[self.ksname]

        self.assertEqual(ksmeta.name, self.ksname)
        self.assertTrue(ksmeta.durable_writes)
        self.assertEqual(ksmeta.replication_strategy.name, 'SimpleStrategy')
        self.assertEqual(ksmeta.replication_strategy.replication_factor, 1)

        self.assertTrue(self.cfname in ksmeta.tables)
        tablemeta = ksmeta.tables[self.cfname]
        self.assertEqual(tablemeta.keyspace, ksmeta)
        self.assertEqual(tablemeta.name, self.cfname)

        self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key])
        self.assertEqual([], tablemeta.clustering_key)
        self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys()))

        for option in tablemeta.options:
            self.assertIn(option, TableMetadata.recognized_options)

        self.check_create_statement(tablemeta, create_statement)

    def test_compound_primary_keys(self):
        create_statement = self.make_create_statement(["a"], ["b"], ["c"])
        create_statement += " WITH CLUSTERING ORDER BY (b ASC)"
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()

        self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key])
        self.assertEqual([u'b'], [c.name for c in tablemeta.clustering_key])
        self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys()))

        self.check_create_statement(tablemeta, create_statement)

    def test_compound_primary_keys_more_columns(self):
        create_statement = self.make_create_statement(["a"], ["b", "c"],
                                                      ["d", "e", "f"])
        create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)"
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()

        self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key])
        self.assertEqual([u'b', u'c'],
                         [c.name for c in tablemeta.clustering_key])
        self.assertEqual([u'a', u'b', u'c', u'd', u'e', u'f'],
                         sorted(tablemeta.columns.keys()))

        self.check_create_statement(tablemeta, create_statement)

    def test_composite_primary_key(self):
        create_statement = self.make_create_statement(["a", "b"], [], ["c"])
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()

        self.assertEqual([u'a', u'b'],
                         [c.name for c in tablemeta.partition_key])
        self.assertEqual([], tablemeta.clustering_key)
        self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys()))

        self.check_create_statement(tablemeta, create_statement)

    def test_composite_in_compound_primary_key(self):
        create_statement = self.make_create_statement(["a", "b"], ["c"],
                                                      ["d", "e"])
        create_statement += " WITH CLUSTERING ORDER BY (c ASC)"
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()

        self.assertEqual([u'a', u'b'],
                         [c.name for c in tablemeta.partition_key])
        self.assertEqual([u'c'], [c.name for c in tablemeta.clustering_key])
        self.assertEqual([u'a', u'b', u'c', u'd', u'e'],
                         sorted(tablemeta.columns.keys()))

        self.check_create_statement(tablemeta, create_statement)

    def test_compound_primary_keys_compact(self):
        create_statement = self.make_create_statement(["a"], ["b"], ["c"],
                                                      compact=True)
        create_statement += " AND CLUSTERING ORDER BY (b ASC)"
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()

        self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key])
        self.assertEqual([u'b'], [c.name for c in tablemeta.clustering_key])
        self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys()))

        self.check_create_statement(tablemeta, create_statement)

    def test_compound_primary_keys_more_columns_compact(self):
        create_statement = self.make_create_statement(["a"], ["b", "c"], ["d"],
                                                      compact=True)
        create_statement += " AND CLUSTERING ORDER BY (b ASC, c ASC)"
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()

        self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key])
        self.assertEqual([u'b', u'c'],
                         [c.name for c in tablemeta.clustering_key])
        self.assertEqual([u'a', u'b', u'c', u'd'],
                         sorted(tablemeta.columns.keys()))

        self.check_create_statement(tablemeta, create_statement)

    def test_composite_primary_key_compact(self):
        create_statement = self.make_create_statement(["a", "b"], [], ["c"],
                                                      compact=True)
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()

        self.assertEqual([u'a', u'b'],
                         [c.name for c in tablemeta.partition_key])
        self.assertEqual([], tablemeta.clustering_key)
        self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys()))

        self.check_create_statement(tablemeta, create_statement)

    def test_composite_in_compound_primary_key_compact(self):
        create_statement = self.make_create_statement(["a", "b"], ["c"], ["d"],
                                                      compact=True)
        create_statement += " AND CLUSTERING ORDER BY (c ASC)"
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()

        self.assertEqual([u'a', u'b'],
                         [c.name for c in tablemeta.partition_key])
        self.assertEqual([u'c'], [c.name for c in tablemeta.clustering_key])
        self.assertEqual([u'a', u'b', u'c', u'd'],
                         sorted(tablemeta.columns.keys()))

        self.check_create_statement(tablemeta, create_statement)

    def test_compound_primary_keys_ordering(self):
        create_statement = self.make_create_statement(["a"], ["b"], ["c"])
        create_statement += " WITH CLUSTERING ORDER BY (b DESC)"
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()
        self.check_create_statement(tablemeta, create_statement)

    def test_compound_primary_keys_more_columns_ordering(self):
        create_statement = self.make_create_statement(["a"], ["b", "c"],
                                                      ["d", "e", "f"])
        create_statement += " WITH CLUSTERING ORDER BY (b DESC, c ASC)"
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()
        self.check_create_statement(tablemeta, create_statement)

    def test_composite_in_compound_primary_key_ordering(self):
        create_statement = self.make_create_statement(["a", "b"], ["c"],
                                                      ["d", "e"])
        create_statement += " WITH CLUSTERING ORDER BY (c DESC)"
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()
        self.check_create_statement(tablemeta, create_statement)

    def test_indexes(self):
        create_statement = self.make_create_statement(["a"], ["b", "c"],
                                                      ["d", "e", "f"])
        create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)"
        self.session.execute(create_statement)

        d_index = "CREATE INDEX d_index ON %s.%s (d)" % (self.ksname,
                                                         self.cfname)
        e_index = "CREATE INDEX e_index ON %s.%s (e)" % (self.ksname,
                                                         self.cfname)
        self.session.execute(d_index)
        self.session.execute(e_index)

        tablemeta = self.get_table_metadata()
        statements = tablemeta.export_as_string().strip()
        statements = [s.strip() for s in statements.split(';')]
        statements = list(filter(bool, statements))
        self.assertEqual(3, len(statements))
        self.assertEqual(d_index, statements[1])
        self.assertEqual(e_index, statements[2])

        # make sure indexes are included in KeyspaceMetadata.export_as_string()
        ksmeta = self.cluster.metadata.keyspaces[self.ksname]
        statement = ksmeta.export_as_string()
        self.assertIn('CREATE INDEX d_index', statement)
        self.assertIn('CREATE INDEX e_index', statement)

    def test_collection_indexes(self):
        self.session.execute(
            "CREATE TABLE %s.%s (a int PRIMARY KEY, b map<text, text>)" %
            (self.ksname, self.cfname))
        self.session.execute("CREATE INDEX index1 ON %s.%s (keys(b))" %
                             (self.ksname, self.cfname))

        tablemeta = self.get_table_metadata()
        self.assertIn('(keys(b))', tablemeta.export_as_string())

        self.session.execute("DROP INDEX %s.index1" % (self.ksname, ))
        self.session.execute("CREATE INDEX index2 ON %s.%s (b)" %
                             (self.ksname, self.cfname))

        tablemeta = self.get_table_metadata()
        self.assertIn(' (b)', tablemeta.export_as_string())

        # test full indexes on frozen collections, if available
        if get_server_versions()[0] >= (2, 1, 3):
            self.session.execute("DROP TABLE %s.%s" %
                                 (self.ksname, self.cfname))
            self.session.execute(
                "CREATE TABLE %s.%s (a int PRIMARY KEY, b frozen<map<text, text>>)"
                % (self.ksname, self.cfname))
            self.session.execute("CREATE INDEX index3 ON %s.%s (full(b))" %
                                 (self.ksname, self.cfname))

            tablemeta = self.get_table_metadata()
            self.assertIn('(full(b))', tablemeta.export_as_string())

    def test_compression_disabled(self):
        create_statement = self.make_create_statement(["a"], ["b"], ["c"])
        create_statement += " WITH compression = {}"
        self.session.execute(create_statement)
        tablemeta = self.get_table_metadata()
        self.assertIn("compression = {}", tablemeta.export_as_string())
Example #47
class BatchStatementTests(unittest.TestCase):
    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION, ))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

        self.session.execute("TRUNCATE test3rf.test")

    def tearDown(self):
        self.cluster.shutdown()

    def confirm_results(self):
        keys = set()
        values = set()
        results = self.session.execute("SELECT * FROM test3rf.test")
        for result in results:
            keys.add(result.k)
            values.add(result.v)

        self.assertEqual(set(range(10)), keys)
        self.assertEqual(set(range(10)), values)

    def test_string_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)",
                      (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_simple_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(
                SimpleStatement(
                    "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_prepared_statements(self):
        prepared = self.session.prepare(
            "INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared, (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_bound_statements(self):
        prepared = self.session.prepare(
            "INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared.bind((i, i)))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_no_parameters(self):
        batch = BatchStatement(BatchType.LOGGED)
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (1, 1)", ())
        batch.add(
            SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)"))
        batch.add(
            SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (3, 3)"),
            ())

        prepared = self.session.prepare(
            "INSERT INTO test3rf.test (k, v) VALUES (4, 4)")
        batch.add(prepared)
        batch.add(prepared, ())
        batch.add(prepared.bind([]))
        batch.add(prepared.bind([]), ())

        batch.add("INSERT INTO test3rf.test (k, v) VALUES (5, 5)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (6, 6)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (7, 7)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (8, 8)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (9, 9)", ())

        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1,))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2, 3))

        self.session.execute(batch)
        self.confirm_results()
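
# All of the batches above use BatchType.LOGGED (atomic, at the cost of a
# batchlog write); the driver also exposes BatchType.UNLOGGED and
# BatchType.COUNTER, e.g. BatchStatement(batch_type=BatchType.UNLOGGED).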
Example #48
class BatchStatementDefaultRoutingKeyTests(unittest.TestCase):
    # Test for PYTHON-126: BatchStatement.add() should set the batch's routing key from the first added statement that has one

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION,))
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.session = self.cluster.connect()
        query = """
                INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
                """
        self.simple_statement = SimpleStatement(query, routing_key='ss_rk', keyspace='keyspace_name')
        self.prepared = self.session.prepare(query)

    def tearDown(self):
        self.cluster.shutdown()

    def test_rk_from_bound(self):
        """
        batch routing key is inherited from BoundStatement
        """
        bound = self.prepared.bind((1, None))
        batch = BatchStatement()
        batch.add(bound)
        self.assertIsNotNone(batch.routing_key)
        self.assertEqual(batch.routing_key, bound.routing_key)

    def test_rk_from_simple(self):
        """
        batch routing key is inherited from SimpleStatement
        """
        batch = BatchStatement()
        batch.add(self.simple_statement)
        self.assertIsNotNone(batch.routing_key)
        self.assertEqual(batch.routing_key, self.simple_statement.routing_key)

    def test_inherit_first_rk_bound(self):
        """
        compound batch inherits the routing key of the first added statement that has one (bound statement is first)
        """
        bound = self.prepared.bind((100000000, None))
        batch = BatchStatement()
        batch.add("ss with no rk")
        batch.add(bound)
        batch.add(self.simple_statement)

        for i in range(3):
            batch.add(self.prepared, (i, i))

        self.assertIsNotNone(batch.routing_key)
        self.assertEqual(batch.routing_key, bound.routing_key)

    def test_inherit_first_rk_simple_statement(self):
        """
        compound batch inherits the routing key of the first added statement that has one (SimpleStatement is first)
        """
        bound = self.prepared.bind((1, None))
        batch = BatchStatement()
        batch.add("ss with no rk")
        batch.add(self.simple_statement)
        batch.add(bound)

        for i in range(10):
            batch.add(self.prepared, (i, i))

        self.assertIsNotNone(batch.routing_key)
        self.assertEqual(batch.routing_key, self.simple_statement.routing_key)

    def test_inherit_first_rk_prepared_param(self):
        """
        compound batch inherits the routing key of the first added statement that has one (prepared statement is first)
        """
        bound = self.prepared.bind((2, None))
        batch = BatchStatement()
        batch.add("ss with no rk")
        batch.add(self.prepared, (1, 0))
        batch.add(bound)
        batch.add(self.simple_statement)

        self.assertIsNotNone(batch.routing_key)
        self.assertEqual(batch.routing_key, self.prepared.bind((1, 0)).routing_key)
Example #49
class SerialConsistencyTests(unittest.TestCase):
    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Serial Consistency, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

    def tearDown(self):
        self.cluster.shutdown()

    def test_conditional_update(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        statement = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
            serial_consistency_level=ConsistencyLevel.SERIAL)
        # crazy test, but PYTHON-299
        # TODO: expand to check more parameters get passed to statement, and on to messages
        self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.SERIAL)
        future = self.session.execute_async(statement)
        result = future.result()
        self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL)
        self.assertTrue(result)
        self.assertFalse(result[0].applied)

        statement = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
            serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL)
        self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL)
        future = self.session.execute_async(statement)
        result = future.result()
        self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL)
        self.assertTrue(result)
        self.assertTrue(result[0].applied)

    def test_conditional_update_with_prepared_statements(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        statement = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=2")

        statement.serial_consistency_level = ConsistencyLevel.SERIAL
        future = self.session.execute_async(statement)
        result = future.result()
        self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL)
        self.assertTrue(result)
        self.assertFalse(result[0].applied)

        statement = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0")
        bound = statement.bind(())
        bound.serial_consistency_level = ConsistencyLevel.LOCAL_SERIAL
        future = self.session.execute_async(bound)
        result = future.result()
        self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL)
        self.assertTrue(result)
        self.assertTrue(result[0].applied)

    def test_conditional_update_with_batch_statements(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        statement = BatchStatement(serial_consistency_level=ConsistencyLevel.SERIAL)
        statement.add("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1")
        self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.SERIAL)
        future = self.session.execute_async(statement)
        result = future.result()
        self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL)
        self.assertTrue(result)
        self.assertFalse(result[0].applied)

        statement = BatchStatement(serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL)
        statement.add("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0")
        self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL)
        future = self.session.execute_async(statement)
        result = future.result()
        self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL)
        self.assertTrue(result)
        self.assertTrue(result[0].applied)

    def test_bad_consistency_level(self):
        statement = SimpleStatement("foo")
        self.assertRaises(ValueError, setattr, statement, 'serial_consistency_level', ConsistencyLevel.ONE)
        self.assertRaises(ValueError, SimpleStatement, 'foo', serial_consistency_level=ConsistencyLevel.ONE)
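
# Note: a serial consistency level applies only to the Paxos phase of
# conditional (lightweight transaction) statements, which is why only
# ConsistencyLevel.SERIAL and ConsistencyLevel.LOCAL_SERIAL are accepted,
# as test_bad_consistency_level verifies.
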
def main():
    cassandra_cluster = Cluster(
        ['10.0.0.5', '10.0.0.7', '10.0.0.12', '10.0.0.19'])
    cassandra_session = cassandra_cluster.connect('insight')
    drop_views(cassandra_session)
    cassandra_cluster.shutdown()
Example #51
class BatchStatementTests(BasicSharedKeyspaceUnitTestCase):

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

    def tearDown(self):
        self.cluster.shutdown()

    def confirm_results(self):
        keys = set()
        values = set()
        # Assuming the test data is inserted at default CL.ONE, we need ALL here to guarantee we see
        # everything inserted
        results = self.session.execute(SimpleStatement("SELECT * FROM test3rf.test",
                                                       consistency_level=ConsistencyLevel.ALL))
        for result in results:
            keys.add(result.k)
            values.add(result.v)

        self.assertEqual(set(range(10)), keys, msg=results)
        self.assertEqual(set(range(10)), values, msg=results)

    def test_string_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_simple_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_prepared_statements(self):
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared, (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_bound_statements(self):
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared.bind((i, i)))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_no_parameters(self):
        batch = BatchStatement(BatchType.LOGGED)
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (1, 1)", ())
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)"))
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (3, 3)"), ())

        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (4, 4)")
        batch.add(prepared)
        batch.add(prepared, ())
        batch.add(prepared.bind([]))
        batch.add(prepared.bind([]), ())

        batch.add("INSERT INTO test3rf.test (k, v) VALUES (5, 5)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (6, 6)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (7, 7)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (8, 8)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (9, 9)", ())

        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1,))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2, 3))

        self.session.execute(batch)
        self.confirm_results()

    def test_unicode(self):
        ddl = '''
            CREATE TABLE test3rf.testtext (
                k int PRIMARY KEY,
                v text )'''
        self.session.execute(ddl)
        unicode_text = u'Fran\u00E7ois'
        query = u'INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)'
        try:
            batch = BatchStatement(BatchType.LOGGED)
            batch.add(u"INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)", (0, unicode_text))
            self.session.execute(batch)
        finally:
            self.session.execute("DROP TABLE test3rf.testtext")

    def test_too_many_statements(self):
        max_statements = 0xFFFF
        ss = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        b = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE)

        # max works
        b.add_all([ss] * max_statements, [None] * max_statements)
        self.session.execute(b)

        # max + 1 raises
        self.assertRaises(ValueError, b.add, ss)

        # also would have bombed trying to encode
        b._statements_and_parameters.append((False, ss.query_string, ()))
        self.assertRaises(NoHostAvailable, self.session.execute, b)
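
    # The 0xFFFF cap above comes from the native protocol, which encodes the
    # number of statements in a BATCH message as an unsigned 16-bit short, so
    # 65535 is the largest representable count (struct.pack('>H', 0xFFFF) fits;
    # 0x10000 would raise struct.error).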
Example #52
class LightweightTransactionTests(unittest.TestCase):
    def setUp(self):
        """
        Test is skipped if run with protocol version < 2
        """
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Lightweight transactions, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.session = self.cluster.connect()

        ddl = '''
            CREATE TABLE test3rf.lwt (
                k int PRIMARY KEY,
                v int )'''
        self.session.execute(ddl)

    def tearDown(self):
        """
        Shutdown cluster
        """
        self.session.execute("DROP TABLE test3rf.lwt")
        self.cluster.shutdown()

    def test_no_connection_refused_on_timeout(self):
        """
        Test for PYTHON-91 "Connection closed after LWT timeout"
        Verifies that the connection to the cluster is not shut down when a timeout occurs.
        The number of iterations can be specified with the LWT_ITERATIONS environment variable.
        The default value is 1000.
        """
        insert_statement = self.session.prepare("INSERT INTO test3rf.lwt (k, v) VALUES (0, 0) IF NOT EXISTS")
        delete_statement = self.session.prepare("DELETE FROM test3rf.lwt WHERE k = 0 IF EXISTS")

        iterations = int(os.getenv("LWT_ITERATIONS", 1000))

        # Prepare series of parallel statements
        statements_and_params = []
        for i in range(iterations):
            statements_and_params.append((insert_statement, ()))
            statements_and_params.append((delete_statement, ()))

        received_timeout = False
        results = execute_concurrent(self.session, statements_and_params, raise_on_first_error=False)
        for (success, result) in results:
            if success:
                continue
            else:
                # In this case result is an exception
                if type(result).__name__ == "NoHostAvailable":
                    self.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message)
                if type(result).__name__ == "WriteTimeout":
                    received_timeout = True
                    continue
                if type(result).__name__ == "WriteFailure":
                    received_timeout = True
                    continue
                if type(result).__name__ == "ReadTimeout":
                    continue
                if type(result).__name__ == "ReadFailure":
                    continue

                self.fail("Unexpected exception %s: %s" % (type(result).__name__, result.message))

        # Make sure test passed
        self.assertTrue(received_timeout)
Example #53
class CassandraIO():
    """
    This module reads the packet info from the Cassandra database and returns
    a histogram of the data. As of this writing, Cassandra is not suitable
    for our needs and the plan is to move to a relational database, so not a lot
    of care has been put into this module.
    """
    def __init__(self, keyspace, table, hostname="127.0.0.1"):
        self._cluster = Cluster([hostname])
        self._session = self._cluster.connect(keyspace)
        self._table = table

    def get_histogram(self, sample_window_sec, slide_window_sec, filter_name,
                      filter_value, features_keep):
        """
        filter_name and filter_value is to define one thing that we're making
        histograms for. For example, source_addr = 10.0.0.1 would build
        histograms for all packets originating from 10.0.0.1 and with whatever
        desired features
        
        features_keep is a tuple of strings representing field names in the 
        database. If the feature is
        inside the text_values map, just pass in the key for the map and this
        code will automatically try the various maps looking for it. We 
        assume the same key name doesn't appear in multiple maps

        As a result the ret dict contains a flat keyspace
        """
        ret = Histograms(sample_window_sec, slide_window_sec)

        self._result = self._session.execute("SELECT * FROM " + self._table)

        temp_data = dict()

        count = 0
        for res in self._result:
            count += 1

            features = ()
            res_dict = res._asdict()

            if res_dict[filter_name] != filter_value:
                continue

            # Count packets per destination address (test and key on the same field)
            if res_dict["dest_addr"] in temp_data:
                temp_data[res_dict["dest_addr"]] += 1
            else:
                temp_data[res_dict["dest_addr"]] = 1

            for f in features_keep.split(","):
                f = f.strip()
                if f in res_dict:
                    features = features + (res_dict[f], )
                elif f in res_dict["text_values"]:
                    features = features + (res_dict["text_values"][f], )
                else:
                    raise Exception("Could not find field " + f)

            sec = time.mktime(res.time_stamp.timetuple())

            ret.insert_one(','.join(map(str, features)), sec)
        return ret

    def close(self):
        self._session = None
        self._cluster.shutdown()
        self._cluster = None
Example #54
class Client:

    def __init__(self, hosts, metadata):
        self.port = 9042
        self.hosts = []
        for host in hosts:
            if ":" in host:
                self.port = int(host.split(":")[-1])  # keep the port numeric for Cluster()
                self.hosts.append(host.split(":")[0])
            else:
                self.hosts.append(host)
        logger.debug(f"Connecting to hosts: {hosts}")
        self.ks = metadata["keyspace"]
        self.repl = metadata["replication"]
        self.cluster = Cluster(self.hosts, port=self.port)
        try:
            self.session = self.cluster.connect(self.ks)
        except NoHostAvailable:
            self.cluster.shutdown()
            self.cluster = Cluster(self.hosts, port=self.port)
            query = f"CREATE KEYSPACE IF NOT EXISTS {self.ks} WITH REPLICATION = {self.repl}"
            logger.debug(f"Keyspace not found: {self.ks}")
            logger.debug(f"Executing query: {query}")
            self.session = self.cluster.connect()
            self.session.execute(query)
            self.session.set_keyspace(self.ks)

    def heartbeat(self):
        return bool(self.cluster.metadata.keyspaces)

    def _format_table(self, table):
        columns = []
        for column_name, column in table.columns.items():
            columns.append({
                "name": column_name,
                "datatype": column.cql_type
            })
        return {
            "name": table.name,
            "columns": columns,
            "primary_key": [x.name for x in table.primary_key]
        }

    def list_tables(self, **kwargs):
        tables = []
        for _, table in self.cluster.metadata.keyspaces[self.ks].tables.items():
            tables.append(self._format_table(table))
        return tables

    def create_table(self, data):
        logger.debug(f"Creating table {data}")
        fields = []
        for i in data["columns"]:
            fields.append(f'{i["name"]} {i["datatype"]}')
        query = f'CREATE TABLE {data["name"]}({",".join(fields)}, PRIMARY KEY({",".join(data["primary_key"])}))'
        logger.debug(f"Executing query: {query}")
        self.session.execute(query)
        return True

    def describe_table(self, table_name):
        table = self.cluster.metadata.keyspaces[self.ks].tables[table_name]
        return self._format_table(table)

    def insert_into(self, table_name, data):
        logger.debug(f"Insert into {table_name}: {data}")
        fields = []
        values = []
        for k, v in data["field_values"].items():
            fields.append(k)
            if isinstance(v, str):
                v = f"'{v}'"
            values.append(str(v))
        query = f'INSERT INTO {table_name}({",".join(fields)}) VALUES({",".join(values)})'
        logger.debug(f"Executing query: {query}")
        self.session.execute(query)
        return True
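
    # A sketch of a safer variant that uses driver-side parameter binding
    # instead of string interpolation (hypothetical method, not part of the
    # original class):
    def insert_into_bound(self, table_name, data):
        fields = list(data["field_values"].keys())
        values = list(data["field_values"].values())
        placeholders = ", ".join(["%s"] * len(fields))
        query = f'INSERT INTO {table_name} ({", ".join(fields)}) VALUES ({placeholders})'
        logger.debug(f"Executing query: {query}")
        # The driver quotes/encodes each bound value, so no manual f"'{v}'" wrapping
        self.session.execute(query, values)
        return True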

    def select_from(self, table_name, data):
        logger.debug(f"Select from {table_name}: {data}")
        query = f'SELECT * FROM {table_name}'
        logger.debug(f"Executing query: {query}")
        rows = self.session.execute(query)
        return rows.all()

    def update_from(self, table_name, data):
        logger.debug(f"Update from {table_name}: {data}")
        updates = []
        filters = []
        for k, v in data["field_values"].items():
            if isinstance(v, str):
                v = f"'{v}'"
            updates.append(f"{k} = {v}")
        for k, v in data["where"].items():
            if isinstance(v, str):
                v = f"'{v}'"
            filters.append(f"{k} = {v}")
        query = f'UPDATE {table_name} SET {",".join(updates)} WHERE {",".join(filters)}'
        logger.debug(f"Executing query: {query}")
        self.session.execute(query)
        return True

    def delete_from(self, table_name, data):
        logger.debug(f"Delete from {table_name}: {data}")
        filters = []
        for k, v in data["where"].items():
            if isinstance(v, str):
                v = f"'{v}'"
            filters.append(f"{k} = {v}")
        query = f'DELETE FROM {table_name} WHERE {",".join(filters)}'
        logger.debug(f"Executing query: {query}")
        self.session.execute(query)
        return True

    def token_aware(self, keyspace, use_prepared=False):
        use_singledc()
        cluster = Cluster(load_balancing_policy=TokenAwarePolicy(
            RoundRobinPolicy()),
                          protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        wait_for_up(cluster, 1, wait=False)
        wait_for_up(cluster, 2, wait=False)
        wait_for_up(cluster, 3)

        create_schema(session, keyspace, replication_factor=1)
        self._insert(session, keyspace)
        self._query(session, keyspace, use_prepared=use_prepared)

        self.coordinator_stats.assert_query_count_equals(self, 1, 0)
        self.coordinator_stats.assert_query_count_equals(self, 2, 12)
        self.coordinator_stats.assert_query_count_equals(self, 3, 0)

        self.coordinator_stats.reset_counts()
        self._query(session, keyspace, use_prepared=use_prepared)

        self.coordinator_stats.assert_query_count_equals(self, 1, 0)
        self.coordinator_stats.assert_query_count_equals(self, 2, 12)
        self.coordinator_stats.assert_query_count_equals(self, 3, 0)

        self.coordinator_stats.reset_counts()
        force_stop(2)
        wait_for_down(cluster, 2, wait=True)

        try:
            self._query(session, keyspace, use_prepared=use_prepared)
            self.fail()
        except Unavailable as e:
            self.assertEqual(e.consistency, 1)
            self.assertEqual(e.required_replicas, 1)
            self.assertEqual(e.alive_replicas, 0)

        self.coordinator_stats.reset_counts()
        start(2)
        wait_for_up(cluster, 2, wait=True)

        self._query(session, keyspace, use_prepared=use_prepared)

        self.coordinator_stats.assert_query_count_equals(self, 1, 0)
        self.coordinator_stats.assert_query_count_equals(self, 2, 12)
        self.coordinator_stats.assert_query_count_equals(self, 3, 0)

        self.coordinator_stats.reset_counts()
        stop(2)
        wait_for_down(cluster, 2, wait=True)

        try:
            self._query(session, keyspace, use_prepared=use_prepared)
            self.fail()
        except Unavailable:
            pass

        self.coordinator_stats.reset_counts()
        start(2)
        wait_for_up(cluster, 2, wait=True)
        decommission(2)
        wait_for_down(cluster, 2, wait=True)

        self._query(session, keyspace, use_prepared=use_prepared)

        results = set([
            self.coordinator_stats.get_query_count(1),
            self.coordinator_stats.get_query_count(3)
        ])
        self.assertEqual(results, set([0, 12]))
        self.coordinator_stats.assert_query_count_equals(self, 2, 0)

        cluster.shutdown()
Example #56
class cassandra_utils():
    ''' Object that holds all cassandra related information '''
    def __init__(self, hosts_list):
        self.hosts_list = hosts_list
        # For now, this is hardcoded. Make it configurable later
        self.replication = {'class': 'SimpleStrategy', 'replication_factor': 1}
        self._connect_to_cluster()

    # Private methods
    def _connect_to_cluster(self):
        logging.info('connecting to Cassandra at: {}'.format(self.hosts_list))
        self.cluster = Cluster(self.hosts_list)
        self.session = self.cluster.connect()
        self.cluster_name = self.cluster.metadata.cluster_name
        logging.info('Connected to cluster named: {}'.format(
            self.cluster_name))

    # Public methods
    def get_replication(self):
        ''' Simply return the replication setting for now '''
        return self.replication

    def get_keyspaces(self):
        ''' Return the list of keyspaces in this cluster. Also init
        a dict of keyspaces and keyspace objects '''
        self.system_ks_list = []
        self.db_ks_list = []
        self.system_ks_dict = {}
        self.db_ks_dict = {}
        for ks, ks_obj in self.cluster.metadata.keyspaces.items():
            logging.debug('Setting up keyspace: {}'.format(ks))
            setattr(self, ks, ks_obj)
            if ks in SYSTEM_KS:
                self.system_ks_list.append(ks)
                self.system_ks_dict[ks] = ks_obj
            else:
                self.db_ks_list.append(ks)
                self.db_ks_dict[ks] = ks_obj

        return self.db_ks_list, self.system_ks_list

    def get_tables_in_keyspace(self, keyspace):
        ''' Given a keyspace, return a list of its tables.
            Also set each table obj as an attr on the keyspace obj '''
        table_list = []
        # get the keyspace attribute
        ks_obj = getattr(self, keyspace)
        # roll through the ks_obj tables dict and set things up
        tables_dict = ks_obj.tables.items()
        for table_name, table_obj in tables_dict:
            logging.debug('Setting up for table: {} in keyspace: {}'.format(
                table_name, keyspace))
            table_list.append(table_name)
            setattr(ks_obj, table_name, table_obj)
        return table_list

    def set_session_keyspace(self, keyspace):
        ''' Set the default keyspace '''
        logging.info('Setting cluster keyspace to: {}'.format(keyspace))
        self.session.set_keyspace(keyspace)

    def create_table(self,
                     ks_name,
                     table_name,
                     table_columns=TABLE_COLUMNS_SV2):
        ''' Create specified table in keyspace ks_name if it does not exist '''
        cmd = "CREATE TABLE IF NOT EXISTS {ks_name}.{table_name} {cols}".format(
            ks_name=ks_name, table_name=table_name, cols=table_columns)
        logging.info('Creating table with command: {}'.format(cmd))
        retval = self.session.execute(cmd)
        # TBD: Not sure yet how to check for creation failures
        logging.info('Create command returned: {}'.format(retval))

    def delete_table(self, ks_name, table_name):
        ''' Delete specified table from keyspace ks_name '''
        cmd = "DROP TABLE IF EXISTS {ks_name}.{table_name};".format(
            ks_name=ks_name, table_name=table_name)
        logging.info('Deleting table with command: {}'.format(cmd))
        retval = self.session.execute(cmd)
        # TBD: Not sure yet how to check for delete failures
        logging.info('Delete command returned: {}'.format(retval))

    def cleanup(self):
        ''' Close all connections to the Cassandra cluster '''
        logging.info('Closing connection to cluster: {}'.format(
            self.cluster_name))
        self.cluster.shutdown()
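
# Usage sketch (illustrative, not part of the original snippet; assumes a
# reachable Cassandra node and that SYSTEM_KS and logging are configured as
# the class above expects):
#
#   cass = cassandra_utils(['127.0.0.1'])
#   db_ks_list, system_ks_list = cass.get_keyspaces()
#   for ks in db_ks_list:
#       print(ks, cass.get_tables_in_keyspace(ks))
#   cass.cleanup()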
Example #57
class HeartbeatTest(unittest.TestCase):
    """
    Test to validate that failing a heartbeat check doesn't mark a host as down

    @since 3.3
    @jira_ticket PYTHON-286
    @expected_result host should not be marked down when heartbeat fails

    @test_category connection heartbeat
    """
    def setUp(self):
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION,
                               idle_heartbeat_interval=1)
        self.session = self.cluster.connect()

    def tearDown(self):
        self.cluster.shutdown()

    def test_heart_beat_timeout(self):
        # Set up a host listener to ensure the nodes don't get marked down
        test_listener = TestHostListener()
        host = "127.0.0.1"
        node = get_node(1)
        initial_connections = self.fetch_connections(host, self.cluster)
        self.assertNotEqual(len(initial_connections), 0)
        self.cluster.register_listener(test_listener)
        # Pause the node
        node.pause()
        # Wait for connections associated with this host to go away
        self.wait_for_no_connections(host, self.cluster)
        # Resume paused node
        node.resume()
        # Run a query to ensure connections are re-established
        current_host = ""
        count = 0
        while current_host != host and count < 100:
            rs = self.session.execute_async("SELECT * FROM system.local",
                                            trace=False)
            rs.result()
            current_host = str(rs._current_host)
            count += 1
            time.sleep(.1)
        self.assertLess(count, 100, "Never connected to the first node")
        new_connections = self.wait_for_connections(host, self.cluster)
        self.assertIsNone(test_listener.host_down)
        # Make sure underlying new connections don't match previous ones
        for connection in initial_connections:
            self.assertFalse(connection in new_connections)

    def fetch_connections(self, host, cluster):
        # Given a cluster object and a host, grab all connections associated with that host
        connections = []
        holders = cluster.get_connection_holders()
        for conn in holders:
            if host == str(getattr(conn, 'host', '')):
                if isinstance(conn, HostConnectionPool):
                    if conn._connections is not None:
                        connections.append(conn._connections)
                else:
                    if conn._connection is not None:
                        connections.append(conn._connection)
        return connections

    def wait_for_connections(self, host, cluster):
        retry = 0
        while retry < 300:
            retry += 1
            connections = self.fetch_connections(host, cluster)
            if len(connections) != 0:
                return connections
            time.sleep(.1)
        self.fail("No new connections found")

    def wait_for_no_connections(self, host, cluster):
        retry = 0
        while retry < 100:
            retry += 1
            connections = self.fetch_connections(host, cluster)
            if len(connections) == 0:
                return
            time.sleep(.5)
        self.fail("Connections never cleared")

    def test_dc_aware_roundrobin_one_remote_host(self):
        use_multidc([2, 2])
        keyspace = 'test_dc_aware_roundrobin_one_remote_host'
        cluster = Cluster(load_balancing_policy=DCAwareRoundRobinPolicy(
            'dc2', used_hosts_per_remote_dc=1),
                          protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        wait_for_up(cluster, 1, wait=False)
        wait_for_up(cluster, 2, wait=False)
        wait_for_up(cluster, 3, wait=False)
        wait_for_up(cluster, 4)

        create_schema(session, keyspace, replication_strategy=[2, 2])
        self._insert(session, keyspace)
        self._query(session, keyspace)

        self.coordinator_stats.assert_query_count_equals(self, 1, 0)
        self.coordinator_stats.assert_query_count_equals(self, 2, 0)
        self.coordinator_stats.assert_query_count_equals(self, 3, 6)
        self.coordinator_stats.assert_query_count_equals(self, 4, 6)

        self.coordinator_stats.reset_counts()
        bootstrap(5, 'dc1')
        wait_for_up(cluster, 5)

        self._query(session, keyspace)

        self.coordinator_stats.assert_query_count_equals(self, 1, 0)
        self.coordinator_stats.assert_query_count_equals(self, 2, 0)
        self.coordinator_stats.assert_query_count_equals(self, 3, 6)
        self.coordinator_stats.assert_query_count_equals(self, 4, 6)
        self.coordinator_stats.assert_query_count_equals(self, 5, 0)

        self.coordinator_stats.reset_counts()
        decommission(3)
        decommission(4)
        wait_for_down(cluster, 3, wait=True)
        wait_for_down(cluster, 4, wait=True)

        self._query(session, keyspace)

        self.coordinator_stats.assert_query_count_equals(self, 3, 0)
        self.coordinator_stats.assert_query_count_equals(self, 4, 0)
        responses = set()
        for node in [1, 2, 5]:
            responses.add(self.coordinator_stats.get_query_count(node))
        self.assertEqual(set([0, 0, 12]), responses)

        self.coordinator_stats.reset_counts()
        decommission(5)
        wait_for_down(cluster, 5, wait=True)

        self._query(session, keyspace)

        self.coordinator_stats.assert_query_count_equals(self, 3, 0)
        self.coordinator_stats.assert_query_count_equals(self, 4, 0)
        self.coordinator_stats.assert_query_count_equals(self, 5, 0)
        responses = set()
        for node in [1, 2]:
            responses.add(self.coordinator_stats.get_query_count(node))
        self.assertEqual(set([0, 12]), responses)

        self.coordinator_stats.reset_counts()
        decommission(1)
        wait_for_down(cluster, 1, wait=True)

        self._query(session, keyspace)

        self.coordinator_stats.assert_query_count_equals(self, 1, 0)
        self.coordinator_stats.assert_query_count_equals(self, 2, 12)
        self.coordinator_stats.assert_query_count_equals(self, 3, 0)
        self.coordinator_stats.assert_query_count_equals(self, 4, 0)
        self.coordinator_stats.assert_query_count_equals(self, 5, 0)

        self.coordinator_stats.reset_counts()
        force_stop(2)
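        # With nodes 1, 3, 4 and 5 decommissioned and node 2 force-stopped,
        # no host remains, so the query now raises NoHostAvailable.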

        try:
            self._query(session, keyspace)
            self.fail()
        except NoHostAvailable:
            pass

        cluster.shutdown()
Example #59
def teardown(hosts):
    cluster = Cluster(hosts)
    cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
    session = cluster.connect()
    session.execute("DROP KEYSPACE " + KEYSPACE)
    cluster.shutdown()
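
# A matching setup sketch (hypothetical; KEYSPACE is assumed to be defined
# elsewhere in the same test module):
#
#   def setup(hosts):
#       cluster = Cluster(hosts)
#       session = cluster.connect()
#       session.execute(
#           "CREATE KEYSPACE IF NOT EXISTS " + KEYSPACE + " WITH replication"
#           " = {'class': 'SimpleStrategy', 'replication_factor': '1'}")
#       cluster.shutdown()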
Example #60
class DatastoreProxy(AppDBInterface):
  """ 
    Cassandra implementation of the AppDBInterface
  """
  def __init__(self, log_level=logging.INFO):
    """
    Constructor.
    """
    class_name = self.__class__.__name__
    self.logger = logging.getLogger(class_name)
    self.logger.setLevel(log_level)
    self.logger.info('Starting {}'.format(class_name))

    self.hosts = appscale_info.get_db_ips()

    remaining_retries = INITIAL_CONNECT_RETRIES
    while True:
      try:
        self.cluster = Cluster(self.hosts, default_retry_policy=BASIC_RETRIES)
        self.session = self.cluster.connect(KEYSPACE)
        break
      except cassandra.cluster.NoHostAvailable as connection_error:
        remaining_retries -= 1
        if remaining_retries < 0:
          raise connection_error
        time.sleep(3)

    self.session.default_consistency_level = ConsistencyLevel.QUORUM
    self.prepared_statements = {}

  def close(self):
    """ Close all sessions and connections to Cassandra. """
    self.cluster.shutdown()

  def batch_get_entity(self, table_name, row_keys, column_names):
    """
    Takes in batches of keys and retrieves their corresponding rows.
    
    Args:
      table_name: The table to access
      row_keys: A list of keys to access
      column_names: A list of columns to access
    Returns:
      A dictionary of rows and columns/values of those rows. The format 
      looks like such: {key:{column_name:value,...}}
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_get could not be performed due to
        an error with Cassandra.
    """
    if not isinstance(table_name, str): raise TypeError("Expected a str")
    if not isinstance(column_names, list): raise TypeError("Expected a list")
    if not isinstance(row_keys, list): raise TypeError("Expected a list")

    row_keys_bytes = [bytearray(row_key) for row_key in row_keys]

    statement = 'SELECT * FROM "{table}" '\
                'WHERE {key} IN %s and {column} IN %s'.format(
                  table=table_name,
                  key=ThriftColumn.KEY,
                  column=ThriftColumn.COLUMN_NAME,
                )
    query = SimpleStatement(statement, retry_policy=BASIC_RETRIES)
    parameters = (ValueSequence(row_keys_bytes), ValueSequence(column_names))

    try:
      results = self.session.execute(query, parameters=parameters)

      results_dict = {row_key: {} for row_key in row_keys}
      for (key, column, value) in results:
        if key not in results_dict:
          results_dict[key] = {}
        results_dict[key][column] = value

      return results_dict
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during batch_get_entity'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def batch_put_entity(self, table_name, row_keys, column_names, cell_values,
                       ttl=None):
    """
    Allows callers to store multiple rows with a single call. Each row can
    have multiple columns and values associated with it. We refer to each
    row as an entity.
   
    Args: 
      table_name: The table to mutate
      row_keys: A list of keys to store on
      column_names: A list of columns to mutate
      cell_values: A dict of key/value pairs
      ttl: The number of seconds to keep the row.
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_put could not be performed due to
        an error with Cassandra.
    """
    if not isinstance(table_name, str):
      raise TypeError("Expected a str")
    if not isinstance(column_names, list):
      raise TypeError("Expected a list")
    if not isinstance(row_keys, list):
      raise TypeError("Expected a list")
    if not isinstance(cell_values, dict):
      raise TypeError("Expected a dict")

    insert_str = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (?, ?, ?)
    """.format(table=table_name,
               key=ThriftColumn.KEY,
               column=ThriftColumn.COLUMN_NAME,
               value=ThriftColumn.VALUE)

    if ttl is not None:
      insert_str += 'USING TTL {}'.format(ttl)

    statement = self.session.prepare(insert_str)

    statements_and_params = []
    for row_key in row_keys:
      for column in column_names:
        params = (bytearray(row_key), column,
                  bytearray(cell_values[row_key][column]))
        statements_and_params.append((statement, params))

    try:
      execute_concurrent(self.session, statements_and_params,
                         raise_on_first_error=True)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during batch_put_entity'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)
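
  # prepare_insert and prepare_delete below cache PreparedStatements per CQL
  # string, so repeated mutations avoid a prepare round trip on every call.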

  def prepare_insert(self, table):
    """ Prepare an insert statement.

    Args:
      table: A string containing the table name.
    Returns:
      A PreparedStatement object.
    """
    statement = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (?, ?, ?)
      USING TIMESTAMP ?
    """.format(table=table,
               key=ThriftColumn.KEY,
               column=ThriftColumn.COLUMN_NAME,
               value=ThriftColumn.VALUE)

    if statement not in self.prepared_statements:
      self.prepared_statements[statement] = self.session.prepare(statement)

    return self.prepared_statements[statement]

  def prepare_delete(self, table):
    """ Prepare a delete statement.

    Args:
      table: A string containing the table name.
    Returns:
      A PreparedStatement object.
    """
    statement = """
      DELETE FROM "{table}"
      USING TIMESTAMP ?
      WHERE {key} = ?
    """.format(table=table, key=ThriftColumn.KEY)

    if statement not in self.prepared_statements:
      self.prepared_statements[statement] = self.session.prepare(statement)

    return self.prepared_statements[statement]

  def _normal_batch(self, mutations, txid):
    """ Use Cassandra's native batch statement to apply mutations atomically.

    Args:
      mutations: A list of dictionaries representing mutations.
      txid: An integer specifying a transaction ID.
    """
    self.logger.debug('Normal batch: {} mutations'.format(len(mutations)))
    batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM,
                           retry_policy=BASIC_RETRIES)
    prepared_statements = {'insert': {}, 'delete': {}}
    for mutation in mutations:
      table = mutation['table']

      if table == 'group_updates':
        key = mutation['key']
        insert = """
          INSERT INTO group_updates (group, last_update)
          VALUES (%(group)s, %(last_update)s)
          USING TIMESTAMP %(timestamp)s
        """
        parameters = {'group': key, 'last_update': mutation['last_update'],
                      'timestamp': get_write_time(txid)}
        batch.add(insert, parameters)
        continue

      if mutation['operation'] == Operations.PUT:
        if table not in prepared_statements['insert']:
          prepared_statements['insert'][table] = self.prepare_insert(table)
        values = mutation['values']
        for column in values:
          batch.add(
            prepared_statements['insert'][table],
            (bytearray(mutation['key']), column, bytearray(values[column]),
             get_write_time(txid))
          )
      elif mutation['operation'] == Operations.DELETE:
        if table not in prepared_statements['delete']:
          prepared_statements['delete'][table] = self.prepare_delete(table)
        batch.add(
          prepared_statements['delete'][table],
          (get_write_time(txid), bytearray(mutation['key']))
        )

    try:
      self.session.execute(batch)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during batch_mutate'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def apply_mutations(self, mutations, txid):
    """ Apply mutations across tables.

    Args:
      mutations: A list of dictionaries representing mutations.
      txid: An integer specifying a transaction ID.
    """
    prepared_statements = {'insert': {}, 'delete': {}}
    statements_and_params = []
    for mutation in mutations:
      table = mutation['table']

      if table == 'group_updates':
        key = mutation['key']
        insert = """
          INSERT INTO group_updates (group, last_update)
          VALUES (%(group)s, %(last_update)s)
          USING TIMESTAMP %(timestamp)s
        """
        parameters = {'group': key, 'last_update': mutation['last_update'],
                      'timestamp': get_write_time(txid)}
        statements_and_params.append((SimpleStatement(insert), parameters))
        continue

      if mutation['operation'] == Operations.PUT:
        if table not in prepared_statements['insert']:
          prepared_statements['insert'][table] = self.prepare_insert(table)
        values = mutation['values']
        for column in values:
          params = (bytearray(mutation['key']), column,
                    bytearray(values[column]), get_write_time(txid))
          statements_and_params.append(
            (prepared_statements['insert'][table], params))
      elif mutation['operation'] == Operations.DELETE:
        if table not in prepared_statements['delete']:
          prepared_statements['delete'][table] = self.prepare_delete(table)
        params = (get_write_time(txid), bytearray(mutation['key']))
        statements_and_params.append(
          (prepared_statements['delete'][table], params))

    execute_concurrent(self.session, statements_and_params,
                       raise_on_first_error=True)

  def _large_batch(self, app, mutations, entity_changes, txn):
    """ Insert or delete multiple rows across tables in an atomic statement.

    Args:
      app: A string containing the application ID.
      mutations: A list of dictionaries representing mutations.
      entity_changes: A list of changes at the entity level.
      txn: A transaction ID handler.
    Raises:
      FailedBatch if a concurrent process modifies the batch status.
      AppScaleDBConnectionError if a database connection error was encountered.
    """
    self.logger.debug('Large batch: transaction {}, {} mutations'.
                      format(txn, len(mutations)))
    large_batch = LargeBatch(self.session, app, txn)
    try:
      large_batch.start()
    except FailedBatch as batch_error:
      raise AppScaleDBConnectionError(str(batch_error))

    insert_item = """
      INSERT INTO batches (app, transaction, namespace, path,
                           old_value, new_value)
      VALUES (?, ?, ?, ?, ?, ?)
    """
    insert_statement = self.session.prepare(insert_item)

    statements_and_params = []
    for entity_change in entity_changes:
      old_value = None
      if entity_change['old'] is not None:
        old_value = bytearray(entity_change['old'].Encode())
      new_value = None
      if entity_change['new'] is not None:
        new_value = bytearray(entity_change['new'].Encode())

      parameters = (app, txn, entity_change['key'].name_space(),
                    bytearray(entity_change['key'].path().Encode()), old_value,
                    new_value)
      statements_and_params.append((insert_statement, parameters))

    try:
      execute_concurrent(self.session, statements_and_params,
                         raise_on_first_error=True)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Unable to write large batch log'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

    try:
      large_batch.set_applied()
    except FailedBatch as batch_error:
      raise AppScaleDBConnectionError(str(batch_error))

    try:
      self.apply_mutations(mutations, txn)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during large batch'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

    try:
      large_batch.cleanup()
    except FailedBatch:
      # This should not raise an exception since the batch is already applied.
      logging.exception('Unable to clear batch status')

    clear_batch = """
      DELETE FROM batches
      WHERE app = %(app)s AND transaction = %(transaction)s
    """
    parameters = {'app': app, 'transaction': txn}
    try:
      self.session.execute(clear_batch, parameters)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      logging.exception('Unable to clear batch log')

  def batch_mutate(self, app, mutations, entity_changes, txn):
    """ Insert or delete multiple rows across tables in an atomic statement.

    Args:
      app: A string containing the application ID.
      mutations: A list of dictionaries representing mutations.
      entity_changes: A list of changes at the entity level.
      txn: A transaction ID handler.
    """
    size = batch_size(mutations)
    if size > LARGE_BATCH_THRESHOLD:
      self._large_batch(app, mutations, entity_changes, txn)
    else:
      self._normal_batch(mutations, txn)

  def batch_delete(self, table_name, row_keys, column_names=()):
    """
    Remove a set of rows corresponding to a set of keys.
     
    Args:
      table_name: Table to delete rows from
      row_keys: A list of keys to remove
      column_names: Not used
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_delete could not be performed due
        to an error with Cassandra.
    """ 
    if not isinstance(table_name, str): raise TypeError("Expected a str")
    if not isinstance(row_keys, list): raise TypeError("Expected a list")

    row_keys_bytes = [bytearray(row_key) for row_key in row_keys]

    statement = 'DELETE FROM "{table}" WHERE {key} IN %s'.\
      format(
        table=table_name,
        key=ThriftColumn.KEY
      )
    query = SimpleStatement(statement, retry_policy=BASIC_RETRIES)
    parameters = (ValueSequence(row_keys_bytes),)

    try:
      self.session.execute(query, parameters=parameters)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during batch_delete'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def delete_table(self, table_name):
    """ 
    Drops a given table (aka column family in Cassandra)
  
    Args:
      table_name: A string name of the table to drop
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the delete_table could not be performed due
        to an error with Cassandra.
    """
    if not isinstance(table_name, str): raise TypeError("Expected a str")

    statement = 'DROP TABLE IF EXISTS "{table}"'.format(table=table_name)
    query = SimpleStatement(statement, retry_policy=BASIC_RETRIES)

    try:
      self.session.execute(query)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during delete_table'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def create_table(self, table_name, column_names):
    """ 
    Creates a table if it doesn't already exist.
    
    Args:
      table_name: The column family name
      column_names: Not used but here to match the interface
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the create_table could not be performed due
        to an error with Cassandra.
    """
    if not isinstance(table_name, str): raise TypeError("Expected a str")
    if not isinstance(column_names, list): raise TypeError("Expected a list")

    statement = 'CREATE TABLE IF NOT EXISTS "{table}" ('\
        '{key} blob,'\
        '{column} text,'\
        '{value} blob,'\
        'PRIMARY KEY ({key}, {column})'\
      ') WITH COMPACT STORAGE'.format(
        table=table_name,
        key=ThriftColumn.KEY,
        column=ThriftColumn.COLUMN_NAME,
        value=ThriftColumn.VALUE
      )
    query = SimpleStatement(statement, retry_policy=NO_RETRIES)

    try:
      self.session.execute(query, timeout=SCHEMA_CHANGE_TIMEOUT)
    except cassandra.OperationTimedOut:
      logging.warning(
        'Encountered an operation timeout while creating a table. Waiting {} '
        'seconds for schema to settle.'.format(SCHEMA_CHANGE_TIMEOUT))
      time.sleep(SCHEMA_CHANGE_TIMEOUT)
      raise AppScaleDBConnectionError('Exception during create_table')
    except tuple(error for error in dbconstants.TRANSIENT_CASSANDRA_ERRORS
                 if error is not cassandra.OperationTimedOut):
      message = 'Exception during create_table'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def range_query(self,
                  table_name,
                  column_names, 
                  start_key, 
                  end_key, 
                  limit, 
                  offset=0, 
                  start_inclusive=True, 
                  end_inclusive=True,
                  keys_only=False):
    """ 
    Gets a dense range ordered by keys. Returns an ordered list of 
    a dictionary of [key:{column1:value1, column2:value2},...]
    or a list of keys if keys only.
     
    Args:
      table_name: Name of table to access
      column_names: Columns which get returned within the key range
      start_key: String key at which the query starts
      end_key: String key at which the query ends
      limit: Maximum number of results to return
      offset: Number of leading results to skip (applied as [offset:])
      start_inclusive: Boolean, whether results should include the start_key
      end_inclusive: Boolean, whether results should include the end_key
      keys_only: Boolean, whether to return only keys and not values
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the range_query could not be performed due
        to an error with Cassandra.
    Returns:
      An ordered list of dictionaries of key=>columns/values
    """
    if not isinstance(table_name, str):
      raise TypeError('table_name must be a string')
    if not isinstance(column_names, list):
      raise TypeError('column_names must be a list')
    if not isinstance(start_key, str):
      raise TypeError('start_key must be a string')
    if not isinstance(end_key, str):
      raise TypeError('end_key must be a string')
    if not isinstance(limit, (int, long)) and limit is not None:
      raise TypeError('limit must be int, long, or NoneType')
    if not isinstance(offset, (int, long)):
      raise TypeError('offset must be int or long')

    if start_inclusive:
      gt_compare = '>='
    else:
      gt_compare = '>'

    if end_inclusive:
      lt_compare = '<='
    else:
      lt_compare = '<'

    query_limit = ''
    if limit is not None:
      query_limit = 'LIMIT {}'.format(len(column_names) * limit)
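    # The scan compares token(key) against raw key bytes; this yields
    # key-ordered results only with an order-preserving partitioner (e.g.
    # ByteOrderedPartitioner), which this schema appears to assume. ALLOW
    # FILTERING permits the otherwise-rejected column IN restriction.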

    statement = """
      SELECT * FROM "{table}" WHERE
      token({key}) {gt_compare} %s AND
      token({key}) {lt_compare} %s AND
      {column} IN %s
      {limit}
      ALLOW FILTERING
    """.format(table=table_name,
               key=ThriftColumn.KEY,
               gt_compare=gt_compare,
               lt_compare=lt_compare,
               column=ThriftColumn.COLUMN_NAME,
               limit=query_limit)

    query = SimpleStatement(statement, retry_policy=BASIC_RETRIES)
    parameters = (bytearray(start_key), bytearray(end_key),
                  ValueSequence(column_names))

    try:
      results = self.session.execute(query, parameters=parameters)

      results_list = []
      current_item = {}
      current_key = None
      for (key, column, value) in results:
        if keys_only:
          results_list.append(key)
          continue

        if key != current_key:
          if current_item:
            results_list.append({current_key: current_item})
          current_item = {}
          current_key = key

        current_item[column] = value
      if current_item:
        results_list.append({current_key: current_item})
      return results_list[offset:]
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception during range_query'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def get_metadata(self, key):
    """ Retrieve a value from the datastore metadata table.

    Args:
      key: A string containing the key to fetch.
    Returns:
      A string containing the value or None if the key is not present.
    """
    statement = """
      SELECT {value} FROM "{table}"
      WHERE {key} = %s
      AND {column} = %s
    """.format(
      value=ThriftColumn.VALUE,
      table=dbconstants.DATASTORE_METADATA_TABLE,
      key=ThriftColumn.KEY,
      column=ThriftColumn.COLUMN_NAME
    )
    try:
      results = self.session.execute(statement, (bytearray(key), key))
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Unable to fetch {} from datastore metadata'.format(key)
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

    try:
      return results[0].value
    except IndexError:
      return None

  def set_metadata(self, key, value):
    """ Set a datastore metadata value.

    Args:
      key: A string containing the key to set.
      value: A string containing the value to set.
    """
    if not isinstance(key, str):
      raise TypeError('key should be a string')

    if not isinstance(value, str):
      raise TypeError('value should be a string')

    statement = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (%(key)s, %(column)s, %(value)s)
    """.format(
      table=dbconstants.DATASTORE_METADATA_TABLE,
      key=ThriftColumn.KEY,
      column=ThriftColumn.COLUMN_NAME,
      value=ThriftColumn.VALUE
    )
    parameters = {'key': bytearray(key),
                  'column': key,
                  'value': bytearray(value)}
    try:
      self.session.execute(statement, parameters)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Unable to set datastore metadata for {}'.format(key)
      logging.exception(message)
      raise AppScaleDBConnectionError(message)
    except cassandra.InvalidRequest:
      self.create_table(dbconstants.DATASTORE_METADATA_TABLE,
                        dbconstants.DATASTORE_METADATA_SCHEMA)
      self.session.execute(statement, parameters)

  def get_indices(self, app_id):
    """ Gets the indices of the given application.

    Args:
      app_id: Name of the application.
    Returns:
      Returns a list of encoded entity_pb.CompositeIndex objects.
    """
    start_key = dbconstants.KEY_DELIMITER.join([app_id, 'index', ''])
    end_key = dbconstants.KEY_DELIMITER.join(
      [app_id, 'index', dbconstants.TERMINATING_STRING])
    result = self.range_query(
      dbconstants.METADATA_TABLE,
      dbconstants.METADATA_SCHEMA,
      start_key,
      end_key,
      dbconstants.MAX_NUMBER_OF_COMPOSITE_INDEXES,
      offset=0,
      start_inclusive=True,
      end_inclusive=True)
    list_result = []
    for list_item in result:
      for key, value in list_item.iteritems():
        list_result.append(value['data'])
    return list_result

  def valid_data_version(self):
    """ Checks whether or not the data layout can be used.

    Returns:
      A boolean.
    """
    try:
      version = self.get_metadata(VERSION_INFO_KEY)
    except cassandra.InvalidRequest:
      return False

    return version is not None and float(version) == EXPECTED_DATA_VERSION

  def group_updates(self, groups):
    """ Fetch the latest transaction IDs for each group.

    Args:
      groups: An iterable containing encoded Reference objects.
    Returns:
      A set of integers specifying transaction IDs.
    """
    futures = []
    for group in groups:
      query = 'SELECT * FROM group_updates WHERE group=%s'
      futures.append(self.session.execute_async(query, [bytearray(group)]))

    updates = set()
    for future in futures:
      rows = future.result()
      try:
        result = rows[0]
      except IndexError:
        continue

      updates.add(result.last_update)

    return updates

  def start_transaction(self, app, txid, is_xg, in_progress):
    """ Persist transaction metadata.

    Args:
      app: A string containing an application ID.
      txid: An integer specifying the transaction ID.
      is_xg: A boolean specifying that the transaction is cross-group.
      in_progress: An iterable containing transaction IDs.
    """
    if in_progress:
      in_progress_bin = bytearray(
        struct.pack('q' * len(in_progress), *in_progress))
    else:
      in_progress_bin = None

    insert = """
      INSERT INTO transactions (txid_hash, operation, namespace, path,
                                start_time, is_xg, in_progress)
      VALUES (%(txid_hash)s, %(operation)s, %(namespace)s, %(path)s,
              %(start_time)s, %(is_xg)s, %(in_progress)s)
      USING TTL {ttl}
    """.format(ttl=dbconstants.MAX_TX_DURATION * 2)
    parameters = {'txid_hash': tx_partition(app, txid),
                  'operation': TxnActions.START,
                  'namespace': '',
                  'path': bytearray(''),
                  'start_time': datetime.datetime.utcnow(),
                  'is_xg': is_xg,
                  'in_progress': in_progress_bin}

    try:
      self.session.execute(insert, parameters)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception while starting a transaction'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def put_entities_tx(self, app, txid, entities):
    """ Update transaction metadata with new put operations.

    Args:
      app: A string containing an application ID.
      txid: An integer specifying the transaction ID.
      entities: A list of entities that will be put upon commit.
    """
    batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM,
                           retry_policy=BASIC_RETRIES)
    insert = self.session.prepare("""
      INSERT INTO transactions (txid_hash, operation, namespace, path, entity)
      VALUES (?, ?, ?, ?, ?)
      USING TTL {ttl}
    """.format(ttl=dbconstants.MAX_TX_DURATION * 2))

    for entity in entities:
      args = (tx_partition(app, txid),
              TxnActions.MUTATE,
              entity.key().name_space(),
              bytearray(entity.key().path().Encode()),
              bytearray(entity.Encode()))
      batch.add(insert, args)

    try:
      self.session.execute(batch)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception while putting entities in a transaction'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def delete_entities_tx(self, app, txid, entity_keys):
    """ Update transaction metadata with new delete operations.

    Args:
      app: A string containing an application ID.
      txid: An integer specifying the transaction ID.
      entity_keys: A list of entity keys that will be deleted upon commit.
    """
    batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM,
                           retry_policy=BASIC_RETRIES)
    insert = self.session.prepare("""
      INSERT INTO transactions (txid_hash, operation, namespace, path, entity)
      VALUES (?, ?, ?, ?, ?)
      USING TTL {ttl}
    """.format(ttl=dbconstants.MAX_TX_DURATION * 2))

    for key in entity_keys:
      # The None value overwrites previous puts.
      args = (tx_partition(app, txid),
              TxnActions.MUTATE,
              key.name_space(),
              bytearray(key.path().Encode()),
              None)
      batch.add(insert, args)

    try:
      self.session.execute(batch)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception while deleting entities in a transaction'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def transactional_tasks_count(self, app, txid):
    """ Count the number of existing tasks associated with the transaction.

    Args:
      app: A string specifying an application ID.
      txid: An integer specifying a transaction ID.
    Returns:
      An integer specifying the number of existing tasks.
    """
    select = """
      SELECT count(*) FROM transactions
      WHERE txid_hash = %(txid_hash)s
      AND operation = %(operation)s
    """
    parameters = {'txid_hash': tx_partition(app, txid),
                  'operation': TxnActions.ENQUEUE_TASK}
    try:
      return self.session.execute(select, parameters)[0].count
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception while fetching task count'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def add_transactional_tasks(self, app, txid, tasks):
    """ Add tasks to be enqueued upon the completion of a transaction.

    Args:
      app: A string specifying an application ID.
      txid: An integer specifying a transaction ID.
      tasks: A list of TaskQueueAddRequest objects.
    """
    batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM,
                           retry_policy=BASIC_RETRIES)
    insert = self.session.prepare("""
      INSERT INTO transactions (txid_hash, operation, namespace, path, task)
      VALUES (?, ?, ?, ?, ?)
      USING TTL {ttl}
    """.format(ttl=dbconstants.MAX_TX_DURATION * 2))

    for task in tasks:
      task.clear_transaction()

      # The path for the task entry doesn't matter as long as it's unique.
      path = bytearray(str(uuid.uuid4()))

      args = (tx_partition(app, txid),
              TxnActions.ENQUEUE_TASK,
              '',
              path,
              task.Encode())
      batch.add(insert, args)

    try:
      self.session.execute(batch)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception while adding tasks in a transaction'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def record_reads(self, app, txid, group_keys):
    """ Keep track of which entity groups were read in a transaction.

    Args:
      app: A string specifying an application ID.
      txid: An integer specifying a transaction ID.
      group_keys: An iterable containing Reference objects.
    """
    batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM,
                           retry_policy=BASIC_RETRIES)
    insert = self.session.prepare("""
      INSERT INTO transactions (txid_hash, operation, namespace, path)
      VALUES (?, ?, ?, ?)
      USING TTL {ttl}
    """.format(ttl=dbconstants.MAX_TX_DURATION * 2))

    for group_key in group_keys:
      if not isinstance(group_key, entity_pb.Reference):
        group_key = entity_pb.Reference(group_key)

      args = (tx_partition(app, txid),
              TxnActions.GET,
              group_key.name_space(),
              bytearray(group_key.path().Encode()))
      batch.add(insert, args)

    try:
      self.session.execute(batch)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception while recording reads in a transaction'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

  def get_transaction_metadata(self, app, txid):
    """ Fetch transaction state.

    Args:
      app: A string specifying an application ID.
      txid: An integer specifying a transaction ID.
    Returns:
      A dictionary containing transaction state.
    """
    select = """
      SELECT namespace, operation, path, start_time, is_xg, in_progress,
             entity, task
      FROM transactions
      WHERE txid_hash = %(txid_hash)s
    """
    parameters = {'txid_hash': tx_partition(app, txid)}
    try:
      results = self.session.execute(select, parameters)
    except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
      message = 'Exception while fetching transaction metadata'
      logging.exception(message)
      raise AppScaleDBConnectionError(message)

    metadata = {'puts': {}, 'deletes': [], 'tasks': [], 'reads': set()}
    for result in results:
      if result.operation == TxnActions.START:
        metadata['start'] = result.start_time
        metadata['is_xg'] = result.is_xg
        metadata['in_progress'] = set()
        if result.in_progress is not None:
          metadata['in_progress'] = set(
            struct.unpack('q' * int(len(result.in_progress) / 8),
                          result.in_progress))
      if result.operation == TxnActions.MUTATE:
        key = create_key(app, result.namespace, result.path)
        if result.entity is None:
          metadata['deletes'].append(key)
        else:
          metadata['puts'][key.Encode()] = result.entity
      if result.operation == TxnActions.GET:
        group_key = create_key(app, result.namespace, result.path)
        metadata['reads'].add(group_key.Encode())
      if result.operation == TxnActions.ENQUEUE_TASK:
        metadata['tasks'].append(
          taskqueue_service_pb.TaskQueueAddRequest(result.task))
    return metadata