-
Notifications
You must be signed in to change notification settings - Fork 0
/
local.py
executable file
·263 lines (217 loc) · 7.8 KB
/
local.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
#!/usr/bin/env python3
import asyncio
import ssl
import logging
import os
import sys
import json
from hashlib import sha512
from common import nyapass_run, ConnectionHandler
from signature import sign_headers, unsign_headers, SignatureError
# Module-wide "nyapass" logger; ClientHandlerManager derives a per-class
# child logger from it via getChild().
log = logging.getLogger("nyapass")
class ClientHandler(ConnectionHandler):
    """Local-side handler that forwards requests to the nyapass server.

    Requests are normally sent to ``(config.server_host, config.server_port)``
    over TLS (``ssl_ctx``) with signed headers, and response headers are
    signature-checked.  When ``config.divert_banned_requests`` is enabled and
    the server answers 410 with ``X-Nyapass-Status: banned``, the request is
    "diverted": its unsigned headers are restored and it is handed to the
    base class's ``prepare_standalone_request()`` (presumably a direct,
    non-server connection -- confirm in common.py).  Banned hosts are
    remembered in the shared ``divert_cache``.
    """
    def __init__(self, divert_cache, ssl_ctx, manager, *args, **kwargs):
        """Create a handler.

        :param divert_cache: dict shared by all handlers of one manager,
            mapping lower-cased hostnames to ``True`` once seen as banned.
        :param ssl_ctx: SSL context for connections to the nyapass server.
        :param manager: owning ``ClientHandlerManager``; used for
            certificate pinning through ``validate_remote``.
        Remaining arguments are forwarded to ``ConnectionHandler``.
        """
        super().__init__(*args, **kwargs)
        self._divert_cache = divert_cache
        # Buffer the request body only when diverting is possible, so the
        # request can be replayed; presumably consumed by ConnectionHandler
        # -- confirm in common.py.
        self.buffer_request_body = self.config.divert_banned_requests
        self._ssl_ctx = ssl_ctx
        self._manager = manager
    def reset_request(self):
        """Reset per-request state; new requests target the nyapass server."""
        super().reset_request()
        # Headers as received from the client, before signing; kept so a
        # diverted request can be re-sent without the signature.
        self._local_headers_unsigned = None
        self._is_diverted_request = False
        self.default_remote = (
            self.config.server_host,
            self.config.server_port,
        )
    @property
    def should_divert(self):
        """Whether the current request should bypass the nyapass server.

        True when diverting is enabled and the request is already diverted,
        its host is in the divert cache, or the server's (unsigned) response
        is 410 with an ``X-Nyapass-Status: banned`` header.  In that last
        case the host is recorded in the divert cache as a side effect.
        """
        if not self.config.divert_banned_requests:
            return False
        if self._is_diverted_request:
            return True
        if not self._local_headers:
            return False
        host = self.request_hostname.lower()
        if host in self._divert_cache:
            return True
        if not self._remote_headers:
            return False
        # Exact, case-sensitive byte match against the raw header block.
        ret = self.response_code == 410 and \
            b"\r\nX-Nyapass-Status: banned\r\n" in self._remote_headers
        if ret:
            self._divert_cache[host] = True
        return ret
    @asyncio.coroutine
    def divert_request(self):
        """Switch the current request to diverted mode.

        Restores the unsigned client headers (if signing already happened)
        and delegates to ``prepare_standalone_request()``.  Must be called
        at most once per request.
        """
        assert not self._is_diverted_request
        self.debug("Diverting request")
        self._is_diverted_request = True
        if self._local_headers_unsigned:
            self._local_headers = self._local_headers_unsigned
        yield from self.prepare_standalone_request()
    @asyncio.coroutine
    def process_request(self):
        """Request hook: divert the request, or sign its headers for the server."""
        if self.should_divert:
            yield from self.divert_request()
            return
        # Remember the original headers before they are replaced by the
        # signed version, in case the request is diverted later.
        self._local_headers_unsigned = self._local_headers
        self._local_headers = sign_headers(
            self.config,
            self._local_headers,
        )
    @asyncio.coroutine
    def process_response(self):
        """Response-header hook.

        Verifies and strips the server's response signature; on failure the
        whole process exits.  If the response marks the host as banned, the
        in-flight request send is cancelled, the server connection is
        dropped, and the request is replayed as a diverted request whose
        response headers are then read in place of the server's.
        """
        if self._is_diverted_request:
            # Diverted traffic does not go through the nyapass server, so
            # there is no signature to check.
            return
        try:
            self._remote_headers = unsign_headers(
                self.config,
                self._remote_headers,
            )
        except SignatureError:
            self.critical(
                "Failed to verify response signature, server may be "
                "improperly configured or we are experiencing MITM attack"
            )
            self.dump_info()
            sys.exit(1)
        if self.should_divert:
            fut = self._send_request_future
            fut.cancel()
            # Give the event loop one turn so the cancellation takes effect;
            # the assert below checks that it did.
            yield from asyncio.sleep(0)
            assert fut.done()
            self.destroy_remote_connection()
            self._remote_headers = None
            yield from self.divert_request()
            if self.request_has_body and self._reader.too_much_data:
                # NOTE(review): presumably the body exceeded what was
                # buffered and cannot be replayed; the client is told to
                # retry, and by then the host is in the divert cache.
                yield from self.respond_and_close(
                    code=503,
                    status="Service Unavailable",
                    body="Diverting request, please retry.",
                )
                assert False # Unreachable
            yield from self.ensure_remote_connection()
            # Sanity check: nothing replaced the cancelled send future
            # before a fresh send is started for the diverted request.
            assert self._send_request_future == fut
            self._send_request_future = None
            self.begin_send_request()
            yield from self.read_remote_headers()
    @asyncio.coroutine
    def connect_to_remote(self, remote, **kwargs):
        """Open the upstream connection to ``remote``.

        Non-diverted (server) connections are made with ``self._ssl_ctx``
        and then checked by the manager's certificate pinning.  Returns the
        base class's result unchanged; ``ret[1]`` is passed on as the stream
        writer (presumably a ``(reader, writer)`` pair -- confirm).
        """
        if not self._is_diverted_request:
            kwargs["ssl"] = self._ssl_ctx
        ret = yield from super().connect_to_remote(
            remote,
            **kwargs
        )
        if not self._is_diverted_request:
            self._manager.validate_remote(remote, ret[1])
        return ret
