def test_pool_management(self):
        # Ensure that in_flight and request_ids quiesce after cluster operations
        cluster = Cluster(
            protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0
        )  # no idle heartbeat here, pool management is tested in test_idle_heartbeat
        session = cluster.connect()
        session2 = cluster.connect()

        # prepare
        p = session.prepare("SELECT * FROM system.local WHERE key=?")
        self.assertTrue(session.execute(p, ('local', )))

        # simple
        self.assertTrue(
            session.execute("SELECT * FROM system.local WHERE key='local'"))

        # set keyspace
        session.set_keyspace('system')
        session.set_keyspace('system_traces')

        # use keyspace
        session.execute('USE system')
        session.execute('USE system_traces')

        # refresh schema
        cluster.refresh_schema_metadata()
        cluster.refresh_schema_metadata(max_schema_agreement_wait=0)

        assert_quiescent_pool_state(self, cluster)

        cluster.shutdown()
    def test_pool_management(self):
        # Ensure that in_flight and request_ids quiesce after cluster operations
        cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0)  # no idle heartbeat here, pool management is tested in test_idle_heartbeat
        session = cluster.connect()
        session2 = cluster.connect()

        # prepare
        p = session.prepare("SELECT * FROM system.local WHERE key=?")
        self.assertTrue(session.execute(p, ('local',)))

        # simple
        self.assertTrue(session.execute("SELECT * FROM system.local WHERE key='local'"))

        # set keyspace
        session.set_keyspace('system')
        session.set_keyspace('system_traces')

        # use keyspace
        session.execute('USE system')
        session.execute('USE system_traces')

        # refresh schema
        cluster.refresh_schema_metadata()
        cluster.refresh_schema_metadata(max_schema_agreement_wait=0)

        # submit schema refresh
        future = cluster.submit_schema_refresh()
        future.result()

        assert_quiescent_pool_state(self, cluster)

        cluster.shutdown()
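
# A minimal sketch of the heartbeat setting referenced in the comments above:
# passing a non-zero idle_heartbeat_interval (seconds) makes the driver ping
# idle connections, while 0 disables the feature as these tests do. The
# interval value below is an arbitrary illustration, not taken from the tests.
from cassandra.cluster import Cluster

heartbeat_cluster = Cluster(protocol_version=PROTOCOL_VERSION,
                            idle_heartbeat_interval=5)
heartbeat_session = heartbeat_cluster.connect()
heartbeat_session.execute("SELECT release_version FROM system.local")
heartbeat_cluster.shutdown()
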
    def test_refresh_schema(self):
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        original_meta = cluster.metadata.keyspaces
        # full schema refresh, with wait
        cluster.refresh_schema_metadata()
        self.assertIsNot(original_meta, cluster.metadata.keyspaces)
        self.assertEqual(original_meta, cluster.metadata.keyspaces)

        cluster.shutdown()
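
# A small sketch contrasting the refresh calls used in the tests above: a full
# refresh_schema_metadata() rebuilds all schema metadata, while the scoped
# variants only touch one keyspace or one table. system_traces/events is used
# here simply because it exists on a stock cluster.
from cassandra.cluster import Cluster

refresh_cluster = Cluster(protocol_version=PROTOCOL_VERSION)
refresh_cluster.connect()

refresh_cluster.refresh_schema_metadata()                          # everything
refresh_cluster.refresh_keyspace_metadata('system_traces')         # one keyspace
refresh_cluster.refresh_table_metadata('system_traces', 'events')  # one table

refresh_cluster.shutdown()
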
class Migrator(object):
    """Execute migration operations in a C* database based on configuration.

    `opts` must contain at least the following attributes:
    - config_file: path to a YAML file containing the configuration
    - profiles: map of profile names to keyspace settings
    - user, password: authentication options. May be None to not use it.
    - hosts: comma-separated list of contact points
    - port: connection port
    """

    logger = logging.getLogger("Migrator")

    def __init__(self,
                 config,
                 profile='dev',
                 hosts=['127.0.0.1'],
                 port=9042,
                 user=None,
                 password=None,
                 host_cert_path=None,
                 client_key_path=None,
                 client_cert_path=None):
        self.config = config

        try:
            self.current_profile = self.config.profiles[profile]
        except KeyError:
            raise ValueError("Invalid profile name '{}'".format(profile))

        if user:
            auth_provider = PlainTextAuthProvider(user, password)
        else:
            auth_provider = None

        if host_cert_path:
            ssl_options = self._build_ssl_options(host_cert_path,
                                                  client_key_path,
                                                  client_cert_path)
        else:
            ssl_options = None

        self.cluster = Cluster(contact_points=hosts,
                               port=port,
                               auth_provider=auth_provider,
                               max_schema_agreement_wait=300,
                               control_connection_timeout=10,
                               connect_timeout=30,
                               ssl_options=ssl_options)

        self._session = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._session is not None:
            self._session.shutdown()
            self._session = None

        if self.cluster is not None:
            self.cluster.shutdown()
            self.cluster = None

    def _build_ssl_options(self, host_cert_path, client_key_path,
                           client_cert_path):
        return {
            'ca_certs': host_cert_path,
            'certfile': client_cert_path,
            'keyfile': client_key_path
        }

    def _check_cluster(self):
        """Check if the cluster is still alive, raise otherwise"""
        if not self.cluster:
            raise RuntimeError("Cluster has shut down")

    def _init_session(self):
        if not self._session:
            s = self._session = self.cluster.connect()
            s.default_consistency_level = ConsistencyLevel.ALL
            s.default_serial_consistency_level = ConsistencyLevel.SERIAL
            s.default_timeout = 120

    @property
    def session(self):
        """Initialize and configure a  C* driver session if needed"""

        self._check_cluster()
        self._init_session()

        return self._session

    def _get_target_version(self, v):
        """
        Parses a version specifier to an actual numeric migration version

        `v` might be:
        - None: the latest version is chosen
        - an int, or a numeric string: that exact version is chosen
        - a string: the version with that name is chosen if it exists

        If an invalid version is found, `ValueError` is raised.
        """

        if v is None:
            return len(self.config.migrations)

        if isinstance(v, int):
            num = v
        elif v.isdigit():
            num = int(v)
        else:
            try:
                num = self.config.migrations.index(v)
            except (ValueError, IndexError):
                num = -1

        if num <= 0:
            raise ValueError('Invalid database version, must be a number > 0 '
                             'or the name of an existing migration')
        return num

    def _q(self, query, **kwargs):
        """
        Format a query with the configured keyspace and migration table

        `keyspace` and `table` are interpolated as named arguments
        """
        return query.format(keyspace=self.config.keyspace,
                            table=self.config.migrations_table,
                            **kwargs)

    def _execute(self, query, *args, **kwargs):
        """Execute a query with the current session"""
        self.logger.debug('Executing query: {}'.format(query))
        return self.session.execute(query, *args, **kwargs)

    def _keyspace_exists(self):
        self._init_session()

        return self.config.keyspace in self.cluster.metadata.keyspaces

    def _ensure_keyspace(self):
        """Create the keyspace if it does not exist"""

        if self._keyspace_exists():
            return

        self.logger.info("Creating keyspace '{}'".format(self.config.keyspace))

        profile = self.current_profile
        self._execute(
            self._q(CREATE_KEYSPACE,
                    replication=cassandra_ddl_repr(profile['replication']),
                    durable_writes=cassandra_ddl_repr(
                        profile['durable_writes'])))

        self.cluster.refresh_keyspace_metadata(self.config.keyspace)

    def _table_exists(self):
        self._init_session()

        ks_metadata = self.cluster.metadata.keyspaces.get(
            self.config.keyspace, None)
        # Fail if the keyspace is missing. If it should be created
        # automatically _ensure_keyspace() must be called first.
        if not ks_metadata:
            raise ValueError("Keyspace '{}' does not exist, "
                             "stopping".format(self.config.keyspace))

        return self.config.migrations_table in ks_metadata.tables

    def _ensure_table(self):
        """Create the migration table if it does not exist"""

        if self._table_exists():
            return

        self.logger.info(
            "Creating table '{table}' in keyspace '{keyspace}'".format(
                keyspace=self.config.keyspace,
                table=self.config.migrations_table))

        self._execute(self._q(CREATE_MIGRATIONS_TABLE))
        self.cluster.refresh_table_metadata(self.config.keyspace,
                                            self.config.migrations_table)

    def _verify_migrations(self,
                           migrations,
                           ignore_failed=False,
                           ignore_concurrent=False):
        """Verify if the version history persisted in C* matches the migrations

        Migrations with corresponding DB versions must have the same content
        and name.
        Every DB version must have a corresponding migration.
        Migrations without corresponding DB versions are considered pending,
        and returned as a result.

        Returns a tuple of (last_version, cur_versions, pending_migrations),
        where pending_migrations is a list of (version, migration) pairs whose
        version numbers continue from the last applied version.
        """

        # Load all the currently existing versions and sort them by version
        # number, as Cassandra can only sort it for us by partition.
        cur_versions = self._execute(
            self._q('SELECT * FROM "{keyspace}"."{table}"'))
        cur_versions = sorted(cur_versions, key=lambda v: v.version)

        last_version = None
        version_pairs = zip_longest(cur_versions, migrations)

        # Work through ordered pairs of (existing version, migration), so that
        # stored versions and expected migrations can be compared for any
        # differences.
        for i, (version, migration) in enumerate(version_pairs, 1):
            # If version is empty, this migration (and everything after it)
            # has not been applied yet, so stop comparing here; the pending
            # list is built from last_version below.
            if not version:
                break

            # If migration is empty, we have a version in the database with
            # no corresponding file. That might mean we're running the wrong
            # migration or have an out-of-date state, so we must fail.
            if not migration:
                raise UnknownMigration(version.version, version.name)

            # A migration was previously run and failed.
            if version.state == Migration.State.FAILED:
                if ignore_failed:
                    break

                raise FailedMigration(version.version, version.name)

            last_version = version.version

            # A stored version's migration differs from the one on disk.
            if version.content != migration.content or \
               version.name != migration.name or \
               bytearray(version.checksum) != bytearray(migration.checksum):
                raise InconsistentState(migration, version)

        if not last_version:
            pending_migrations = list(migrations)
        else:
            pending_migrations = list(migrations)[last_version:]

        if not pending_migrations:
            self.logger.info('Database is already up-to-date')
        else:
            self.logger.info('Pending migrations found. Current version: {}, '
                             'Latest version: {}'.format(
                                 last_version, len(migrations)))

        pending_migrations = enumerate(pending_migrations,
                                       (last_version or 0) + 1)
        return last_version, cur_versions, list(pending_migrations)

    def _create_version(self, version, migration):
        """
        Write an in-progress version entry to C*

        The migration is inserted with the given `version` number if and only
        if it does not exist already (using a CompareAndSet operation).

        If the insert succeeds (with the migration marked as in-progress), we
        can continue and actually execute it. Otherwise, there was a concurrent
        write and we must fail to allow the other write to continue.

        """

        self.logger.info('Writing in-progress migration version {}: {}'.format(
            version, migration))

        version_id = uuid.uuid4()
        result = self._execute(
            self._q(CREATE_DB_VERSION),
            (version_id, version, migration.name, migration.content,
             bytearray(migration.checksum), Migration.State.IN_PROGRESS))

        return version_id

    def _apply_cql_migration(self, version, migration):
        """
        Persist and apply a cql migration

        First create an in-progress version entry, apply the script, then
        finalize the entry as succeeded, failed or skipped.
        """

        self.logger.info('Applying cql migration')

        statements = CqlSplitter.split(migration.content)

        try:
            if statements:
                self.logger.info('Executing migration with '
                                 '{} CQL statements'.format(len(statements)))

            for statement in statements:
                self.session.execute(statement)
        except Exception:
            raise FailedMigration(version, migration.name)

    def _apply_python_migration(self, version, migration):
        """
        Persist and apply a python migration

        First create an in-progress version entry, apply the script, then
        finalize the entry as succeeded, failed or skipped.
        """
        self.logger.info('Applying python script')

        try:
            mod, _ = os.path.splitext(os.path.basename(migration.path))
            migration_script = importlib.import_module(mod)
            migration_script.execute(self._session)
        except Exception:
            self.logger.exception('Failed to execute script')
            raise FailedMigration(version, migration.name)

    def _apply_migration(self, version, migration, skip=False):
        """
        Persist and apply a migration

        When `skip` is True, do everything but actually run the script, for
        example, when baselining instead of migrating.
        """

        self.logger.info('Advancing to version {}'.format(version))

        version_uuid = self._create_version(version, migration)
        new_state = Migration.State.FAILED
        sys.path.append(self.config.migrations_path)

        result = None

        try:
            if skip:
                self.logger.info('Migration is marked for skipping, '
                                 'not actually running script')
            else:
                if migration.is_python:
                    self._apply_python_migration(version, migration)
                else:
                    self._apply_cql_migration(version, migration)
        except Exception:
            self.logger.exception('Failed to execute migration')
            raise FailedMigration(version, migration.name)
        else:
            new_state = (Migration.State.SUCCEEDED
                         if not skip else Migration.State.SKIPPED)
        finally:
            self.logger.info('Finalizing migration version with '
                             'state {}'.format(new_state))
            result = self._execute(self._q(FINALIZE_DB_VERSION),
                                   (new_state, version_uuid))

    def _cleanup_previous_versions(self, cur_versions):
        if not cur_versions:
            return

        last_version = cur_versions[-1]
        if last_version.state != Migration.State.FAILED:
            return

        self.logger.warning('Cleaning up previous failed migration '
                         '(version {}): {}'.format(last_version.version,
                                                   last_version.name))

        result = self._execute(self._q(DELETE_DB_VERSION),
                               (last_version.id, Migration.State.FAILED))

    def _advance(self,
                 migrations,
                 target,
                 cur_versions,
                 skip=False,
                 force=False):
        """Apply all necessary migrations to reach a target version"""
        if force:
            self._cleanup_previous_versions(cur_versions)

        target_version = self._get_target_version(target)

        if migrations:
            # Set default keyspace so migrations don't need to refer to it
            # manually
            # Fixes https://github.com/Cobliteam/cassandra-migrate/issues/5
            self.session.execute('USE {};'.format(self.config.keyspace))

        for version, migration in migrations:
            if version > target_version:
                break

            self._apply_migration(version, migration, skip=skip)

        self.cluster.refresh_schema_metadata()

    def baseline(self, opts):
        """Baseline a database, by advancing migration state without changes"""

        self._check_cluster()
        self._ensure_table()

        last_version, cur_versions, pending_migrations = \
            self._verify_migrations(self.config.migrations,
                                    ignore_failed=False)

        self._advance(pending_migrations,
                      opts.db_version,
                      cur_versions,
                      skip=True)

    @confirmation_required
    def migrate(self, opts):
        """
        Migrate a database to a given version, applying any needed migrations
        """

        self._check_cluster()

        self._ensure_keyspace()
        self._ensure_table()

        last_version, cur_versions, pending_migrations = \
            self._verify_migrations(self.config.migrations,
                                    ignore_failed=opts.force)

        self._advance(pending_migrations,
                      opts.db_version,
                      cur_versions,
                      force=opts.force)

    @confirmation_required
    def reset(self, opts):
        """Reset a database, by dropping the keyspace then migrating"""
        self._check_cluster()

        self.logger.info("Dropping existing keyspace '{}'".format(
            self.config.keyspace))

        self._execute(self._q(DROP_KEYSPACE))
        self.cluster.refresh_schema_metadata()

        opts.force = False
        self.migrate(opts)

    @staticmethod
    def _bytes_to_hex(bs):
        return codecs.getencoder('hex')(bs)[0]

    def status(self, opts):
        self._check_cluster()

        if not self._keyspace_exists():
            print("Keyspace '{}' does not exist".format(self.config.keyspace))
            return

        if not self._table_exists():
            print("Migration table '{table}' does not exist in "
                  "keyspace '{keyspace}'".format(
                      keyspace=self.config.keyspace,
                      table=self.config.migrations_table))
            return

        last_version, cur_versions, pending_migrations = \
            self._verify_migrations(self.config.migrations,
                                    ignore_failed=True,
                                    ignore_concurrent=True)
        latest_version = len(self.config.migrations)

        print(
            tabulate((('Keyspace:', self.config.keyspace),
                      ('Migrations table:', self.config.migrations_table),
                      ('Current DB version:', last_version),
                      ('Latest DB version:', latest_version)),
                     tablefmt='plain'))

        if cur_versions:
            print('\n## Applied migrations\n')

            data = []
            for version in cur_versions:
                checksum = self._bytes_to_hex(version.checksum)
                date = arrow.get(version.applied_at).format()
                data.append((str(version.version), version.name, version.state,
                             date, checksum))
            print(
                tabulate(
                    data,
                    headers=['#', 'Name', 'State', 'Date applied',
                             'Checksum']))

        if pending_migrations:
            print('\n## Pending migrations\n')

            data = []
            for version, migration in pending_migrations:
                checksum = self._bytes_to_hex(migration.checksum)
                data.append((str(version), migration.name, checksum))
            print(tabulate(data, headers=['#', 'Name', 'Checksum']))
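
# Hypothetical usage sketch for the Migrator class above, not taken from the
# project's CLI: load_config and the attribute names on opts are assumptions
# based on what the class accesses (opts.db_version, opts.force). migrate()
# and reset() are wrapped by confirmation_required, which may prompt or read
# extra options.
from argparse import Namespace

config = load_config('cassandra-migrate.yml')  # assumed helper, for illustration
opts = Namespace(db_version=None, force=False)

with Migrator(config, profile='dev', hosts=['127.0.0.1'], port=9042) as migrator:
    migrator.status(opts)    # show applied and pending migrations
    migrator.migrate(opts)   # apply everything up to the latest version
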
    def test_refresh_schema_no_wait(self):

        contact_points = ['127.0.0.1']
        cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=10,
                          contact_points=contact_points, load_balancing_policy=WhiteListRoundRobinPolicy(contact_points))
        session = cluster.connect()

        schema_ver = session.execute("SELECT schema_version FROM system.local WHERE key='local'")[0][0]

        # create a schema disagreement
        session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (uuid4(),))

        try:
            agreement_timeout = 1

            # cluster agreement wait exceeded
            c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=agreement_timeout)
            start_time = time.time()
            s = c.connect()
            end_time = time.time()
            self.assertGreaterEqual(end_time - start_time, agreement_timeout)
            self.assertTrue(c.metadata.keyspaces)

            # cluster agreement wait used for refresh
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            self.assertRaisesRegexp(Exception, r"Schema metadata was not refreshed.*", c.refresh_schema_metadata)
            end_time = time.time()
            self.assertGreaterEqual(end_time - start_time, agreement_timeout)
            self.assertIs(original_meta, c.metadata.keyspaces)
            
            # refresh wait overrides cluster value
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            c.refresh_schema_metadata(max_schema_agreement_wait=0)
            end_time = time.time()
            self.assertLess(end_time - start_time, agreement_timeout)
            self.assertIsNot(original_meta, c.metadata.keyspaces)
            self.assertEqual(original_meta, c.metadata.keyspaces)

            c.shutdown()

            refresh_threshold = 0.5
            # cluster agreement bypass
            c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=0)
            start_time = time.time()
            s = c.connect()
            end_time = time.time()
            self.assertLess(end_time - start_time, refresh_threshold)
            self.assertTrue(c.metadata.keyspaces)

            # cluster agreement wait used for refresh
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            c.refresh_schema_metadata()
            end_time = time.time()
            self.assertLess(end_time - start_time, refresh_threshold)
            self.assertIsNot(original_meta, c.metadata.keyspaces)
            self.assertEqual(original_meta, c.metadata.keyspaces)
            
            # refresh wait overrides cluster value
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            self.assertRaisesRegexp(Exception, r"Schema metadata was not refreshed.*", c.refresh_schema_metadata,
                                    max_schema_agreement_wait=agreement_timeout)
            end_time = time.time()
            self.assertGreaterEqual(end_time - start_time, agreement_timeout)
            self.assertIs(original_meta, c.metadata.keyspaces)

            c.shutdown()
        finally:
            session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (schema_ver,))

        cluster.shutdown()
    def test_refresh_schema_no_wait(self):

        contact_points = ['127.0.0.1']
        cluster = Cluster(
            protocol_version=PROTOCOL_VERSION,
            max_schema_agreement_wait=10,
            contact_points=contact_points,
            load_balancing_policy=WhiteListRoundRobinPolicy(contact_points))
        session = cluster.connect()

        schema_ver = session.execute(
            "SELECT schema_version FROM system.local WHERE key='local'")[0][0]
        new_schema_ver = uuid4()
        session.execute(
            "UPDATE system.local SET schema_version=%s WHERE key='local'",
            (new_schema_ver, ))

        try:
            agreement_timeout = 1

            # cluster agreement wait exceeded
            c = Cluster(protocol_version=PROTOCOL_VERSION,
                        max_schema_agreement_wait=agreement_timeout)
            c.connect()
            self.assertTrue(c.metadata.keyspaces)

            # cluster agreement wait used for refresh
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            self.assertRaisesRegexp(Exception,
                                    r"Schema metadata was not refreshed.*",
                                    c.refresh_schema_metadata)
            end_time = time.time()
            self.assertGreaterEqual(end_time - start_time, agreement_timeout)
            self.assertIs(original_meta, c.metadata.keyspaces)

            # refresh wait overrides cluster value
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            c.refresh_schema_metadata(max_schema_agreement_wait=0)
            end_time = time.time()
            self.assertLess(end_time - start_time, agreement_timeout)
            self.assertIsNot(original_meta, c.metadata.keyspaces)
            self.assertEqual(original_meta, c.metadata.keyspaces)

            c.shutdown()

            refresh_threshold = 0.5
            # cluster agreement bypass
            c = Cluster(protocol_version=PROTOCOL_VERSION,
                        max_schema_agreement_wait=0)
            start_time = time.time()
            s = c.connect()
            end_time = time.time()
            self.assertLess(end_time - start_time, refresh_threshold)
            self.assertTrue(c.metadata.keyspaces)

            # cluster agreement wait used for refresh
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            c.refresh_schema_metadata()
            end_time = time.time()
            self.assertLess(end_time - start_time, refresh_threshold)
            self.assertIsNot(original_meta, c.metadata.keyspaces)
            self.assertEqual(original_meta, c.metadata.keyspaces)

            # refresh wait overrides cluster value
            original_meta = c.metadata.keyspaces
            start_time = time.time()
            self.assertRaisesRegexp(
                Exception,
                r"Schema metadata was not refreshed.*",
                c.refresh_schema_metadata,
                max_schema_agreement_wait=agreement_timeout)
            end_time = time.time()
            self.assertGreaterEqual(end_time - start_time, agreement_timeout)
            self.assertIs(original_meta, c.metadata.keyspaces)
            c.shutdown()
        finally:
            # TODO: remove this extra connect call once the underlying issue is fixed
            session = cluster.connect()
            session.execute(
                "UPDATE system.local SET schema_version=%s WHERE key='local'",
                (schema_ver, ))

        cluster.shutdown()
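
# A compact sketch of the override behaviour exercised above: the per-call
# max_schema_agreement_wait takes precedence over the Cluster-level value.
# With a wait of 0 the agreement check is skipped entirely; with a non-zero
# wait the refresh raises if the nodes still disagree within that window.
from cassandra.cluster import Cluster

override_cluster = Cluster(protocol_version=PROTOCOL_VERSION,
                           max_schema_agreement_wait=10)
override_cluster.connect()

override_cluster.refresh_schema_metadata(max_schema_agreement_wait=0)  # no wait

try:
    override_cluster.refresh_schema_metadata(max_schema_agreement_wait=1)
except Exception:
    # Only raised while the schema versions actually disagree.
    pass

override_cluster.shutdown()
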
                    "-mode cql3", "native", "-rate threads=1",
                    "-log level=verbose interval=5",
                    "-errors retries=999 ignore", "-node {}".format(
                        master_nodes[0].ip())
                ],
                                 stdout=stressor_log,
                                 stderr=subprocess.STDOUT))

        log('Letting stressor run for a while...')
        # Sleep long enough for stressor to create the keyspace
        time.sleep(10)

        log('Fetching schema definitions from master cluster.')
        cm.control_connection_timeout = 20
        with cm.connect() as _:
            cm.refresh_schema_metadata(max_schema_agreement_wait=10)
            ks = cm.metadata.keyspaces[KS_NAME]
            ut_ddls = [t[1].as_cql_query() for t in ks.user_types.items()]
            table_ddls = []
            for name, table in ks.tables.items():
                if name.endswith('_scylla_cdc_log'):
                    continue
                if 'cdc' in table.extensions:
                    del table.extensions['cdc']
                table_ddls.append(table.as_cql_query())

        log('User types:\n{}'.format('\n'.join(ut_ddls)))
        log('Table definitions:\n{}'.format('\n'.join(table_ddls)))

        log('Letting stressor run for a while...')
        time.sleep(5)
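
# Related sketch: the driver metadata can also emit a whole keyspace as CQL in
# one call, which covers the user types and tables collected manually above.
# KS_NAME is reused from the snippet above and assumed to exist already.
from cassandra.cluster import Cluster

export_cluster = Cluster(protocol_version=PROTOCOL_VERSION)
with export_cluster.connect():
    export_cluster.refresh_schema_metadata()
    print(export_cluster.metadata.keyspaces[KS_NAME].export_as_string())
export_cluster.shutdown()
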
def prime_cassandra(replication):
  """ Create Cassandra keyspace and initial tables.

  Args:
    replication: An integer specifying the replication factor for the keyspace.
  Raises:
    AppScaleBadArg if replication factor is not greater than 0.
    TypeError if replication is not an integer.
  """
  if not isinstance(replication, int):
    raise TypeError('Replication must be an integer')

  if int(replication) <= 0:
    raise dbconstants.AppScaleBadArg('Replication must be greater than zero')

  hosts = appscale_info.get_db_ips()

  cluster = None
  session = None
  remaining_retries = INITIAL_CONNECT_RETRIES
  while True:
    try:
      cluster = Cluster(hosts)
      session = cluster.connect()
      break
    except cassandra.cluster.NoHostAvailable as connection_error:
      remaining_retries -= 1
      if remaining_retries < 0:
        raise connection_error
      time.sleep(3)
  session.default_consistency_level = ConsistencyLevel.QUORUM

  create_keyspace = """
    CREATE KEYSPACE IF NOT EXISTS "{keyspace}"
    WITH REPLICATION = %(replication)s
  """.format(keyspace=KEYSPACE)
  keyspace_replication = {'class': 'SimpleStrategy',
                          'replication_factor': replication}
  session.execute(create_keyspace, {'replication': keyspace_replication})
  session.set_keyspace(KEYSPACE)

  for table in dbconstants.INITIAL_TABLES:
    create_table = """
      CREATE TABLE IF NOT EXISTS "{table}" (
        {key} blob,
        {column} text,
        {value} blob,
        PRIMARY KEY ({key}, {column})
      ) WITH COMPACT STORAGE
    """.format(table=table,
               key=ThriftColumn.KEY,
               column=ThriftColumn.COLUMN_NAME,
               value=ThriftColumn.VALUE)
    logging.info('Trying to create {}'.format(table))
    cluster.refresh_schema_metadata()
    session.execute(create_table)

  create_batch_tables(cluster, session)
  create_pull_queue_tables(cluster, session)

  first_entity = session.execute(
    'SELECT * FROM "{}" LIMIT 1'.format(dbconstants.APP_ENTITY_TABLE))
  existing_entities = len(list(first_entity)) == 1

  define_ua_schema(session)

  metadata_insert = """
    INSERT INTO "{table}" ({key}, {column}, {value})
    VALUES (%(key)s, %(column)s, %(value)s)
  """.format(
    table=dbconstants.DATASTORE_METADATA_TABLE,
    key=ThriftColumn.KEY,
    column=ThriftColumn.COLUMN_NAME,
    value=ThriftColumn.VALUE
  )

  if not existing_entities:
    parameters = {'key': bytearray(cassandra_interface.VERSION_INFO_KEY),
                  'column': cassandra_interface.VERSION_INFO_KEY,
                  'value': bytearray(str(POST_JOURNAL_VERSION))}
    session.execute(metadata_insert, parameters)

  # Indicate that the database has been successfully primed.
  parameters = {'key': bytearray(cassandra_interface.PRIMED_KEY),
                'column': cassandra_interface.PRIMED_KEY,
                'value': bytearray('true')}
  session.execute(metadata_insert, parameters)
  logging.info('Cassandra is primed.')
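
# Sketch of the connect-with-retries pattern used in prime_cassandra above,
# pulled out as a standalone helper. The retry count and 3-second delay mirror
# the constants used there but are plain defaults here.
import time

import cassandra
from cassandra.cluster import Cluster


def connect_with_retries(hosts, retries=5, delay=3):
  """ Return (cluster, session), retrying while no host is available. """
  while True:
    try:
      cluster = Cluster(hosts)
      return cluster, cluster.connect()
    except cassandra.cluster.NoHostAvailable:
      retries -= 1
      if retries < 0:
        raise
      time.sleep(delay)
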
class DatastoreProxy(AppDBInterface):
    """ 
    Cassandra implementation of the AppDBInterface
  """

    def __init__(self, log_level=logging.INFO):
        """
    Constructor.
    """
        class_name = self.__class__.__name__
        self.logger = logging.getLogger(class_name)
        self.logger.setLevel(log_level)
        self.logger.info("Starting {}".format(class_name))

        self.hosts = appscale_info.get_db_ips()

        remaining_retries = INITIAL_CONNECT_RETRIES
        while True:
            try:
                self.cluster = Cluster(self.hosts)
                self.session = self.cluster.connect(KEYSPACE)
                break
            except cassandra.cluster.NoHostAvailable as connection_error:
                remaining_retries -= 1
                if remaining_retries < 0:
                    raise connection_error
                time.sleep(3)

        self.session.default_consistency_level = ConsistencyLevel.QUORUM
        self.retry_policy = IdempotentRetryPolicy()

    def close(self):
        """ Close all sessions and connections to Cassandra. """
        self.cluster.shutdown()

    def batch_get_entity(self, table_name, row_keys, column_names):
        """
    Takes in batches of keys and retrieves their corresponding rows.
    
    Args:
      table_name: The table to access
      row_keys: A list of keys to access
      column_names: A list of columns to access
    Returns:
      A dictionary of rows and columns/values of those rows. The format 
      looks like such: {key:{column_name:value,...}}
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_get could not be performed due to
        an error with Cassandra.
    """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(column_names, list):
            raise TypeError("Expected a list")
        if not isinstance(row_keys, list):
            raise TypeError("Expected a list")

        row_keys_bytes = [bytearray(row_key) for row_key in row_keys]

        statement = (
            'SELECT * FROM "{table}" '
            'WHERE {key} IN %s AND {column} IN %s'
        ).format(
            table=table_name, key=ThriftColumn.KEY, column=ThriftColumn.COLUMN_NAME
        )
        query = SimpleStatement(statement, retry_policy=self.retry_policy)
        parameters = (ValueSequence(row_keys_bytes), ValueSequence(column_names))

        try:
            results = self.session.execute(query, parameters=parameters)

            results_dict = {row_key: {} for row_key in row_keys}
            for (key, column, value) in results:
                if key not in results_dict:
                    results_dict[key] = {}
                results_dict[key][column] = value

            return results_dict
        except (cassandra.Unavailable, cassandra.Timeout, cassandra.CoordinationFailure, cassandra.OperationTimedOut):
            message = "Exception during batch_get_entity"
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def batch_put_entity(self, table_name, row_keys, column_names, cell_values, ttl=None):
        """
    Allows callers to store multiple rows with a single call. A row can 
    have multiple columns and values with them. We refer to each row as 
    an entity.
   
    Args: 
      table_name: The table to mutate
      row_keys: A list of keys to store on
      column_names: A list of columns to mutate
      cell_values: A dict of key/value pairs
      ttl: The number of seconds to keep the row.
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_put could not be performed due to
        an error with Cassandra.
    """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(column_names, list):
            raise TypeError("Expected a list")
        if not isinstance(row_keys, list):
            raise TypeError("Expected a list")
        if not isinstance(cell_values, dict):
            raise TypeError("Expected a dict")

        insert_str = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (?, ?, ?)
    """.format(
            table=table_name, key=ThriftColumn.KEY, column=ThriftColumn.COLUMN_NAME, value=ThriftColumn.VALUE
        )

        if ttl is not None:
            insert_str += "USING TTL {}".format(ttl)

        statement = self.session.prepare(insert_str)

        batch_insert = BatchStatement(retry_policy=self.retry_policy)

        for row_key in row_keys:
            for column in column_names:
                batch_insert.add(statement, (bytearray(row_key), column, bytearray(cell_values[row_key][column])))

        try:
            self.session.execute(batch_insert)
        except (cassandra.Unavailable, cassandra.Timeout, cassandra.CoordinationFailure, cassandra.OperationTimedOut):
            message = "Exception during batch_put_entity"
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def prepare_insert(self, table):
        """ Prepare an insert statement.

    Args:
      table: A string containing the table name.
    Returns:
      A PreparedStatement object.
    """
        statement = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (?, ?, ?)
    """.format(
            table=table, key=ThriftColumn.KEY, column=ThriftColumn.COLUMN_NAME, value=ThriftColumn.VALUE
        )
        return self.session.prepare(statement)

    def prepare_delete(self, table):
        """ Prepare a delete statement.

    Args:
      table: A string containing the table name.
    Returns:
      A PreparedStatement object.
    """
        statement = """
      DELETE FROM "{table}" WHERE {key} = ?
    """.format(
            table=table, key=ThriftColumn.KEY
        )
        return self.session.prepare(statement)

    def _normal_batch(self, mutations):
        """ Use Cassandra's native batch statement to apply mutations atomically.

    Args:
      mutations: A list of dictionaries representing mutations.
    """
        self.logger.debug("Normal batch: {} mutations".format(len(mutations)))
        batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM, retry_policy=self.retry_policy)
        prepared_statements = {"insert": {}, "delete": {}}
        for mutation in mutations:
            table = mutation["table"]
            if mutation["operation"] == TxnActions.PUT:
                if table not in prepared_statements["insert"]:
                    prepared_statements["insert"][table] = self.prepare_insert(table)
                values = mutation["values"]
                for column in values:
                    batch.add(
                        prepared_statements["insert"][table],
                        (bytearray(mutation["key"]), column, bytearray(values[column])),
                    )
            elif mutation["operation"] == TxnActions.DELETE:
                if table not in prepared_statements["delete"]:
                    prepared_statements["delete"][table] = self.prepare_delete(table)
                batch.add(prepared_statements["delete"][table], (bytearray(mutation["key"]),))

        try:
            self.session.execute(batch)
        except (cassandra.Unavailable, cassandra.Timeout, cassandra.CoordinationFailure, cassandra.OperationTimedOut):
            message = "Exception during batch_mutate"
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def apply_mutations(self, mutations):
        """ Apply mutations across tables.

    Args:
      mutations: A list of dictionaries representing mutations.
    """
        for mutation in mutations:
            table = mutation["table"]
            if mutation["operation"] == TxnActions.PUT:
                values = mutation["values"]
                for column in values:
                    insert_row = """
            INSERT INTO "{table}" ({key}, {column}, {value})
            VALUES (%(key)s, %(column)s, %(value)s)
          """.format(
                        table=table, key=ThriftColumn.KEY, column=ThriftColumn.COLUMN_NAME, value=ThriftColumn.VALUE
                    )
                    parameters = {
                        "key": bytearray(mutation["key"]),
                        "column": column,
                        "value": bytearray(values[column]),
                    }
                    self.session.execute(insert_row, parameters)
            elif mutation["operation"] == TxnActions.DELETE:
                delete_row = """
          DELETE FROM "{table}" WHERE {key} = %(key)s
        """.format(
                    table=table, key=ThriftColumn.KEY
                )
                parameters = {"key": bytearray(mutation["key"])}
                self.session.execute(delete_row, parameters)

    def _large_batch(self, app, mutations, entity_changes, txn):
        """ Insert or delete multiple rows across tables in an atomic statement.

    Args:
      app: A string containing the application ID.
      mutations: A list of dictionaries representing mutations.
      entity_changes: A list of changes at the entity level.
      txn: A transaction ID handler.
    Raises:
      FailedBatch if a concurrent process modifies the batch status.
    """
        self.logger.debug("Large batch: transaction {}, {} mutations".format(txn, len(mutations)))
        set_status = """
      INSERT INTO batch_status (app, transaction, applied)
      VALUES (%(app)s, %(transaction)s, False)
      IF NOT EXISTS
    """
        parameters = {"app": app, "transaction": txn}
        result = self.session.execute(set_status, parameters)
        if not result.was_applied:
            raise FailedBatch("A batch for transaction {} already exists".format(txn))

        for entity_change in entity_changes:
            insert_item = """
        INSERT INTO batches (app, transaction, namespace, path,
                             old_value, new_value)
        VALUES (%(app)s, %(transaction)s, %(namespace)s, %(path)s,
                %(old_value)s, %(new_value)s)
      """
            old_value = None
            if entity_change["old"] is not None:
                old_value = bytearray(entity_change["old"].Encode())
            new_value = None
            if entity_change["new"] is not None:
                new_value = bytearray(entity_change["new"].Encode())

            parameters = {
                "app": app,
                "transaction": txn,
                "namespace": entity_change["key"].name_space(),
                "path": bytearray(entity_change["key"].path().Encode()),
                "old_value": old_value,
                "new_value": new_value,
            }
            self.session.execute(insert_item, parameters)

        update_status = """
      UPDATE batch_status
      SET applied = True
      WHERE app = %(app)s
      AND transaction = %(transaction)s
      IF applied = False
    """
        parameters = {"app": app, "transaction": txn}
        result = self.session.execute(update_status, parameters)
        if not result.was_applied:
            raise FailedBatch("Another process modified batch for transaction {}".format(txn))

        self.apply_mutations(mutations)

        clear_batch = """
      DELETE FROM batches
      WHERE app = %(app)s AND transaction = %(transaction)s
    """
        parameters = {"app": app, "transaction": txn}
        self.session.execute(clear_batch, parameters)

        clear_status = """
      DELETE FROM batch_status
      WHERE app = %(app)s and transaction = %(transaction)s
    """
        parameters = {"app": app, "transaction": txn}
        self.session.execute(clear_status, parameters)

    def batch_mutate(self, app, mutations, entity_changes, txn):
        """ Insert or delete multiple rows across tables in an atomic statement.

    Args:
      app: A string containing the application ID.
      mutations: A list of dictionaries representing mutations.
      entity_changes: A list of changes at the entity level.
      txn: A transaction ID handler.
    """
        size = batch_size(mutations)
        self.logger.debug("batch_size: {}".format(size))
        if size > LARGE_BATCH_THRESHOLD:
            self._large_batch(app, mutations, entity_changes, txn)
        else:
            self._normal_batch(mutations)

    def batch_delete(self, table_name, row_keys, column_names=()):
        """
    Remove a set of rows corresponding to a set of keys.
     
    Args:
      table_name: Table to delete rows from
      row_keys: A list of keys to remove
      column_names: Not used
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_delete could not be performed due
        to an error with Cassandra.
    """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(row_keys, list):
            raise TypeError("Expected a list")

        row_keys_bytes = [bytearray(row_key) for row_key in row_keys]

        statement = 'DELETE FROM "{table}" WHERE {key} IN %s'.format(table=table_name, key=ThriftColumn.KEY)
        query = SimpleStatement(statement, retry_policy=self.retry_policy)
        parameters = (ValueSequence(row_keys_bytes),)

        try:
            self.session.execute(query, parameters=parameters)
        except (cassandra.Unavailable, cassandra.Timeout, cassandra.CoordinationFailure, cassandra.OperationTimedOut):
            message = "Exception during batch_delete"
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def delete_table(self, table_name):
        """ 
    Drops a given table (aka column family in Cassandra)
  
    Args:
      table_name: A string name of the table to drop
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the delete_table could not be performed due
        to an error with Cassandra.
    """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")

        statement = 'DROP TABLE IF EXISTS "{table}"'.format(table=table_name)
        query = SimpleStatement(statement, retry_policy=self.retry_policy)

        try:
            self.session.execute(query)
        except (cassandra.Unavailable, cassandra.Timeout, cassandra.CoordinationFailure, cassandra.OperationTimedOut):
            message = "Exception during delete_table"
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def create_table(self, table_name, column_names):
        """ 
    Creates a table if it doesn't already exist.
    
    Args:
      table_name: The column family name
      column_names: Not used but here to match the interface
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the create_table could not be performed due
        to an error with Cassandra.
    """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(column_names, list):
            raise TypeError("Expected a list")

        self.cluster.refresh_schema_metadata()
        statement = (
            'CREATE TABLE IF NOT EXISTS "{table}" ('
            "{key} blob,"
            "{column} text,"
            "{value} blob,"
            "PRIMARY KEY ({key}, {column})"
            ") WITH COMPACT STORAGE".format(
                table=table_name, key=ThriftColumn.KEY, column=ThriftColumn.COLUMN_NAME, value=ThriftColumn.VALUE
            )
        )
        query = SimpleStatement(statement)

        try:
            self.session.execute(query)
        except (cassandra.Unavailable, cassandra.Timeout, cassandra.CoordinationFailure, cassandra.OperationTimedOut):
            message = "Exception during create_table"
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def range_query(
        self,
        table_name,
        column_names,
        start_key,
        end_key,
        limit,
        offset=0,
        start_inclusive=True,
        end_inclusive=True,
        keys_only=False,
    ):
        """ 
    Gets a dense range ordered by keys. Returns an ordered list of 
    a dictionary of [key:{column1:value1, column2:value2},...]
    or a list of keys if keys only.
     
    Args:
      table_name: Name of table to access
      column_names: Columns which get returned within the key range
      start_key: String for which the query starts at
      end_key: String for which the query ends at
      limit: Maximum number of results to return
      offset: Cuts off these many from the results [offset:]
      start_inclusive: Boolean if results should include the start_key
      end_inclusive: Boolean if results should include the end_key
      keys_only: Boolean if to only keys and not values
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the range_query could not be performed due
        to an error with Cassandra.
    Returns:
      An ordered list of dictionaries of key=>columns/values
    """
        if not isinstance(table_name, str):
            raise TypeError("table_name must be a string")
        if not isinstance(column_names, list):
            raise TypeError("column_names must be a list")
        if not isinstance(start_key, str):
            raise TypeError("start_key must be a string")
        if not isinstance(end_key, str):
            raise TypeError("end_key must be a string")
        if not isinstance(limit, (int, long)) and limit is not None:
            raise TypeError("limit must be int, long, or NoneType")
        if not isinstance(offset, (int, long)):
            raise TypeError("offset must be int or long")

        if start_inclusive:
            gt_compare = ">="
        else:
            gt_compare = ">"

        if end_inclusive:
            lt_compare = "<="
        else:
            lt_compare = "<"

        query_limit = ""
        if limit is not None:
            query_limit = "LIMIT {}".format(len(column_names) * limit)

        statement = """
      SELECT * FROM "{table}" WHERE
      token({key}) {gt_compare} %s AND
      token({key}) {lt_compare} %s AND
      {column} IN %s
      {limit}
      ALLOW FILTERING
    """.format(
            table=table_name,
            key=ThriftColumn.KEY,
            gt_compare=gt_compare,
            lt_compare=lt_compare,
            column=ThriftColumn.COLUMN_NAME,
            limit=query_limit,
        )

        query = SimpleStatement(statement, retry_policy=self.retry_policy)
        parameters = (bytearray(start_key), bytearray(end_key), ValueSequence(column_names))

        try:
            results = self.session.execute(query, parameters=parameters)

            results_list = []
            current_item = {}
            current_key = None
            for (key, column, value) in results:
                if keys_only:
                    results_list.append(key)
                    continue

                if key != current_key:
                    if current_item:
                        results_list.append({current_key: current_item})
                    current_item = {}
                    current_key = key

                current_item[column] = value
            if current_item:
                results_list.append({current_key: current_item})
            return results_list[offset:]
        except (cassandra.Unavailable, cassandra.Timeout, cassandra.CoordinationFailure, cassandra.OperationTimedOut):
            message = "Exception during range_query"
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def get_metadata(self, key):
        """ Retrieve a value from the datastore metadata table.

    Args:
      key: A string containing the key to fetch.
    Returns:
      A string containing the value or None if the key is not present.
    """
        statement = """
      SELECT {value} FROM "{table}"
      WHERE {key} = %s
      AND {column} = %s
    """.format(
            value=ThriftColumn.VALUE,
            table=dbconstants.DATASTORE_METADATA_TABLE,
            key=ThriftColumn.KEY,
            column=ThriftColumn.COLUMN_NAME,
        )
        try:
            results = self.session.execute(statement, (bytearray(key), key))
        except (cassandra.Unavailable, cassandra.Timeout, cassandra.CoordinationFailure, cassandra.OperationTimedOut):
            message = "Unable to fetch {} from datastore metadata".format(key)
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

        try:
            return results[0].value
        except IndexError:
            return None

    def set_metadata(self, key, value):
        """ Set a datastore metadata value.

    Args:
      key: A string containing the key to set.
      value: A string containing the value to set.
    """
        if not isinstance(key, str):
            raise TypeError("key should be a string")

        if not isinstance(value, str):
            raise TypeError("value should be a string")

        statement = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (%(key)s, %(column)s, %(value)s)
    """.format(
            table=dbconstants.DATASTORE_METADATA_TABLE,
            key=ThriftColumn.KEY,
            column=ThriftColumn.COLUMN_NAME,
            value=ThriftColumn.VALUE,
        )
        parameters = {"key": bytearray(key), "column": key, "value": bytearray(value)}
        try:
            self.session.execute(statement, parameters)
        except (cassandra.Unavailable, cassandra.Timeout, cassandra.CoordinationFailure, cassandra.OperationTimedOut):
            message = "Unable to set datastore metadata for {}".format(key)
            logging.exception(message)
            raise AppScaleDBConnectionError(message)
        except cassandra.InvalidRequest:
            self.create_table(dbconstants.DATASTORE_METADATA_TABLE, dbconstants.DATASTORE_METADATA_SCHEMA)
            self.session.execute(statement, parameters)

    def get_indices(self, app_id):
        """ Gets the indices of the given application.

    Args:
      app_id: Name of the application.
    Returns:
      Returns a list of encoded entity_pb.CompositeIndex objects.
    """
        start_key = dbconstants.KEY_DELIMITER.join([app_id, "index", ""])
        end_key = dbconstants.KEY_DELIMITER.join([app_id, "index", dbconstants.TERMINATING_STRING])
        result = self.range_query(
            dbconstants.METADATA_TABLE,
            dbconstants.METADATA_SCHEMA,
            start_key,
            end_key,
            dbconstants.MAX_NUMBER_OF_COMPOSITE_INDEXES,
            offset=0,
            start_inclusive=True,
            end_inclusive=True,
        )
        list_result = []
        for list_item in result:
            for key, value in list_item.iteritems():
                list_result.append(value["data"])
        return list_result

    def valid_data_version(self):
        """ Checks whether or not the data layout can be used.

    Returns:
      A boolean.
    """
        try:
            version = self.get_metadata(VERSION_INFO_KEY)
        except cassandra.InvalidRequest:
            return False

        return version is not None and float(version) == EXPECTED_DATA_VERSION
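
# Hypothetical round-trip sketch for the DatastoreProxy above: put a single
# cell and read it back. It assumes a primed AppScale Cassandra deployment and
# an existing table; the table, key and column names are placeholders.
db = DatastoreProxy()
db.batch_put_entity('example_table', ['row1'], ['col1'],
                    {'row1': {'col1': 'value1'}})
rows = db.batch_get_entity('example_table', ['row1'], ['col1'])
print(rows)   # maps each row key to a {column: value} dict
db.close()
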
def prime_cassandra(replication):
    """ Create Cassandra keyspace and initial tables.

  Args:
    replication: An integer specifying the replication factor for the keyspace.
  Raises:
    AppScaleBadArg if replication factor is not greater than 0.
    TypeError if replication is not an integer.
  """
    if not isinstance(replication, int):
        raise TypeError('Replication must be an integer')

    if int(replication) <= 0:
        raise dbconstants.AppScaleBadArg(
            'Replication must be greater than zero')

    hosts = appscale_info.get_db_ips()

    cluster = None
    session = None
    remaining_retries = INITIAL_CONNECT_RETRIES
    while True:
        try:
            cluster = Cluster(hosts)
            session = cluster.connect()
            break
        except cassandra.cluster.NoHostAvailable as connection_error:
            remaining_retries -= 1
            if remaining_retries < 0:
                raise connection_error
            time.sleep(3)
    session.default_consistency_level = ConsistencyLevel.QUORUM

    create_keyspace = """
    CREATE KEYSPACE IF NOT EXISTS "{keyspace}"
    WITH REPLICATION = %(replication)s
  """.format(keyspace=KEYSPACE)
    keyspace_replication = {
        'class': 'SimpleStrategy',
        'replication_factor': replication
    }
    session.execute(create_keyspace, {'replication': keyspace_replication})
    session.set_keyspace(KEYSPACE)

    for table in dbconstants.INITIAL_TABLES:
        create_table = """
      CREATE TABLE IF NOT EXISTS "{table}" (
        {key} blob,
        {column} text,
        {value} blob,
        PRIMARY KEY ({key}, {column})
      ) WITH COMPACT STORAGE
    """.format(table=table,
               key=ThriftColumn.KEY,
               column=ThriftColumn.COLUMN_NAME,
               value=ThriftColumn.VALUE)
        logging.info('Trying to create {}'.format(table))
        cluster.refresh_schema_metadata()
        session.execute(create_table)

    create_batch_tables(cluster, session)
    create_pull_queue_tables(cluster, session)

    first_entity = session.execute('SELECT * FROM "{}" LIMIT 1'.format(
        dbconstants.APP_ENTITY_TABLE))
    existing_entities = len(list(first_entity)) == 1

    define_ua_schema(session)

    metadata_insert = """
    INSERT INTO "{table}" ({key}, {column}, {value})
    VALUES (%(key)s, %(column)s, %(value)s)
  """.format(table=dbconstants.DATASTORE_METADATA_TABLE,
             key=ThriftColumn.KEY,
             column=ThriftColumn.COLUMN_NAME,
             value=ThriftColumn.VALUE)

    if not existing_entities:
        parameters = {
            'key': bytearray(cassandra_interface.VERSION_INFO_KEY),
            'column': cassandra_interface.VERSION_INFO_KEY,
            'value': bytearray(str(POST_JOURNAL_VERSION))
        }
        session.execute(metadata_insert, parameters)

    # Indicate that the database has been successfully primed.
    parameters = {
        'key': bytearray(cassandra_interface.PRIMED_KEY),
        'column': cassandra_interface.PRIMED_KEY,
        'value': bytearray('true')
    }
    session.execute(metadata_insert, parameters)
    logging.info('Cassandra is primed.')
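
# Usage sketch (hypothetical replication factor): the keyspace and table
# creation shown here use IF NOT EXISTS and the metadata writes are upserts,
# so a repeat run is not expected to fail on already-existing objects.
#
#   prime_cassandra(replication=3)
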
class DatastoreProxy(AppDBInterface):
    """ 
    Cassandra implementation of the AppDBInterface
  """
    def __init__(self, log_level=logging.INFO):
        """
    Constructor.
    """
        class_name = self.__class__.__name__
        self.logger = logging.getLogger(class_name)
        self.logger.setLevel(log_level)
        self.logger.info('Starting {}'.format(class_name))

        self.hosts = appscale_info.get_db_ips()

        remaining_retries = INITIAL_CONNECT_RETRIES
        while True:
            try:
                self.cluster = Cluster(self.hosts)
                self.session = self.cluster.connect(KEYSPACE)
                break
            except cassandra.cluster.NoHostAvailable as connection_error:
                remaining_retries -= 1
                if remaining_retries < 0:
                    raise connection_error
                time.sleep(3)

        self.session.default_consistency_level = ConsistencyLevel.QUORUM
        self.retry_policy = IdempotentRetryPolicy()

    def close(self):
        """ Close all sessions and connections to Cassandra. """
        self.cluster.shutdown()
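
    # Usage sketch (hypothetical caller code, not part of the interface): the
    # proxy connects in __init__ and retries on NoHostAvailable, so
    # construction may block for several seconds while Cassandra comes up.
    # Callers are expected to shut the connection down explicitly when done.
    #
    #   db = DatastoreProxy()
    #   try:
    #       pass  # issue queries through db
    #   finally:
    #       db.close()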

    def batch_get_entity(self, table_name, row_keys, column_names):
        """
    Takes in batches of keys and retrieves their corresponding rows.
    
    Args:
      table_name: The table to access
      row_keys: A list of keys to access
      column_names: A list of columns to access
    Returns:
      A dictionary mapping each row key to its columns and values, in the
      format {key: {column_name: value, ...}}.
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_get could not be performed due to
        an error with Cassandra.
    """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(column_names, list):
            raise TypeError("Expected a list")
        if not isinstance(row_keys, list):
            raise TypeError("Expected a list")

        row_keys_bytes = [bytearray(row_key) for row_key in row_keys]

        statement = 'SELECT * FROM "{table}" '\
                    'WHERE {key} IN %s and {column} IN %s'.format(
                      table=table_name,
                      key=ThriftColumn.KEY,
                      column=ThriftColumn.COLUMN_NAME,
                    )
        query = SimpleStatement(statement, retry_policy=self.retry_policy)
        parameters = (ValueSequence(row_keys_bytes),
                      ValueSequence(column_names))

        try:
            results = self.session.execute(query, parameters=parameters)

            results_dict = {row_key: {} for row_key in row_keys}
            for (key, column, value) in results:
                if key not in results_dict:
                    results_dict[key] = {}
                results_dict[key][column] = value

            return results_dict
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during batch_get_entity'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)
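
    # Usage sketch (hypothetical arguments; table and schema names are assumed
    # to come from dbconstants): every requested key appears in the returned
    # dict, mapped to an empty dict when no matching columns were found.
    #
    #   rows = db.batch_get_entity(dbconstants.APP_ENTITY_TABLE,
    #                              ['key1', 'key2'],
    #                              dbconstants.APP_ENTITY_SCHEMA)
    #   # rows == {'key1': {...}, 'key2': {}}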

    def batch_put_entity(self,
                         table_name,
                         row_keys,
                         column_names,
                         cell_values,
                         ttl=None):
        """
    Allows callers to store multiple rows with a single call. A row can
    have multiple columns and associated values. We refer to each row as
    an entity.
   
    Args: 
      table_name: The table to mutate
      row_keys: A list of keys to store on
      column_names: A list of columns to mutate
      cell_values: A dict of key/value pairs
      ttl: The number of seconds to keep the row.
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_put could not be performed due to
        an error with Cassandra.
    """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(column_names, list):
            raise TypeError("Expected a list")
        if not isinstance(row_keys, list):
            raise TypeError("Expected a list")
        if not isinstance(cell_values, dict):
            raise TypeError("Expected a dict")

        insert_str = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (?, ?, ?)
    """.format(table=table_name,
               key=ThriftColumn.KEY,
               column=ThriftColumn.COLUMN_NAME,
               value=ThriftColumn.VALUE)

        if ttl is not None:
            insert_str += 'USING TTL {}'.format(ttl)

        statement = self.session.prepare(insert_str)

        batch_insert = BatchStatement(retry_policy=self.retry_policy)

        for row_key in row_keys:
            for column in column_names:
                batch_insert.add(statement,
                                 (bytearray(row_key), column,
                                  bytearray(cell_values[row_key][column])))

        try:
            self.session.execute(batch_insert)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during batch_put_entity'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)
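
    # Usage sketch (hypothetical table and values): cell_values must hold an
    # entry for every (row_key, column) combination, since the batch adds one
    # prepared insert per pair; ttl is optional.
    #
    #   db.batch_put_entity('my_table', ['k1'], ['col_a', 'col_b'],
    #                       {'k1': {'col_a': 'v1', 'col_b': 'v2'}},
    #                       ttl=3600)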

    def prepare_insert(self, table):
        """ Prepare an insert statement.

    Args:
      table: A string containing the table name.
    Returns:
      A PreparedStatement object.
    """
        statement = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (?, ?, ?)
    """.format(table=table,
               key=ThriftColumn.KEY,
               column=ThriftColumn.COLUMN_NAME,
               value=ThriftColumn.VALUE)
        return self.session.prepare(statement)

    def prepare_delete(self, table):
        """ Prepare a delete statement.

    Args:
      table: A string containing the table name.
    Returns:
      A PreparedStatement object.
    """
        statement = """
      DELETE FROM "{table}" WHERE {key} = ?
    """.format(table=table, key=ThriftColumn.KEY)
        return self.session.prepare(statement)

    def _normal_batch(self, mutations):
        """ Use Cassandra's native batch statement to apply mutations atomically.

    Args:
      mutations: A list of dictionaries representing mutations.
    """
        self.logger.debug('Normal batch: {} mutations'.format(len(mutations)))
        batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM,
                               retry_policy=self.retry_policy)
        prepared_statements = {'insert': {}, 'delete': {}}
        for mutation in mutations:
            table = mutation['table']
            if mutation['operation'] == TxnActions.PUT:
                if table not in prepared_statements['insert']:
                    prepared_statements['insert'][table] = self.prepare_insert(
                        table)
                values = mutation['values']
                for column in values:
                    batch.add(prepared_statements['insert'][table], (bytearray(
                        mutation['key']), column, bytearray(values[column])))
            elif mutation['operation'] == TxnActions.DELETE:
                if table not in prepared_statements['delete']:
                    prepared_statements['delete'][table] = self.prepare_delete(
                        table)
                batch.add(prepared_statements['delete'][table],
                          (bytearray(mutation['key']), ))

        try:
            self.session.execute(batch)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during batch_mutate'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)
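
    # Sketch of the mutation dictionaries this method reads (inferred from the
    # lookups above, not an exhaustive schema):
    #
    #   mutations = [
    #       {'table': 'my_table', 'operation': TxnActions.PUT,
    #        'key': 'row-key', 'values': {'col_a': 'new-value'}},
    #       {'table': 'my_table', 'operation': TxnActions.DELETE,
    #        'key': 'stale-row-key'},
    #   ]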

    def apply_mutations(self, mutations):
        """ Apply mutations across tables.

    Args:
      mutations: A list of dictionaries representing mutations.
    """
        for mutation in mutations:
            table = mutation['table']
            if mutation['operation'] == TxnActions.PUT:
                values = mutation['values']
                for column in values:
                    insert_row = """
            INSERT INTO "{table}" ({key}, {column}, {value})
            VALUES (%(key)s, %(column)s, %(value)s)
          """.format(table=table,
                     key=ThriftColumn.KEY,
                     column=ThriftColumn.COLUMN_NAME,
                     value=ThriftColumn.VALUE)
                    parameters = {
                        'key': bytearray(mutation['key']),
                        'column': column,
                        'value': bytearray(values[column])
                    }
                    self.session.execute(insert_row, parameters)
            elif mutation['operation'] == TxnActions.DELETE:
                delete_row = """
          DELETE FROM "{table}" WHERE {key} = %(key)s
        """.format(table=table, key=ThriftColumn.KEY)
                parameters = {'key': bytearray(mutation['key'])}
                self.session.execute(delete_row, parameters)

    def _large_batch(self, app, mutations, entity_changes, txn):
        """ Insert or delete multiple rows across tables in an atomic statement.

    Args:
      app: A string containing the application ID.
      mutations: A list of dictionaries representing mutations.
      entity_changes: A list of changes at the entity level.
      txn: A transaction ID handler.
    Raises:
      FailedBatch if a concurrent process modifies the batch status.
    """
        self.logger.debug('Large batch: transaction {}, {} mutations'.format(
            txn, len(mutations)))
        set_status = """
      INSERT INTO batch_status (app, transaction, applied)
      VALUES (%(app)s, %(transaction)s, False)
      IF NOT EXISTS
    """
        parameters = {'app': app, 'transaction': txn}
        result = self.session.execute(set_status, parameters)
        if not result.was_applied:
            raise FailedBatch(
                'A batch for transaction {} already exists'.format(txn))

        for entity_change in entity_changes:
            insert_item = """
        INSERT INTO batches (app, transaction, namespace, path,
                             old_value, new_value)
        VALUES (%(app)s, %(transaction)s, %(namespace)s, %(path)s,
                %(old_value)s, %(new_value)s)
      """
            old_value = None
            if entity_change['old'] is not None:
                old_value = bytearray(entity_change['old'].Encode())
            new_value = None
            if entity_change['new'] is not None:
                new_value = bytearray(entity_change['new'].Encode())

            parameters = {
                'app': app,
                'transaction': txn,
                'namespace': entity_change['key'].name_space(),
                'path': bytearray(entity_change['key'].path().Encode()),
                'old_value': old_value,
                'new_value': new_value
            }
            self.session.execute(insert_item, parameters)

        update_status = """
      UPDATE batch_status
      SET applied = True
      WHERE app = %(app)s
      AND transaction = %(transaction)s
      IF applied = False
    """
        parameters = {'app': app, 'transaction': txn}
        result = self.session.execute(update_status, parameters)
        if not result.was_applied:
            raise FailedBatch(
                'Another process modified batch for transaction {}'.format(
                    txn))

        self.apply_mutations(mutations)

        clear_batch = """
      DELETE FROM batches
      WHERE app = %(app)s AND transaction = %(transaction)s
    """
        parameters = {'app': app, 'transaction': txn}
        self.session.execute(clear_batch, parameters)

        clear_status = """
      DELETE FROM batch_status
      WHERE app = %(app)s and transaction = %(transaction)s
    """
        parameters = {'app': app, 'transaction': txn}
        self.session.execute(clear_status, parameters)

    def batch_mutate(self, app, mutations, entity_changes, txn):
        """ Insert or delete multiple rows across tables in an atomic statement.

    Args:
      app: A string containing the application ID.
      mutations: A list of dictionaries representing mutations.
      entity_changes: A list of changes at the entity level.
      txn: A transaction ID handler.
    """
        size = batch_size(mutations)
        self.logger.debug('batch_size: {}'.format(size))
        if size > LARGE_BATCH_THRESHOLD:
            self._large_batch(app, mutations, entity_changes, txn)
        else:
            self._normal_batch(mutations)
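
    # Usage sketch (hypothetical arguments): callers pass the same arguments
    # regardless of size; only when batch_size(mutations) exceeds
    # LARGE_BATCH_THRESHOLD does the slower batch_status/batches bookkeeping
    # path in _large_batch run.
    #
    #   db.batch_mutate('guestbook', mutations, entity_changes, txn_id)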

    def batch_delete(self, table_name, row_keys, column_names=()):
        """
    Remove a set of rows corresponding to a set of keys.
     
    Args:
      table_name: Table to delete rows from
      row_keys: A list of keys to remove
      column_names: Not used
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the batch_delete could not be performed due
        to an error with Cassandra.
    """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(row_keys, list):
            raise TypeError("Expected a list")

        row_keys_bytes = [bytearray(row_key) for row_key in row_keys]

        statement = 'DELETE FROM "{table}" WHERE {key} IN %s'.\
          format(
            table=table_name,
            key=ThriftColumn.KEY
          )
        query = SimpleStatement(statement, retry_policy=self.retry_policy)
        parameters = (ValueSequence(row_keys_bytes), )

        try:
            self.session.execute(query, parameters=parameters)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during batch_delete'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)
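
    # Usage sketch (hypothetical keys): column_names is accepted only to match
    # the interface and is ignored; whole rows are removed by key.
    #
    #   db.batch_delete(dbconstants.APP_ENTITY_TABLE,
    #                   ['stale-key-1', 'stale-key-2'])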

    def delete_table(self, table_name):
        """ 
    Drops a given table (aka column family in Cassandra)
  
    Args:
      table_name: A string name of the table to drop
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the delete_table could not be performed due
        to an error with Cassandra.
    """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")

        statement = 'DROP TABLE IF EXISTS "{table}"'.format(table=table_name)
        query = SimpleStatement(statement, retry_policy=self.retry_policy)

        try:
            self.session.execute(query)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during delete_table'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def create_table(self, table_name, column_names):
        """ 
    Creates a table if it doesn't already exist.
    
    Args:
      table_name: The column family name
      column_names: Not used but here to match the interface
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the create_table could not be performed due
        to an error with Cassandra.
    """
        if not isinstance(table_name, str):
            raise TypeError("Expected a str")
        if not isinstance(column_names, list):
            raise TypeError("Expected a list")

        self.cluster.refresh_schema_metadata()
        statement = 'CREATE TABLE IF NOT EXISTS "{table}" ('\
            '{key} blob,'\
            '{column} text,'\
            '{value} blob,'\
            'PRIMARY KEY ({key}, {column})'\
          ') WITH COMPACT STORAGE'.format(
            table=table_name,
            key=ThriftColumn.KEY,
            column=ThriftColumn.COLUMN_NAME,
            value=ThriftColumn.VALUE
          )
        query = SimpleStatement(statement)

        try:
            self.session.execute(query)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during create_table'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

    def range_query(self,
                    table_name,
                    column_names,
                    start_key,
                    end_key,
                    limit,
                    offset=0,
                    start_inclusive=True,
                    end_inclusive=True,
                    keys_only=False):
        """ 
    Gets a dense range ordered by keys. Returns an ordered list of
    dictionaries of the form [{key: {column1: value1, column2: value2}}, ...],
    or a list of keys if keys_only is True.
     
    Args:
      table_name: Name of table to access
      column_names: Columns which get returned within the key range
      start_key: String at which the query starts
      end_key: String at which the query ends
      limit: Maximum number of results to return
      offset: Number of leading results to skip (applied as results[offset:])
      start_inclusive: Boolean indicating whether results include the start_key
      end_inclusive: Boolean indicating whether results include the end_key
      keys_only: Boolean indicating whether to return only keys, not values
    Raises:
      TypeError: If an argument passed in was not of the expected type.
      AppScaleDBConnectionError: If the range_query could not be performed due
        to an error with Cassandra.
    Returns:
      An ordered list of dictionaries of key=>columns/values
    """
        if not isinstance(table_name, str):
            raise TypeError('table_name must be a string')
        if not isinstance(column_names, list):
            raise TypeError('column_names must be a list')
        if not isinstance(start_key, str):
            raise TypeError('start_key must be a string')
        if not isinstance(end_key, str):
            raise TypeError('end_key must be a string')
        if not isinstance(limit, (int, long)) and limit is not None:
            raise TypeError('limit must be int, long, or NoneType')
        if not isinstance(offset, (int, long)):
            raise TypeError('offset must be int or long')

        if start_inclusive:
            gt_compare = '>='
        else:
            gt_compare = '>'

        if end_inclusive:
            lt_compare = '<='
        else:
            lt_compare = '<'

        query_limit = ''
        if limit is not None:
            query_limit = 'LIMIT {}'.format(len(column_names) * limit)

        statement = """
      SELECT * FROM "{table}" WHERE
      token({key}) {gt_compare} %s AND
      token({key}) {lt_compare} %s AND
      {column} IN %s
      {limit}
      ALLOW FILTERING
    """.format(table=table_name,
               key=ThriftColumn.KEY,
               gt_compare=gt_compare,
               lt_compare=lt_compare,
               column=ThriftColumn.COLUMN_NAME,
               limit=query_limit)

        query = SimpleStatement(statement, retry_policy=self.retry_policy)
        parameters = (bytearray(start_key), bytearray(end_key),
                      ValueSequence(column_names))

        try:
            results = self.session.execute(query, parameters=parameters)

            results_list = []
            current_item = {}
            current_key = None
            for (key, column, value) in results:
                if keys_only:
                    results_list.append(key)
                    continue

                if key != current_key:
                    if current_item:
                        results_list.append({current_key: current_item})
                    current_item = {}
                    current_key = key

                current_item[column] = value
            if current_item:
                results_list.append({current_key: current_item})
            return results_list[offset:]
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Exception during range_query'
            logging.exception(message)
            raise AppScaleDBConnectionError(message)
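
    # Usage sketch (hypothetical app name): fetch up to five index rows for an
    # application, inclusive of both bounds. The CQL LIMIT is multiplied by the
    # number of columns because each (key, column) pair is a separate row in
    # the wide-row layout.
    #
    #   start = dbconstants.KEY_DELIMITER.join(['guestbook', 'index', ''])
    #   end = dbconstants.KEY_DELIMITER.join(
    #       ['guestbook', 'index', dbconstants.TERMINATING_STRING])
    #   rows = db.range_query(dbconstants.METADATA_TABLE,
    #                         dbconstants.METADATA_SCHEMA, start, end, 5)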

    def get_metadata(self, key):
        """ Retrieve a value from the datastore metadata table.

    Args:
      key: A string containing the key to fetch.
    Returns:
      A string containing the value or None if the key is not present.
    """
        statement = """
      SELECT {value} FROM "{table}"
      WHERE {key} = %s
      AND {column} = %s
    """.format(value=ThriftColumn.VALUE,
               table=dbconstants.DATASTORE_METADATA_TABLE,
               key=ThriftColumn.KEY,
               column=ThriftColumn.COLUMN_NAME)
        try:
            results = self.session.execute(statement, (bytearray(key), key))
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Unable to fetch {} from datastore metadata'.format(key)
            logging.exception(message)
            raise AppScaleDBConnectionError(message)

        try:
            return results[0].value
        except IndexError:
            return None

    def set_metadata(self, key, value):
        """ Set a datastore metadata value.

    Args:
      key: A string containing the key to set.
      value: A string containing the value to set.
    """
        if not isinstance(key, str):
            raise TypeError('key should be a string')

        if not isinstance(value, str):
            raise TypeError('value should be a string')

        statement = """
      INSERT INTO "{table}" ({key}, {column}, {value})
      VALUES (%(key)s, %(column)s, %(value)s)
    """.format(table=dbconstants.DATASTORE_METADATA_TABLE,
               key=ThriftColumn.KEY,
               column=ThriftColumn.COLUMN_NAME,
               value=ThriftColumn.VALUE)
        parameters = {
            'key': bytearray(key),
            'column': key,
            'value': bytearray(value)
        }
        try:
            self.session.execute(statement, parameters)
        except dbconstants.TRANSIENT_CASSANDRA_ERRORS:
            message = 'Unable to set datastore metadata for {}'.format(key)
            logging.exception(message)
            raise AppScaleDBConnectionError(message)
        except cassandra.InvalidRequest:
            self.create_table(dbconstants.DATASTORE_METADATA_TABLE,
                              dbconstants.DATASTORE_METADATA_SCHEMA)
            self.session.execute(statement, parameters)

    def get_indices(self, app_id):
        """ Gets the indices of the given application.

    Args:
      app_id: Name of the application.
    Returns:
      A list of encoded entity_pb.CompositeIndex objects.
    """
        start_key = dbconstants.KEY_DELIMITER.join([app_id, 'index', ''])
        end_key = dbconstants.KEY_DELIMITER.join(
            [app_id, 'index', dbconstants.TERMINATING_STRING])
        result = self.range_query(dbconstants.METADATA_TABLE,
                                  dbconstants.METADATA_SCHEMA,
                                  start_key,
                                  end_key,
                                  dbconstants.MAX_NUMBER_OF_COMPOSITE_INDEXES,
                                  offset=0,
                                  start_inclusive=True,
                                  end_inclusive=True)
        list_result = []
        for list_item in result:
            for key, value in list_item.iteritems():
                list_result.append(value['data'])
        return list_result

    def valid_data_version(self):
        """ Checks whether or not the data layout can be used.

    Returns:
      A boolean.
    """
        try:
            version = self.get_metadata(VERSION_INFO_KEY)
        except cassandra.InvalidRequest:
            return False

        return version is not None and float(version) == EXPECTED_DATA_VERSION
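
# Usage sketch (hypothetical caller code): checking the data layout before
# serving requests. valid_data_version returns False both when the version
# metadata is missing and when the metadata table does not exist yet.
#
#   db = DatastoreProxy()
#   if not db.valid_data_version():
#       raise AppScaleDBConnectionError('Unsupported data layout version')
#   db.close()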