def test_size_check_ignores_files(self):
    """
    DATA_UPLOAD_MAX_MEMORY_SIZE must count ordinary form fields but not
    the bytes of uploaded files.
    """
    file_data = (
        b"FAKEPDFBYTESGOHERETHISISREALLYLONGBUTNOTUSEDTOCOMPUTETHESIZEOFTHEREQUEST"
    )
    body = b"".join(
        [
            b"--BOUNDARY\r\n",
            b'Content-Disposition: form-data; name="Title"\r\n\r\n',
            b"My first book\r\n",
            b"--BOUNDARY\r\n",
            b'Content-Disposition: form-data; name="pdf"; filename="book.pdf"\r\n\r\n',
            file_data,
            b"--BOUNDARY--",
        ]
    )
    scope = {
        "http_version": "1.1",
        "method": "POST",
        "path": "/test/",
        "headers": {
            "content-type": b"multipart/form-data; boundary=BOUNDARY",
            "content-length": str(len(body)).encode("latin1"),
        },
    }

    def parse_post():
        # A fresh stream each time: the body can only be consumed once.
        return AsgiRequest(scope, BytesIO(body)).POST

    # A one-byte limit is exceeded by the plain "Title" field.
    with override_settings(DATA_UPLOAD_MAX_MEMORY_SIZE=1):
        with pytest.raises(RequestDataTooBig):
            parse_post()

    # This limit is smaller than the file payload. No exception is expected,
    # because file upload data must not count toward the setting.
    limit_below_file_size = len(file_data) - 20
    with override_settings(DATA_UPLOAD_MAX_MEMORY_SIZE=limit_below_file_size):
        parse_post()
def get_scope_user(scope):
    """Return the Django user for ``scope``, caching it in ``scope["_cached_user"]``.

    Django's auth helper works on a request object, so a throwaway
    ``AsgiRequest`` is built from the scope and given ``scope["session"]``.
    The result of ``_get_user`` is then stored in the scope, so later calls
    reuse it.

    ``scope["method"]`` is only filled in (with ``"FAKE"``) when the scope
    has none, for example a websocket scope. Previously a real HTTP method
    already present in the scope was overwritten.

    Raises:
        KeyError: if a user must be resolved and ``scope`` has no
            ``"session"`` entry.
    """
    if "_cached_user" not in scope:
        # We need to fake a request so the auth code works.
        from channels.http import AsgiRequest

        scope.setdefault("method", "FAKE")
        fake_request = AsgiRequest(scope, b'')
        fake_request.session = scope["session"]
        scope["_cached_user"] = _get_user(fake_request)
    return scope["_cached_user"]
def __call__(self, scope):
    """Attach an OAuth2-authenticated user to ``scope``, then call ``self.inner``.

    If ``scope['user']`` already holds an authenticated user, token
    authentication is skipped. Otherwise, when the request carries an
    ``Authorization`` header, it is verified with ``get_oauthlib_core()``.
    A resolved user is stored in ``scope['user']``.

    Side effect: ``scope['method']`` is set to ``"FAKE"`` when absent, so an
    ``AsgiRequest`` can be built from the scope.
    """

    def authenticate(request):
        """
        Authenticate the request.

        Returns a ``(user, access_token)`` tuple, or ``(None, None)`` when
        the request does not verify.
        """
        oauthlib_core = get_oauthlib_core()
        valid, r = oauthlib_core.verify_request(request, scopes=[])
        if valid:
            return r.user, r.access_token
        return None, None

    current_user = scope.get('user')
    # The old check compared the user instance with the AnonymousUser
    # *class*. That comparison is never equal, so anonymous users were
    # treated as already authenticated and token auth never ran for them.
    if current_user and getattr(current_user, 'is_authenticated', False):
        # We already have an authenticated user.
        return self.inner(scope)

    if "method" not in scope:
        scope['method'] = "FAKE"
    request = AsgiRequest(scope, b"")
    if request.META.get("HTTP_AUTHORIZATION"):
        user, _ = authenticate(request)
        if user:
            scope['user'] = user
    return self.inner(scope)
