Beispiel #1
0
    def __init__(
        self,
        model_: Optional[Type[model.Model]] = None,
        table_name: Optional[str] = None,
        primary_keys: Optional[List[str]] = None,
        fields: Optional[Dict[str, field.Field]] = None,
        relations: Optional[Dict[str, relationship.Relationship]] = None,
        interleaved: Optional[Type["model.Model"]] = None,
    ):
        """Captures a table description from a model or from explicit parts.

        Args:
          model_: Model class whose table, primary_keys, fields, relations
            and interleaved attributes are copied onto this object
          table_name: Table name, used when no model is given
          primary_keys: Primary key column names, used when no model is given
          fields: Mapping of column name to field, used when no model is given
          relations: Mapping of relation name to relationship; an empty dict
            is stored when this is omitted or empty
          interleaved: Model class stored as self._interleaved

        Raises:
          error.SpannerError: If model_ and the (table_name, primary_keys,
            fields) trio are both supplied or both incomplete, or if model_
            is combined with any other truthy argument
        """
        from_model = model_ is not None
        from_parts = (table_name is not None and primary_keys is not None
                      and fields is not None)
        # Exactly one of the two construction modes must be selected
        if from_model == from_parts:
            raise error.SpannerError(
                "Exactly one of: [model_], [table_name, primary_keys, fields is, relations] is required"
            )
        if model_ and any(
                (table_name, primary_keys, fields, relations, interleaved)):
            raise error.SpannerError(
                "Can not specify any other optional param if model_ is specified"
            )

        if not model_:
            self._table_name = table_name
            self._primary_keys = primary_keys
            self._fields = fields
            self._relations = relations or {}
            self._interleaved = interleaved
            return

        self._table_name = model_.table
        self._primary_keys = model_.primary_keys
        self._fields = model_.fields
        self._relations = model_.relations
        self._interleaved = model_.interleaved
Beispiel #2
0
    def _execute_write(cls, db_api: Callable[..., Any],
                       transaction: Optional[spanner_transaction.Transaction],
                       dictionaries: Iterable[Dict[str, Any]]) -> None:
        """Checks every row, then hands the whole batch to the write call.

        Args:
          db_api: Write callable; called directly as
            db_api(transaction, table, columns, values) when a transaction is
            given, otherwise passed to cls.spanner_api().run_write
          transaction: Transaction to write in, or None
          dictionaries: Rows to write, each mapping column name to value

        Returns:
          Whatever the underlying write call returns

        Raises:
          error.SpannerError: If a row has keys outside cls.columns or the
            rows do not all share the same set of keys
        """
        columns = None
        rows = []
        for row in dictionaries:
            unknown = set(row.keys()) - set(cls.columns)
            if unknown:
                raise error.SpannerError(
                    'Invalid keys set on {model}: {keys}'.format(
                        model=cls.__name__, keys=unknown))

            # The first row fixes the column set for the rest of the batch
            if columns is None:
                columns = row.keys()
            elif columns != row.keys():
                raise error.SpannerError(
                    'Attempted to update rows with different sets of keys')

            for name, value in row.items():
                cls.validate_value(name, value, error.SpannerError)
            rows.append([row[name] for name in columns])

        write_args = (cls.table, columns, rows)
        if transaction is None:
            return cls.spanner_api().run_write(db_api, *write_args)
        return db_api(transaction, *write_args)
    def rollback(self, target_migration: str) -> None:
        """Rolls back migrated migrations on the current database.

        Note: SpannerAdminApi connection is modified as a result of calling
        this method. Other connections to SpannerAdminApi in the same process
        may be affected.

        Args:
          target_migration: Stop rolling back migrations after this migration
            is rolled back. Must be present in list of migrated migrations.

        Raises:
          error.SpannerError: If target_migration is empty, or a migration's
            downgrade() does not return a SchemaUpdate
        """
        if not target_migration:
            raise error.SpannerError('Must specify a migration to roll back')

        self._connect()
        try:
            self._validate_migrations()
            # Filter to migrated migrations from most recently applied
            migrations = self._filter_migrations(
                reversed(self.migrations()), True, target_migration)
            for migration_ in migrations:
                _logger.info('Processing migration %s',
                             migration_.migration_id)
                schema_update = migration_.downgrade()
                if not isinstance(schema_update, update.SchemaUpdate):
                    raise error.SpannerError(
                        'Migration {} did not return a SchemaUpdate'.format(
                            migration_.migration_id))
                schema_update.execute()

                self._update_status(migration_.migration_id, False)
        finally:
            # Undo _connect() even when validation or a downgrade fails
            self._hangup()
