def test_pool_management(self):
    # Ensure that in_flight and request_ids quiesce after cluster operations.
    # No idle heartbeat here; pool management is tested in test_idle_heartbeat.
    cluster = Cluster(protocol_version=PROTOCOL_VERSION,
                      idle_heartbeat_interval=0)
    session = cluster.connect()
    session2 = cluster.connect()

    # prepare
    p = session.prepare("SELECT * FROM system.local WHERE key=?")
    self.assertTrue(session.execute(p, ('local',)))

    # simple
    self.assertTrue(
        session.execute("SELECT * FROM system.local WHERE key='local'"))

    # set keyspace
    session.set_keyspace('system')
    session.set_keyspace('system_traces')

    # use keyspace
    session.execute('USE system')
    session.execute('USE system_traces')

    # refresh schema
    cluster.refresh_schema_metadata()
    cluster.refresh_schema_metadata(max_schema_agreement_wait=0)

    # submit schema refresh
    future = cluster.submit_schema_refresh()
    future.result()

    assert_quiescent_pool_state(self, cluster)

    cluster.shutdown()
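# assert_quiescent_pool_state is a helper from the driver's integration-test
# utilities and is not defined in this file. A rough sketch of the idea,
# relying on driver-private attributes whose names vary across driver versions
# (all of this is an assumption, not the actual helper):
def assert_quiescent_pool_state_sketch(test, cluster):
    for session in tuple(cluster.sessions):
        for pool in session.get_pools():
            # HostConnectionPool keeps _connections; newer HostConnection
            # keeps a single _connection
            conns = getattr(pool, '_connections', None) or [pool._connection]
            for conn in conns:
                # once operations settle, nothing should remain in flight
                test.assertEqual(conn.in_flight, 0)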
def test_refresh_schema(self):
    cluster = Cluster(protocol_version=PROTOCOL_VERSION)
    session = cluster.connect()

    original_meta = cluster.metadata.keyspaces

    # full schema refresh, with wait
    cluster.refresh_schema_metadata()
    self.assertIsNot(original_meta, cluster.metadata.keyspaces)
    self.assertEqual(original_meta, cluster.metadata.keyspaces)

    cluster.shutdown()
class Migrator(object):
    """Execute migration operations in a C* database based on configuration.

    The constructor takes:
    - config: the parsed migration configuration (keyspace, profiles,
      migrations, migrations_table, migrations_path)
    - profile: name of the profile in `config.profiles` with keyspace settings
    - hosts: list of contact points
    - port: connection port
    - user, password: authentication options; may be None to disable auth
    - host_cert_path, client_key_path, client_cert_path: optional SSL settings

    Command methods (migrate, baseline, reset, status) take an `opts` object
    carrying at least `db_version` and `force`.
    """

    logger = logging.getLogger("Migrator")

    def __init__(self, config, profile='dev', hosts=['127.0.0.1'], port=9042,
                 user=None, password=None, host_cert_path=None,
                 client_key_path=None, client_cert_path=None):
        self.config = config

        try:
            self.current_profile = self.config.profiles[profile]
        except KeyError:
            raise ValueError("Invalid profile name '{}'".format(profile))

        if user:
            auth_provider = PlainTextAuthProvider(user, password)
        else:
            auth_provider = None

        if host_cert_path:
            ssl_options = self._build_ssl_options(host_cert_path,
                                                  client_key_path,
                                                  client_cert_path)
        else:
            ssl_options = None

        self.cluster = Cluster(contact_points=hosts,
                               port=port,
                               auth_provider=auth_provider,
                               max_schema_agreement_wait=300,
                               control_connection_timeout=10,
                               connect_timeout=30,
                               ssl_options=ssl_options)
        self._session = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._session is not None:
            self._session.shutdown()
            self._session = None
        if self.cluster is not None:
            self.cluster.shutdown()
            self.cluster = None

    def _build_ssl_options(self, host_cert_path, client_key_path,
                           client_cert_path):
        return {
            'ca_certs': host_cert_path,
            'certfile': client_cert_path,
            'keyfile': client_key_path
        }

    def _check_cluster(self):
        """Check if the cluster is still alive, raise otherwise"""
        if not self.cluster:
            raise RuntimeError("Cluster has shut down")

    def _init_session(self):
        if not self._session:
            s = self._session = self.cluster.connect()
            s.default_consistency_level = ConsistencyLevel.ALL
            s.default_serial_consistency_level = ConsistencyLevel.SERIAL
            s.default_timeout = 120

    @property
    def session(self):
        """Initialize and configure a C* driver session if needed"""
        self._check_cluster()
        self._init_session()
        return self._session

    def _get_target_version(self, v):
        """Parse a version specifier into an actual numeric migration version.

        `v` may be:
        - None: the latest version is chosen
        - an int, or a numeric string: that exact version is chosen
        - a string: the version with that name is chosen if it exists

        If an invalid version is found, `ValueError` is raised.
        """
        if v is None:
            return len(self.config.migrations)

        if isinstance(v, int):
            num = v
        elif v.isdigit():
            num = int(v)
        else:
            try:
                # list.index() raises ValueError, not IndexError, when the
                # name is missing
                num = self.config.migrations.index(v)
            except ValueError:
                num = -1

        if num <= 0:
            raise ValueError('Invalid database version, must be a number > 0 '
                             'or the name of an existing migration')
        return num

    def _q(self, query, **kwargs):
        """Format a query with the configured keyspace and migration table.

        `keyspace` and `table` are interpolated as named arguments.
        """
        return query.format(keyspace=self.config.keyspace,
                            table=self.config.migrations_table,
                            **kwargs)

    def _execute(self, query, *args, **kwargs):
        """Execute a query with the current session"""
        self.logger.debug('Executing query: {}'.format(query))
        return self.session.execute(query, *args, **kwargs)

    def _keyspace_exists(self):
        self._init_session()
        return self.config.keyspace in self.cluster.metadata.keyspaces

    def _ensure_keyspace(self):
        """Create the keyspace if it does not exist"""
        if self._keyspace_exists():
            return

        self.logger.info("Creating keyspace '{}'".format(self.config.keyspace))

        profile = self.current_profile
        self._execute(self._q(
            CREATE_KEYSPACE,
            replication=cassandra_ddl_repr(profile['replication']),
            durable_writes=cassandra_ddl_repr(profile['durable_writes'])))

        self.cluster.refresh_keyspace_metadata(self.config.keyspace)

    def _table_exists(self):
        self._init_session()
        ks_metadata = self.cluster.metadata.keyspaces.get(
            self.config.keyspace, None)
        # Fail if the keyspace is missing. If it should be created
        # automatically, _ensure_keyspace() must be called first.
        if not ks_metadata:
            raise ValueError("Keyspace '{}' does not exist, "
                             "stopping".format(self.config.keyspace))

        return self.config.migrations_table in ks_metadata.tables

    def _ensure_table(self):
        """Create the migration table if it does not exist"""
        if self._table_exists():
            return

        self.logger.info(
            "Creating table '{table}' in keyspace '{keyspace}'".format(
                keyspace=self.config.keyspace,
                table=self.config.migrations_table))

        self._execute(self._q(CREATE_MIGRATIONS_TABLE))
        self.cluster.refresh_table_metadata(self.config.keyspace,
                                            self.config.migrations_table)

    def _verify_migrations(self, migrations, ignore_failed=False,
                           ignore_concurrent=False):
        """Verify that the version history persisted in C* matches the
        migrations on disk.

        Migrations with corresponding DB versions must have the same content
        and name. Every DB version must have a corresponding migration.
        Migrations without corresponding DB versions are considered pending.

        Returns a tuple of (last_version, cur_versions, pending_migrations),
        where pending_migrations is a list of (version, migration) pairs, with
        version starting from 1 and incrementing linearly.
        """
        # Load all the currently existing versions and sort them by version
        # number, as Cassandra can only sort them for us by partition.
        cur_versions = self._execute(
            self._q('SELECT * FROM "{keyspace}"."{table}"'))
        cur_versions = sorted(cur_versions, key=lambda v: v.version)

        last_version = None
        version_pairs = zip_longest(cur_versions, migrations)

        # Work through ordered pairs of (existing version, migration), so that
        # stored versions and expected migrations can be compared for any
        # differences.
        for i, (version, migration) in enumerate(version_pairs, 1):
            # If version is empty, the migration has not yet been applied;
            # stop here and treat it and everything after it as pending.
            if not version:
                break

            # If migration is empty, we have a version in the database with
            # no corresponding file. That might mean we're running the wrong
            # migrations or have an out-of-date state, so we must fail.
            if not migration:
                raise UnknownMigration(version.version, version.name)

            # A migration was previously run and failed.
            if version.state == Migration.State.FAILED:
                if ignore_failed:
                    break
                raise FailedMigration(version.version, version.name)

            last_version = version.version

            # A stored version's migration differs from the one in the FS.
            if (version.content != migration.content or
                    version.name != migration.name or
                    bytearray(version.checksum) != bytearray(migration.checksum)):
                raise InconsistentState(migration, version)

        if not last_version:
            pending_migrations = list(migrations)
        else:
            pending_migrations = list(migrations)[last_version:]

        if not pending_migrations:
            self.logger.info('Database is already up-to-date')
        else:
            self.logger.info('Pending migrations found. Current version: {}, '
                             'Latest version: {}'.format(last_version,
                                                         len(migrations)))

        pending_migrations = enumerate(pending_migrations,
                                       (last_version or 0) + 1)
        return last_version, cur_versions, list(pending_migrations)

    def _create_version(self, version, migration):
        """Write an in-progress version entry to C*.

        The migration is inserted with the given `version` number if and only
        if it does not exist already (using a compare-and-set operation).

        If the insert succeeds (with the migration marked as in-progress), we
        can continue and actually execute it. Otherwise, there was a
        concurrent write and we must fail to allow the other writer to
        continue.
        """
        self.logger.info('Writing in-progress migration version {}: {}'.format(
            version, migration))

        version_id = uuid.uuid4()
        result = self._execute(
            self._q(CREATE_DB_VERSION),
            (version_id, version, migration.name, migration.content,
             bytearray(migration.checksum), Migration.State.IN_PROGRESS))

        return version_id

    def _apply_cql_migration(self, version, migration):
        """Apply a CQL migration by splitting the script into statements and
        executing each one."""
        self.logger.info('Applying cql migration')

        statements = CqlSplitter.split(migration.content)

        try:
            if statements:
                self.logger.info('Executing migration with '
                                 '{} CQL statements'.format(len(statements)))
            for statement in statements:
                self.session.execute(statement)
        except Exception:
            raise FailedMigration(version, migration.name)

    def _apply_python_migration(self, version, migration):
        """Apply a Python migration by importing the script and calling its
        execute() function with the current session."""
        self.logger.info('Applying python script')

        try:
            mod, _ = os.path.splitext(os.path.basename(migration.path))
            migration_script = importlib.import_module(mod)
            migration_script.execute(self._session)
        except Exception:
            self.logger.exception('Failed to execute script')
            raise FailedMigration(version, migration.name)

    def _apply_migration(self, version, migration, skip=False):
        """Persist and apply a migration.

        First create an in-progress version entry, apply the script, then
        finalize the entry as succeeded, failed or skipped.

        When `skip` is True, do everything but actually run the script, for
        example, when baselining instead of migrating.
        """
        self.logger.info('Advancing to version {}'.format(version))

        version_uuid = self._create_version(version, migration)
        new_state = Migration.State.FAILED
        sys.path.append(self.config.migrations_path)
        result = None

        try:
            if skip:
                self.logger.info('Migration is marked for skipping, '
                                 'not actually running script')
            else:
                if migration.is_python:
                    self._apply_python_migration(version, migration)
                else:
                    self._apply_cql_migration(version, migration)
        except Exception:
            self.logger.exception('Failed to execute migration')
            raise FailedMigration(version, migration.name)
        else:
            new_state = (Migration.State.SUCCEEDED if not skip
                         else Migration.State.SKIPPED)
        finally:
            self.logger.info('Finalizing migration version with '
                             'state {}'.format(new_state))
            result = self._execute(self._q(FINALIZE_DB_VERSION),
                                   (new_state, version_uuid))

    def _cleanup_previous_versions(self, cur_versions):
        if not cur_versions:
            return

        last_version = cur_versions[-1]
        if last_version.state != Migration.State.FAILED:
            return

        self.logger.warning('Cleaning up previous failed migration '
                            '(version {}): {}'.format(last_version.version,
                                                      last_version.name))
        result = self._execute(self._q(DELETE_DB_VERSION),
                               (last_version.id, Migration.State.FAILED))

    def _advance(self, migrations, target, cur_versions, skip=False,
                 force=False):
        """Apply all necessary migrations to reach a target version"""
        if force:
            self._cleanup_previous_versions(cur_versions)

        target_version = self._get_target_version(target)

        if migrations:
            # Set the default keyspace so migrations don't need to refer to
            # it manually.
            # Fixes https://github.com/Cobliteam/cassandra-migrate/issues/5
            self.session.execute('USE {};'.format(self.config.keyspace))

        for version, migration in migrations:
            if version > target_version:
                break
            self._apply_migration(version, migration, skip=skip)

        self.cluster.refresh_schema_metadata()

    def baseline(self, opts):
        """Baseline a database, by advancing migration state without changes"""
        self._check_cluster()
        self._ensure_table()

        last_version, cur_versions, pending_migrations = \
            self._verify_migrations(self.config.migrations,
                                    ignore_failed=False)

        self._advance(pending_migrations, opts.db_version, cur_versions,
                      skip=True)

    @confirmation_required
    def migrate(self, opts):
        """Migrate a database to a given version, applying any needed
        migrations."""
        self._check_cluster()
        self._ensure_keyspace()
        self._ensure_table()

        last_version, cur_versions, pending_migrations = \
            self._verify_migrations(self.config.migrations,
                                    ignore_failed=opts.force)

        self._advance(pending_migrations, opts.db_version, cur_versions,
                      force=opts.force)

    @confirmation_required
    def reset(self, opts):
        """Reset a database, by dropping the keyspace then migrating"""
        self._check_cluster()

        self.logger.info("Dropping existing keyspace '{}'".format(
            self.config.keyspace))

        self._execute(self._q(DROP_KEYSPACE))
        self.cluster.refresh_schema_metadata()

        opts.force = False
        self.migrate(opts)

    @staticmethod
    def _bytes_to_hex(bs):
        return codecs.getencoder('hex')(bs)[0]

    def status(self, opts):
        self._check_cluster()

        if not self._keyspace_exists():
            print("Keyspace '{}' does not exist".format(self.config.keyspace))
            return

        if not self._table_exists():
            print("Migration table '{table}' does not exist in "
                  "keyspace '{keyspace}'".format(
                      keyspace=self.config.keyspace,
                      table=self.config.migrations_table))
            return

        last_version, cur_versions, pending_migrations = \
            self._verify_migrations(self.config.migrations,
                                    ignore_failed=True,
                                    ignore_concurrent=True)

        latest_version = len(self.config.migrations)
        print(tabulate(
            (('Keyspace:', self.config.keyspace),
             ('Migrations table:', self.config.migrations_table),
             ('Current DB version:', last_version),
             ('Latest DB version:', latest_version)),
            tablefmt='plain'))

        if cur_versions:
            print('\n## Applied migrations\n')
            data = []
            for version in cur_versions:
                checksum = self._bytes_to_hex(version.checksum)
                date = arrow.get(version.applied_at).format()
                data.append((str(version.version), version.name,
                             version.state, date, checksum))
            print(tabulate(data, headers=['#', 'Name', 'State',
                                          'Date applied', 'Checksum']))

        if pending_migrations:
            print('\n## Pending migrations\n')
            data = []
            for version, migration in pending_migrations:
                checksum = self._bytes_to_hex(migration.checksum)
                data.append((str(version), migration.name, checksum))
            print(tabulate(data, headers=['#', 'Name', 'Checksum']))
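# Minimal usage sketch for the Migrator above, assuming a loaded `config`
# object with keyspace/profiles/migrations and an argparse-style opts
# namespace. confirmation_required may prompt interactively; the assume_yes
# flag here is a guess at how the real CLI suppresses that prompt.
from types import SimpleNamespace

def migrate_to_latest(config):
    # db_version=None means "advance to the latest migration"
    opts = SimpleNamespace(db_version=None, force=False, assume_yes=True)
    with Migrator(config, profile='dev', hosts=['127.0.0.1'],
                  port=9042) as migrator:
        migrator.migrate(opts)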
def test_refresh_schema_no_wait(self):
    contact_points = ['127.0.0.1']
    cluster = Cluster(
        protocol_version=PROTOCOL_VERSION,
        max_schema_agreement_wait=10,
        contact_points=contact_points,
        load_balancing_policy=WhiteListRoundRobinPolicy(contact_points))
    session = cluster.connect()

    schema_ver = session.execute(
        "SELECT schema_version FROM system.local WHERE key='local'")[0][0]

    # create a schema disagreement
    new_schema_ver = uuid4()
    session.execute(
        "UPDATE system.local SET schema_version=%s WHERE key='local'",
        (new_schema_ver,))

    try:
        agreement_timeout = 1

        # cluster agreement wait exceeded
        c = Cluster(protocol_version=PROTOCOL_VERSION,
                    max_schema_agreement_wait=agreement_timeout)
        c.connect()
        self.assertTrue(c.metadata.keyspaces)

        # cluster agreement wait used for refresh
        original_meta = c.metadata.keyspaces
        start_time = time.time()
        self.assertRaisesRegexp(Exception,
                                r"Schema metadata was not refreshed.*",
                                c.refresh_schema_metadata)
        end_time = time.time()
        self.assertGreaterEqual(end_time - start_time, agreement_timeout)
        self.assertIs(original_meta, c.metadata.keyspaces)

        # refresh wait overrides cluster value
        original_meta = c.metadata.keyspaces
        start_time = time.time()
        c.refresh_schema_metadata(max_schema_agreement_wait=0)
        end_time = time.time()
        self.assertLess(end_time - start_time, agreement_timeout)
        self.assertIsNot(original_meta, c.metadata.keyspaces)
        self.assertEqual(original_meta, c.metadata.keyspaces)
        c.shutdown()

        refresh_threshold = 0.5

        # cluster agreement bypass
        c = Cluster(protocol_version=PROTOCOL_VERSION,
                    max_schema_agreement_wait=0)
        start_time = time.time()
        s = c.connect()
        end_time = time.time()
        self.assertLess(end_time - start_time, refresh_threshold)
        self.assertTrue(c.metadata.keyspaces)

        # cluster agreement wait used for refresh
        original_meta = c.metadata.keyspaces
        start_time = time.time()
        c.refresh_schema_metadata()
        end_time = time.time()
        self.assertLess(end_time - start_time, refresh_threshold)
        self.assertIsNot(original_meta, c.metadata.keyspaces)
        self.assertEqual(original_meta, c.metadata.keyspaces)

        # refresh wait overrides cluster value
        original_meta = c.metadata.keyspaces
        start_time = time.time()
        self.assertRaisesRegexp(
            Exception, r"Schema metadata was not refreshed.*",
            c.refresh_schema_metadata,
            max_schema_agreement_wait=agreement_timeout)
        end_time = time.time()
        self.assertGreaterEqual(end_time - start_time, agreement_timeout)
        self.assertIs(original_meta, c.metadata.keyspaces)
        c.shutdown()
    finally:
        # Restore the original schema version.
        # TODO: revisit whether this reconnect is needed once the underlying
        # issue is fixed.
        session = cluster.connect()
        session.execute(
            "UPDATE system.local SET schema_version=%s WHERE key='local'",
            (schema_ver,))
        cluster.shutdown()
"-mode cql3", "native", "-rate threads=1", "-log level=verbose interval=5", "-errors retries=999 ignore", "-node {}".format( master_nodes[0].ip()) ], stdout=stressor_log, stderr=subprocess.STDOUT)) log('Letting stressor run for a while...') # Sleep long enough for stressor to create the keyspace time.sleep(10) log('Fetching schema definitions from master cluster.') cm.control_connection_timeout = 20 with cm.connect() as _: cm.refresh_schema_metadata(max_schema_agreement_wait=10) ks = cm.metadata.keyspaces[KS_NAME] ut_ddls = [t[1].as_cql_query() for t in ks.user_types.items()] table_ddls = [] for name, table in ks.tables.items(): if name.endswith('_scylla_cdc_log'): continue if 'cdc' in table.extensions: del table.extensions['cdc'] table_ddls.append(table.as_cql_query()) log('User types:\n{}'.format('\n'.join(ut_ddls))) log('Table definitions:\n{}'.format('\n'.join(table_ddls))) log('Letting stressor run for a while...') time.sleep(5)
def prime_cassandra(replication):
    """ Create Cassandra keyspace and initial tables.

    Args:
        replication: An integer specifying the replication factor for the
            keyspace.
    Raises:
        AppScaleBadArg if replication factor is not greater than 0.
        TypeError if replication is not an integer.
    """
    if not isinstance(replication, int):
        raise TypeError('Replication must be an integer')

    if int(replication) <= 0:
        raise dbconstants.AppScaleBadArg(
            'Replication must be greater than zero')

    hosts = appscale_info.get_db_ips()

    cluster = None
    session = None
    remaining_retries = INITIAL_CONNECT_RETRIES
    while True:
        try:
            cluster = Cluster(hosts)
            session = cluster.connect()
            break
        except cassandra.cluster.NoHostAvailable as connection_error:
            remaining_retries -= 1
            if remaining_retries < 0:
                raise connection_error
            time.sleep(3)
    session.default_consistency_level = ConsistencyLevel.QUORUM

    create_keyspace = """
        CREATE KEYSPACE IF NOT EXISTS "{keyspace}"
        WITH REPLICATION = %(replication)s
    """.format(keyspace=KEYSPACE)
    keyspace_replication = {
        'class': 'SimpleStrategy',
        'replication_factor': replication
    }
    session.execute(create_keyspace, {'replication': keyspace_replication})
    session.set_keyspace(KEYSPACE)

    for table in dbconstants.INITIAL_TABLES:
        create_table = """
            CREATE TABLE IF NOT EXISTS "{table}" (
                {key} blob,
                {column} text,
                {value} blob,
                PRIMARY KEY ({key}, {column})
            ) WITH COMPACT STORAGE
        """.format(table=table,
                   key=ThriftColumn.KEY,
                   column=ThriftColumn.COLUMN_NAME,
                   value=ThriftColumn.VALUE)
        logging.info('Trying to create {}'.format(table))
        cluster.refresh_schema_metadata()
        session.execute(create_table)

    create_batch_tables(cluster, session)
    create_pull_queue_tables(cluster, session)

    first_entity = session.execute('SELECT * FROM "{}" LIMIT 1'.format(
        dbconstants.APP_ENTITY_TABLE))
    existing_entities = len(list(first_entity)) == 1

    define_ua_schema(session)

    metadata_insert = """
        INSERT INTO "{table}" ({key}, {column}, {value})
        VALUES (%(key)s, %(column)s, %(value)s)
    """.format(table=dbconstants.DATASTORE_METADATA_TABLE,
               key=ThriftColumn.KEY,
               column=ThriftColumn.COLUMN_NAME,
               value=ThriftColumn.VALUE)

    if not existing_entities:
        parameters = {
            'key': bytearray(cassandra_interface.VERSION_INFO_KEY),
            'column': cassandra_interface.VERSION_INFO_KEY,
            'value': bytearray(str(POST_JOURNAL_VERSION))
        }
        session.execute(metadata_insert, parameters)

        # Indicate that the database has been successfully primed.
        parameters = {
            'key': bytearray(cassandra_interface.PRIMED_KEY),
            'column': cassandra_interface.PRIMED_KEY,
            'value': bytearray('true')
        }
        session.execute(metadata_insert, parameters)
    logging.info('Cassandra is primed.')
class DatastoreProxy(AppDBInterface):
    """ Cassandra implementation of the AppDBInterface """

    def __init__(self, log_level=logging.INFO):
        """ Constructor. """
        class_name = self.__class__.__name__
        self.logger = logging.getLogger(class_name)
        self.logger.setLevel(log_level)
        self.logger.info('Starting {}'.format(class_name))

        self.hosts = appscale_info.get_db_ips()

        remaining_retries = INITIAL_CONNECT_RETRIES
        while True:
            try:
                self.cluster = Cluster(self.hosts)
                self.session = self.cluster.connect(KEYSPACE)
                break
            except cassandra.cluster.NoHostAvailable as connection_error:
                remaining_retries -= 1
                if remaining_retries < 0:
                    raise connection_error
                time.sleep(3)

        self.session.default_consistency_level = ConsistencyLevel.QUORUM
        self.retry_policy = IdempotentRetryPolicy()

    def close(self):
        """ Close all sessions and connections to Cassandra. """
        self.cluster.shutdown()

    def batch_get_entity(self, table_name, row_keys, column_names):
        """ Takes in batches of keys and retrieves their corresponding rows.

        Args:
            table_name: The table to access
            row_keys: A list of keys to access
            column_names: A list of columns to access
        Returns:
            A dictionary of rows and columns/values of those rows. The format
            looks like such: {key:{column_name:value,...}}
        Raises:
            TypeError: If an argument passed in was not of the expected type.
            AppScaleDBConnectionError: If the batch_get could not be performed
                due to an error with Cassandra.
        """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(column_names, list):
            raise TypeError("Expected a list")
        if not isinstance(row_keys, list):
            raise TypeError("Expected a list")

        row_keys_bytes = [bytearray(row_key) for row_key in row_keys]

        statement = 'SELECT * FROM "{table}" '\
                    'WHERE {key} IN %s and {column} IN %s'.format(
                        table=table_name,
                        key=ThriftColumn.KEY,
                        column=ThriftColumn.COLUMN_NAME,
                    )
        query = SimpleStatement(statement, retry_policy=self.retry_policy)
        parameters = (ValueSequence(row_keys_bytes),
                      ValueSequence(column_names))

        try:
            results = self.session.execute(query, parameters=parameters)

            results_dict = {row_key: {} for row_key in row_keys}
            for (key, column, value) in results:
                if key not in results_dict:
                    results_dict[key] = {}
                results_dict[key][column] = value

            return results_dict
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during batch_get_entity'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def batch_put_entity(self, table_name, row_keys, column_names,
                         cell_values, ttl=None):
        """ Allows callers to store multiple rows with a single call. A row
        can have multiple columns and values with them. We refer to each row
        as an entity.

        Args:
            table_name: The table to mutate
            row_keys: A list of keys to store on
            column_names: A list of columns to mutate
            cell_values: A dict of key/value pairs
            ttl: The number of seconds to keep the row.
        Raises:
            TypeError: If an argument passed in was not of the expected type.
            AppScaleDBConnectionError: If the batch_put could not be performed
                due to an error with Cassandra.
        """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(column_names, list):
            raise TypeError("Expected a list")
        if not isinstance(row_keys, list):
            raise TypeError("Expected a list")
        if not isinstance(cell_values, dict):
            raise TypeError("Expected a dict")

        insert_str = """
            INSERT INTO "{table}" ({key}, {column}, {value})
            VALUES (?, ?, ?)
        """.format(table=table_name,
                   key=ThriftColumn.KEY,
                   column=ThriftColumn.COLUMN_NAME,
                   value=ThriftColumn.VALUE)

        if ttl is not None:
            insert_str += 'USING TTL {}'.format(ttl)

        statement = self.session.prepare(insert_str)

        batch_insert = BatchStatement(retry_policy=self.retry_policy)

        for row_key in row_keys:
            for column in column_names:
                batch_insert.add(
                    statement,
                    (bytearray(row_key), column,
                     bytearray(cell_values[row_key][column])))

        try:
            self.session.execute(batch_insert)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during batch_put_entity'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def prepare_insert(self, table):
        """ Prepare an insert statement.

        Args:
            table: A string containing the table name.
        Returns:
            A PreparedStatement object.
        """
        statement = """
            INSERT INTO "{table}" ({key}, {column}, {value})
            VALUES (?, ?, ?)
        """.format(table=table,
                   key=ThriftColumn.KEY,
                   column=ThriftColumn.COLUMN_NAME,
                   value=ThriftColumn.VALUE)
        return self.session.prepare(statement)

    def prepare_delete(self, table):
        """ Prepare a delete statement.

        Args:
            table: A string containing the table name.
        Returns:
            A PreparedStatement object.
        """
        statement = """
            DELETE FROM "{table}" WHERE {key} = ?
        """.format(table=table, key=ThriftColumn.KEY)
        return self.session.prepare(statement)

    def _normal_batch(self, mutations):
        """ Use Cassandra's native batch statement to apply mutations
        atomically.

        Args:
            mutations: A list of dictionaries representing mutations.
        """
        self.logger.debug('Normal batch: {} mutations'.format(len(mutations)))
        batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM,
                               retry_policy=self.retry_policy)
        prepared_statements = {'insert': {}, 'delete': {}}
        for mutation in mutations:
            table = mutation['table']
            if mutation['operation'] == TxnActions.PUT:
                if table not in prepared_statements['insert']:
                    prepared_statements['insert'][table] = \
                        self.prepare_insert(table)
                values = mutation['values']
                for column in values:
                    batch.add(prepared_statements['insert'][table],
                              (bytearray(mutation['key']), column,
                               bytearray(values[column])))
            elif mutation['operation'] == TxnActions.DELETE:
                if table not in prepared_statements['delete']:
                    prepared_statements['delete'][table] = \
                        self.prepare_delete(table)
                batch.add(prepared_statements['delete'][table],
                          (bytearray(mutation['key']),))

        try:
            self.session.execute(batch)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during batch_mutate'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def apply_mutations(self, mutations):
        """ Apply mutations across tables.

        Args:
            mutations: A list of dictionaries representing mutations.
        """
        for mutation in mutations:
            table = mutation['table']
            if mutation['operation'] == TxnActions.PUT:
                values = mutation['values']
                for column in values:
                    insert_row = """
                        INSERT INTO "{table}" ({key}, {column}, {value})
                        VALUES (%(key)s, %(column)s, %(value)s)
                    """.format(table=table,
                               key=ThriftColumn.KEY,
                               column=ThriftColumn.COLUMN_NAME,
                               value=ThriftColumn.VALUE)
                    parameters = {
                        'key': bytearray(mutation['key']),
                        'column': column,
                        'value': bytearray(values[column])
                    }
                    self.session.execute(insert_row, parameters)
            elif mutation['operation'] == TxnActions.DELETE:
                delete_row = """
                    DELETE FROM "{table}" WHERE {key} = %(key)s
                """.format(table=table, key=ThriftColumn.KEY)
                parameters = {'key': bytearray(mutation['key'])}
                self.session.execute(delete_row, parameters)

    def _large_batch(self, app, mutations, entity_changes, txn):
        """ Insert or delete multiple rows across tables in an atomic
        statement.

        Args:
            app: A string containing the application ID.
            mutations: A list of dictionaries representing mutations.
            entity_changes: A list of changes at the entity level.
            txn: A transaction ID handler.
        Raises:
            FailedBatch if a concurrent process modifies the batch status.
        """
        self.logger.debug('Large batch: transaction {}, {} mutations'.format(
            txn, len(mutations)))
        set_status = """
            INSERT INTO batch_status (app, transaction, applied)
            VALUES (%(app)s, %(transaction)s, False)
            IF NOT EXISTS
        """
        parameters = {'app': app, 'transaction': txn}
        result = self.session.execute(set_status, parameters)
        if not result.was_applied:
            raise FailedBatch(
                'A batch for transaction {} already exists'.format(txn))

        for entity_change in entity_changes:
            insert_item = """
                INSERT INTO batches (app, transaction, namespace,
                                     path, old_value, new_value)
                VALUES (%(app)s, %(transaction)s, %(namespace)s,
                        %(path)s, %(old_value)s, %(new_value)s)
            """
            old_value = None
            if entity_change['old'] is not None:
                old_value = bytearray(entity_change['old'].Encode())
            new_value = None
            if entity_change['new'] is not None:
                new_value = bytearray(entity_change['new'].Encode())
            parameters = {
                'app': app,
                'transaction': txn,
                'namespace': entity_change['key'].name_space(),
                'path': bytearray(entity_change['key'].path().Encode()),
                'old_value': old_value,
                'new_value': new_value
            }
            self.session.execute(insert_item, parameters)

        update_status = """
            UPDATE batch_status
            SET applied = True
            WHERE app = %(app)s
            AND transaction = %(transaction)s
            IF applied = False
        """
        parameters = {'app': app, 'transaction': txn}
        result = self.session.execute(update_status, parameters)
        if not result.was_applied:
            raise FailedBatch(
                'Another process modified batch for transaction {}'.format(
                    txn))

        self.apply_mutations(mutations)

        clear_batch = """
            DELETE FROM batches
            WHERE app = %(app)s AND transaction = %(transaction)s
        """
        parameters = {'app': app, 'transaction': txn}
        self.session.execute(clear_batch, parameters)

        clear_status = """
            DELETE FROM batch_status
            WHERE app = %(app)s and transaction = %(transaction)s
        """
        parameters = {'app': app, 'transaction': txn}
        self.session.execute(clear_status, parameters)

    def batch_mutate(self, app, mutations, entity_changes, txn):
        """ Insert or delete multiple rows across tables in an atomic
        statement.

        Args:
            app: A string containing the application ID.
            mutations: A list of dictionaries representing mutations.
            entity_changes: A list of changes at the entity level.
            txn: A transaction ID handler.
        """
        size = batch_size(mutations)
        self.logger.debug('batch_size: {}'.format(size))
        if size > LARGE_BATCH_THRESHOLD:
            self._large_batch(app, mutations, entity_changes, txn)
        else:
            self._normal_batch(mutations)

    def batch_delete(self, table_name, row_keys, column_names=()):
        """ Remove a set of rows corresponding to a set of keys.

        Args:
            table_name: Table to delete rows from
            row_keys: A list of keys to remove
            column_names: Not used
        Raises:
            TypeError: If an argument passed in was not of the expected type.
            AppScaleDBConnectionError: If the batch_delete could not be
                performed due to an error with Cassandra.
        """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(row_keys, list):
            raise TypeError("Expected a list")

        row_keys_bytes = [bytearray(row_key) for row_key in row_keys]

        statement = 'DELETE FROM "{table}" WHERE {key} IN %s'.format(
            table=table_name, key=ThriftColumn.KEY)
        query = SimpleStatement(statement, retry_policy=self.retry_policy)
        parameters = (ValueSequence(row_keys_bytes),)

        try:
            self.session.execute(query, parameters=parameters)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during batch_delete'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def delete_table(self, table_name):
        """ Drops a given table (aka column family in Cassandra)

        Args:
            table_name: A string name of the table to drop
        Raises:
            TypeError: If an argument passed in was not of the expected type.
            AppScaleDBConnectionError: If the delete_table could not be
                performed due to an error with Cassandra.
        """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")

        statement = 'DROP TABLE IF EXISTS "{table}"'.format(table=table_name)
        query = SimpleStatement(statement, retry_policy=self.retry_policy)

        try:
            self.session.execute(query)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during delete_table'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def create_table(self, table_name, column_names):
        """ Creates a table if it doesn't already exist.

        Args:
            table_name: The column family name
            column_names: Not used but here to match the interface
        Raises:
            TypeError: If an argument passed in was not of the expected type.
            AppScaleDBConnectionError: If the create_table could not be
                performed due to an error with Cassandra.
        """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(column_names, list):
            raise TypeError("Expected a list")

        self.cluster.refresh_schema_metadata()
        statement = 'CREATE TABLE IF NOT EXISTS "{table}" ('\
                    '{key} blob,'\
                    '{column} text,'\
                    '{value} blob,'\
                    'PRIMARY KEY ({key}, {column})'\
                    ') WITH COMPACT STORAGE'.format(
                        table=table_name,
                        key=ThriftColumn.KEY,
                        column=ThriftColumn.COLUMN_NAME,
                        value=ThriftColumn.VALUE
                    )
        query = SimpleStatement(statement)

        try:
            self.session.execute(query)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during create_table'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def range_query(self, table_name, column_names, start_key, end_key,
                    limit, offset=0, start_inclusive=True,
                    end_inclusive=True, keys_only=False):
        """ Gets a dense range ordered by keys. Returns an ordered list of
        a dictionary of [key:{column1:value1, column2:value2},...]
        or a list of keys if keys only.

        Args:
            table_name: Name of table to access
            column_names: Columns which get returned within the key range
            start_key: String for which the query starts at
            end_key: String for which the query ends at
            limit: Maximum number of results to return
            offset: Cuts off these many from the results [offset:]
            start_inclusive: Boolean if results should include the start_key
            end_inclusive: Boolean if results should include the end_key
            keys_only: Boolean if to only keys and not values
        Raises:
            TypeError: If an argument passed in was not of the expected type.
            AppScaleDBConnectionError: If the range_query could not be
                performed due to an error with Cassandra.
        Returns:
            An ordered list of dictionaries of key=>columns/values
        """
        if not isinstance(table_name, str):
            raise TypeError('table_name must be a string')
        if not isinstance(column_names, list):
            raise TypeError('column_names must be a list')
        if not isinstance(start_key, str):
            raise TypeError('start_key must be a string')
        if not isinstance(end_key, str):
            raise TypeError('end_key must be a string')
        if not isinstance(limit, (int, long)) and limit is not None:
            raise TypeError('limit must be int, long, or NoneType')
        if not isinstance(offset, (int, long)):
            raise TypeError('offset must be int or long')

        if start_inclusive:
            gt_compare = '>='
        else:
            gt_compare = '>'

        if end_inclusive:
            lt_compare = '<='
        else:
            lt_compare = '<'

        query_limit = ''
        if limit is not None:
            query_limit = 'LIMIT {}'.format(len(column_names) * limit)

        statement = """
            SELECT * FROM "{table}" WHERE
            token({key}) {gt_compare} %s AND
            token({key}) {lt_compare} %s AND
            {column} IN %s
            {limit}
            ALLOW FILTERING
        """.format(table=table_name,
                   key=ThriftColumn.KEY,
                   gt_compare=gt_compare,
                   lt_compare=lt_compare,
                   column=ThriftColumn.COLUMN_NAME,
                   limit=query_limit)

        query = SimpleStatement(statement, retry_policy=self.retry_policy)
        parameters = (bytearray(start_key), bytearray(end_key),
                      ValueSequence(column_names))

        try:
            results = self.session.execute(query, parameters=parameters)

            results_list = []
            current_item = {}
            current_key = None
            for (key, column, value) in results:
                if keys_only:
                    results_list.append(key)
                    continue

                if key != current_key:
                    if current_item:
                        results_list.append({current_key: current_item})
                    current_item = {}
                    current_key = key

                current_item[column] = value

            if current_item:
                results_list.append({current_key: current_item})
            return results_list[offset:]
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during range_query'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def get_metadata(self, key):
        """ Retrieve a value from the datastore metadata table.

        Args:
            key: A string containing the key to fetch.
        Returns:
            A string containing the value or None if the key is not present.
        """
        statement = """
            SELECT {value} FROM "{table}"
            WHERE {key} = %s
            AND {column} = %s
        """.format(value=ThriftColumn.VALUE,
                   table=dbconstants.DATASTORE_METADATA_TABLE,
                   key=ThriftColumn.KEY,
                   column=ThriftColumn.COLUMN_NAME)

        try:
            results = self.session.execute(statement, (bytearray(key), key))
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Unable to fetch {} from datastore metadata'.format(key)
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

        try:
            return results[0].value
        except IndexError:
            return None

    def set_metadata(self, key, value):
        """ Set a datastore metadata value.

        Args:
            key: A string containing the key to set.
            value: A string containing the value to set.
        """
        if not isinstance(key, str):
            raise TypeError('key should be a string')
        if not isinstance(value, str):
            raise TypeError('value should be a string')

        statement = """
            INSERT INTO "{table}" ({key}, {column}, {value})
            VALUES (%(key)s, %(column)s, %(value)s)
        """.format(table=dbconstants.DATASTORE_METADATA_TABLE,
                   key=ThriftColumn.KEY,
                   column=ThriftColumn.COLUMN_NAME,
                   value=ThriftColumn.VALUE)

        parameters = {
            'key': bytearray(key),
            'column': key,
            'value': bytearray(value)
        }

        try:
            self.session.execute(statement, parameters)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Unable to set datastore metadata for {}'.format(key)
            logging.exception(message)
            raise AppScaleDBConnectionError(message)
        except cassandra.InvalidRequest:
            self.create_table(dbconstants.DATASTORE_METADATA_TABLE,
                              dbconstants.DATASTORE_METADATA_SCHEMA)
            self.session.execute(statement, parameters)

    def get_indices(self, app_id):
        """ Gets the indices of the given application.

        Args:
            app_id: Name of the application.
        Returns:
            Returns a list of encoded entity_pb.CompositeIndex objects.
        """
        start_key = dbconstants.KEY_DELIMITER.join([app_id, 'index', ''])
        end_key = dbconstants.KEY_DELIMITER.join(
            [app_id, 'index', dbconstants.TERMINATING_STRING])
        result = self.range_query(
            dbconstants.METADATA_TABLE,
            dbconstants.METADATA_SCHEMA,
            start_key,
            end_key,
            dbconstants.MAX_NUMBER_OF_COMPOSITE_INDEXES,
            offset=0,
            start_inclusive=True,
            end_inclusive=True)
        list_result = []
        for list_item in result:
            for key, value in list_item.iteritems():
                list_result.append(value['data'])
        return list_result

    def valid_data_version(self):
        """ Checks whether or not the data layout can be used.

        Returns:
            A boolean.
        """
        try:
            version = self.get_metadata(VERSION_INFO_KEY)
        except cassandra.InvalidRequest:
            return False

        return version is not None and float(version) == EXPECTED_DATA_VERSION