def secure(self):
    """Authenticate the websocket handshake with Django's session/auth middleware.

    Builds a Django request from ``self.scope`` (with ``method`` forced to
    ``'WEBSOCKET'``) and runs ``process_request`` of ``SessionMiddleware`` and
    ``AuthenticationMiddleware`` on it. The resulting ``session`` and
    ``user`` are copied onto the consumer.

    Raises:
        DenyConnection: if a middleware's ``process_request`` returns a truthy
            value, or if no session, no user, or an unauthenticated user is
            available afterwards.
    """
    # CSRF is not checked: the token cannot be set on the HTTP request that
    # initiates the websocket (it could probably travel in the query string).
    django_middleware = [
        'django.contrib.sessions.middleware.SessionMiddleware',
        'django.contrib.auth.middleware.AuthenticationMiddleware',
    ]

    # Build a Django request from the websocket scope. The body is bytes;
    # it was previously the str ''.
    self.scope['method'] = 'WEBSOCKET'
    request = AsgiRequest(self.scope, b'')

    # Instantiate each middleware with a dummy get_response; only
    # process_request is used here.
    middleware = [import_string(path)(lambda x: None) for path in django_middleware]

    # The request must pass every middleware's process_request.
    for m in middleware:
        if m.process_request(request):
            raise DenyConnection()

    if hasattr(request, 'session'):
        self.session = request.session
    if hasattr(request, 'user'):
        self.user = request.user

    # Use getattr so a consumer that never had session/user assigned is
    # denied rather than crashing with AttributeError.
    session = getattr(self, 'session', None)
    user = getattr(self, 'user', None)
    if not session or not user or not user.is_authenticated:
        raise DenyConnection()
def test_stream(self):
    """
    Reading from the request must behave like reading the body stream.
    """
    scope = {
        "http_version": "1.1",
        "method": "PUT",
        "path": "/",
        "headers": {"host": b"example.com", "content-length": b"11"},
    }
    request = AsgiRequest(scope, BytesIO(b"onetwothree"))

    self.assertEqual(request.method, "PUT")
    head = request.read(3)
    self.assertEqual(head, b"one")
    tail = request.read()
    self.assertEqual(tail, b"twothree")
def test_stream(self):
    """
    Reading from the request must behave like reading the body stream.
    """
    scope = {
        "http_version": "1.1",
        "method": "PUT",
        "path": "/",
        "headers": {"host": b"example.com", "content-length": b"11"},
    }
    request = AsgiRequest(scope, b"onetwothree")

    self.assertEqual(request.method, "PUT")
    head = request.read(3)
    self.assertEqual(head, b"one")
    tail = request.read()
    self.assertEqual(tail, b"twothree")
def test_post(self):
    """
    Tests a POST body.
    """
    body = b"djangoponies=are+awesome"
    request = AsgiRequest(
        {
            "http_version": "1.1",
            "method": "POST",
            "path": "/test2/",
            "query_string": "django=great",
            "headers": {
                "host": b"example.com",
                "content-type": b"application/x-www-form-urlencoded",
                # Derived from the body (24 bytes). The old literal b"18"
                # understated it, so any reader that honours CONTENT_LENGTH
                # would truncate the form data.
                "content-length": str(len(body)).encode("ascii"),
            },
        },
        body,
    )
    self.assertEqual(request.path, "/test2/")
    self.assertEqual(request.method, "POST")
    self.assertEqual(request.body, body)
    self.assertEqual(request.META["HTTP_HOST"], "example.com")
    self.assertEqual(request.META["CONTENT_TYPE"], "application/x-www-form-urlencoded")
    self.assertEqual(request.GET["django"], "great")
    self.assertEqual(request.POST["djangoponies"], "are awesome")
    # Query-string and form-body parameters must not leak into each other.
    with self.assertRaises(KeyError):
        request.POST["django"]
    with self.assertRaises(KeyError):
        request.GET["djangoponies"]
