/
flask_halalchemy.py
207 lines (162 loc) · 7.32 KB
/
flask_halalchemy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# encoding=utf-8
from flask import request, url_for, current_app, make_response
from flask.helpers import json
from flask.views import MethodView
from werkzeug.utils import cached_property
from dictshield.base import ShieldDocException
from dictshield.document import Document
# Response headers attached to every validation-error response produced by
# FormView.error_response.
_error_response_headers = dict([('Content-Type', 'application/json')])
class FormView(MethodView):
    """
    Validate form/JSON API requests against a dictshield document.

    Subclass and set `fields` to the form fields you wish to validate
    against. PATCH requests are validated as partial updates, whereas the
    other methods listed in `validate_on_methods` (POST and PUT by default)
    are validated as complete documents, so all required fields must be
    present.

    `fields` maps exposed field names to instances of
    `dictshield.fields.BaseField` to validate against.
    """
    # Field name -> dictshield BaseField instance. Used to build a Document
    # class when no explicit document is passed to the constructor.
    fields = {}
    # HTTP methods whose payload is validated before dispatching.
    validate_on_methods = ['POST', 'PATCH', 'PUT']

    def __init__(self, document=None):
        """
        :param document: optional `dictshield.document.Document` subclass to
            validate against. When omitted, a Document subclass named
            `<ViewClassName>Document` is generated from `fields`.
        :raises TypeError: if `document` is not a `Document` subclass.
        """
        if document is None:
            cls_name = self.__class__.__name__ + "Document"
            # Hand `type()` a copy: the namespace dict is passed on to the
            # Document metaclass, which may write into it. Copying keeps the
            # class-level `fields` mapping (shared by every instance, and by
            # subclasses that do not override it) untouched.
            self.document = type(cls_name, (Document, ), dict(self.fields))
        else:
            # issubclass() raises its own unhelpful TypeError when given a
            # non-class (e.g. a Document *instance*), so check that first.
            if not (isinstance(document, type) and
                    issubclass(document, Document)):
                raise TypeError("Form documents must be subclasses of `dictshield.document.Document`")
            self.document = document

    @cached_property
    def data(self):
        """The request payload: the JSON body if present and truthy,
        otherwise the form fields as a plain dict."""
        return request.json or request.form.to_dict()

    @cached_property
    def clean(self):
        """The request data converted by the document's `to_python()` and
        then filtered through the document's `make_ownersafe`."""
        return self.document.make_ownersafe(self.document(**self.data).to_python())

    def validate(self):
        """
        Validate `self.data` and store the result in `self.errors`.

        PATCH requests are validated as partial documents, so missing
        required fields are allowed. Every other method validates the
        complete document. dictshield exposes these two modes through
        different APIs: a classmethod returning an error list, and an
        instance method raising `ShieldDocException`. Hence the two branches.

        :returns: `True` if valid, `False` if `self.errors` is non-empty.
        """
        if request.method == "PATCH":
            # Allow partial documents when PATCH'ing.
            self.errors = self.document.validate_class_partial(
                self.data, validate_all=True)
        else:
            try:
                self.document(**self.data).validate(validate_all=True)
            except ShieldDocException as e:
                self.errors = e.error_list
            else:
                self.errors = None
        return not self.errors

    def error_response(self):
        """
        Return a basic application/json response with status code 422 to
        inform the consumer of validation errors in the form request.

        The body is `{"message": "Validation error", "errors": {field: reason}}`.
        Must be called after `validate()` has populated `self.errors`.
        """
        # TODO: multiple errors per field. Only the last reported error for
        # a given field survives, because later entries overwrite earlier ones.
        errors = {e.field_name: e.reason for e in self.errors}
        content = json.dumps(dict(message="Validation error", errors=errors))
        return make_response(content, 422, _error_response_headers)

    def dispatch_request(self, *args, **kwargs):
        """Reject invalid payloads with a 422 before calling the handler."""
        if request.method in self.validate_on_methods and not self.validate():
            return self.error_response()
        return super(FormView, self).dispatch_request(*args, **kwargs)

    def schema_response(self):
        """Return a schema+json response for the document."""
        # NOTE(review): `Accept` is a request header and has no meaning on a
        # response; perhaps a charset on Content-Type was intended. Kept
        # unchanged for compatibility.
        return self.document.to_jsonschema(), 200, {
            'Content-Type': 'application/schema+json',
            'Accept': 'application/json; charset=utf-8'}
class QueryView(MethodView):
    """
    Keep the route's keyword arguments on the view instance.

    The URL variables Flask hands to `dispatch_request` are stored as
    `self.url_kwargs`. The HTTP-method handler (`get`, `post`, ...) is then
    invoked with *no* arguments, so handlers read URL variables from
    `self.url_kwargs` instead of from their own signature.
    """

    def dispatch_request(self, *_positional, **route_kwargs):
        # Positional arguments are deliberately not forwarded.
        self.url_kwargs = route_kwargs
        parent = super(QueryView, self)
        return parent.dispatch_request()
