forked from riffm/testalchemy
-
Notifications
You must be signed in to change notification settings - Fork 1
/
testalchemy.py
209 lines (171 loc) · 7.25 KB
/
testalchemy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
# -*- coding: utf-8 -*-
import types
from sqlalchemy import event
from sqlalchemy.orm import util, Session, ScopedSession
# Names exported by a star-import of this module.  ``sample_property`` is
# intentionally left out: it is the descriptor Sample builds internally.
__all__ = [
    'Sample',
    'Restorable',
    'DBHistory',
]
class sample_property(object):
def __init__(self, method, name=None):
self.method = method
self.__doc__ = method.__doc__
self.name = name or method.__name__
def __get__(self, inst, cls):
if inst is None:
return self
result = self.method(inst)
if isinstance(result, (list, tuple)):
inst.db.add_all(result)
else:
inst.db.add(result)
inst.used_properties.add(self.name)
setattr(inst, self.name, result)
return result
def __call__(self, obj):
return self.method(obj)
class Sample(object):
    """Container of test fixtures declared as plain methods.

    Public methods of subclasses are turned into ``sample_property``
    descriptors, so each one is evaluated lazily on first access, its result
    is added to ``self.db`` and cached on the instance.
    """

    class __metaclass__(type):
        # NOTE: only Python 2 honours a ``__metaclass__`` class attribute;
        # under Python 3 this conversion does not run.
        def __new__(cls, cls_name, bases, attributes):
            klass = type.__new__(cls, cls_name, bases, attributes)
            for name in dir(klass):
                if name.startswith('_') or name == 'create_all':
                    continue
                value = getattr(klass, name)
                if isinstance(value, types.MethodType):
                    # Only unbound (plain instance) methods become sample
                    # properties.  Classmethods and bound methods of other
                    # objects are also MethodType, but have im_self set;
                    # wrapping them would call the function with the wrong
                    # first argument.
                    if value.im_self is not None:
                        continue
                    new_value = value.im_func
                elif isinstance(value, sample_property) and name != value.name:
                    # Already-decorated attribute assigned from another
                    # class: re-wrap it under its new attribute name.
                    new_value = value.method
                else:
                    # staticmethods, plain data attributes, etc.
                    continue
                setattr(klass, name, sample_property(new_value, name=name))
            return klass

    def __init__(self, db, **kwargs):
        """Bind the sample set to a session.

        :param db: a Session, or a ScopedSession which is resolved to its
            current session via ``registry()``.
        :param kwargs: extra values stored as instance attributes.
        """
        if isinstance(db, ScopedSession):
            db = db.registry()
        self.db = db
        self.used_properties = set()
        self.__dict__.update(kwargs)

    def create_all(self):
        """Evaluate every attribute listed by ``dir(self)`` and commit.

        Touching each attribute materialises all sample properties (adding
        their objects to the session) inside one transaction.
        """
        if self.db.autocommit:
            self.db.begin()
        # An explicit loop, not map(): map() is lazy on Python 3 and would
        # evaluate nothing.
        for name in dir(self):
            getattr(self, name)
        self.db.commit()
class Restorable(object):
    """Context manager that removes rows inserted while it is active.

    While active, an ``after_flush`` listener on ``watch`` records the
    identity of every pending instance flushed.  On exit the session is
    rolled back and emptied, and each recorded row that still loads is
    deleted and committed.
    """

    def __init__(self, db, watch=None):
        """
        :param db: Session (or ScopedSession, resolved via ``registry()``)
            used for cleanup.
        :param watch: event target for the ``after_flush`` listener;
            defaults to the session itself.
        """
        if isinstance(db, ScopedSession):
            db = db.registry()
        self.db = db
        self.watch = watch or db
        # model class -> set of identity tuples flushed as new
        self.history = {}

    def __enter__(self):
        event.listen(self.watch, 'after_flush', self.after_flush)
        return self

    def __exit__(self, type, value, traceback):
        db = self.db
        try:
            db.rollback()
            db.expunge_all()
            old_autoflush = db.autoflush
            # Avoid flushing half-deleted state while querying rows below.
            db.autoflush = False
            try:
                if db.autocommit:
                    db.begin()
                for cls, ident_set in self.history.items():
                    for ident in ident_set:
                        instance = db.query(cls).get(ident)
                        if instance is not None:
                            db.delete(instance)
                db.commit()
                db.close()
            finally:
                # Restore the session setting even if cleanup failed.
                db.autoflush = old_autoflush
        finally:
            # Always detach the listener so it does not outlive this block.
            # NOTE(review): private SQLAlchemy API; newer releases expose
            # ``event.remove`` -- confirm against the pinned version.
            event.Events._remove(self.watch, 'after_flush',
                                 self.after_flush)

    def after_flush(self, db, flush_context, instances=None):
        """Record identities of instances that were pending in this flush."""
        for instance in db.new:
            cls, ident = util.identity_key(instance=instance)
            self.history.setdefault(cls, set()).add(ident)