def test_basic(self):
    """
    A minimal request message, with every optional field left out, must
    still be decoded by the handler.
    """
    minimal_scope = {"http_version": "1.1", "method": "GET", "path": "/test/"}
    request = AsgiRequest(minimal_scope, b"")

    self.assertEqual(request.path, "/test/")
    self.assertEqual(request.method, "GET")
    self.assertFalse(request.body)

    # Client-derived META keys are absent; server keys always exist.
    for missing_key in ("HTTP_HOST", "REMOTE_ADDR", "REMOTE_HOST", "REMOTE_PORT"):
        self.assertNotIn(missing_key, request.META)
    for required_key in ("SERVER_NAME", "SERVER_PORT"):
        self.assertIn(required_key, request.META)

    # Looked up lazily, one at a time, in the original order.
    for attr in ("GET", "POST", "COOKIES"):
        self.assertFalse(getattr(request, attr))
def test_extended(self):
    """
    A GET request that uses most optional scope fields must be decoded.
    """
    scope = {
        "http_version": "1.1",
        "method": "GET",
        "path": "/test2/",
        "query_string": b"x=1&y=%26foo+bar%2Bbaz",
        "headers": {
            "host": b"example.com",
            "cookie": b"test-time=1448995585123; test-value=yeah",
        },
        "client": ["10.0.0.1", 1234],
        "server": ["10.0.0.2", 80],
    }
    request = AsgiRequest(scope, b"")

    self.assertEqual(request.path, "/test2/")
    self.assertEqual(request.method, "GET")
    self.assertFalse(request.body)

    expected_meta = {
        "HTTP_HOST": "example.com",
        "REMOTE_ADDR": "10.0.0.1",
        "REMOTE_HOST": "10.0.0.1",
        "REMOTE_PORT": 1234,
        "SERVER_NAME": "10.0.0.2",
        "SERVER_PORT": "80",
    }
    for meta_key, meta_value in expected_meta.items():
        self.assertEqual(request.META[meta_key], meta_value)

    self.assertEqual(request.GET["x"], "1")
    self.assertEqual(request.GET["y"], "&foo bar+baz")

    for cookie_name, cookie_value in (
        ("test-time", "1448995585123"),
        ("test-value", "yeah"),
    ):
        self.assertEqual(request.COOKIES[cookie_name], cookie_value)

    self.assertFalse(request.POST)
def test_websocket_handshake(self):
    """
    A websocket connection scope must decode as its HTTP handshake.

    The ASGI spec leaves ``method`` out of websocket scopes, but RFC 6455
    requires the opening handshake to have been a GET, so GET is expected.
    """
    handshake_scope = {
        "type": "websocket",
        "http_version": "1.1",
        "query_string": "django=great",
        "path": "/test/",
        "headers": {"sec-websocket-protocol": b"350"},
    }
    request = AsgiRequest(handshake_scope, BytesIO(b""))

    self.assertEqual(request.path, "/test/")
    self.assertEqual(request.method, "GET")
    self.assertEqual(request.META["REQUEST_METHOD"], "GET")
    self.assertEqual(request.GET["django"], "great")
    self.assertEqual(request.META["HTTP_SEC_WEBSOCKET_PROTOCOL"], "350")
def test_post_files(self):
    """
    Multipart form data with one text field and one uploaded file.
    """
    body = b"".join(
        [
            b"--BOUNDARY\r\n",
            b'Content-Disposition: form-data; name="title"\r\n\r\n',
            b"My First Book\r\n",
            b"--BOUNDARY\r\n",
            b'Content-Disposition: form-data; name="pdf"; filename="book.pdf"\r\n\r\n',
            b"FAKEPDFBYTESGOHERE",
            b"--BOUNDARY--",
        ]
    )
    multipart_headers = {
        "content-type": b"multipart/form-data; boundary=BOUNDARY",
        "content-length": str(len(body)).encode("ascii"),
    }
    scope = {
        "http_version": "1.1",
        "method": "POST",
        "path": "/test/",
        "headers": multipart_headers,
    }
    request = AsgiRequest(scope, body)

    self.assertEqual(request.method, "POST")
    self.assertEqual(len(request.body), len(body))
    content_type = request.META["CONTENT_TYPE"]
    self.assertTrue(content_type.startswith("multipart/form-data"))
    self.assertFalse(request._post_parse_error)
    self.assertEqual(request.POST["title"], "My First Book")
    uploaded_pdf = request.FILES["pdf"]
    self.assertEqual(uploaded_pdf.read(), b"FAKEPDFBYTESGOHERE")