class ResourceView(QueryView):
    """
    Serve a single object as a HAL (`application/hal+json`) document.

    Subclasses provide a `query()` method returning the object to render.
    That object must expose a `json` mapping attribute and may expose a
    callable `links()` whose result is appended to this resource's links.
    """
    content_type = 'application/hal+json'

    def get_url(self):
        """Return `self.url` when set (see `as_resource`), otherwise the
        path of the current request."""
        if hasattr(self, "url"):
            return self.url
        return request.path

    def links(self):
        """
        Build the value for `_links`: a list whose first entry is the `self`
        link, extended with whatever the queried object's `links()` returns
        (if it has a callable `links`).
        """
        # NOTE(review): HAL specifies `_links` as an object keyed by link
        # relation, and IndexView.links returns a dict. This returns a list.
        # Kept as-is for compatibility with existing consumers.
        links = [{'self': {'href': self.get_url()}}]
        # Call `query()` only once here; it may hit the database.
        obj_links = getattr(self.query(), "links", None)
        if callable(obj_links):
            links += obj_links()
        return links

    @property
    def json(self):
        """The HAL document: the object's `json` mapping plus `_links`."""
        return dict(_links=self.links(), **self.query().json)

    def get(self):
        return json.dumps(self.json), 200, {'Content-Type': self.content_type}

    @classmethod
    def as_resource(cls, endpoint, model_instance=None):
        """
        Instantiate this view for `model_instance`, e.g. to embed it as a
        subresource of another response.

        URL arguments are taken from the first GET rule registered for
        `endpoint` that has arguments. Every argument that `model_instance`
        has as an attribute is filled from that attribute, and the resulting
        URL is stored as `url`.

        :param endpoint: endpoint name the resource view is registered under.
        :param model_instance: object to render. When given, `query()` returns
            it directly instead of querying again.
        :raises KeyError: if no rule is registered for `endpoint`.
        """
        def get_url_kwargs():
            # Public equivalent of `url_map._rules_by_endpoint[endpoint]`.
            for rule in current_app.url_map.iter_rules(endpoint):
                if 'GET' in rule.methods and rule.arguments:
                    for arg in rule.arguments:
                        if hasattr(model_instance, arg):
                            yield arg, getattr(model_instance, arg)
                    # Use only the first matching rule. A plain `return` ends
                    # the generator. `raise StopIteration()` would be turned
                    # into a RuntimeError by PEP 479 (Python 3.7+).
                    return

        resource = cls()
        resource.url_kwargs = dict(get_url_kwargs())
        resource.url = url_for(endpoint, **resource.url_kwargs)
        if model_instance is not None:
            # Avoid n+1 querying by having `query` return the given instance.
            resource.query = lambda: model_instance
        return resource
class IndexView(QueryView):
    """
    Paginated collection of resources, served as HAL.

    Uses the `?page=<int>` and `?per_page=<int>` URL arguments. `per_page` is
    capped at the class attribute `per_page`. Subclass and implement
    `query()` (it must return something with a `paginate` method), then
    route the view like so:

        workout_resource = WorkoutResource.as_view('workout')
        workout_index = WorkoutIndex.as_view(
            'workouts', subresource_endpoint='workout')
        app.add_url_rule('/workouts/<int:id>', view_func=workout_resource,
                         methods=['GET'])
        app.add_url_rule('/workouts', view_func=workout_index,
                         methods=['GET'])

    Notice that the resource view is created and registered first. HAL
    embeds subresources, and with `subresource_endpoint` set each item is
    rendered through that resource view, producing a HAL compliant
    structure for this index.

    It is a good idea to order `query` to get predictable pages.
    """
    content_type = 'application/hal+json'
    # Default and maximum number of items per page.
    per_page = 40

    def __init__(self, subresource_endpoint=None):
        """
        :param subresource_endpoint: endpoint of a ResourceView used to render
            each embedded item. When None, each item's own `json` is used.
        """
        self.subresource_endpoint = subresource_endpoint

    @property
    def json(self):
        """Pagination metadata merged into the response body."""
        return {'total': self.page.total, 'per_page': self.page.per_page}

    def query(self):
        """Return a paginatable query for the collection. Must be overridden."""
        raise NotImplementedError()

    def _page_url(self, page=None):
        """URL of this index for `page` (no `page` argument when None).

        The route's URL variables and a non-default `per_page` are carried
        over, so links stay valid and page boundaries stay consistent.
        """
        params = dict(getattr(self, 'url_kwargs', None) or {})
        if self.page.per_page != self.per_page:
            params['per_page'] = self.page.per_page
        if page is not None:
            params['page'] = page
        return url_for(request.url_rule.endpoint, **params)

    def links(self):
        """Build the HAL `_links` object: self, last, next and previous."""
        page = self.page
        # Page 1 keeps the bare URL; later pages point at themselves, so
        # `last` (aliased to `self` on the final page) is correct.
        self_page = page.page if page.page > 1 else None
        _links = {'self': {'href': self._page_url(self_page)}}
        if page.pages > 0:
            if page.page == page.pages:
                _links['last'] = _links['self']
            else:
                _links['last'] = {'href': self._page_url(page.pages)}
        if page.has_next:
            _links['next'] = {'href': self._page_url(page.next_num)}
        if page.has_prev:
            _links['previous'] = {'href': self._page_url(page.prev_num)}
        return _links

    def embedded(self):
        """Render the items of the current page.

        Items are rendered through the subresource view when one is
        configured, otherwise via each item's own `json`.
        """
        endpoint = self.subresource_endpoint
        if endpoint is None:
            get_json = lambda o: o.json
        else:
            resource = current_app.view_functions[endpoint].view_class
            get_json = lambda o: resource.as_resource(endpoint, o).json
        return [get_json(item) for item in self.page.items]

    def get(self):
        # `type=int` falls back to the default instead of raising ValueError
        # (a 500 response) on a non-numeric query-string value.
        page_num = request.args.get('page', 1, type=int)
        per_page = request.args.get('per_page', self.per_page, type=int)
        # Upper limit from the class; lower limit of 1, because a zero or
        # negative page size makes no sense for pagination.
        per_page = max(1, min(per_page, self.per_page))
        self.page = self.query().paginate(page_num, per_page=per_page)
        content = json.dumps(dict(
            _embedded={request.url_rule.endpoint: self.embedded()},
            _links=self.links(), **self.json))
        return content, 200, {'Content-Type': self.content_type}