class ClientHandlerManager:
    """Handler factory passed to ``nyapass_run``.

    Every handler it creates shares one SSL context, one divert cache and
    one known-hosts store (host -> SHA-512 certificate fingerprint) used
    for trust-on-first-use certificate pinning.
    """

    def __init__(self, config, handler_cls=ClientHandler):
        """Set up shared state and load saved known hosts, if configured.

        :param config: nyapass configuration object.
        :param handler_cls: handler class to instantiate per connection.
        """
        self.log = log.getChild(self.__class__.__name__)
        self.config = config
        self._handler_cls = handler_cls
        self._ssl_ctx = create_ssl_ctx(config)
        self._divert_cache = {}
        self._known_hosts = {}
        if self.config.known_hosts_file:
            self.config.known_hosts_file = os.path.expanduser(
                self.config.known_hosts_file,
            )
            self._read_known_hosts()

    def __call__(self, *args, **kwargs):
        """Create a handler, injecting the shared manager state."""
        return self._handler_cls(
            *args,
            divert_cache=self._divert_cache,
            ssl_ctx=self._ssl_ctx,
            manager=self,
            **kwargs
        )

    @property
    def __name__(self):
        """Readable name such as ``ClientHandlerManager(ClientHandler)``."""
        return "%s(%s)" % (
            self.__class__.__name__,
            getattr(
                self._handler_cls,
                "__name__",
                str(self._handler_cls),
            ),
        )

    def _read_known_hosts(self):
        """Merge saved host -> fingerprint entries into the in-memory store.

        A missing file is ignored.  An unreadable, undecodable or
        non-object file is logged and skipped, so startup is never blocked
        by a bad known-hosts file.
        """
        if not os.path.isfile(self.config.known_hosts_file):
            return
        try:
            with open(self.config.known_hosts_file, "r") as f:
                saved = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            self.log.warning("Failed to load saved known hosts: %s", e)
            return
        if not isinstance(saved, dict):
            self.log.warning(
                "Failed to load saved known hosts: %s",
                "expected a JSON object",
            )
            return
        self._known_hosts.update(saved)

    def _write_known_hosts(self):
        """Persist the known-hosts store as JSON (best effort; errors are logged)."""
        path = self.config.known_hosts_file
        if not path:
            return
        tmp_path = path + ".tmp"
        try:
            parent = os.path.dirname(path)
            # dirname() is "" for a bare filename, and os.makedirs("")
            # raises, which previously prevented saving altogether.
            if parent:
                os.makedirs(parent, exist_ok=True)
            # Write a sibling temp file and atomically swap it in, so a
            # crash mid-write cannot leave a truncated pin file behind.
            with open(tmp_path, "w") as f:
                json.dump(self._known_hosts, f)
            os.replace(tmp_path, path)
        except OSError as e:
            self.log.warning("Failed to save known hosts: %s", e)

    def get_remote_cert(self, writer):
        """Return the peer certificate (DER bytes) of a TLS stream writer.

        Works around differences between Python versions in how the
        underlying SSL object is exposed.
        """
        sslobj = writer.get_extra_info("ssl_object") # Python 3.5.1+
        if not sslobj:
            sslobj = writer.get_extra_info("socket") # Python 3.4.x?
        assert sslobj
        if not hasattr(sslobj, "getpeercert"):
            # Python 3.5.0, no public way to get this, so we have to...
            sslobj = writer.transport._ssl_protocol._sslpipe.ssl_object
        # binary_form=True: the raw DER bytes, which get hashed for pinning.
        return sslobj.getpeercert(True)

    def validate_remote(self, remote, remote_writer):
        """Pin-check the server certificate if ``config.pin_server_cert`` is set.

        :param remote: ``(host, port)`` pair; the store key is ``"host:port"``.
        :param remote_writer: stream writer of the established TLS connection.
        """
        if not self.config.pin_server_cert:
            return
        self.validate_host_cert(
            "%s:%s" % remote,
            self.get_remote_cert(remote_writer),
        )

    def validate_host_cert(self, host, cert):
        """Trust-on-first-use check of ``cert`` (DER bytes) for ``host``.

        Unknown hosts have their SHA-512 fingerprint recorded and saved.
        A fingerprint mismatch for a known host is logged as critical and
        the process exits with status 1.
        """
        assert cert
        cert_hash = sha512(cert).hexdigest()
        if host not in self._known_hosts:
            self.log.info(
                "Adding %s to known hosts (fingerprint: %s)",
                host, cert_hash,
            )
            self._known_hosts[host] = cert_hash
            self._write_known_hosts()
        elif self._known_hosts[host] != cert_hash:
            self.log.critical(
                "Certificate of %s has changed (old = %s, new = %s). "
                "If you haven't changed your certificate recently, "
                "this probably means that someone is MITMing us. "
                "If you believe that %s is safe, "
                "delete entry of %s in %s and restart nyapass.",
                host, self._known_hosts[host], cert_hash,
                host, host, self.config.known_hosts_file,
            )
            sys.exit(1)
