コード例 #1
0
ファイル: storing.py プロジェクト: b8va/everest
 def __collect(self, resource):
     """
     Stage *resource* in a fresh staging collection and return a map of
     member class -> root collection for every entity class that ended
     up in the staging cache.
     """
     ent_cls = get_entity_class(resource)
     coll_cls = get_collection_class(resource)
     cache = EntityCacheMap()
     agg = StagingAggregate(ent_cls, cache)
     coll = coll_cls.create_from_aggregate(agg)
     coll.add(resource)
     # Dict comprehension instead of dict([...]) — avoids building an
     # intermediate list (flake8-comprehensions C404); also renamed the
     # loop variable so it no longer shadows `ent_cls` above.
     return {get_member_class(cached_cls):
                 coll.get_root_collection(cached_cls)
             for cached_cls in cache.keys()}
コード例 #2
0
 def __collect(self, resource):
     """
     Stage *resource* in a fresh staging collection and return a map of
     member class -> root collection for every entity class that ended
     up in the staging cache.
     """
     ent_cls = get_entity_class(resource)
     coll_cls = get_collection_class(resource)
     cache = EntityCacheMap()
     agg = StagingAggregate(ent_cls, cache)
     coll = coll_cls.create_from_aggregate(agg)
     coll.add(resource)
     # Dict comprehension instead of dict([...]) — avoids building an
     # intermediate list (flake8-comprehensions C404); also renamed the
     # loop variable so it no longer shadows `ent_cls` above.
     return {get_member_class(cached_cls):
                 coll.get_root_collection(cached_cls)
             for cached_cls in cache.keys()}
コード例 #3
0
ファイル: staging.py プロジェクト: papagr/everest
 def __init__(self, entity_class, cache=None):
     """
     :param entity_class: entity class this aggregate manages.
     :param cache: entity cache map to use; a fresh
       :class:`EntityCacheMap` is created when omitted.
     """
     Aggregate.__init__(self)
     self.entity_class = entity_class
     # Fall back to a brand-new cache map when the caller did not
     # supply one.
     self.__cache_map = EntityCacheMap() if cache is None else cache
     self.__visitor = AruVisitor(entity_class,
                                 self.__add, self.__remove, self.__update)
コード例 #4
0
 def test_basics(self):
     """Smoke-test add / lookup / containment / remove on EntityCacheMap."""
     ecm = EntityCacheMap()
     ent = MyEntity(id=0)
     ecm.add(MyEntity, ent)
     assert ecm.has_key(MyEntity)
     assert ecm[MyEntity].get_by_id(ent.id) == ent
     assert ent in ecm
     assert list(ecm.keys()) == [MyEntity]
     ecm.remove(MyEntity, ent)
     # PEP 8 / E713: `x not in y` instead of `not x in y`.
     assert ent not in ecm
コード例 #5
0
ファイル: repository.py プロジェクト: b8va/everest
 def __init__(self, name, aggregate_class=None,
              join_transaction=False, autocommit=False):
     """
     :param name: repository name, forwarded to the base class.
     :param aggregate_class: aggregate implementation; defaults to
       :class:`MemoryAggregate` when not given.
     """
     # Default to the in-memory aggregate implementation.
     agg_cls = MemoryAggregate if aggregate_class is None \
               else aggregate_class
     Repository.__init__(self, name, agg_cls,
                         join_transaction=join_transaction,
                         autocommit=autocommit)
     self.__cache_map = EntityCacheMap()
     # No cache loader is configured out of the box.
     self.configure(cache_loader=None)
コード例 #6
0
 def test_traverse_with_remove_sequence(self):
     """Traversing a two-entity sequence with REMOVE empties the aggregate."""
     first = create_entity(entity_id=0)
     second = create_entity(entity_id=None)
     cache_map = EntityCacheMap()
     aggregate = StagingAggregate(MyEntity, cache=cache_map)
     for entity in (first, second):
         aggregate.add(entity)
     traverser = SourceTargetDataTreeTraverser.make_traverser(
         None, [first, second], RELATION_OPERATIONS.REMOVE)
     visitor = AruVisitor(MyEntity, remove_callback=aggregate.remove)
     traverser.run(visitor)
     self.assert_equal(len(list(iter(aggregate))), 0)
コード例 #7
0
ファイル: test_memory_cache.py プロジェクト: b8va/everest
 def test_basics(self):
     """Exercise add, lookup, containment and removal on the cache map."""
     cache_map = EntityCacheMap()
     entity = MyEntity(id=0)
     cache_map.add(MyEntity, entity)
     self.assert_equal(cache_map[MyEntity].get_by_id(0), entity)
     self.assert_true(entity in cache_map)
     self.assert_equal(list(cache_map.keys()), [MyEntity])
     cache_map.remove(MyEntity, entity)
     self.assert_false(entity in cache_map)
