Ejemplo n.º 1
0
class UnitOfWorkTestCase(Pep8CompliantTestCase):
    """
    Tests for the status bookkeeping of :class:`UnitOfWork` and for the
    errors raised when entities are used with the wrong unit of work or in
    an invalid state.
    """

    def set_up(self):
        self._uow = UnitOfWork()

    def test_basics(self):
        # Walk a single entity through NEW -> CLEAN -> DIRTY -> DELETED and
        # check that the matching get_* accessor reports it at every step.
        ent = _MyEntity(id=0)
        self._uow.register_new(_MyEntity, ent)
        self.assert_equal(EntityState.get_state(ent).status,
                          ENTITY_STATUS.NEW)
        self.assert_equal([item.entity for item in self._uow.iterator()],
                          [ent])
        self.assert_equal(list(self._uow.get_new(_MyEntity)), [ent])
        self._uow.mark_clean(ent)
        self.assert_equal(list(self._uow.get_clean(_MyEntity)), [ent])
        self._uow.mark_dirty(ent)
        self.assert_equal(list(self._uow.get_dirty(_MyEntity)), [ent])
        self._uow.mark_deleted(ent)
        self.assert_equal(list(self._uow.get_deleted(_MyEntity)), [ent])
        self._uow.unregister(_MyEntity, ent)
        self.assert_equal(list(self._uow.iterator()), [])
        self._uow.reset()

    def test_get_state_unregistered_fails(self):
        # An entity never registered with any unit of work has no state.
        ent = _MyEntity()
        with self.assert_raises(ValueError) as cm:
            EntityState.get_state(ent)
        msg = 'Trying to obtain state for un-managed entity'
        self.assert_true(cm.exception.args[0].startswith(msg))

    def test_is_marked_unregistered(self):
        # In this version the is_marked_* queries report False (rather than
        # raising) for entities that are not registered.
        ent = _MyEntity()
        self.assert_false(self._uow.is_marked_persisted(ent))
        self.assert_false(self._uow.is_marked_pending(ent))

    def test_mark_unregistered_fails(self):
        ent = _MyEntity()
        with self.assert_raises(ValueError) as cm:
            self._uow.mark_dirty(ent)
        msg = 'Trying to obtain state for un-managed entity'
        self.assert_true(cm.exception.args[0].startswith(msg))

    def test_release_unregistered_fails(self):
        ent = _MyEntity()
        with self.assert_raises(ValueError) as cm:
            self._uow.unregister(_MyEntity, ent)
        msg = 'Trying to unregister an entity that has not been'
        self.assert_true(str(cm.exception).startswith(msg))

    def test_registered_with_other_uow_fails(self):
        # An entity registered with one unit of work can neither be
        # registered with nor unregistered from a different one.
        ent = _MyEntity()
        uow = UnitOfWork()
        uow.register_new(_MyEntity, ent)
        with self.assert_raises(ValueError) as cm1:
            self._uow.register_new(_MyEntity, ent)
        msg1 = 'Trying to register an entity that has been'
        self.assert_true(str(cm1.exception).startswith(msg1))
        with self.assert_raises(ValueError) as cm2:
            self._uow.unregister(_MyEntity, ent)
        msg2 = 'Trying to unregister an entity that has been'
        self.assert_true(str(cm2.exception).startswith(msg2))

    def test_mark_deleted_as_dirty(self):
        # DELETED -> DIRTY is not a valid status transition.
        ent = _MyEntity()
        self._uow.register_new(_MyEntity, ent)
        self._uow.mark_deleted(ent)
        with self.assert_raises(ValueError) as cm:
            self._uow.mark_dirty(ent)
        msg = 'Invalid status transition'
        self.assert_true(str(cm.exception).startswith(msg))

    def test_check_unregistered_is_marked_new(self):
        ent = _MyEntity()
        self.assert_false(self._uow.is_marked_new(ent))

    def test_mark_deleted_as_new(self):
        # DELETED -> NEW (re-adding a removed entity) must be allowed and
        # must actually leave the entity in the NEW state.
        ent = _MyEntity(id=0)
        self._uow.register_deleted(_MyEntity, ent)
        self._uow.mark_new(ent)
        self.assert_equal(EntityState.get_state(ent).status,
                          ENTITY_STATUS.NEW)
