# stupidsessions.py
# 115 lines (95 loc) · 3.39 KB
from flask import Flask, Session
from html5lib import HTMLParser, serialize
from html5lib.treebuilders.simpletree import Element
from werkzeug import url_quote_plus
# Module-level html5lib parser shared by StupidSessionMixin._inject_session
# for every HTML response it rewrites.
_parser = HTMLParser()
class StupidSessionMixin(object):
    """Flask mixin that carries the session in URLs and forms, not cookies.

    The serialized session is read from the ``session_url_key`` request
    value. On the way out it is:

    * appended to the URL attributes listed in ``session_url_rewrite_map``
      of an HTML response,
    * added to every ``<form>`` as a hidden input,
    * appended to the ``Location`` header of redirects.

    NOTE(review): links to other hosts are rewritten as well, which hands
    the serialized session to third-party sites (URL / Referer). Consider
    restricting rewriting to same-host URLs.
    """

    # Name of the query-string / form parameter that carries the session.
    session_url_key = 'FLASKSESSION'
    # Tag name -> attributes whose URL value gets the session appended.
    session_url_rewrite_map = {
        'a': ['href'],
        'img': ['src'],
        'script': ['src'],
        'link': ['href'],
    }

    def open_session(self, request):
        """Load the session from the ``session_url_key`` request value.

        Returns ``None`` when no ``secret_key`` is configured. Otherwise it
        returns ``Session.unserialize`` applied to the value found in the
        request (``''`` when absent).
        """
        key = self.secret_key
        if key is None:
            return None
        value = request.values.get(self.session_url_key, '')
        return Session.unserialize(value, key)

    def save_session(self, session, response):
        """Write *session* into the outgoing *response*.

        HTML bodies get their links and forms rewritten. A ``Location``
        header (redirect) gets the session appended as a query parameter.
        """
        # We only support html; other bodies pass through untouched.
        if response.mimetype == 'text/html':
            response.data = self._inject_session(session, response.data)
        # Handle redirects: the browser follows Location, so it must carry
        # the session too.
        if 'location' in response.headers:
            response.headers['Location'] = self._rewrite_session_url(
                response.headers['location'], session.serialize())

    def _rewrite_session_url(self, url, sess):
        """Return *url* with the serialized session *sess* as a query param.

        Any ``#fragment`` is kept at the end, and an existing session
        parameter is replaced rather than duplicated.
        """
        return self._set_query_param(
            url, self.session_url_key, url_quote_plus(sess))

    @staticmethod
    def _set_query_param(url, name, quoted_value):
        """Return *url* with ``name=quoted_value`` as its only *name* param.

        *quoted_value* must already be URL-encoded.

        Existing *name* parameters are dropped. ``request.values.get`` (as
        used by ``open_session``) would otherwise pick the first, stale
        occurrence.

        The fragment is re-attached after the query. A query appended after
        ``#`` would be part of the fragment and never reach the server.
        """
        base, hash_sign, fragment = url.partition('#')
        path, _, query = base.partition('?')
        prefix = name + '='
        params = [param for param in query.split('&')
                  if param and param != name and not param.startswith(prefix)]
        params.append(prefix + quoted_value)
        return '%s?%s%s%s' % (path, '&'.join(params), hash_sign, fragment)

    @staticmethod
    def _is_rewritable_url(url):
        """Tell whether *url* is a navigation target that should get the session.

        Pure same-document fragments (``#top``) need no session. Non-HTTP
        schemes such as ``mailto:``, ``javascript:`` or ``data:`` would be
        corrupted by an appended query string. Relative URLs and
        ``http``/``https`` URLs are rewritable.
        """
        candidate = url.strip()
        if candidate.startswith('#'):
            return False
        scheme, colon, _ = candidate.partition(':')
        # RFC 3986: scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ).
        # Anything else before the first ':' is a path/query, not a scheme.
        has_scheme = (bool(colon) and bool(scheme) and scheme[0].isalpha()
                      and all(ch.isalnum() or ch in '+-.' for ch in scheme))
        return not has_scheme or scheme.lower() in ('http', 'https')

    def _inject_session(self, session, html):
        """Return *html* re-serialized with the session woven into it.

        Attributes listed in ``session_url_rewrite_map`` get the session
        appended when ``_is_rewritable_url`` allows it. Every ``<form>``
        gets a hidden input named ``session_url_key``.
        """
        serialized = session.serialize()

        def _walk(node):
            # Children are visited before the node itself, so the hidden
            # input appended to a form below is never walked.
            for child in node.childNodes:
                _walk(child)
            for attr in self.session_url_rewrite_map.get(node.name, ()):
                value = node.attributes.get(attr)
                if value is None or not self._is_rewritable_url(value):
                    continue
                node.attributes[attr] = self._rewrite_session_url(
                    value, serialized)
            if node.name == 'form':
                hidden = Element('input')
                hidden.attributes.update(
                    type='hidden',
                    name=self.session_url_key,
                    value=serialized,
                )
                # appendChild is the tree-builder API and also sets
                # hidden.parent. A bare childNodes.append leaves that link
                # unset. NOTE(review): html5lib's tree walkers appear to
                # follow .parent when serializing; confirm.
                node.appendChild(hidden)

        tree = _parser.parse(html)
        _walk(tree)
        return serialize(tree)
class StupidSessionFlask(StupidSessionMixin, Flask):
    """Flask application that uses StupidSessionMixin's URL-carried sessions."""
def testapp():
    """Build and return a small demo app for URL-carried sessions.

    Routes: ``/`` shows whether a user is logged in, ``/login`` shows and
    handles a username form, and ``/logout`` forgets the user.
    """
    from flask import request, session, g, escape, redirect, url_for
    app = StupidSessionFlask(__name__)
    # Fixed key for the demo only; never ship a hard-coded secret.
    app.secret_key = 'testing'

    @app.before_request
    def pull_user():
        # Expose the logged-in user name (or None) to every view.
        g.user = session.get('username')

    @app.route('/')
    def index():
        if g.user is not None:
            return '''
<p>You are logged in as %s.
<p><a href=/logout>Logout</a>
''' % escape(g.user)
        return 'You are not logged in. <a href=/login>Login</a>'

    @app.route('/login', methods=['GET', 'POST'])
    def login():
        if request.method == 'POST':
            username = request.form['username']
            # An empty username falls through and re-renders the form.
            if username:
                session['username'] = username
                return redirect(url_for('index'))
        return '''
<form action="" method=post>
<p>Username:
<input type=text name=username>
<input type=submit value=Login>
</form>
'''

    @app.route('/logout')
    def logout():
        # Remove the key instead of storing None, so the serialized session
        # (and every URL carrying it) no longer contains a dead entry.
        session.pop('username', None)
        return redirect(url_for('index'))

    return app
if __name__ == '__main__':
    # Script entry point: build the demo application and serve it in debug mode.
    app = testapp()
    app.run(debug=True)