Beispiel #4
0
 def validate(self) -> None:
   """Checks the table exists and the field is nullable and not a key.

   Raises:
     error.SpannerError: If no model is found for the table, the field is
       not nullable, or the field is a primary key
   """
   if not metadata.SpannerMetadata.model(self._table):
     raise error.SpannerError('Table {} does not exist'.format(self._table))

   # Pick the first failing column rule, if any
   problem = None
   if not self._field.nullable():
     problem = 'Column {} is not nullable'
   elif self._field.primary_key():
     problem = 'Column {} is a primary key'

   if problem is not None:
     raise error.SpannerError(problem.format(self._column))
Beispiel #5
0
 def _validate_not_interleaved(self,
                               existing_model: Type[model.Model]) -> None:
   """Rejects the table if a known model or index is interleaved in it.

   Args:
     existing_model: Model compared against each known model's interleaved
       attribute

   Raises:
     error.SpannerError: If some model is interleaved in existing_model, or
       some index names this table as its parent
   """
   for candidate in metadata.SpannerMetadata.models().values():
     if candidate.interleaved == existing_model:
       raise error.SpannerError('Table {} has interleaved table {}'.format(
           self._table, candidate.table))

     child_index = next(
         (entry for entry in candidate.indexes.values()
          if entry.parent == self._table), None)
     if child_index is not None:
       raise error.SpannerError('Table {} has interleaved index {}'.format(
           self._table, child_index.name))
Beispiel #6
0
  def _validate_primary_keys(self) -> None:
    """Checks the model has primary keys and each one is a known field.

    Raises:
      error.SpannerError: If the model has no primary keys, or a primary key
        is missing from the model's fields
    """
    table_model = self._model
    key_names = table_model.primary_keys
    if not key_names:
      raise error.SpannerError('Table {} has no primary key'.format(
          table_model.table))

    unknown = [name for name in key_names if name not in table_model.fields]
    if unknown:
      # Report the first offending key, in primary key order
      raise error.SpannerError(
          'Table {} column {} in primary key but not in schema'.format(
              table_model.table, unknown[0]))
Beispiel #7
0
  def validate(self) -> None:
    """Checks the table exists, has one index at most, and has no children.

    Raises:
      error.SpannerError: If no model is found for the table or the model
        has more than one index; _validate_not_interleaved may raise as well
    """
    table_model = metadata.SpannerMetadata.model(self._table)
    if not table_model:
      raise error.SpannerError('Table {} does not exist'.format(self._table))

    # The primary index is counted too, so one entry is always expected
    secondary_count = len(table_model.indexes) - 1
    if secondary_count > 0:
      raise error.SpannerError('Table {} has a secondary index'.format(
          self._table))

    self._validate_not_interleaved(table_model)
Beispiel #8
0
    def get(self, name: Union[Type[Any], str]) -> Type[Any]:
        """Looks up the single registered class that matches name.

        Args:
          name: Either a name string, or a class that is first converted to
            a name with _name_from_class

        Returns:
          The first (and only) reference registered under that name

        Raises:
          error.SpannerError: If nothing is registered under the name, or
            more than one reference is
        """
        key = self._name_from_class(name) if isinstance(name, type) else name

        if key not in self._registered:
            raise error.SpannerError(
                '{} was not found, verify it has been imported'.format(key))

        candidates = self._registered[key].references
        if len(candidates) > 1:
            raise error.SpannerError(
                'Multiple classes match {}, add more specificity'.format(key))
        return candidates[0]