Ejemplo n.º 2
0
class MemorySession(Session):
    """
    Session object.

    The session
     * Holds a Unit Of Work;
     * Serves as identity and slug map;
     * Performs synchronized commit on repository;
     * Sets up data manager to hook into transaction.

    Changes made via :meth:`add`, :meth:`remove` and :meth:`update` are only
    recorded in the unit of work. They are pushed to the repository by
    :meth:`flush`, which runs automatically before the next lookup
    (:meth:`get_by_id`, :meth:`get_by_slug`, :meth:`query`, :meth:`load`).
    """
    IS_MANAGING_BACKREFERENCES = True

    def __init__(self, repository, query_class=None, clone_on_load=True):
        """
        :param repository: Repository backing this session. The session uses
          its ``lock``, ``flush``, ``commit``, ``rollback``,
          ``get_aggregate`` and ``is_registered_resource`` members.
        :param query_class: Class instantiated by :meth:`query`; defaults to
          :class:`MemoryRepositoryQuery`.
        :param bool clone_on_load: If set (the default), :meth:`load` puts a
          clone of the repository entity into the session; otherwise the
          repository entity itself is cached and returned.
        """
        self.__repository = repository
        self.__unit_of_work = UnitOfWork()
        # Maps entity class -> EntityCache (identity and slug map).
        self.__cache_map = {}
        if query_class is None:
            query_class = MemoryRepositoryQuery
        self.__query_class = query_class
        self.__clone_on_load = clone_on_load
        # Set by add/remove/update; cleared once the changes were flushed.
        self.__needs_flushing = False
        # Guard against re-entrant flushes.
        self.__is_flushing = False

    def get_by_id(self, entity_class, entity_id):
        """
        Return the session entity of class *entity_class* with the given ID
        (``None`` if it is not cached). Pending changes are flushed first.
        """
        if self.__needs_flushing:
            self.flush()
        cache = self.__get_cache(entity_class)
        return cache.get_by_id(entity_id)

    def get_by_slug(self, entity_class, entity_slug):
        """
        Return what the entity cache of *entity_class* holds for the given
        slug. Pending changes are flushed first.
        """
        if self.__needs_flushing:
            self.flush()
        cache = self.__get_cache(entity_class)
        return cache.get_by_slug(entity_slug)

    def add(self, entity_class, data):
        """
        Record the addition of *data* (traversed with the aggregate of
        *entity_class*). The change is flushed lazily.
        """
        self.__traverse(entity_class, data, None, RELATION_OPERATIONS.ADD)

    def remove(self, entity_class, data):
        """
        Record the removal of *data* (traversed with the aggregate of
        *entity_class*). The change is flushed lazily.
        """
        self.__traverse(entity_class, None, data, RELATION_OPERATIONS.REMOVE)

    def update(self, entity_class, data, target=None):
        """
        Record an update of *target* from *data* and return the root entity
        produced by the traversal. The change is flushed lazily.
        """
        return self.__traverse(entity_class, data, target,
                               RELATION_OPERATIONS.UPDATE)

    def query(self, entity_class):
        """
        Return a new query object for *entity_class* bound to this session
        and its repository. Pending changes are flushed first.
        """
        if self.__needs_flushing:
            self.flush()
        return self.__query_class(entity_class, self, self.__repository)

    def flush(self):
        """
        Push the changes recorded in the unit of work to the repository.

        The repository flush runs under the repository lock. Afterwards the
        entity caches are rebuilt because the flush may have auto-generated
        IDs for NEW entities. A call made while a flush is already in
        progress does not flush again. If the repository flush raises, the
        session stays marked as needing a flush.
        """
        if self.__needs_flushing and not self.__is_flushing:
            self.__is_flushing = True
            try:
                with self.__repository.lock:
                    self.__repository.flush(self.__unit_of_work)
            finally:
                # Always clear the guard; otherwise one failed flush would
                # turn every later flush into a silent no-op.
                self.__is_flushing = False
            for ent_cls, cache in self.__cache_map.items():
                # The flush may have auto-generated IDs for NEW entities,
                # so we rebuild the cache.
                cache.rebuild(self.__unit_of_work.get_new(ent_cls))
        self.__needs_flushing = False

    def begin(self):
        """Start a new transaction by resetting the unit of work."""
        self.__unit_of_work.reset()

    def commit(self):
        """
        Commit the unit of work to the repository (under the repository
        lock), then reset the unit of work and drop all entity caches.
        """
        with self.__repository.lock:
            self.__repository.commit(self.__unit_of_work)
        self.__unit_of_work.reset()
        self.__cache_map.clear()

    def rollback(self):
        """
        Roll back the unit of work in the repository (under the repository
        lock), then reset the unit of work and drop all entity caches.
        """
        with self.__repository.lock:
            self.__repository.rollback(self.__unit_of_work)
        self.__unit_of_work.reset()
        self.__cache_map.clear()

    def reset(self):
        """Alias for :meth:`rollback`."""
        self.rollback()

    def load(self, entity_class, entity):
        """
        Load the given repository entity into the session and return a
        clone. If it was already loaded before, look up the loaded entity
        and return it.

        All entities referenced by the loaded entity will also be loaded
        (and cloned) recursively.

        :raises ValueError: When an attempt is made to load an entity that
          has no ID
        """
        if self.__needs_flushing:
            self.flush()
        if entity.id is None:
            raise ValueError('Can not load entity without an ID.')
        cache = self.__get_cache(entity_class)
        sess_ent = cache.get_by_id(entity.id)
        if sess_ent is None:
            if self.__clone_on_load:
                sess_ent = self.__clone(entity, cache)
            else:  # pragma: no cover -- only needed by the nosql backend.
                cache.add(entity)
                sess_ent = entity
            self.__unit_of_work.register_clean(entity_class, sess_ent)
        return sess_ent

    @property
    def new(self):
        """Entities currently marked NEW in the unit of work."""
        return self.__unit_of_work.get_new()

    @property
    def deleted(self):
        """Entities currently marked DELETED in the unit of work."""
        return self.__unit_of_work.get_deleted()

    def __contains__(self, entity):
        # An entity is "in" the session if the cache for its exact class
        # contains it; classes without a cache yet contain nothing.
        cache = self.__cache_map.get(type(entity))
        return cache is not None and entity in cache

    def __traverse(self, entity_class, source_data, target_data, rel_op):
        """
        Run a source/target data tree traversal for *rel_op*, dispatching
        to :meth:`__add`, :meth:`__remove` and :meth:`__update`, and mark
        the session as needing a flush. Returns the visitor's root.
        """
        agg = self.__repository.get_aggregate(entity_class)
        trv = SourceTargetDataTreeTraverser.make_traverser(source_data,
                                                           target_data,
                                                           rel_op,
                                                           accessor=agg)
        vst = AruVisitor(entity_class, self.__add, self.__remove,
                         self.__update)
        trv.run(vst)
        # Indicate that we need to flush the changes.
        self.__needs_flushing = True
        return vst.root

    def __add(self, entity):
        """
        Traversal callback: register *entity* as NEW (or revive it if it was
        marked DELETED) and put it into the cache of its class.

        :raises ValueError: If a different entity with the same ID is
          already cached.
        """
        entity_class = type(entity)
        cache = self.__get_cache(entity_class)
        # We allow adding the same entity multiple times.
        if entity.id is not None and cache.get_by_id(entity.id) is entity:
            return
        if not self.__unit_of_work.is_marked_deleted(entity):
            # Check for an ID clash *before* registering so that a rejected
            # add does not leave an orphaned NEW entity in the unit of work
            # (which the next flush would otherwise persist).
            if entity.id is not None and cache.has_id(entity.id):
                raise ValueError('Duplicate entity ID "%s".' % entity.id)
            self.__unit_of_work.register_new(entity_class, entity)
            # FIXME: This is only necessary if the call above re-uses
            #        an existing state, in which case it needs to be
            #        marked as pending explicitly. Consider rewriting
            #        this whole method.
            self.__unit_of_work.mark_pending(entity)
        elif self.__unit_of_work.is_marked_pending(entity):
            # The changes were not flushed yet; just mark as clean.
            self.__unit_of_work.mark_clean(entity)
        else:
            self.__unit_of_work.mark_new(entity)
            self.__unit_of_work.mark_pending(entity)
        cache.add(entity)

    def __remove(self, entity):
        """
        Traversal callback: mark *entity* as DELETED (registering it first
        if necessary) and drop it from the cache of its class.

        :raises ValueError: If *entity* is not registered and has no ID.
        """
        entity_class = type(entity)
        if not self.__unit_of_work.is_registered(entity):
            if entity.id is None:
                raise ValueError('Can not remove un-registered entity '
                                 'without an ID')
            self.__unit_of_work.register_deleted(entity_class, entity)
        elif not self.__unit_of_work.is_marked_new(entity):
            self.__unit_of_work.mark_deleted(entity)
        elif self.__unit_of_work.is_marked_pending(entity):
            # The changes were not flushed yet; just mark as clean.
            self.__unit_of_work.mark_clean(entity)
        else:
            self.__unit_of_work.mark_deleted(entity)
            self.__unit_of_work.mark_pending(entity)
        cache = self.__get_cache(entity_class)
        if entity in cache:
            cache.remove(entity)

    def __update(self, source_data, target_entity):
        """
        Traversal callback: copy *source_data* onto *target_entity* and, if
        the target was already persisted, mark it as pending again.
        """
        EntityState.set_state_data(target_entity, source_data)
        if self.__unit_of_work.is_marked_persisted(target_entity):
            self.__unit_of_work.mark_pending(target_entity)

    def __get_cache(self, entity_class):
        # Return the entity cache for *entity_class*, creating it on demand.
        cache = self.__cache_map.get(entity_class)
        if cache is None:
            cache = self.__cache_map[entity_class] = EntityCache()
        return cache

    def __clone(self, entity, cache):
        """
        Create and return a session clone of the repository entity *entity*.

        The clone is added to *cache* before its state is populated so that
        circular references resolve to the clone. MEMBER values and
        non-empty COLLECTION values are loaded recursively through
        :meth:`load`. All other attribute values (terminals, empty
        collections, and non-terminals whose type is not registered with
        this session's repository) are copied over as-is.

        :raises ValueError: If a non-empty collection value is neither a
          mutable sequence nor a mutable set.
        """
        clone = object.__new__(entity.__class__)
        # We add the clone with its ID set to the cache *before* we load it
        # so that circular references will work.
        clone.id = entity.id
        cache.add(clone)
        state = EntityState.get_state_data(entity)
        id_attr = None
        for attr, value in iteritems_(state):
            if attr.entity_attr == 'id':
                id_attr = attr
                continue
            attr_type = attr.attr_type
            if attr.kind != RESOURCE_ATTRIBUTE_KINDS.TERMINAL \
               and not self.__repository.is_registered_resource(attr_type):
                # Prevent loading of entities from other repositories.
                # FIXME: Doing this here is inconsistent, since e.g. the RDB
                #        session does not perform this kind of check.
                continue
            elif attr.kind == RESOURCE_ATTRIBUTE_KINDS.MEMBER \
               and value is not None:
                ent_cls = get_entity_class(attr_type)
                state[attr] = self.load(ent_cls, value)
            elif attr.kind == RESOURCE_ATTRIBUTE_KINDS.COLLECTION \
                 and len(value) > 0:
                # NOTE(review): empty collections skip this branch and keep
                #        the original value object, which is presumably
                #        shared with the repository entity - confirm that
                #        this aliasing is intended.
                value_type = type(value)
                new_value = value_type.__new__(value_type)
                if issubclass(value_type, MutableSequence):
                    add_op = new_value.append
                elif issubclass(value_type, MutableSet):
                    add_op = new_value.add
                else:
                    raise ValueError('Do not know how to clone value of type '
                                     '%s for resource attribute %s.' %
                                     (type(new_value), attr))
                ent_cls = get_entity_class(attr_type)
                for child in value:
                    add_op(self.load(ent_cls, child))
                state[attr] = new_value
        # We set the ID already above.
        if id_attr is not None:
            del state[id_attr]
        EntityState.set_state_data(clone, state)
        return clone
