Example #1
class GrpcOptionsSchema(Schema):
    """
    Manifest gRPC options schema
    """
    client_host = fields.String(required=True, default=DEFAULT_HOST)
    server_host = fields.String(required=True, default=DEFAULT_HOST)
    secure_channel = fields.Bool(required=True, default=DEFAULT_IS_SECURE)
    port = fields.String(required=True, default=DEFAULT_PORT)
    grace = fields.Float(required=True, default=DEFAULT_GRACE)
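
For orientation, here is a plain-Python sketch of the default-filling behavior the field declarations above describe. It is not the schema library's API, and the constant values are illustrative assumptions.

```python
# Plain-Python sketch (not the schema library's API) of the default-filling
# behavior declared above; the constant values are illustrative assumptions.
DEFAULT_HOST = 'localhost'
DEFAULT_PORT = '50051'
DEFAULT_IS_SECURE = False
DEFAULT_GRACE = 5.0

def apply_grpc_defaults(options: dict) -> dict:
    defaults = {
        'client_host': DEFAULT_HOST,
        'server_host': DEFAULT_HOST,
        'secure_channel': DEFAULT_IS_SECURE,
        'port': DEFAULT_PORT,
        'grace': DEFAULT_GRACE,
    }
    # explicit, non-None values win over the declared defaults
    return {**defaults, **{k: v for k, v in options.items() if v is not None}}

assert apply_grpc_defaults({'port': '50052'})['client_host'] == DEFAULT_HOST
```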
Example #2
class Schema(Schema):
    base = fields.FilePath()
    package = fields.String()
    bindings = fields.List(Binding.Schema(), default=[])
    bootstraps = fields.List(Bootstrap.Schema(), default=[])
    values = fields.Dict(default={})
    logging = fields.Nested({
        'level': fields.Enum(
            fields.String(),
            {'DEBUG', 'INFO', 'WARNING', 'CRITICAL', 'ERROR'},
            default='DEBUG'),
    })
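
For reference, a hypothetical manifest dict of the shape this schema appears to describe; every name and value below is an illustrative assumption, and the binding/bootstrap entry shapes mirror the schemas shown in Examples #4 and #5 further down.

```python
# Hypothetical manifest data matching the schema above; all values here are
# illustrative assumptions rather than anything taken from a real project.
manifest_data = {
    'base': '/path/to/project',
    'package': 'my_app',
    'bindings': [
        {'resource': 'User', 'store': 'SqlalchemyStore', 'params': {}},
    ],
    'bootstraps': [
        {'store': 'SqlalchemyStore', 'default': True, 'params': {}},
    ],
    'values': {},
    'logging': {'level': 'INFO'},
}
```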
Example #3
class Resource(Entity, metaclass=ResourceMeta):

    # metaclass attribute stubs:
    ravel = None

    # base resource fields:
    _id = fields.UuidString(default=lambda: uuid.uuid4().hex, nullable=False)
    _rev = fields.String()

    def __init__(self, state=None, **more_state):
        # initialize internal state data dict
        self.internal = DictObject()
        self.internal.state = DirtyDict()
        self.merge(state, **more_state)

        # eagerly generate default ID if none provided
        if ID not in self.internal.state:
            id_func = self.ravel.defaults.get(ID)
            self.internal.state[ID] = id_func() if id_func else None

    def __getitem__(self, key):
        if key in self.ravel.resolvers:
            return getattr(self, key)
        raise KeyError(key)

    def __setitem__(self, key, value):
        if key in self.ravel.resolvers:
            return setattr(self, key, value)
        raise KeyError(key)

    def __delitem__(self, key):
        if key in self.ravel.resolvers:
            delattr(self, key)
        else:
            raise KeyError(key)

    def __iter__(self):
        return iter(self.internal.state)

    def __contains__(self, key):
        return key in self.internal.state

    def __repr__(self):
        name = get_class_name(self)
        dirty = '*' if self.is_dirty else ''
        id_value = self.internal.state.get(ID)
        if id_value is None:
            id_str = '?'
        elif isinstance(id_value, str):
            id_str = id_value[:7]
        elif isinstance(id_value, uuid.UUID):
            id_str = id_value.hex[:7]
        else:
            id_str = repr(id_value)

        return f'{name}({id_str}){dirty}'

    @classmethod
    def __abstract__(cls) -> bool:
        return True

    @classmethod
    def __protected__(cls) -> bool:
        return False

    @classmethod
    def __store__(cls) -> Type[Store]:
        return SimulationStore

    @classmethod
    def on_bootstrap(cls, app, **kwargs):
        pass

    @classmethod
    def on_bind(cls):
        pass

    @classmethod
    def bootstrap(cls, app: 'Application', **kwargs):
        t1 = datetime.now()
        cls.ravel.app = app

        # bootstrap all resolvers owned by this class
        for resolver in cls.ravel.resolvers.values():
            if not resolver.is_bootstrapped():
                resolver.bootstrap(cls.ravel.app)

        # lastly perform custom developer logic
        cls.on_bootstrap(app, **kwargs)
        cls.ravel.local.is_bootstrapped = True

        t2 = datetime.now()
        secs = (t2 - t1).total_seconds()
        console.debug(f'bootstrapped {TypeUtils.get_class_name(cls)} '
                      f'in {secs:.2f}s')

    @classmethod
    def bind(cls, store: 'Store', **kwargs):
        t1 = datetime.now()
        cls.ravel.local.store = store

        for resolver in cls.ravel.resolvers.values():
            resolver.bind()

        cls.on_bind()
        cls.ravel.is_bound = True

        t2 = datetime.now()
        secs = (t2 - t1).total_seconds()
        console.debug(f'bound {TypeUtils.get_class_name(store)} to '
                      f'{TypeUtils.get_class_name(cls)} '
                      f'in {secs:.2f}s')

    @classmethod
    def is_bootstrapped(cls) -> bool:
        return cls.ravel.local.is_bootstrapped

    @classmethod
    def is_bound(cls) -> bool:
        return cls.ravel.is_bound

    @property
    def app(self) -> 'Application':
        if not self.ravel.app:
            raise NotBootstrapped(f'{get_class_name(self)} must be associated '
                                  f'with a bootstrapped app')
        return self.ravel.app

    @property
    def log(self) -> 'ConsoleLoggerInterface':
        return self.ravel.app.log if self.ravel.app else None

    @property
    def class_name(self) -> Text:
        return get_class_name(self)

    @property
    def store(self) -> 'Store':
        return self.ravel.local.store

    @property
    def is_dirty(self) -> bool:
        return bool(self.internal.state.dirty
                    & self.ravel.schema.fields.keys())

    @property
    def dirty(self) -> Dict:
        return {
            k: self.internal.state[k]
            for k in self.internal.state.dirty if k in self.Schema.fields
        }

    @classmethod
    def generate(
        cls,
        keys: Set[Text] = None,
        values: Dict = None,
        use_defaults=True,
    ) -> 'Resource':
        instance = cls()
        values = values or {}
        keys = keys or set(cls.ravel.schema.fields.keys())
        resolver_objs = Resolver.sort(
            [cls.ravel.resolvers[k] for k in keys if k not in {REV}])

        for resolver in resolver_objs:
            if resolver.name in values:
                value = values[resolver.name]
            elif use_defaults and resolver.name in cls.ravel.defaults:
                func = cls.ravel.defaults[resolver.name]
                value = func()
            else:
                value = resolver.generate(instance)

            instance.internal.state[resolver.name] = value

        return instance

    def set(self, other=None, **values) -> 'Resource':
        return self.merge(other=other, **values)

    def merge(self, other=None, **values) -> 'Resource':
        try:
            if isinstance(other, dict):
                for k, v in other.items():
                    setattr(self, k, v)
            elif isinstance(other, Resource):
                for k, v in other.internal.state.items():
                    setattr(self, k, v)

            if values:
                self.merge(values)

            return self
        except Exception:
            self.app.log.error(message=(f'failed to merge object into '
                                        f'resource {str(self._id)[:7]}'),
                               data={
                                   'resource': self._id,
                                   'class': self.class_name,
                                   'other': other,
                                   'values': values,
                               })
            raise

    def clean(self, fields=None) -> 'Resource':
        if fields:
            fields = fields if is_sequence(fields) else {fields}
            keys = self._normalize_selectors(fields)
        else:
            keys = set(self.ravel.resolvers.keys())

        if keys:
            self.internal.state.clean(keys=keys)

        return self

    def mark(self, fields=None) -> 'Resource':
        # TODO: rename "mark" method to "touch"
        if fields is not None:
            if not fields:
                return self
            fields = fields if is_sequence(fields) else {fields}
            keys = self._normalize_selectors(fields)
        else:
            keys = set(self.Schema.fields.keys())

        self.internal.state.mark(keys)
        return self

    def dump(self,
             resolvers: Optional[Set[Text]] = None,
             style: Optional[DumpStyle] = None) -> Dict:
        """
        Dump the fields of this business object along with its related objects
        (declared as relationships) to a plain ol' dict.
        """
        # get Dumper instance based on DumpStyle (nested, side-loaded, etc)
        dumper = Dumper.for_style(style or DumpStyle.nested)

        if resolvers is not None:
            # only dump resolver state specifically requested
            keys_to_dump = self._normalize_selectors(resolvers)
        else:
            # or else dump all instance state
            keys_to_dump = list(self.internal.state.keys())

        dumped_instance_state = dumper.dump(self, keys=keys_to_dump)
        return dumped_instance_state

    def copy(self) -> 'Resource':
        """
        Create a clone of this Resource
        """
        clone = type(self)(state=deepcopy(self.internal.state))
        return clone.clean()

    def validate(self, resolvers: Set[Text] = None, strict=False) -> Dict:
        """
        Validate an object's loaded state data. If you need to check whether
        some state data is loaded and raise an exception when it is absent,
        use self.require instead.
        """
        errors = {}
        resolver_names_to_validate = (resolvers
                                      or set(self.ravel.resolvers.keys()))
        for name in resolver_names_to_validate:
            resolver = self.ravel.resolvers[name]
            if name not in self.internal.state:
                console.warning(message=f'skipping {name} validation',
                                data={'reason': 'not loaded'})
            else:
                value = self.internal.state.get(name)
                if value is None and not resolver.nullable:
                    errors[name] = 'not nullable'
                if name in self.ravel.schema.fields:
                    field = self.ravel.schema.fields[name]
                    _value, error = field.process(value)
                    if error is not None:
                        errors[name] = error

        if strict and errors:
            console.error(message='validation error', data={'errors': errors})
            raise ValidationError('see error log for details')

        return errors

    def require(
        self,
        resolvers: Set[Text] = None,
        use_defaults=True,
        strict=False,
        overwrite=False,
    ) -> Set[Text]:
        """
        Check that all specified resolvers are present. If required resolvers
        are missing, an exception is raised in `strict` mode; otherwise, a set
        of the missing resolver names is returned.
        """
        if isinstance(resolvers, str):
            resolvers = {resolvers}
        required_resolver_names = (resolvers
                                   or set(k
                                          for k in self.ravel.resolvers.keys()
                                          if self.ravel.resolvers[k].required))
        missing_resolver_names = set()
        defaults = self.ravel.defaults
        for name in required_resolver_names:
            resolver = self.ravel.resolvers[name]
            if overwrite or name not in self.internal.state:
                if use_defaults and name in defaults:
                    value = defaults[name]()
                    if value is None and not resolver.nullable:
                        raise ValidationError(f'{name} not nullable')
                    self[name] = value
                else:
                    missing_resolver_names.add(name)

        if strict and missing_resolver_names:
            console.error(message=f'{self.class_name} missing required data',
                          data={'missing': missing_resolver_names})
            raise ValidationError('see error log for details')

        return missing_resolver_names

    def resolve(self, resolvers: Union[Text, Set[Text]] = None) -> 'Resource':
        """
        Execute each of the resolvers, specified by name, storing the results
        in `self.internal.state`.
        """
        if self._id is None:
            return self

        if isinstance(resolvers, str):
            resolvers = {resolvers}
        elif not resolvers:
            resolvers = set(self.ravel.resolvers.keys())

        # execute all requested resolvers
        for k in resolvers:
            resolver = self.ravel.resolvers.get(k)
            if resolver is not None:
                if k in self.ravel.resolvers.fields:
                    # field loader resolvers are treated specially to overcome
                    # the limitation that Resolver.target is always expected
                    # to be a Resource class.
                    obj = resolver.resolve(self)
                    setattr(self, k, getattr(obj, k))
                else:
                    setattr(self, k, resolver.resolve(self))

        # clean the resolved values so they aren't accidentally saved on
        # update/create, as we just fetched them from the store.
        self.clean(resolvers)

        return self

    def refresh(self, keys: Union[Text, Set[Text]] = None) -> 'Resource':
        """
        For any field currently loaded in the resource, fetch a fresh copy
        from the store.
        """
        if not keys:
            keys = set(self.internal.state.keys())
        elif isinstance(keys, str):
            keys = {keys}

        return self.resolve(keys)

    def unload(self, keys: Set[Text] = None) -> 'Resource':
        """
        Remove the given keys from field data and/or relationship data.
        """
        if keys:
            if isinstance(keys, str):
                keys = {keys}
            keys = self._normalize_selectors(keys)
        else:
            keys = set(self.internal.state.keys()
                       | self.ravel.resolvers.keys())
        for k in keys:
            if k in self.internal.state:
                del self.internal.state[k]

        return self

    def is_loaded(self, resolvers: Union[Text, Set[Text]]) -> bool:
        """
        Are all given field and/or relationship values loaded?
        """
        if resolvers:
            if isinstance(resolvers, str):
                resolvers = {resolvers}
            keys = self._normalize_selectors(resolvers)
        else:
            keys = set(self.internal.state.keys()
                       | self.ravel.resolvers.keys())

        for k in keys:
            is_key_in_data = k in self.internal.state
            is_key_in_resolvers = k in self.ravel.resolvers
            if not (is_key_in_data or is_key_in_resolvers):
                return False

        return True

    def _prepare_record_for_create(self,
                                   keys_to_save: Optional[Set[Text]] = None):
        """
        Prepares a Resource state dict for insertion via the DAL.
        """
        # extract only those elements of state data that correspond to
        # Fields declared on this Resource class.
        if ID not in self.internal.state:
            self._id = self.ravel.local.store.create_id(self.internal.state)

        # when inserting or updating, we don't want to write the _rev value by
        # accident. The DAL is solely responsible for modifying this value.
        if REV in self.internal.state:
            del self.internal.state[REV]

        record = {}
        keys_to_save = keys_to_save or set(self.internal.state.keys())
        keys_to_save |= self.ravel.defaults.keys()
        keys_to_save &= self.ravel.resolvers.fields.keys()

        for key in keys_to_save:
            resolver = self.ravel.resolvers[key]
            default = self.ravel.defaults.get(key)
            if key not in self.internal.state:
                if default is not None:
                    self.internal.state[key] = value = default()
                    record[key] = value
                elif resolver.required:
                    raise ValidationError(f'{key} is a required field')
            else:
                value = self.internal.state[key]
                if value is None and (not resolver.nullable):
                    if default:
                        self.internal.state[key] = value = default()
                    # if the value is still none, just remove it from
                    # the state dict instead of raising
                    if self.internal.state[key] is None:
                        console.warning(
                            message=(f'resolved None for {resolver} but not '
                                     f'nullable'),
                            data={
                                'resource': self._id,
                                'class': self.class_name,
                                'field': key,
                            })
                        del self.internal.state[key]
                        continue
                record[key] = self.internal.state[key]

        return record

    @staticmethod
    def _normalize_selectors(selectors: Set):
        keys = set()
        for k in selectors:
            if isinstance(k, str):
                keys.add(k)
            elif isinstance(k, ResolverProperty):
                keys.add(k.name)
        return keys

    # CRUD Methods

    @classmethod
    def select(
        cls,
        *resolvers: Tuple[Text],
        parent: 'Query' = None,
        request: 'Request' = None,
    ) -> 'Query':
        query = Query(target=cls, request=request, parent=parent)
        query.select(resolvers)
        cls.on_select(query)
        query.callbacks.append(cls.post_select)
        return query

    def create(self, data: Dict = None, **more_data) -> 'Resource':
        self.pre_create()

        data = dict(data or {}, **more_data)
        if data:
            self.merge(data)

        prepared_record = self._prepare_record_for_create()
        prepared_record.pop(REV, None)

        self.on_create(prepared_record)

        created_record = self.ravel.local.store.dispatch(
            'create', (prepared_record, ))

        self.merge(created_record)
        self.clean()
        self.post_create()

        return self

    def setdefault(self, key, value):
        if key in self.internal.state:
            return self[key]
        else:
            self[key] = value
            return value

    @classmethod
    def get(cls, _id, select=None) -> Optional[Union['Resource', 'Batch']]:
        if _id is None:
            return None

        if is_sequence(_id):
            return cls.get_many(_ids=_id, select=select)

        if not select:
            select = set(cls.ravel.schema.fields.keys())
        elif not isinstance(select, set):
            select = set(select)

        select |= {ID, REV}
        select -= cls.ravel.virtual_fields.keys()

        cls.on_get(_id, select)

        state = cls.ravel.local.store.dispatch('fetch', (_id, ),
                                               {'fields': select})

        resource = cls(state=state).clean() if state else None

        cls.post_get(resource)

        return resource

    @classmethod
    def get_many(
        cls,
        _ids: List = None,
        select=None,
        offset=None,
        limit=None,
        order_by=None,
    ) -> 'Batch':
        """
        Return a list of Resources in the store.
        """
        if not _ids:
            return cls.Batch()

        if not select:
            select = set(cls.ravel.schema.fields)
        elif not isinstance(select, set):
            select = set(select)

        select |= {ID, REV}
        select -= cls.ravel.virtual_fields.keys()

        if not (offset or limit or order_by):
            store = cls.ravel.local.store
            args = (_ids, )
            kwargs = {'fields': select}
            states = store.dispatch('fetch_many', args, kwargs).values()
            cls.on_get_many(_ids, select)
            batch = cls.Batch(
                cls(state=state).clean() for state in states
                if state is not None)
            cls.post_get_many(batch)
        else:
            query = cls.select(select).where(cls._id.including(_ids))
            query = query.order_by(order_by).offset(offset).limit(limit)
            cls.on_select(query)
            batch = query.execute()
            cls.post_select(query, batch)

        return batch

    @classmethod
    def get_all(
        cls,
        select: Set[Text] = None,
        order_by: List['OrderBy'] = None,
        offset: int = None,
        limit: int = None,
    ) -> 'Batch':
        """
        Return a list of all Resources in the store.
        """
        # build the query
        query = cls.select().where(cls._id != None)

        if select:
            query = query.select(select)
        if order_by:
            query = query.order_by(order_by)
        if limit and limit > 0:
            query = query.limit(limit)
        if offset is not None and offset >= 0:
            query = query.offset(offset)

        cls.on_select(query)

        batch = query.execute()

        cls.post_select(query, batch)

        return batch

    def delete(self) -> 'Resource':
        """
        Call delete on this object's store, then mark all fields as dirty and
        clear its _id so that a subsequent save triggers Store.create.
        """
        self.ravel.local.store.dispatch('delete', (self._id, ))
        self.mark(self.internal.state.keys())
        self._id = None
        self._rev = None
        return self

    @classmethod
    def delete_many(cls, resources: List['Resource']) -> None:
        # extract ID's of all objects to delete and clear
        # them from the instance objects' state dicts
        resource_ids = []
        for resource in resources:
            resource.mark()
            resource_ids.append(resource._id)
            resource._id = None
            resource._rev = None

        if resource_ids:
            store = cls.ravel.local.store
            cls.on_delete_many(resource_ids)
            store.dispatch('delete_many', args=(resource_ids, ))
            cls.post_delete_many(resource_ids)

    @classmethod
    def delete_all(cls) -> None:
        store = cls.ravel.local.store
        store.dispatch('delete_all')

    @classmethod
    def exists(cls, entity: 'Entity') -> bool:
        """
        Perform a simple check of whether a Resource exists, by ID.
        """
        store = cls.ravel.local.store

        if not entity:
            return False

        if is_resource(entity):
            args = (entity._id, )
        else:
            id_value, errors = cls._id.resolver.field.process(entity)
            args = (id_value, )
            if errors:
                raise ValueError(str(errors))

        return store.dispatch('exists', args=args)

    @classmethod
    def exists_many(cls, entity: 'Entity') -> bool:
        """
        Perform a simple check of whether the given Resources exist, by ID.
        """
        store = cls.ravel.local.store

        if not entity:
            return False

        if is_batch(entity):
            args = (entity._id, )
        else:
            assert is_sequence(entity)
            id_list = entity
            args = (id_list, )
            for id_value in id_list:
                value, errors = cls._id.resolver.field.process(id_value)
                if errors:
                    raise ValueError(str(errors))

        return store.dispatch('exists_many', args=args)

    def save(self,
             resolvers: Union[Text, Set[Text]] = None,
             depth=0) -> 'Resource':
        return self.save_many([self], resolvers=resolvers, depth=depth)[0]

    def update(self, data: Dict = None, **more_data) -> 'Resource':
        data = dict(data or {}, **more_data)
        if data:
            self.merge(data)

        raw_record = self.dirty.copy()
        raw_record.pop(REV, None)
        raw_record.pop(ID, None)

        errors = {}
        prepared_record = {}
        for k, v in raw_record.items():
            field = self.Schema.fields.get(k)
            if field is not None:
                if field.name not in self.ravel.virtual_fields:
                    if v is None and field.nullable:
                        prepared_record[k] = None
                    else:
                        prepared_record[k], error = field.process(v)
                        if error:
                            console.error(
                                f'{self} failed validation for {k}: {v}')
                            errors[k] = error

        self.on_update(prepared_record)

        if errors:
            raise ValidationError(f'update failed for {self}: {errors}')

        updated_record = self.ravel.local.store.dispatch(
            'update', (self._id, prepared_record))

        if updated_record:
            self.merge(updated_record)

        self.clean(prepared_record.keys()
                   | (updated_record.keys() if updated_record else set()))
        self.post_update()

        return self

    @classmethod
    def create_many(cls,
                    resources: List['Resource'],
                    fields: Set[Text] = None) -> 'Batch':
        """
        Call `store.create_many` on the input `Resource` list and return the
        created resources in the form of a Batch.
        """
        # normalize resources to a list of Resource objects
        records = []
        prepared_resources = []
        for resource in resources:
            if resource is None:
                continue
            if isinstance(resource, dict):
                # convert raw dict into a proper Resource object
                state_dict = resource
                resource = cls(state=state_dict)

            record = resource._prepare_record_for_create(fields)
            records.append(record)

            resource.internal.state.update(record)
            prepared_resources.append(resource)

        if not records:
            return cls.Batch()

        cls.on_create_many(records)

        store = cls.ravel.local.store
        created_records = store.dispatch('create_many', (records, ))

        for resource, record in zip(prepared_resources, created_records):
            resource.merge(record)
            resource.clean()

        # wrap the prepared resources in a Batch
        batch = cls.Batch(prepared_resources)
        cls.post_create_many(batch)

        return batch

    @classmethod
    def update_many(cls,
                    resources: List['Resource'],
                    fields: Set[Text] = None,
                    data: Dict = None,
                    **more_data) -> 'Batch':
        """
        Call the Store's update_many method on the list of Resources.
        Multiple Store calls may be made. As a preprocessing step, the input
        resource list is partitioned into groups, according to which subset
        of fields are dirty.

        For example, consider this list of resources,

        ```python
        resources = [
            user1,     # dirty == {'email'}
            user2,     # dirty == {'email', 'name'}
            user3,     # dirty == {'email'}
        ]
        ```

        Calling update on this list will result in two partitions:
        ```python
        assert part1 == {user1, user3}
        assert part2 == {user2}
        ```

        A separate store call to `update_many` will be made for each partition.
        """
        # common_values are values that should be updated
        # across all objects.
        common_values = dict(data or {}, **more_data)

        # in the procedure below, we partition all incoming Resources into
        # groups according to the set of fields being updated. In this way,
        # we issue one update_many statement per partition in the DAL.
        partitions = defaultdict(cls.Batch)

        fields_to_update = fields

        for resource in resources:
            if resource is None:
                continue
            if common_values:
                resource.merge(common_values)

            partition_key = tuple(resource.dirty.keys())
            partitions[partition_key].append(resource)

        # id_2_copies is used to synchronize updated state across all
        # instances that share the same ID.
        id_2_copies = defaultdict(list)

        cls.on_update_many(cls.Batch(resources))

        for partition_key_tuple, resource_partition in partitions.items():
            records, _ids = [], []

            for resource in resource_partition:
                record = resource.dirty.copy()
                record.pop(REV, None)
                record.pop(ID, None)
                if fields_to_update:
                    record = {
                        k: v
                        for k, v in record.items()
                        if ((k in fields_to_update) and (
                            k not in cls.ravel.virtual_fields))
                    }
                records.append(record)
                _ids.append(resource._id)

            store = cls.ravel.local.store
            updated_records = store.dispatch('update_many', (_ids, records))

            if not updated_records:
                resource_partition.clean(partition_key_tuple)
                continue

            for resource in resource_partition:
                record = updated_records.get(resource._id)
                if record:
                    resource.merge(record)
                    resource.clean(record.keys())

                    # sync updated state across previously encountered
                    # instances of this resource (according to ID)
                    if resource._id in id_2_copies:
                        for res_copy in id_2_copies[resource._id]:
                            res_copy.merge(record)
                    id_2_copies[resource._id].append(resource)

        updated_resources = cls.Batch(resources)

        cls.post_update_many(updated_resources)

        return updated_resources

    @classmethod
    def save_many(cls,
                  resources: List['Resource'],
                  resolvers: Union[Text, Set[Text]] = None,
                  depth: int = 0) -> 'Batch':
        """
        Essentially a bulk upsert.
        """
        def seems_created(resource):
            return ((ID in resource.internal.state)
                    and (ID not in resource.internal.state.dirty))

        if resolvers is not None:
            if isinstance(resolvers, str):
                resolvers = {resolvers}
            elif not isinstance(resolvers, set):
                resolvers = set(resolvers)
            fields_to_save = set()
            resolvers_to_save = set()
            for k in resolvers:
                if k in cls.ravel.schema.fields:
                    fields_to_save.add(k)
                else:
                    resolvers_to_save.add(k)
        else:
            fields_to_save = None
            resolvers_to_save = set()

        # partition resources into those that are "uncreated" and those which
        # simply need to be updated.
        to_update = []
        to_create = []
        for resource in resources:
            # TODO: merge duplicates
            if not seems_created(resource):
                to_create.append(resource)
            else:
                to_update.append(resource)

        # perform bulk create and update
        if to_create:
            cls.create_many(to_create, fields=fields_to_save)
        if to_update:
            cls.update_many(to_update, fields=fields_to_save)

        retval = cls.Batch(to_update + to_create)

        if depth < 1:
            # base case. do not recurse on Resolvers
            return retval

        # aggregate and save all Resources referenced by the objects in
        # `resources` via their resolvers.
        class_2_objects = defaultdict(set)
        resolvers = cls.ravel.resolvers.by_tag('fields', invert=True)
        for resolver in resolvers.values():
            if resolver.name not in resolvers_to_save:
                continue
            for resource in resources:
                if resolver.name in resource.internal.state:
                    value = resource.internal.state[resolver.name]
                    resolver.on_save(resolver, resource, value)
                    if value:
                        if is_resource(value):
                            class_2_objects[resolver.owner].add(value)
                        else:
                            assert is_sequence(value)
                            class_2_objects[resolver.owner].update(value)
                    elif value is None and not resolver.nullable:
                        raise ValidationError(
                            f'{get_class_name(cls)}.{resolver.name} '
                            f'is required by save')

        # recursively call save_many for each type of Resource
        for resource_type, resources in class_2_objects.items():
            resource_type.save_many(resources, depth=depth - 1)

        return retval

    @classmethod
    def on_select(cls, query: 'Query'):
        pass

    @classmethod
    def post_select(cls, query: 'Query', result: 'Entity'):
        pass

    @classmethod
    def on_get(cls, _id, fields):
        pass

    @classmethod
    def on_get_many(cls, _ids, fields):
        pass

    @classmethod
    def post_get_many(cls, resources):
        pass

    @classmethod
    def post_get(cls, resource: 'Resource'):
        pass

    @classmethod
    def on_update_many(cls, batch):
        pass

    @classmethod
    def post_update_many(cls, batch):
        pass

    @classmethod
    def on_create_many(cls, batch):
        pass

    @classmethod
    def post_create_many(cls, batch):
        pass

    @classmethod
    def on_delete_many(cls, _ids):
        pass

    @classmethod
    def post_delete_many(cls, _ids):
        pass

    def on_update(self, values: Dict):
        pass

    def post_update(self):
        pass

    def pre_create(self):
        pass

    def on_create(self, values: Dict):
        pass

    def post_create(self):
        pass

    def on_delete(self):
        pass

    def post_delete(self):
        pass
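
The `mark`, `clean`, `dirty`, and `is_dirty` members above all lean on a `DirtyDict` whose implementation is not shown in this example. Below is a minimal, self-contained sketch of that dirty-key bookkeeping, inferred from how those methods use it; the real `DirtyDict` may differ.

```python
# Minimal sketch of dirty-key tracking as inferred from the Resource methods
# above; the actual DirtyDict implementation is not shown here and may differ.
class DirtyDictSketch(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dirty = set()          # keys modified since the last clean()

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.dirty.add(key)

    def mark(self, keys):
        self.dirty |= set(keys)     # force keys to count as modified

    def clean(self, keys=None):
        self.dirty -= set(keys) if keys is not None else set(self.dirty)

state = DirtyDictSketch()
state['email'] = 'user@example.com'
assert state.dirty == {'email'}     # setting a key marks it dirty
state.clean(keys={'email'})
assert not state.dirty              # cleaning clears the dirty flag
```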
Example #4
class Schema(Schema):
    store = fields.String()
    default = fields.Bool(default=False)
    params = fields.Dict(default={})
Example #5
class Schema(Schema):
    resource = fields.String()
    store = fields.String()
    params = fields.Dict(default={})
Example #6
class SqlalchemyStore(Store):
    """
    A SQLAlchemy-based store, which keeps a single connection pool (AKA
    Engine) shared by all threads; however, each thread keeps its own singleton
    thread-local database connection and transaction objects, managed through
    connect()/close() and begin()/end().
    """

    env = Environment(
        SQLALCHEMY_STORE_ECHO=fields.Bool(default=False),
        SQLALCHEMY_STORE_SHOW_QUERIES=fields.Bool(default=False),
        SQLALCHEMY_STORE_DIALECT=fields.Enum(fields.String(),
                                             Dialect.values(),
                                             default=Dialect.sqlite),
        SQLALCHEMY_STORE_PROTOCOL=fields.String(default='sqlite'),
        SQLALCHEMY_STORE_USER=fields.String(),
        SQLALCHEMY_STORE_HOST=fields.String(),
        SQLALCHEMY_STORE_PORT=fields.String(),
        SQLALCHEMY_STORE_NAME=fields.String(),
    )

    # _id_column_names is a mapping from table name to its _id column name
    _id_column_names = {}

    # only one thread needs to bootstrap the SqlalchemyStore. This lock
    # ensures that only one does when the host app bootstraps.
    _bootstrap_lock = RLock()

    @classmethod
    def get_default_adapters(cls, dialect: Dialect,
                             table_name) -> List[Field.Adapter]:
        # TODO: Move this into the adapters file

        adapters = [
            fields.Field.adapt(
                on_adapt=lambda field: sa.Text,
                on_encode=lambda x: cls.ravel.app.json.encode(x),
                on_decode=lambda x: cls.ravel.app.json.decode(x),
            ),
            fields.Email.adapt(on_adapt=lambda field: sa.Text),
            fields.Bytes.adapt(on_adapt=lambda field: sa.LargeBinary),
            fields.BcryptString.adapt(on_adapt=lambda field: sa.Text),
            fields.Float.adapt(on_adapt=lambda field: sa.Float),
            fields.DateTime.adapt(on_adapt=lambda field: UtcDateTime),
            fields.Timestamp.adapt(on_adapt=lambda field: UtcDateTime),
            fields.Bool.adapt(on_adapt=lambda field: sa.Boolean),
            fields.TimeDelta.adapt(on_adapt=lambda field: sa.Interval),
            fields.Enum.adapt(
                on_adapt=lambda field: {
                    fields.String: sa.Text,
                    fields.Int: sa.Integer,
                    fields.Float: sa.Float,
                }[type(field.nested)]),
        ]
        adapters.extend(
            field_class.adapt(on_adapt=lambda field: sa.Text)
            for field_class in {
                fields.String, fields.FormatString, fields.UuidString,
                fields.DateTimeString
            })
        adapters.extend(
            field_class.adapt(on_adapt=lambda field: sa.BigInteger)
            for field_class in {
                fields.Int,
                fields.Uint32,
                fields.Uint64,
                fields.Uint,
                fields.Int32,
            })
        if dialect == Dialect.postgresql:
            adapters.extend(cls.get_postgresql_default_adapters(table_name))
        elif dialect == Dialect.mysql:
            adapters.extend(cls.get_mysql_default_adapters(table_name))
        elif dialect == Dialect.sqlite:
            adapters.extend(cls.get_sqlite_default_adapters(table_name))

        return adapters

    @classmethod
    def get_postgresql_default_adapters(cls,
                                        table_name) -> List[Field.Adapter]:
        pg_types = sa.dialects.postgresql

        def on_adapt_list(field):
            if isinstance(field.nested, fields.Enum):
                name = f'{table_name}__{field.name}'
                return ArrayOfEnum(
                    pg_types.ENUM(*field.nested.values, name=name))
            return pg_types.ARRAY({
                fields.String: sa.Text,
                fields.Email: sa.Text,
                fields.Uuid: pg_types.UUID,
                fields.Int: sa.Integer,
                fields.Bool: sa.Boolean,
                fields.Float: sa.Float,
                fields.DateTime: UtcDateTime,
                fields.Timestamp: UtcDateTime,
                fields.Dict: pg_types.JSONB,
                fields.Field: pg_types.JSONB,
                fields.Nested: pg_types.JSONB,
            }.get(type(field.nested), sa.Text))

        return [
            Point.adapt(
                on_adapt=lambda field: GeoalchemyGeometry(field.geo_type),
                on_encode=lambda x: x.to_EWKT_string(),
                on_decode=lambda x:
                (PointGeometry(x['geometry']['coordinates']) if x else None)),
            Polygon.adapt(
                on_adapt=lambda field: GeoalchemyGeometry(field.geo_type),
                on_encode=lambda x: x.to_EWKT_string(),
                on_decode=lambda x:
                (PolygonGeometry(x['geometry']['coordinates']) if x else None)),
            fields.Field.adapt(on_adapt=lambda field: pg_types.JSONB),
            fields.Uuid.adapt(on_adapt=lambda field: pg_types.UUID),
            fields.Dict.adapt(on_adapt=lambda field: pg_types.JSONB),
            fields.Nested.adapt(on_adapt=lambda field: pg_types.JSONB),
            fields.Set.adapt(on_adapt=lambda field: pg_types.JSONB,
                             on_encode=lambda x: list(x),
                             on_decode=lambda x: set(x)),
            fields.UuidString.adapt(
                on_adapt=lambda field: pg_types.UUID,
                on_decode=lambda x: x.replace('-', '') if x else x,
            ),
            fields.List.adapt(on_adapt=on_adapt_list)
        ]

    @classmethod
    def get_mysql_default_adapters(cls, table_name) -> List[Field.Adapter]:
        return [
            fields.Dict.adapt(on_adapt=lambda field: sa.JSON),
            fields.Nested.adapt(on_adapt=lambda field: sa.JSON),
            fields.List.adapt(on_adapt=lambda field: sa.JSON),
            fields.Set.adapt(
                on_adapt=lambda field: sa.JSON,
                on_encode=lambda x: cls.ravel.app.json.encode(x),
                on_decode=lambda x: set(cls.ravel.app.json.decode(x))),
        ]

    @classmethod
    def get_sqlite_default_adapters(cls, table_name) -> List[Field.Adapter]:
        adapters = [
            field_class.adapt(
                on_adapt=lambda field: sa.Text,
                on_encode=lambda x: cls.ravel.app.json.encode(x),
                on_decode=lambda x: cls.ravel.app.json.decode(x),
            ) for field_class in {fields.Dict, fields.List, fields.Nested}
        ]
        adapters.append(
            fields.Set.adapt(
                on_adapt=lambda field: sa.Text,
                on_encode=lambda x: cls.ravel.app.json.encode(x),
                on_decode=lambda x: set(cls.ravel.app.json.decode(x))))
        return adapters

    def __init__(self, adapters: List[Field.Adapter] = None):
        super().__init__()
        self._custom_adapters = adapters or []
        self._table = None
        self._builder = None
        self._adapters = None
        self._id_column = None
        self._options = {}

    @property
    def adapters(self):
        return self._adapters

    @property
    def id_column_name(self):
        return self.resource_type.Schema.fields[ID].source

    def prepare(self, record: Dict, serialize=True) -> Dict:
        """
        When inserting or updating data, some raw values in the record dict
        must be transformed before their corresponding sqlalchemy column type
        will accept the data.
        """
        cb_name = 'on_encode' if serialize else 'on_decode'
        prepared_record = {}
        for k, v in record.items():
            if k == REV:
                prepared_record[k] = v
                continue
            adapter = self._adapters.get(k)
            if adapter:
                callback = getattr(adapter, cb_name, None)
                if callback:
                    try:
                        prepared_record[k] = callback(v)
                        continue
                    except Exception:
                        console.error(
                            message=f'failed to adapt column value: {k}',
                            data={
                                'value': v,
                                'field': adapter.field_class
                            })
                        raise
            prepared_record[k] = v
        return prepared_record

    def adapt_id(self, _id, serialize=True):
        cb_name = 'on_encode' if serialize else 'on_decode'
        adapter = self._adapters.get(self.id_column_name)
        if adapter:
            callback = getattr(adapter, cb_name)
            if callback:
                return callback(_id)
        return _id

    @classmethod
    def on_bootstrap(cls,
                     url=None,
                     dialect=None,
                     echo=False,
                     db=None,
                     **kwargs):
        """
        Initialize the SQLAlchemy connection pool (AKA Engine).
        """
        with cls._bootstrap_lock:
            cls.ravel.kwargs = kwargs

            # construct the URL to the DB server
            # url can be a string or a dict
            if isinstance(url, dict):
                url_parts = url
                cls.ravel.app.shared.sqla_url = (
                    '{protocol}://{user}@{host}:{port}/{db}'.format(
                        **url_parts))
            elif isinstance(url, str):
                cls.ravel.app.shared.sqla_url = url
            else:
                url_parts = dict(
                    protocol=cls.env.SQLALCHEMY_STORE_PROTOCOL,
                    user=cls.env.SQLALCHEMY_STORE_USER or '',
                    host=('@' + cls.env.SQLALCHEMY_STORE_HOST
                          if cls.env.SQLALCHEMY_STORE_HOST else ''),
                    port=(':' + cls.env.SQLALCHEMY_STORE_PORT
                          if cls.env.SQLALCHEMY_STORE_PORT else ''),
                    db=('/' + (db or cls.env.SQLALCHEMY_STORE_NAME or '')))
                cls.ravel.app.shared.sqla_url = url or (
                    '{protocol}://{user}{host}{port}{db}'.format(**url_parts))

            cls.dialect = dialect or cls.env.SQLALCHEMY_STORE_DIALECT

            from sqlalchemy.dialects import postgresql, sqlite, mysql

            cls.sa_dialect = None
            if cls.dialect == Dialect.postgresql:
                cls.sa_dialect = postgresql
            elif cls.dialect == Dialect.sqlite:
                cls.sa_dialect = sqlite
            elif cls.dialect == Dialect.mysql:
                cls.sa_dialect = mysql

            console.debug(message='creating sqlalchemy engine',
                          data={
                              'echo': cls.env.SQLALCHEMY_STORE_ECHO,
                              'dialect': cls.dialect,
                              'url': cls.ravel.app.shared.sqla_url,
                          })

            cls.ravel.local.sqla_tx = None
            cls.ravel.local.sqla_conn = None
            cls.ravel.local.sqla_metadata = sa.MetaData()
            cls.ravel.local.sqla_metadata.bind = sa.create_engine(
                name_or_url=cls.ravel.app.shared.sqla_url,
                echo=bool(echo or cls.env.SQLALCHEMY_STORE_ECHO),
                strategy='threadlocal')

            # set global thread-local sqlalchemy store method aliases
            cls.ravel.app.local.create_tables = cls.create_tables

    def on_bind(self,
                resource_type: Type['Resource'],
                table: Text = None,
                schema: 'Schema' = None,
                **kwargs):
        """
        Initialize the SQLAlchemy data structures used to construct the SQL
        expressions that manage the bound resource type.
        """
        # map each of the resource's schema fields to a corresponding adapter,
        # which is used to prepare values upon insert and update.
        table = (table
                 or SqlalchemyTableBuilder.derive_table_name(resource_type))
        field_class_2_adapter = {
            adapter.field_class: adapter
            for adapter in self.get_default_adapters(self.dialect, table) +
            self._custom_adapters
        }
        self._adapters = {
            field_name: field_class_2_adapter[type(field)]
            for field_name, field in self.resource_type.Schema.fields.items()
            if (type(field) in field_class_2_adapter
                and field.meta.get('ravel_on_resolve') is None)
        }

        # build the Sqlalchemy Table object for the bound resource type.
        self._builder = SqlalchemyTableBuilder(self)

        try:
            self._table = self._builder.build_table(name=table, schema=schema)
        except Exception:
            console.error(f'failed to build sa.Table: {table}')
            raise

        self._id_column = getattr(self._table.c, self.id_column_name)

        # remember which column is the _id column
        self._id_column_names[self._table.name] = self.id_column_name

        # set SqlalchemyStore options here, using bootstrap-level
        # options as base/default options.
        self._options = dict(self.ravel.kwargs, **kwargs)

    def query(
        self,
        predicate: 'Predicate',
        fields: Set[Text] = None,
        limit: int = None,
        offset: int = None,
        order_by: Tuple = None,
        **kwargs,
    ):
        fields = fields or {k: None for k in self._adapters}
        fields.update({
            self.id_column_name: None,
            self.resource_type.Schema.fields[REV].source: None,
        })

        columns = []
        table_alias = self.table.alias(''.join(
            s.strip('_')[0] for s in self.table.name.split('_')))
        for k in fields:
            col = getattr(table_alias.c, k)
            if isinstance(col.type, GeoalchemyGeometry):
                columns.append(sa.func.ST_AsGeoJSON(col).label(k))
            else:
                columns.append(col)

        predicate = Predicate.deserialize(predicate)
        filters = self._prepare_predicate(table_alias, predicate)

        # build the query object
        query = sa.select(columns).where(filters)

        if order_by:
            sa_order_by = [
                sa.desc(getattr(table_alias.c, x.key)) if x.desc else sa.asc(
                    getattr(table_alias.c, x.key)) for x in order_by
            ]
            query = query.order_by(*sa_order_by)

        if limit is not None:
            query = query.limit(max(0, limit))
        if offset is not None:
            query = query.offset(max(0, offset))

        console.debug(
            message=(f'SQL: SELECT FROM {self.table}' +
                     (f' OFFSET {offset}' if offset is not None else '') +
                     (f' LIMIT {limit}' if limit else '') +
                     (f' ORDER BY {", ".join(x.to_sql() for x in order_by)}'
                      if order_by else '')),
            data={
                'stack':
                traceback.format_stack(),
                'statement':
                str(query.compile(
                    compile_kwargs={'literal_binds': True})).split('\n')
            } if self.env.SQLALCHEMY_STORE_SHOW_QUERIES else None)
        # execute query, aggregating resulting records
        cursor = self.conn.execute(query)
        records = []

        while True:
            page = [
                self.prepare(dict(row.items()), serialize=False)
                for row in cursor.fetchmany(512)
            ]
            if page:
                records.extend(page)
            else:
                break

        return records

    def _prepare_predicate(self, table, pred, empty=set()):
        if isinstance(pred, ConditionalPredicate):
            if not pred.ignore_field_adapter:
                adapter = self._adapters.get(pred.field.source)
                if adapter and adapter.on_encode:
                    pred.value = adapter.on_encode(pred.value)
            col = getattr(table.c, pred.field.source)
            if pred.op == OP_CODE.EQ:
                return col == pred.value
            elif pred.op == OP_CODE.NEQ:
                return col != pred.value
            elif pred.op == OP_CODE.GEQ:
                return col >= pred.value
            elif pred.op == OP_CODE.GT:
                return col > pred.value
            elif pred.op == OP_CODE.LT:
                return col < pred.value
            elif pred.op == OP_CODE.LEQ:
                return col <= pred.value
            elif pred.op == OP_CODE.INCLUDING:
                return col.in_(pred.value)
            elif pred.op == OP_CODE.EXCLUDING:
                return ~col.in_(pred.value)
            elif pred.op == POSTGIS_OP_CODE.CONTAINS:
                if isinstance(pred.value, GeometryObject):
                    EWKT_str = pred.value.to_EWKT_string()
                else:
                    EWKT_str = pred.value
                return sa.func.ST_Contains(
                    col,
                    sa.func.ST_GeomFromEWKT(EWKT_str),
                )
            elif pred.op == POSTGIS_OP_CODE.CONTAINED_BY:
                if isinstance(pred.value, GeometryObject):
                    EWKT_str = pred.value.to_EWKT_string()
                else:
                    EWKT_str = pred.value
                return sa.func.ST_Contains(sa.func.ST_GeomFromEWKT(EWKT_str),
                                           col)
            elif pred.op == POSTGIS_OP_CODE.WITHIN_RADIUS:
                center = pred.value['center']
                radius = pred.value['radius']
                return sa.func.ST_PointInsideCircle(col, center[0], center[1],
                                                    radius)
            else:
                raise Exception('unrecognized conditional predicate')
        elif isinstance(pred, BooleanPredicate):
            if pred.op == OP_CODE.AND:
                lhs_result = self._prepare_predicate(table, pred.lhs)
                rhs_result = self._prepare_predicate(table, pred.rhs)
                return sa.and_(lhs_result, rhs_result)
            elif pred.op == OP_CODE.OR:
                lhs_result = self._prepare_predicate(table, pred.lhs)
                rhs_result = self._prepare_predicate(table, pred.rhs)
                return sa.or_(lhs_result, rhs_result)
            else:
                raise Exception('unrecognized boolean predicate')
        else:
            raise Exception('unrecognized predicate type')

    def exists(self, _id) -> bool:
        columns = [sa.func.count(self._id_column)]
        query = (sa.select(columns).where(
            self._id_column == self.adapt_id(_id)))
        result = self.conn.execute(query)
        return bool(result.scalar())

    def exists_many(self, _ids: Set) -> Dict[object, bool]:
        columns = [self._id_column, sa.func.count(self._id_column)]
        query = (sa.select(columns).where(
            self._id_column.in_([self.adapt_id(_id) for _id in _ids])))
        return {row[0]: row[1] for row in self.conn.execute(query)}

    def count(self) -> int:
        query = sa.select([sa.func.count(self._id_column)])
        result = self.conn.execute(query)
        return result.scalar()

    def fetch(self, _id, fields=None) -> Dict:
        records = self.fetch_many(_ids=[_id], fields=fields)
        return records[_id] if records else None

    def fetch_many(self, _ids: List, fields=None, as_list=False) -> Dict:
        prepared_ids = [self.adapt_id(_id, serialize=True) for _id in _ids]

        if fields:
            if not isinstance(fields, set):
                fields = set(fields)
        else:
            fields = {
                f.source
                for f in self.resource_type.Schema.fields.values()
                if f.name in self._adapters
            }
        fields.update({
            self.id_column_name,
            self.resource_type.Schema.fields[REV].source,
        })

        columns = []
        for k in fields:
            col = getattr(self.table.c, k)
            if isinstance(col.type, GeoalchemyGeometry):
                columns.append(sa.func.ST_AsGeoJSON(col).label(k))
            else:
                columns.append(col)

        select_stmt = sa.select(columns)

        id_col = getattr(self.table.c, self.id_column_name)

        if prepared_ids:
            select_stmt = select_stmt.where(id_col.in_(prepared_ids))
        cursor = self.conn.execute(select_stmt)
        records = {} if not as_list else []

        while True:
            page = cursor.fetchmany(512)
            if page:
                for row in page:
                    raw_record = dict(row.items())
                    record = self.prepare(raw_record, serialize=False)
                    _id = self.adapt_id(row[self.id_column_name],
                                        serialize=False)
                    if as_list:
                        records.append(record)
                    else:
                        records[_id] = record
            else:
                break

        return records

    def fetch_all(self, fields: Set[Text] = None) -> Dict:
        return self.fetch_many([], fields=fields)

    def create(self, record: dict) -> dict:
        record[self.id_column_name] = self.create_id(record)
        prepared_record = self.prepare(record, serialize=True)
        insert_stmt = self.table.insert().values(**prepared_record)
        _id = prepared_record.get('_id', '')
        console.debug(f'SQL: INSERT {str(_id)[:7] + " " if _id else ""}'
                      f'INTO {self.table}')
        try:
            if self.supports_returning:
                insert_stmt = insert_stmt.return_defaults()
                result = self.conn.execute(insert_stmt)
                return dict(record, **(result.returned_defaults or {}))
            else:
                result = self.conn.execute(insert_stmt)
                return self.fetch(_id=record[self.id_column_name])
        except Exception:
            console.error(message='failed to insert record',
                          data={
                              'record': record,
                              'resource': get_class_name(self.resource_type),
                          })
            raise

    def create_many(self, records: List[Dict]) -> Dict:
        prepared_records = []
        nullable_fields = self.resource_type.Schema.nullable_fields
        for record in records:
            record[self.id_column_name] = self.create_id(record)
            prepared_record = self.prepare(record, serialize=True)
            prepared_records.append(prepared_record)
            for nullable_field in nullable_fields.values():
                if nullable_field.name not in prepared_record:
                    prepared_record[nullable_field.name] = None

        try:
            self.conn.execute(self.table.insert(), prepared_records)
        except Exception:
            console.error('failed to insert records')
            raise

        n = len(prepared_records)
        id_list_str = (', '.join(
            str(x['_id'])[:7] for x in prepared_records if x.get('_id')))
        console.debug(f'SQL: INSERT {id_list_str} INTO {self.table} ' +
                      (f'(count: {n})' if n > 1 else ''))

        if self.supports_returning:
            # TODO: use implicit returning if possible
            pass

        return self.fetch_many((rec[self.id_column_name] for rec in records),
                               as_list=True)
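
    # Similarly, a sketch of record creation (hypothetical names again):
    #
    #   created = store.create({'name': 'alice'})        # single INSERT
    #   batch = store.create_many([{'name': 'bob'},      # bulk INSERT, then
    #                              {'name': 'carol'}])   # re-fetched as a list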

    def update(self, _id, data: Dict) -> Dict:
        prepared_id = self.adapt_id(_id)
        prepared_data = self.prepare(data, serialize=True)
        if not prepared_data:
            # nothing to update
            return prepared_data
        update_stmt = (self.table.update().values(**prepared_data).where(
            self._id_column == prepared_id))
        if self.supports_returning:
            update_stmt = update_stmt.return_defaults()
            console.debug(f'SQL: UPDATE {self.table}')
            result = self.conn.execute(update_stmt)
            return dict(data, **(result.returned_defaults or {}))
        else:
            self.conn.execute(update_stmt)
            if self._options.get('fetch_on_update', True):
                return self.fetch(_id)
            return data

    def update_many(self, _ids: List, data: List[Dict] = None) -> None:
        assert data

        prepared_ids = []
        prepared_records = []

        for _id, record in zip(_ids, data):
            prepared_id = self.adapt_id(_id)
            prepared_record = self.prepare(record, serialize=True)
            if prepared_record:
                prepared_ids.append(prepared_id)
                prepared_records.append(prepared_record)
                prepared_record[ID] = prepared_id

        if prepared_records:
            n = len(prepared_records)
            console.debug(f'SQL: UPDATE {self.table} ' +
                          (f'({n}x)' if n > 1 else ''))
            values = {k: bindparam(k) for k in prepared_records[0].keys()}
            update_stmt = (self.table.update().where(
                self._id_column == bindparam(self.id_column_name)).values(
                    **values))
            self.conn.execute(update_stmt, prepared_records)

        if self._options.get('fetch_on_update', True):
            # TODO: use implicit RETURNING when supported instead of re-fetching
            return self.fetch_many(_ids)
        return

    def delete(self, _id) -> None:
        prepared_id = self.adapt_id(_id)
        delete_stmt = self.table.delete().where(self._id_column == prepared_id)
        self.conn.execute(delete_stmt)

    def delete_many(self, _ids: list) -> None:
        prepared_ids = [self.adapt_id(_id) for _id in _ids]
        delete_stmt = self.table.delete().where(
            self._id_column.in_(prepared_ids))
        self.conn.execute(delete_stmt)

    def delete_all(self):
        delete_stmt = self.table.delete()
        self.conn.execute(delete_stmt)
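
    # And a sketch of the update/delete methods (hypothetical `store` and ids):
    #
    #   store.update(some_id, {'name': 'alice2'})            # single-row UPDATE
    #   store.update_many([id1, id2], [{'n': 1}, {'n': 2}])  # bulk UPDATE
    #   store.delete(some_id)                                 # single-row DELETE
    #   store.delete_many([id1, id2])                         # bulk DELETE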

    @property
    def table(self):
        return self._table

    @property
    def conn(self):
        sqla_conn = getattr(self.ravel.local, 'sqla_conn', None)
        if sqla_conn is None:
            # lazily initialize a connection for this thread
            self.connect()
        return self.ravel.local.sqla_conn

    @property
    def supports_returning(self):
        if not self.is_bootstrapped():
            return False
        metadata = self.get_metadata()
        return metadata.bind.dialect.implicit_returning

    @classmethod
    def create_tables(cls, overwrite=False):
        """
        Create all tables for all SqlalchemyStores used in the host app.
        """
        if not cls.is_bootstrapped():
            console.error(f'{get_class_name(cls)} cannot create '
                          f'tables unless bootstrapped')
            return

        meta = cls.get_metadata()
        engine = cls.get_engine()

        if overwrite:
            console.info('dropping Resource SQL tables...')
            meta.drop_all(engine)

        # create all tables
        console.info('creating Resource SQL tables...')
        meta.create_all(engine)
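
    # Usage sketch: once the store class has been bootstrapped by the host
    # Application, all Resource tables can be (re)created in one call.
    #
    #   SqlalchemyStore.create_tables(overwrite=True)   # drop_all + create_all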

    @classmethod
    def get_active_connection(cls):
        return getattr(cls.ravel.local, 'sqla_conn', None)

    @classmethod
    def connect(cls, refresh=True):
        """
        Create a singleton thread-local SQLAlchemy connection, shared across
        all Resources backed by a SQLAlchemy store. When working with multiple
        threads or processes, make sure to 
        """
        sqla_conn = getattr(cls.ravel.local, 'sqla_conn', None)
        metadata = cls.ravel.local.sqla_metadata
        if sqla_conn is not None:
            console.warning(
                message='sqlalchemy store already has a connection')
            if refresh:
                cls.close()
                cls.ravel.local.sqla_conn = metadata.bind.connect()
        else:
            cls.ravel.local.sqla_conn = metadata.bind.connect()

        return cls.ravel.local.sqla_conn

    @classmethod
    def close(cls):
        """
        Return the thread-local database connection to the sqlalchemy
        connection pool (AKA the "engine").
        """
        sqla_conn = getattr(cls.ravel.local, 'sqla_conn', None)
        if sqla_conn is not None:
            console.debug('closing sqlalchemy connection')
            sqla_conn.close()
            cls.ravel.local.sqla_conn = None
        else:
            console.warning('sqlalchemy has no connection to close')
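
    # Connection lifecycle sketch: the connection lives in thread-local state,
    # so each thread checks out and returns its own connection.
    #
    #   SqlalchemyStore.connect()    # check a connection out of the engine pool
    #   ...                          # perform store operations on this thread
    #   SqlalchemyStore.close()      # return the connection to the pool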

    @classmethod
    def begin(cls, auto_connect=True, **kwargs):
        """
        Initialize a thread-local transaction. An exception is raised if
        there's already a pending transaction.
        """
        conn = cls.get_active_connection()
        if conn is None:
            if auto_connect:
                conn = cls.connect()
            else:
                raise Exception('no active sqlalchemy connection')

        existing_tx = getattr(cls.ravel.local, 'sqla_tx', None)
        if existing_tx is not None:
            console.debug('there is already an open transaction')
        else:
            new_tx = cls.ravel.local.sqla_conn.begin()
            cls.ravel.local.sqla_tx = new_tx

    @classmethod
    def commit(cls, rollback=True, **kwargs):
        """
        Call commit on the thread-local database transaction. "Begin" must be
        called to start a new transaction at this point, if a new transaction
        is desired.
        """
        def perform_sqlalchemy_commit():
            tx = getattr(cls.ravel.local, 'sqla_tx', None)
            if tx is not None:
                cls.ravel.local.sqla_tx.commit()
                cls.ravel.local.sqla_tx = None

        # try to commit the transaction
        console.debug('committing sqlalchemy transaction')
        try:
            perform_sqlalchemy_commit()
        except Exception:
            if rollback:
                # if the commit fails, roll back the transaction
                console.critical('rolling back sqlalchemy transaction')
                cls.rollback()
            else:
                console.exception('sqlalchemy transaction failed to commit')
        finally:
            # ensure we close the connection either way
            cls.close()
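
    # Transaction lifecycle sketch (hypothetical usage): `commit` closes the
    # thread-local connection in its `finally` block, so `connect` and `begin`
    # must be called again before starting another transaction.
    #
    #   SqlalchemyStore.begin()
    #   try:
    #       ...  # create/update/delete calls against bound stores
    #       SqlalchemyStore.commit()
    #   except Exception:
    #       SqlalchemyStore.rollback()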

    @classmethod
    def rollback(cls, **kwargs):
        tx = getattr(cls.ravel.local, 'sqla_tx', None)
        if tx is not None:
            cls.ravel.local.sqla_tx = None
            try:
                tx.rollback()
            except Exception:
                console.exception('sqlalchemy transaction failed to rollback')

    @classmethod
    def has_transaction(cls) -> bool:
        return getattr(cls.ravel.local, 'sqla_tx', None) is not None

    @classmethod
    def get_metadata(cls):
        return cls.ravel.local.sqla_metadata

    @classmethod
    def get_engine(cls):
        return cls.get_metadata().bind

    @classmethod
    def dispose(cls):
        meta = cls.get_metadata()
        if not meta:
            cls.ravel.local.sqla_metadata = None
            return

        engine = meta.bind
        engine.dispose()