def connect(self):
    """Build a request from the scope, authenticate it, and accept or deny.

    Stores ``scope["service"]`` on the consumer. The request is given the
    scope's user and session. ``self.authenticate`` runs only for an
    unauthenticated user. ``check_permissions`` then decides between
    ``AcceptConnection`` and ``DenyConnection``; one of them is always raised.
    """
    self.service = self.scope["service"]

    # Work on a shallow copy so the consumer's own scope is not modified.
    request_scope = {**self.scope, "method": "get"}
    request = AsgiRequest(request_scope, b"")
    # NOTE(review): presumably lets code that expects a wrapper exposing
    # `_request` (DRF-style) work with this plain request. Confirm.
    request._request = request

    scope_user = self.scope["user"]
    request.user = scope_user
    request.session = self.scope["session"]

    if not scope_user.is_authenticated:
        self.authenticate(request)

    if not self.check_permissions(request):
        raise DenyConnection()
    raise AcceptConnection()
def test_reading_body_after_stream_raises(self):
    """
    After the body stream has been partially read, accessing
    ``request.body`` must raise ``RawPostDataException``.
    """
    body = b"djangoponies=are+awesome"
    request = AsgiRequest(
        {
            "http_version": "1.1",
            "method": "POST",
            "path": "/test2/",
            "query_string": "django=great",
            "headers": {
                "host": b"example.com",
                "content-type": b"application/x-www-form-urlencoded",
                # Derived from the body (24 bytes); the old literal b"18"
                # disagreed with the payload.
                "content-length": str(len(body)).encode("ascii"),
            },
        },
        BytesIO(body),
    )
    self.assertEqual(request.read(3), b"dja")
    with pytest.raises(RawPostDataException):
        request.body
def test_script_name(self):
    """The scope's ``root_path`` must prefix ``request.path``."""
    scope = {
        "http_version": "1.1",
        "method": "GET",
        "path": "/test/",
        "root_path": "/path/to/",
    }
    request = AsgiRequest(scope, b"")
    self.assertEqual(request.path, "/path/to/test/")
def test_size_exceeded(self):
    """A declared content-length above the limit makes ``.body`` raise."""
    scope = {
        "http_version": "1.1",
        "method": "PUT",
        "path": "/",
        "headers": {"host": b"example.com", "content-length": b"1000"},
    }
    # Construction and body access both happen under the 1-byte limit.
    with override_settings(DATA_UPLOAD_MAX_MEMORY_SIZE=1), pytest.raises(
        RequestDataTooBig
    ):
        request = AsgiRequest(scope, BytesIO(b""))
        request.body
def from_scope(cls, viewset_action, scope, view_kwargs, query_params):
    """
    Build a fully initialized view instance from a websocket ASGI scope.

    "This is the magic." Follows the DRF viewset setup referenced at
    https://github.com/encode/django-rest-framework/blob/1e383f/rest_framework/viewsets.py#L47
    so that ``get_queryset()``, ``get_serializer_class()`` and permission
    checks find attributes such as ``self.kwargs`` and ``self.request``
    populated, as they expect.

    Producing a Django HttpRequest from a websocket scope, rather than from
    a real HTTP request, is the largest "hack" here. By inspection of the
    ASGI spec (https://asgi.readthedocs.io/en/latest/specs/www.html), the
    only difference between websocket and HTTP scopes is the presence of an
    HTTP method. A websocket connection is established over HTTP, so
    headers and everything else are set as for a normal HTTP request. The
    base request for every broadcast is therefore the initial HTTP request.
    Subscriptions are retrievals, so the method is fixed to GET.
    """
    view = cls()
    view.format_kwarg = None
    view.action_map = {}
    view.args = []
    view.kwargs = view_kwargs

    request_scope = dict(scope, method="GET", query_string=urlencode(query_params))
    base_request = AsgiRequest(request_scope, BytesIO())

    # TODO: Run other middleware?
    base_request.user = scope.get("user")
    base_request.session = scope.get("session")

    view.request = view.initialize_request(base_request)
    view.action = viewset_action  # TODO: custom subscription actions?
    return view