コード例 #8
0
ファイル: repository.py プロジェクト: b8va/everest
class MemoryRepository(Repository):
    """
    A repository that caches entities in memory.

    Entities live in a per-class :class:`EntityCacheMap`; an optional
    ``cache_loader`` configurable can populate a class's cache on first
    access.
    """
    _configurables = Repository._configurables \
                     + ['cache_loader']

    # Class-level reentrant lock. Not used in the code shown here —
    # presumably guards cache access elsewhere (TODO confirm).
    lock = RLock()

    def __init__(self, name, aggregate_class=None,
                 join_transaction=False, autocommit=False):
        """
        :param name: repository name, forwarded to the base class.
        :param aggregate_class: aggregate implementation; defaults to
          :class:`MemoryAggregate`.
        """
        if aggregate_class is None:
            aggregate_class = MemoryAggregate
        Repository.__init__(self, name, aggregate_class,
                            join_transaction=join_transaction,
                            autocommit=autocommit)
        self.__cache_map = EntityCacheMap()
        # By default, we do not use a cache loader.
        self.configure(cache_loader=None)

    def retrieve(self, entity_class, filter_expression=None,
                 order_expression=None, slice_key=None):
        """Delegate retrieval to the cache for *entity_class*."""
        cache = self.__get_cache(entity_class)
        return cache.retrieve(filter_expression=filter_expression,
                              order_expression=order_expression,
                              slice_key=slice_key)

    def flush(self, unit_of_work):
        """Persist every state in *unit_of_work* not yet persisted."""
        for state in unit_of_work.iterator():
            # Flattened the original `continue`/`else` — act only on
            # states that still need persisting.
            if not state.is_persisted:
                self.__persist(state)
                unit_of_work.mark_persisted(state.entity)

    def commit(self, unit_of_work):
        """Committing an in-memory repository is just a flush."""
        self.flush(unit_of_work)

    def rollback(self, unit_of_work):
        """Undo cache changes for every already-persisted state."""
        for state in unit_of_work.iterator():
            if state.is_persisted:
                self.__rollback(state)

    def __persist(self, state):
        # Apply one entity state to the cache: add NEW entities
        # (auto-generating missing IDs), remove DELETED ones, update
        # DIRTY ones in place.
        source_entity = state.entity
        cache = self.__get_cache(type(source_entity))
        status = state.status
        if status == ENTITY_STATUS.NEW:
            # Autogenerate new ID.
            if source_entity.id is None:
                source_entity.id = new_entity_id()
            cache.add(source_entity)
        else:
            target_entity = cache.get_by_id(source_entity.id)
            if target_entity is None:
                raise ValueError('Could not persist data - target entity not '
                                 'found (ID used for lookup: %s).'
                                 % source_entity.id)
            if status == ENTITY_STATUS.DELETED:
                cache.remove(target_entity)
            elif status == ENTITY_STATUS.DIRTY:
                cache.update(state.data, target_entity)

    def __rollback(self, state):
        # Invert __persist: re-add what was deleted, remove what was
        # newly added, and restore clean data over dirty entities.
        source_entity = state.entity
        cache = self.__get_cache(type(source_entity))
        if state.status == ENTITY_STATUS.DELETED:
            cache.add(source_entity)
        else:
            if state.status == ENTITY_STATUS.NEW:
                cache.remove(source_entity)
            elif state.status == ENTITY_STATUS.DIRTY:
                target_entity = cache.get_by_id(source_entity.id)
                cache.update(state.clean_data, target_entity)

    def _initialize(self):
        # Nothing to set up for the in-memory backend.
        pass

    def _make_session_factory(self):
        return MemorySessionFactory(self)

    def __get_cache(self, entity_class):
        # Run the loader the first time an entity class is seen.
        # (E713: `not in` instead of `not ... in`.)
        run_loader = entity_class not in self.__cache_map
        if run_loader:
            # The very first load is "top level" and also pulls in
            # unlinked registered resources (see __load_entities).
            is_top_level = not self.__cache_map.keys()
            self.__load_entities(entity_class, is_top_level)
        return self.__cache_map[entity_class]

    def __load_entities(self, entity_class, is_top_level):
        # Check if we have an entity loader configured.
        loader = self.configuration['cache_loader']
        if loader is not None:  # E714: `is not None`.
            cache = self.__cache_map[entity_class]
            for ent in loader(entity_class):
                if ent.id is None:
                    ent.id = new_entity_id()
                cache.add(ent)
            # To fully initialize the cache, we also need to load collections
            # that are not linked to from any of the entities just loaded.
            if is_top_level:
                for reg_rc in self.registered_resources:
                    reg_ent_cls = get_entity_class(reg_rc)
                    if reg_ent_cls not in self.__cache_map:
                        self.__load_entities(reg_ent_cls, False)