Beispiel #9
0
  def validate(self) -> None:
    """Checks the table and index exist and the index is not the primary.

    Raises:
      error.SpannerError: If no model is found for the table, the index is
        not among the model's indexes, or the index is marked primary
    """
    table_model = metadata.SpannerMetadata.model(self._table)
    if not table_model:
      raise error.SpannerError('Table {} does not exist'.format(self._table))

    target = table_model.indexes.get(self._index)
    if not target:
      problem = 'Index {} does not exist'
    elif target.primary:
      problem = 'Index {} is the primary index'
    else:
      return
    raise error.SpannerError(problem.format(self._index))
Beispiel #10
0
    def _validate_primary_keys(self) -> None:
        """Checks primary keys are present and each one is a known field.

        Raises:
          error.SpannerError: If there are no primary keys, or a primary key
            is missing from the fields
        """
        key_names = self._primary_keys
        if not key_names:
            raise error.SpannerError("Table {} has no primary key".format(
                self._table_name))

        unknown = [name for name in key_names if name not in self._fields]
        if unknown:
            # Report the first offending key, in primary key order
            raise error.SpannerError(
                "Table {} column {} in primary key but not in schema".
                format(self._table_name, unknown[0]))
Beispiel #11
0
    def _validate_parent(self) -> None:
        """Checks this table's primary keys start with the parent's keys.

        Raises:
          error.SpannerError: If any parent key differs from the key at the
            same position here, or the parent has more keys than this table
        """
        parent_keys = self._interleaved.primary_keys
        own_keys = self._primary_keys
        message = "Table {} is not a child of parent table {}".format(
            self._table_name, self._interleaved.table)

        differs = any(
            theirs != ours for theirs, ours in zip(parent_keys, own_keys))
        # zip stops at the shorter list, so a longer parent needs its own check
        if differs or len(parent_keys) > len(own_keys):
            raise error.SpannerError(message)
Beispiel #12
0
    def validate(self) -> None:
        """Checks the new table has a name, is not a duplicate, and is sound.

        Raises:
          error.SpannerError: If the table name is empty or a model is
            already found under that name; _validate_parent and
            _validate_primary_keys may raise as well
        """
        new_name = self._table_name
        if not new_name:
            raise error.SpannerError("New table has no name")

        if metadata.SpannerMetadata.model(new_name):
            raise error.SpannerError("Table {} already exists".format(
                self._table_name))

        # The parent check only applies when an interleaved model was given
        if self._interleaved:
            self._validate_parent()
        self._validate_primary_keys()
Beispiel #13
0
  def _validate_columns(self, model_: Type[model.Model]) -> None:
    """Verifies all columns exist and are not part of the primary key.

    Args:
      model_: Model whose columns and primary_keys are checked against

    Raises:
      error.SpannerError: If a column or storing column is not in
        model_.columns, or a storing column is in model_.primary_keys
    """

    def no_such_column(name):
      return error.SpannerError('Table {} has no column {}'.format(
          self._table, name))

    absent = [name for name in self._columns if name not in model_.columns]
    if absent:
      raise no_such_column(absent[0])

    for stored in self._storing_columns:
      if stored not in model_.columns:
        raise no_such_column(stored)
      if stored in model_.primary_keys:
        raise error.SpannerError('{} is part of the primary key for {}'.format(
            stored, self._table))
Beispiel #14
0
  def validate(self) -> None:
    """Checks the table and column exist and the column is not indexed.

    Raises:
      error.SpannerError: If no model is found for the table, the column is
        not among the model's fields, or the index column count for this
        table and column is above zero
    """
    table_model = metadata.SpannerMetadata.model(self._table)
    if not table_model:
      raise error.SpannerError('Table {} does not exist'.format(self._table))

    if self._column not in table_model.fields:
      raise error.SpannerError('Column {} does not exist on {}'.format(
          self._column, self._table))

    # Any index column row for this table/column pair blocks the update
    filters = (
        condition.equal_to('column_name', self._column),
        condition.equal_to('table_name', self._table),
    )
    if index_column.IndexColumnSchema.count(None, *filters) > 0:
      raise error.SpannerError('Column {} is indexed'.format(self._column))