def test_latin1_headers(self):
    """Non-ASCII header bytes must round-trip as latin-1 text."""
    latin1_value = "äbcd".encode("latin1")
    scope = {
        "http_version": "1.1",
        "method": "GET",
        "path": "/test2/",
        "headers": {"host": b"example.com", "foo": latin1_value},
    }
    request = AsgiRequest(scope, BytesIO(b""))
    self.assertEqual(request.headers["foo"], "äbcd")
def receive(self, text_data=None, bytes_data=None):
    """Run a GraphQL query received over the websocket and broadcast the result.

    ``text_data`` is used as the GraphQL query string. It is wrapped in a
    synthetic HTTP POST to ``/graphql/``, carrying the client address and
    the ``user-agent`` header from the websocket scope. The request is
    executed through ``BetterGraphQLView``, and the decoded JSON response is
    sent to ``self.group_name`` as a ``query_result`` message.
    """
    gql = text_data
    asgi_body = json.dumps({"query": gql}).encode('utf-8')

    # ASGI header values are byte strings, so default to b'' rather than
    # the str ''. dict() keeps the last value when a header name repeats,
    # as the old scanning loop did.
    user_agent = dict(self.scope['headers']).get(b'user-agent', b'')

    asgi_scope = {
        'client': self.scope['client'],
        'path': '/graphql/',
        'type': 'http',
        'headers': [
            (b'origin', b'null'),
            (b'connection', b'keep-alive'),
            (b'accept-encoding', b'gzip, deflate, br'),
            (b'user-agent', user_agent),
            (b'content-type', b'application/json'),
            (b'host', b'localhost:8000'),
            # Derived from the real body; it was hard-coded to b'51'.
            (b'content-length', str(len(asgi_body)).encode('ascii')),
            (b'accept-language', b'zh-CN,zh;q=0.9'),
            (b'accept', b'application/json'),
        ],
        'scheme': 'http',
        'method': 'POST',
        'query_string': b'',
        'root_path': '',
        'http_version': '1.1',
        'server': ['127.0.0.1', 8000],
    }
    asgi_request = AsgiRequest(asgi_scope, asgi_body)
    resp = BetterGraphQLView.as_view(schema=schema)(asgi_request)
    async_to_sync(self.channel_layer.group_send)(
        self.group_name,
        {
            'type': 'query_result',
            'data': json.loads(resp.content.decode('utf-8')),
        },
    )
async def handle(self, body):
    """Serve an event-stream request: respond immediately or start streaming.

    The request is parsed, and errors become immediate error responses.
    Grip-proxied requests get their events as one complete HTTP response.
    Every other request receives ``text/event-stream`` headers, after which
    ``self.stream(event_request)`` is scheduled as a task.
    """
    from django_grip import GripMiddleware
    from .eventrequest import EventRequest
    from .eventstream import EventPermissionError
    from .utils import sse_error_response

    def encoded(pairs):
        # Encode any text header name/value to UTF-8 bytes; bytes pass through.
        for key, val in pairs:
            if isinstance(key, six.text_type):
                key = key.encode('utf-8')
            if isinstance(val, six.text_type):
                val = val.encode('utf-8')
            yield key, val

    self.listener = None

    request = AsgiRequest(self.scope, body)
    grip_mw = GripMiddleware()
    grip_mw.process_request(request)

    if 'user' in self.scope:
        request.user = self.scope['user']

    # Handler order matters: the specific EventRequest errors come before
    # the generic EventRequest.Error.
    try:
        event_request = await self.parse_request(request)
        response = None
    except EventRequest.ResumeNotAllowedError as exc:
        response = HttpResponseBadRequest('Invalid request: %s.\n' % str(exc))
    except EventRequest.GripError as exc:
        if request.grip.proxied:
            response = sse_error_response(
                'internal-error', 'Invalid internal request.')
        else:
            response = sse_error_response(
                'bad-request', 'Invalid request: %s.' % str(exc))
    except EventRequest.Error as exc:
        response = sse_error_response(
            'bad-request', 'Invalid request: %s.' % str(exc))

    # Grip requests are answered immediately with the events.
    if not response and request.grip.proxied:
        try:
            event_response = await self.get_events(event_request)
            response = event_response.to_http_response(request)
        except EventPermissionError as exc:
            response = sse_error_response(
                'forbidden', str(exc), {'channels': exc.channels})

    extra_headers = {'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no'}
    augment_cors_headers(extra_headers)

    # A grip request or an error: send the complete response now.
    if response:
        response = grip_mw.process_response(request, response)
        headers = list(encoded(response.items()))
        headers.extend(encoded(extra_headers.items()))
        await self.send_response(
            response.status_code, response.content, headers=headers)
        return

    # Not a grip request and no errors: begin a local stream response.
    headers = [(six.b('Content-Type'), six.b('text/event-stream'))]
    headers.extend(encoded(extra_headers.items()))
    await self.send_headers(headers=headers)

    self.listener = Listener()
    self.is_streaming = True

    asyncio.get_event_loop().create_task(self.stream(event_request))