Ejemplo n.º 3
0
class UnitOfWorkTestCase(Pep8CompliantTestCase):
    """
    Tests for the status bookkeeping of :class:`UnitOfWork` and for the
    errors raised when entities are used with the wrong unit of work or in
    an invalid state.
    """

    def set_up(self):
        self._uow = UnitOfWork()

    def test_basics(self):
        # Walk a single entity through NEW -> CLEAN -> DIRTY -> DELETED and
        # check that the matching get_* accessor reports it at every step.
        ent = _MyEntity(id=0)
        self._uow.register_new(_MyEntity, ent)
        self.assert_equal(EntityState.get_state(ent).status, ENTITY_STATUS.NEW)
        self.assert_equal([item.entity for item in self._uow.iterator()],
                          [ent])
        self.assert_equal(list(self._uow.get_new(_MyEntity)), [ent])
        self._uow.mark_clean(ent)
        self.assert_equal(list(self._uow.get_clean(_MyEntity)), [ent])
        self._uow.mark_dirty(ent)
        self.assert_equal(list(self._uow.get_dirty(_MyEntity)), [ent])
        self._uow.mark_deleted(ent)
        self.assert_equal(list(self._uow.get_deleted(_MyEntity)), [ent])
        self._uow.unregister(_MyEntity, ent)
        self.assert_equal(list(self._uow.iterator()), [])
        self._uow.reset()

    def test_get_state_unregistered_fails(self):
        # An entity never registered with any unit of work has no state.
        ent = _MyEntity()
        with self.assert_raises(ValueError) as cm:
            EntityState.get_state(ent)
        msg = 'Trying to obtain state for un-managed entity'
        self.assert_true(cm.exception.args[0].startswith(msg))

    def test_is_marked_unregistered(self):
        # In this version the persisted/pending queries raise for entities
        # that are not registered.
        ent = _MyEntity()
        self.assert_raises(ValueError, self._uow.is_marked_persisted, ent)
        self.assert_raises(ValueError, self._uow.is_marked_pending, ent)

    def test_mark_unregistered_fails(self):
        ent = _MyEntity()
        with self.assert_raises(ValueError) as cm:
            self._uow.mark_dirty(ent)
        msg = 'Trying to obtain state for un-managed entity'
        self.assert_true(cm.exception.args[0].startswith(msg))

    def test_release_unregistered_fails(self):
        ent = _MyEntity()
        with self.assert_raises(ValueError) as cm:
            self._uow.unregister(_MyEntity, ent)
        msg = 'Trying to unregister an entity that has not been'
        self.assert_true(str(cm.exception).startswith(msg))

    def test_registered_with_other_uow_fails(self):
        # An entity registered with one unit of work can neither be
        # registered with nor unregistered from a different one.
        ent = _MyEntity()
        uow = UnitOfWork()
        uow.register_new(_MyEntity, ent)
        with self.assert_raises(ValueError) as cm1:
            self._uow.register_new(_MyEntity, ent)
        msg1 = 'Trying to register an entity that has been'
        self.assert_true(str(cm1.exception).startswith(msg1))
        with self.assert_raises(ValueError) as cm2:
            self._uow.unregister(_MyEntity, ent)
        msg2 = 'Trying to unregister an entity that has been'
        self.assert_true(str(cm2.exception).startswith(msg2))

    def test_mark_deleted_as_dirty(self):
        # DELETED -> DIRTY is not a valid status transition.
        ent = _MyEntity()
        self._uow.register_new(_MyEntity, ent)
        self._uow.mark_deleted(ent)
        with self.assert_raises(ValueError) as cm:
            self._uow.mark_dirty(ent)
        msg = 'Invalid status transition'
        self.assert_true(str(cm.exception).startswith(msg))

    def test_check_unregistered_is_marked_new(self):
        ent = _MyEntity()
        self.assert_false(self._uow.is_marked_new(ent))

    def test_mark_deleted_as_new(self):
        # DELETED -> NEW (re-adding a removed entity) must be allowed and
        # must actually leave the entity in the NEW state.
        ent = _MyEntity(id=0)
        self._uow.register_deleted(_MyEntity, ent)
        self._uow.mark_new(ent)
        self.assert_equal(EntityState.get_state(ent).status, ENTITY_STATUS.NEW)