Beispiel #15
0
  def validate(self) -> None:
    """Checks the table exists and the new index definition is acceptable.

    Raises:
      error.SpannerError: If no model is found for the table, the index has
        no columns, or the index name is already among the model's indexes;
        _validate_columns and _validate_parent may raise as well
    """
    table_model = metadata.SpannerMetadata.model(self._table)
    if not table_model:
      raise error.SpannerError('Table {} does not exist'.format(self._table))

    problem = None
    if not self._columns:
      problem = 'Index {} has no columns'
    elif self._index in table_model.indexes:
      problem = 'Index {} already exists'
    if problem is not None:
      raise error.SpannerError(problem.format(self._index))

    self._validate_columns(table_model)
    # The parent check only applies when a parent table was given
    if self._parent_table:
      self._validate_parent(table_model)
Beispiel #16
0
    def _migration_from_file(self, filename: str) -> migration.Migration:
        """Loads a single migration from the given filename in the base dir.

        Args:
          filename: Name of the migration file, joined onto self.basedir

        Returns:
          Migration built from the module's migration_id, prev_migration_id,
          the first line of its docstring ("<unknown>" when there is none)
          and its upgrade/downgrade attributes (None when absent)

        Raises:
          error.SpannerError: If building the migration hits an
            AttributeError, e.g. the module lacks migration_id or
            prev_migration_id
        """
        module_name = re.sub(r"\W", "_", filename)
        path = os.path.join(self.basedir, filename)

        if self._pkg_name is not None:
            # Prepend package name to module name to import full module name
            module_name = "{}.{}".format(self._pkg_name, module_name)

        spec = importlib.util.spec_from_file_location(module_name, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # __doc__ is None for a module without a docstring, and splitting an
        # empty string still yields [''], so normalise before taking line one
        doc_lines = (module.__doc__ or "").strip().splitlines()
        description = doc_lines[0] if doc_lines else "<unknown>"

        try:
            return migration.Migration(
                module.migration_id,
                module.prev_migration_id,
                description,
                getattr(module, "upgrade", None),
                getattr(module, "downgrade", None),
            )
        except AttributeError as err:
            raise error.SpannerError(
                "{} has no migration id".format(path)) from err
Beispiel #17
0
    def finalize(self) -> None:
        """Computes the derived metadata once all configuration is present.

        Fields are ordered by position. If no primary index has been stored
        yet, one is built from the fields that report primary_key(). The
        primary key and column lists are then filled in, every relation gets
        this model class as its origin, and the model class is registered.

        Raises:
          error.SpannerError: If this metadata was already finalized
        """
        if self._finalized:
            raise error.SpannerError('Metadata was already finalized')

        def by_position(entry):
            return entry.position

        ordered = sorted(self.fields.values(), key=by_position)
        primary_name = index.Index.PRIMARY_INDEX

        if primary_name not in self.indexes:
            key_columns = [entry.name for entry in ordered
                           if entry.primary_key()]
            primary = index.Index(key_columns)
            primary.name = primary_name
            self.indexes[primary_name] = primary
        self.primary_keys = self.indexes[primary_name].columns

        self.columns = [entry.name for entry in ordered]

        for relation in self.relations.values():
            relation.origin = self.model_class
        registry.model_registry().register(self.model_class)
        self._finalized = True
Beispiel #18
0
 def add_field(self, name: str, new_field: field.Field) -> None:
     """Registers new_field under name and records its position.

     Args:
       name: Name to store the field under; also assigned to new_field.name
       new_field: Field to add; its position is set to the number of fields
         already present

     Raises:
       error.SpannerError: If a field with that name is already present. The
         rejected new_field is left unmodified.
     """
     # Check for a clash first so a rejected field is not renamed/repositioned
     if name in self.fields:
         raise error.SpannerError(
             'Already contains a field named "{}"'.format(name))
     new_field.name = name
     new_field.position = len(self.fields)
     self.fields[name] = new_field
    def migrate(self, target_migration: Optional[str] = None) -> None:
        """Executes unmigrated migrations on the current database.

        Note: SpannerAdminApi connection is modified as a result of calling
        this method. Other connections to SpannerAdminApi in the same process
        may be affected.

        Args:
          target_migration: If present, stop migrations after the target is
            executed. If None (default), executes all unmigrated migrations

        Raises:
          error.SpannerError: If a migration's upgrade() does not return a
            SchemaUpdate
        """
        self._connect()
        try:
            self._validate_migrations()
            # Filter to unmigrated migrations
            migrations = self._filter_migrations(self.migrations(), False,
                                                 target_migration)
            for migration_ in migrations:
                _logger.info('Processing migration %s',
                             migration_.migration_id)
                schema_update = migration_.upgrade()
                if not isinstance(schema_update, update.SchemaUpdate):
                    raise error.SpannerError(
                        'Migration {} did not return a SchemaUpdate'.format(
                            migration_.migration_id))
                schema_update.execute()

                self._update_status(migration_.migration_id, True)
        finally:
            # Undo _connect() even when validation or an upgrade fails
            self._hangup()
Beispiel #20
0
  def field_type(self) -> Type[field.FieldType]:
    """Returns the first entry of field.ALL_TYPES matching spanner_type.

    An entry matches when its ddl() equals self.spanner_type.

    Raises:
      error.SpannerError: If no entry matches
    """
    unmatched = object()
    found = next(
        (candidate for candidate in field.ALL_TYPES
         if self.spanner_type == candidate.ddl()), unmatched)
    if found is unmatched:
      raise error.SpannerError('No corresponding Type for {}'.format(
          self.spanner_type))
    return found
    def _filter_migrations(
        self,
        migrations: Iterable[migration.Migration],
        migrated: bool,
        last_migration: Optional[str] = None,
    ) -> List[migration.Migration]:
        """Selects the migrations whose migrated status matches the flag.

        Args:
          migrations: Migrations to examine, in the order they should be kept
          migrated: Keep only migrations whose self.migrated() result equals
            this flag
          last_migration: If given, stop right after the kept migration with
            this id

        Returns:
          List of the kept migrations

        Raises:
          error.SpannerError: If last_migration is given but no kept
            migration has that id
        """
        selected = []
        for candidate in migrations:
            if self.migrated(candidate.migration_id) != migrated:
                continue
            selected.append(candidate)

            if last_migration and candidate.migration_id == last_migration:
                # Target reached; nothing after it is wanted
                break
        else:
            # The loop ran to completion, so the target was never kept
            if last_migration:
                raise error.SpannerError(
                    "{} already has desired status or does not exist".format(
                        last_migration))
        return selected
Beispiel #22
0
    def __init__(self, values: Dict[str, Any], persisted: bool = False):
        """Populates the instance from a mapping of column/relation values.

        Args:
          values: Mapping from column or relation name to value; columns
            missing from it are stored as None
          persisted: When True, validation of the values is skipped

        Raises:
          error.SpannerError: If persisted is False and a primary key is
            missing from values
        """
        # Attributes are written straight into the instance dict
        state = self.__dict__
        snapshot = {}
        state['start_values'] = snapshot
        state['_persisted'] = persisted

        # If the values came from Spanner, trust them and skip validation
        if not persisted:
            # An object is invalid if primary key values are missing
            absent = set(self._primary_keys) - set(values.keys())
            if absent:
                raise error.SpannerError(
                    'All primary keys must be specified. Missing: {keys}'.
                    format(keys=absent))

            for name in self._columns:
                self._metaclass.validate_value(name, values.get(name),
                                               ValueError)

        for name in self._columns:
            current = values.get(name)
            # Keep a shallow copy as the starting value of the column
            snapshot[name] = copy.copy(current)
            state[name] = current

        for name in self._relations:
            if name in values:
                state[name] = values[name]
Beispiel #23
0
 def __init__(self, model: Type["Model"], conditions: Iterable[condition.Condition]):
     """Initialises the query and rejects conditions outside WHERE/FROM.

     Args:
       model: Passed through to the parent initialiser
       conditions: Passed through to the parent initialiser, then each one
         is checked by its segment()

     Raises:
       error.SpannerError: If a condition's segment is neither WHERE nor FROM
     """
     super().__init__(model, conditions)
     allowed_segments = (condition.Segment.WHERE, condition.Segment.FROM)
     if any(entry.segment() not in allowed_segments for entry in conditions):
         raise error.SpannerError(
             "Only conditions that affect the WHERE or "
             "FROM clauses are allowed for count queries"
         )
 def destination(self) -> Type[Any]:
     """Returns the class on the far side of the bound relation.

     Returns:
       relation.constraint.referenced_table when foreign_key_relation is
       truthy, otherwise relation.destination

     Raises:
       error.SpannerError: If no relation has been set yet
     """
     bound = self.relation
     if not bound:
         raise error.SpannerError(
             'Condition must be bound before destination is called')
     return (bound.constraint.referenced_table
             if self.foreign_key_relation else bound.destination)
 def __init__(self, *orderings: Tuple[Union[field.Field, str], OrderType]):
     """Stores the orderings after checking each one's order type.

     Args:
       *orderings: Pairs whose second item must be an OrderType

     Raises:
       error.SpannerError: If the second item of a pair is not an OrderType
     """
     super().__init__()
     for _target, direction in orderings:
         if isinstance(direction, OrderType):
             continue
         raise error.SpannerError(
             '{order} is not of type OrderType'.format(order=direction))
     self.orderings = orderings
Beispiel #26
0
 def __init__(self, *condition_lists: List[Condition]):
     """Stores the condition lists and a flattened list of all conditions.

     Args:
       *condition_lists: Two or more lists of conditions

     Raises:
       error.SpannerError: If fewer than two lists are given
     """
     super().__init__()
     if len(condition_lists) < 2:
         raise error.SpannerError(
             "OrCondition requires at least two lists of conditions")
     self.condition_lists = condition_lists
     # Flatten in the order given: list by list, item by item
     self.all_conditions = [
         entry for group in condition_lists for entry in group
     ]
    def _validate_migrations(self) -> None:
        """Checks that the migrated status of the migrations is consistent.

        Raises:
          error.SpannerError: If the first migration's predecessor is not
            migrated, or a migrated migration has an unmigrated predecessor
        """
        known = self.migrations()
        if known:
            head = known[0]
            if not self.migrated(head.prev_migration_id):
                raise error.SpannerError(
                    "First migration {} depends on unmigrated migration {}".format(
                        head.migration_id, head.prev_migration_id))

            for entry in known:
                # Only applied migrations constrain their predecessor
                if not self.migrated(entry.migration_id):
                    continue
                if not self.migrated(entry.prev_migration_id):
                    raise error.SpannerError(
                        "Migrated migration {} depends on an unmigrated migration".
                        format(entry.migration_id))
    def __init__(self, value: int, offset: int = 0):
        """Stores the limit and offset after checking both are ints.

        Args:
          value: Stored as self.limit
          offset: Stored as self.offset

        Raises:
          error.SpannerError: If value or offset is not an int
        """
        super().__init__()
        for param in (value, offset):
            if isinstance(param, int):
                continue
            raise error.SpannerError(
                '{param} is not of type int'.format(param=param))

        self.limit, self.offset = value, offset
Beispiel #29
0
    def types(self) -> Dict[str, type_pb2.Type]:
        """Returns parameter types to be used in the SQL query.

        Returns:
          The result of self._types(): a dictionary mapping parameter name to
          the type of the value substituted for that parameter

        Raises:
          error.SpannerError: If model_class has not been set
        """
        if self.model_class:
            return self._types()
        raise error.SpannerError("Condition must be bound before usage")
    def params(self) -> Dict[str, Any]:
        """Returns parameters to be used in the SQL query.

        Returns:
          The result of self._params(): a dictionary mapping parameter name
          to the value substituted for that parameter

        Raises:
          error.SpannerError: If model_class has not been set
        """
        if self.model_class:
            return self._params()
        raise error.SpannerError('Condition must be bound before usage')