def create_ssl_ctx(config):
    """Build the client SSL context used for connections to the nyapass server.

    Starts from ``ssl.create_default_context`` (certificate and hostname
    verification on, secure default options) and requires TLS 1.2 or
    newer.  If ``config.ssl_verify`` is false, certificate and hostname
    verification are turned off.

    :param config: object with a boolean ``ssl_verify`` attribute.
    :returns: configured ``ssl.SSLContext``.
    """
    ctx = ssl.create_default_context(
        purpose=ssl.Purpose.SERVER_AUTH,
    )
    if hasattr(ssl, "TLSVersion"):
        # Python 3.7+: the preferred way to set a protocol floor; the
        # OP_NO_TLSv1* flags are deprecated since 3.10.
        ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    else:
        # Add to the defaults rather than assigning: plain assignment would
        # clear protections such as OP_NO_COMPRESSION that
        # create_default_context() enables.
        ctx.options |= (
            ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
            ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
        )
    if not config.ssl_verify:
        # check_hostname must be disabled before verify_mode can be CERT_NONE.
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
    return ctx
def main(config, handler_cls=ClientHandler):
    """Entry point: configure logging, then run the nyapass client.

    :param config: nyapass configuration; ``config.log_level`` sets the
        root logging level.
    :param handler_cls: per-connection handler class to use.
    """
    logging.basicConfig(level=config.log_level)
    factory = ClientHandlerManager(config, handler_cls=handler_cls)
    nyapass_run(handler_factory=factory, config=config)