Ejemplo n.º 4
0
class MemorySession(Session):
    """
    Session object.

    The session
     * Holds a Unit Of Work;
     * Serves as identity and slug map;
     * Performs synchronized commit on repository;
     * Sets up data manager to hook into transaction.

    Changes made via :meth:`add`, :meth:`remove` and :meth:`update` are only
    recorded in the unit of work. They are pushed to the repository by
    :meth:`flush`, which runs automatically before the next lookup
    (:meth:`get_by_id`, :meth:`get_by_slug`, :meth:`query`, :meth:`load`).
    """
    IS_MANAGING_BACKREFERENCES = True

    def __init__(self, repository, query_class=None, clone_on_load=True):
        """
        :param repository: Repository backing this session. The session uses
          its ``lock``, ``flush``, ``commit``, ``rollback``,
          ``get_aggregate`` and ``is_registered_resource`` members.
        :param query_class: Class instantiated by :meth:`query`; defaults to
          :class:`MemoryRepositoryQuery`.
        :param bool clone_on_load: If set (the default), :meth:`load` puts a
          clone of the repository entity into the session; otherwise the
          repository entity itself is cached and returned.
        """
        self.__repository = repository
        self.__unit_of_work = UnitOfWork()
        # Maps entity class -> EntityCache (identity and slug map).
        self.__cache_map = {}
        if query_class is None:
            query_class = MemoryRepositoryQuery
        self.__query_class = query_class
        self.__clone_on_load = clone_on_load
        # Set by add/remove/update; cleared once the changes were flushed.
        self.__needs_flushing = False
        # Guard against re-entrant flushes.
        self.__is_flushing = False

    def get_by_id(self, entity_class, entity_id):
        """
        Return the session entity of class *entity_class* with the given ID
        (``None`` if it is not cached). Pending changes are flushed first.
        """
        if self.__needs_flushing:
            self.flush()
        cache = self.__get_cache(entity_class)
        return cache.get_by_id(entity_id)

    def get_by_slug(self, entity_class, entity_slug):
        """
        Return what the entity cache of *entity_class* holds for the given
        slug. Pending changes are flushed first.
        """
        if self.__needs_flushing:
            self.flush()
        cache = self.__get_cache(entity_class)
        return cache.get_by_slug(entity_slug)

    def add(self, entity_class, data):
        """
        Record the addition of *data* (traversed with the aggregate of
        *entity_class*). The change is flushed lazily.
        """
        self.__traverse(entity_class, data, None, RELATION_OPERATIONS.ADD)

    def remove(self, entity_class, data):
        """
        Record the removal of *data* (traversed with the aggregate of
        *entity_class*). The change is flushed lazily.
        """
        self.__traverse(entity_class, None, data, RELATION_OPERATIONS.REMOVE)

    def update(self, entity_class, data, target=None):
        """
        Record an update of *target* from *data* and return the root entity
        produced by the traversal. The change is flushed lazily.
        """
        return self.__traverse(entity_class, data, target,
                               RELATION_OPERATIONS.UPDATE)

    def query(self, entity_class):
        """
        Return a new query object for *entity_class* bound to this session
        and its repository. Pending changes are flushed first.
        """
        if self.__needs_flushing:
            self.flush()
        return self.__query_class(entity_class, self, self.__repository)

    def flush(self):
        """
        Push the changes recorded in the unit of work to the repository.

        The repository flush runs under the repository lock. Afterwards the
        entity caches are rebuilt because the flush may have auto-generated
        IDs for NEW entities. A call made while a flush is already in
        progress does not flush again. If the repository flush raises, the
        session stays marked as needing a flush.
        """
        if self.__needs_flushing and not self.__is_flushing:
            self.__is_flushing = True
            try:
                with self.__repository.lock:
                    self.__repository.flush(self.__unit_of_work)
            finally:
                # Always clear the guard; otherwise one failed flush would
                # turn every later flush into a silent no-op.
                self.__is_flushing = False
            for ent_cls, cache in self.__cache_map.items():
                # The flush may have auto-generated IDs for NEW entities,
                # so we rebuild the cache.
                cache.rebuild(self.__unit_of_work.get_new(ent_cls))
        self.__needs_flushing = False

    def begin(self):
        """Start a new transaction by resetting the unit of work."""
        self.__unit_of_work.reset()

    def commit(self):
        """
        Commit the unit of work to the repository (under the repository
        lock), then reset the unit of work and drop all entity caches.
        """
        with self.__repository.lock:
            self.__repository.commit(self.__unit_of_work)
        self.__unit_of_work.reset()
        self.__cache_map.clear()

    def rollback(self):
        """
        Roll back the unit of work in the repository (under the repository
        lock), then reset the unit of work and drop all entity caches.
        """
        with self.__repository.lock:
            self.__repository.rollback(self.__unit_of_work)
        self.__unit_of_work.reset()
        self.__cache_map.clear()

    def reset(self):
        """Alias for :meth:`rollback`."""
        self.rollback()

    def load(self, entity_class, entity):
        """
        Load the given repository entity into the session and return a
        clone. If it was already loaded before, look up the loaded entity
        and return it.

        All entities referenced by the loaded entity will also be loaded
        (and cloned) recursively.

        :raises ValueError: When an attempt is made to load an entity that
          has no ID
        """
        if self.__needs_flushing:
            self.flush()
        if entity.id is None:
            raise ValueError('Can not load entity without an ID.')
        cache = self.__get_cache(entity_class)
        sess_ent = cache.get_by_id(entity.id)
        if sess_ent is None:
            if self.__clone_on_load:
                sess_ent = self.__clone(entity, cache)
            else:  # pragma: no cover -- only needed by the nosql backend.
                cache.add(entity)
                sess_ent = entity
            self.__unit_of_work.register_clean(entity_class, sess_ent)
        return sess_ent

    @property
    def new(self):
        """Entities currently marked NEW in the unit of work."""
        return self.__unit_of_work.get_new()

    @property
    def deleted(self):
        """Entities currently marked DELETED in the unit of work."""
        return self.__unit_of_work.get_deleted()

    def __contains__(self, entity):
        # An entity is "in" the session if the cache for its exact class
        # contains it; classes without a cache yet contain nothing.
        cache = self.__cache_map.get(type(entity))
        return cache is not None and entity in cache

    def __traverse(self, entity_class, source_data, target_data, rel_op):
        """
        Run a source/target data tree traversal for *rel_op*, dispatching
        to :meth:`__add`, :meth:`__remove` and :meth:`__update`, and mark
        the session as needing a flush. Returns the visitor's root.
        """
        agg = self.__repository.get_aggregate(entity_class)
        trv = SourceTargetDataTreeTraverser.make_traverser(source_data,
                                                           target_data,
                                                           rel_op,
                                                           accessor=agg)
        vst = AruVisitor(entity_class,
                         self.__add, self.__remove, self.__update)
        trv.run(vst)
        # Indicate that we need to flush the changes.
        self.__needs_flushing = True
        return vst.root

    def __add(self, entity):
        """
        Traversal callback: register *entity* as NEW (or revive it if it was
        marked DELETED) and put it into the cache of its class.

        :raises ValueError: If a different entity with the same ID is
          already cached.
        """
        entity_class = type(entity)
        cache = self.__get_cache(entity_class)
        # We allow adding the same entity multiple times.
        if entity.id is not None and cache.get_by_id(entity.id) is entity:
            return
        if not self.__unit_of_work.is_marked_deleted(entity):
            # Check for an ID clash *before* registering so that a rejected
            # add does not leave an orphaned NEW entity in the unit of work
            # (which the next flush would otherwise persist).
            if entity.id is not None and cache.has_id(entity.id):
                raise ValueError('Duplicate entity ID "%s".' % entity.id)
            self.__unit_of_work.register_new(entity_class, entity)
            # FIXME: This is only necessary if the call above re-uses
            #        an existing state, in which case it needs to be
            #        marked as pending explicitly. Consider rewriting
            #        this whole method.
            self.__unit_of_work.mark_pending(entity)
        elif self.__unit_of_work.is_marked_pending(entity):
            # The changes were not flushed yet; just mark as clean.
            self.__unit_of_work.mark_clean(entity)
        else:
            self.__unit_of_work.mark_new(entity)
            self.__unit_of_work.mark_pending(entity)
        cache.add(entity)

    def __remove(self, entity):
        """
        Traversal callback: mark *entity* as DELETED (registering it first
        if necessary) and drop it from the cache of its class.

        :raises ValueError: If *entity* is not registered and has no ID.
        """
        entity_class = type(entity)
        if not self.__unit_of_work.is_registered(entity):
            if entity.id is None:
                raise ValueError('Can not remove un-registered entity '
                                 'without an ID')
            self.__unit_of_work.register_deleted(entity_class, entity)
        elif not self.__unit_of_work.is_marked_new(entity):
            self.__unit_of_work.mark_deleted(entity)
        elif self.__unit_of_work.is_marked_pending(entity):
            # The changes were not flushed yet; just mark as clean.
            self.__unit_of_work.mark_clean(entity)
        else:
            self.__unit_of_work.mark_deleted(entity)
            self.__unit_of_work.mark_pending(entity)
        cache = self.__get_cache(entity_class)
        if entity in cache:
            cache.remove(entity)

    def __update(self, source_data, target_entity):
        """
        Traversal callback: copy *source_data* onto *target_entity* and, if
        the target was already persisted, mark it as pending again.
        """
        EntityState.set_state_data(target_entity, source_data)
        if self.__unit_of_work.is_marked_persisted(target_entity):
            self.__unit_of_work.mark_pending(target_entity)

    def __get_cache(self, entity_class):
        # Return the entity cache for *entity_class*, creating it on demand.
        cache = self.__cache_map.get(entity_class)
        if cache is None:
            cache = self.__cache_map[entity_class] = EntityCache()
        return cache

    def __clone(self, entity, cache):
        """
        Create and return a session clone of the repository entity *entity*.

        The clone is added to *cache* before its state is populated so that
        circular references resolve to the clone. MEMBER values and
        non-empty COLLECTION values are loaded recursively through
        :meth:`load`. All other attribute values (terminals, empty
        collections, and non-terminals whose type is not registered with
        this session's repository) are copied over as-is.

        :raises ValueError: If a non-empty collection value is neither a
          mutable sequence nor a mutable set.
        """
        clone = object.__new__(entity.__class__)
        # We add the clone with its ID set to the cache *before* we load it
        # so that circular references will work.
        clone.id = entity.id
        cache.add(clone)
        state = EntityState.get_state_data(entity)
        id_attr = None
        for attr, value in iteritems_(state):
            if attr.entity_attr == 'id':
                id_attr = attr
                continue
            attr_type = attr.attr_type
            if attr.kind != RESOURCE_ATTRIBUTE_KINDS.TERMINAL \
               and not self.__repository.is_registered_resource(attr_type):
                # Prevent loading of entities from other repositories.
                # FIXME: Doing this here is inconsistent, since e.g. the RDB
                #        session does not perform this kind of check.
                continue
            elif attr.kind == RESOURCE_ATTRIBUTE_KINDS.MEMBER \
               and value is not None:
                ent_cls = get_entity_class(attr_type)
                state[attr] = self.load(ent_cls, value)
            elif attr.kind == RESOURCE_ATTRIBUTE_KINDS.COLLECTION \
                 and len(value) > 0:
                # NOTE(review): empty collections skip this branch and keep
                #        the original value object, which is presumably
                #        shared with the repository entity - confirm that
                #        this aliasing is intended.
                value_type = type(value)
                new_value = value_type.__new__(value_type)
                if issubclass(value_type, MutableSequence):
                    add_op = new_value.append
                elif issubclass(value_type, MutableSet):
                    add_op = new_value.add
                else:
                    raise ValueError('Do not know how to clone value of type '
                                     '%s for resource attribute %s.'
                                     % (type(new_value), attr))
                ent_cls = get_entity_class(attr_type)
                for child in value:
                    add_op(self.load(ent_cls, child))
                state[attr] = new_value
        # We set the ID already above.
        if id_attr is not None:
            del state[id_attr]
        EntityState.set_state_data(clone, state)
        return clone