class DBHistory(object):
    """Track instances created, updated and deleted by session flushes.

    Use as a context manager: while active, an ``after_flush`` listener
    accumulates the session's new/dirty/deleted instances and their identity
    keys, which the ``last_*`` and ``assert_*`` helpers then query.
    """

    def __init__(self, session):
        assert isinstance(session, (Session, ScopedSession))
        self.session = session
        # XXX: It is not clear whether events should be attached to the
        # class or to the object.
        self._target = session
        if isinstance(session, ScopedSession):
            self._target = session.registry()
        self.clear()

    def last(self, model_cls, mode):
        """Return the set of ``model_cls`` instances recorded for ``mode``.

        :param mode: one of ``'created'``, ``'updated'``, ``'deleted'``.
        For ``'deleted'`` the detached instances captured at flush time are
        returned (the rows no longer exist).  Otherwise each recorded
        identity is reloaded through the session; identities that no longer
        load (e.g. rows deleted afterwards) are skipped.
        """
        assert mode in ('created', 'updated', 'deleted')
        if mode == 'deleted':
            return set(inst for inst in self.deleted
                       if isinstance(inst, model_cls))
        idents = getattr(self, '%s_idents' % mode).get(model_cls, set())
        found = set()
        for ident in idents:
            instance = self.session.query(model_cls).get(ident)
            # A None here would make the result look non-empty and break
            # identity matching in assert_().
            if instance is not None:
                found.add(instance)
        return found

    def last_created(self, model_cls):
        return self.last(model_cls, 'created')

    def last_updated(self, model_cls):
        return self.last(model_cls, 'updated')

    def last_deleted(self, model_cls):
        return self.last(model_cls, 'deleted')

    @staticmethod
    def _normalize_ident(ident):
        """Return ``ident`` as a tuple comparable with an identity key."""
        if isinstance(ident, (tuple, list)):
            # identity keys are tuples; a list would never compare equal
            return tuple(ident)
        return (ident,)

    def assert_(self, model_cls, ident=None, mode='created'):
        """Assert that some ``model_cls`` instance was recorded for ``mode``.

        :param ident: optional primary key (scalar, tuple or list).  When
            given, return the matching instance instead of the whole set.
        :raises AssertionError: when nothing (or nothing matching) was
            recorded.  Raised explicitly so ``python -O`` keeps the check.
        """
        dataset = self.last(model_cls, mode)
        if not dataset:
            raise AssertionError('No instances of %s were %s'
                                 % (model_cls, mode))
        if ident is None:
            return dataset
        ident = self._normalize_ident(ident)
        for instance in dataset:
            if util.identity_key(instance=instance)[1] == ident:
                return instance
        raise AssertionError('No instances of %s with identity %r were %s'
                             % (model_cls, ident, mode))

    def assert_created(self, model_cls, ident=None):
        return self.assert_(model_cls, ident, 'created')

    def assert_updated(self, model_cls, ident=None):
        return self.assert_(model_cls, ident, 'updated')

    def assert_deleted(self, model_cls, ident=None):
        return self.assert_(model_cls, ident, 'deleted')

    def assert_one(self, dataset, model_cls, mode):
        """Return the single element of ``dataset`` (a set; it is popped).

        :raises AssertionError: if ``dataset`` does not hold exactly one item.
        """
        if len(dataset) != 1:
            raise AssertionError('%d instance(s) of %s %s, '
                                 'need only one' % (len(dataset),
                                                    model_cls,
                                                    mode))
        return dataset.pop()

    def assert_created_one(self, model_cls):
        result = self.assert_created(model_cls)
        return self.assert_one(result, model_cls, 'created')

    def assert_deleted_one(self, model_cls):
        result = self.assert_deleted(model_cls)
        return self.assert_one(result, model_cls, 'deleted')

    def assert_updated_one(self, model_cls):
        result = self.assert_updated(model_cls)
        return self.assert_one(result, model_cls, 'updated')

    def clear(self):
        """Forget everything recorded so far."""
        self.created = set()
        self.deleted = set()
        self.updated = set()
        # model class -> set of identity tuples
        self.created_idents = {}
        self.updated_idents = {}
        self.deleted_idents = {}

    def __enter__(self):
        event.listen(self._target, 'after_flush', self._after_flush)
        return self

    def __exit__(self, type, value, traceback):
        # NOTE(review): private SQLAlchemy API; newer releases expose
        # ``event.remove`` -- confirm against the pinned version.
        event.Events._remove(self._target, 'after_flush', self._after_flush)

    def _populate_idents_dict(self, idents, objects):
        """Add each object's identity tuple to ``idents`` keyed by class."""
        for obj in objects:
            ident = util.identity_key(instance=obj)
            idents.setdefault(ident[0], set()).add(ident[1])

    def _after_flush(self, db, flush_context, instances=None):
        """``after_flush`` listener: accumulate new/dirty/deleted instances."""
        def identityset_to_set(obj):
            # Session.new/dirty/deleted are IdentitySets; their private
            # ``_members`` mapping holds the member instances.
            return set(obj._members.values())
        self.created = self.created.union(identityset_to_set(db.new))
        self.updated = self.updated.union(identityset_to_set(db.dirty))
        self.deleted = self.deleted.union(identityset_to_set(db.deleted))
        self._populate_idents_dict(self.created_idents, self.created)
        self._populate_idents_dict(self.updated_idents, self.updated)
        self._populate_idents_dict(self.deleted_idents, self.deleted)