async def handle(self, body):
    """Serve an event-stream request: respond immediately or start streaming.

    The request is parsed, and errors become immediate error responses.
    Requests carrying a ``grip-sig`` header are treated as grip-proxied and
    get their events as one complete HTTP response. Every other request
    receives ``text/event-stream`` headers and an initial body chunk, after
    which ``self.stream(event_request, user)`` is scheduled as a task.

    All header names and values are sent as bytes. ASGI requires byte
    strings; this variant previously sent str.
    """
    from .eventrequest import EventRequest
    from .eventstream import EventPermissionError
    from .utils import sse_error_response

    def to_asgi_headers(pairs):
        # Encode text header names/values to UTF-8 bytes; bytes pass through.
        out = []
        for name, value in pairs:
            if isinstance(name, str):
                name = name.encode('utf-8')
            if isinstance(value, str):
                value = value.encode('utf-8')
            out.append((name, value))
        return out

    self.listener = None

    request = AsgiRequest(self.scope, body)

    # TODO use GripMiddleware
    request.grip_proxied = any(
        name == b'grip-sig' for name, _ in self.scope['headers'])

    if 'user' in self.scope:
        request.user = self.scope['user']

    # Handler order matters: the specific EventRequest errors come before
    # the generic EventRequest.Error.
    try:
        event_request = await self.parse_request(request)
        response = None
    except EventRequest.ResumeNotAllowedError as e:
        response = HttpResponseBadRequest(
            'Invalid request: %s.\n' % str(e))
    except EventRequest.GripError as e:
        if request.grip_proxied:
            response = sse_error_response(
                'internal-error',
                'Invalid internal request.')
        else:
            response = sse_error_response(
                'bad-request',
                'Invalid request: %s.' % str(e))
    except EventRequest.Error as e:
        response = sse_error_response(
            'bad-request',
            'Invalid request: %s.' % str(e))

    user = None
    if hasattr(request, 'user') and request.user.is_authenticated:
        user = request.user

    # For grip requests, prepare an immediate response.
    if not response and request.grip_proxied:
        try:
            event_response = await self.get_events(event_request, user)
            response = event_response.to_http_response(request)
        except EventPermissionError as e:
            response = sse_error_response(
                'forbidden',
                str(e),
                {'channels': e.channels})

    extra_headers = {'Cache-Control': 'no-cache'}
    cors_origin = getattr(settings, 'EVENTSTREAM_ALLOW_ORIGIN', '')
    if cors_origin:
        extra_headers['Access-Control-Allow-Origin'] = cors_origin

    # A grip request or an error: respond now.
    if response:
        headers = to_asgi_headers(response.items())
        headers.extend(to_asgi_headers(extra_headers.items()))
        await self.send_response(
            response.status_code,
            response.content,
            headers=headers
        )
        return

    # Not a grip request and no errors: begin a local stream response.
    headers = to_asgi_headers([('Content-Type', 'text/event-stream')])
    headers.extend(to_asgi_headers(extra_headers.items()))
    await self.send_headers(headers=headers)

    # An SSE comment line of 2048 spaces followed by a stream-open event.
    # NOTE(review): the padding presumably makes clients/intermediaries
    # flush the stream early. Confirm.
    preamble = b':' + (b' ' * 2048) + b'\n\n'
    preamble += b'event: stream-open\ndata:\n\n'
    await self.send_body(preamble, more_body=True)

    self.listener = Listener()
    self.is_streaming = True

    asyncio.get_event_loop().create_task(self.stream(event_request, user))