예제 #1
0
    def test_size_check_ignores_files(self):
        """
        Checks that DATA_UPLOAD_MAX_MEMORY_SIZE is applied to a multipart
        request without counting the bytes of the uploaded file.
        """
        pdf_bytes = (
            b"FAKEPDFBYTESGOHERETHISISREALLYLONGBUTNOTUSEDTOCOMPUTETHESIZEOFTHEREQUEST"
        )
        payload = b"".join([
            b"--BOUNDARY\r\n",
            b'Content-Disposition: form-data; name="',
            b"Title",
            b'"\r\n\r\n',
            b"My first book",
            b"\r\n",
            b"--BOUNDARY\r\n",
            b'Content-Disposition: form-data; name="pdf"; filename="book.pdf"\r\n\r\n',
            pdf_bytes,
            b"--BOUNDARY--",
        ])

        request_headers = {
            "content-type": b"multipart/form-data; boundary=BOUNDARY",
            "content-length": str(len(payload)).encode("latin1"),
        }
        scope = {
            "http_version": "1.1",
            "method": "POST",
            "path": "/test/",
            "headers": request_headers,
        }

        # With a one-byte limit, accessing POST must be rejected.
        with override_settings(DATA_UPLOAD_MAX_MEMORY_SIZE=1), \
                pytest.raises(RequestDataTooBig):
            AsgiRequest(scope, BytesIO(payload)).POST

        # A limit below the size of the file part must still be accepted:
        # the uploaded file data is not supposed to count towards the setting.
        limit_below_file_size = len(pdf_bytes) - 20
        with override_settings(DATA_UPLOAD_MAX_MEMORY_SIZE=limit_below_file_size):
            AsgiRequest(scope, BytesIO(payload)).POST
def get_scope_user(scope):
    """
    Return the user for ``scope``, resolving it on first use.

    The resolved user is cached under ``scope["_cached_user"]`` so later
    calls with the same scope return the stored value without a new lookup.

    Raises:
        KeyError: if the user is not cached yet and ``scope`` has no
            ``"session"`` entry.
    """
    if "_cached_user" not in scope:
        # We need to fake a request so the auth code works. Only fill in a
        # placeholder method when the scope has none (e.g. websocket scopes);
        # an HTTP scope must keep its real request method.
        scope.setdefault('method', "FAKE")
        from channels.http import AsgiRequest
        fake_request = AsgiRequest(scope, b'')
        fake_request.session = scope["session"]

        scope["_cached_user"] = _get_user(fake_request)
    return scope["_cached_user"]
예제 #3
0
    def __call__(self, scope):
        """
        Call the middleware.

        When the scope does not already carry an authenticated user and the
        request has an ``Authorization`` header, the request is verified
        through oauthlib and, on success, the resulting user is stored in
        ``scope['user']``. The scope is then passed to the wrapped
        application and its result is returned.
        """
        def authenticate(request):
            """
            Authenticate the request.

            Returns a tuple containing the user and their access token.
            If it's not valid then ``(None, None)`` is returned.
            """
            oauthlib_core = get_oauthlib_core()
            valid, r = oauthlib_core.verify_request(request, scopes=[])
            if valid:
                return r.user, r.access_token
            return None, None

        existing_user = scope.get('user')
        # Comparing with the AnonymousUser *class* never matches an anonymous
        # user *instance*, so also consult ``is_authenticated``. Objects that
        # lack the attribute are treated as authenticated, as before.
        if (existing_user and existing_user != AnonymousUser
                and getattr(existing_user, 'is_authenticated', True)):
            # We already have an authenticated user
            return self.inner(scope)

        # Scopes without an HTTP method get a placeholder so that a request
        # object can be built from them.
        if "method" not in scope:
            scope['method'] = "FAKE"

        request = AsgiRequest(scope, b"")

        if request.META.get("HTTP_AUTHORIZATION"):
            user, _ = authenticate(request)
            if user:
                scope['user'] = user

        return self.inner(scope)
예제 #4
0
파일: consumer.py 프로젝트: mochi-ai/rested
    def secure(self):
        """
        Authenticate the initial websocket connection via Django middleware.

        Builds a Django request from the connection scope, runs it through
        the session and authentication middleware, and stores the resulting
        session and user on the consumer.

        Raises:
            DenyConnection: if a middleware returns a truthy value, or if no
                session or no authenticated user is available afterwards.
        """
        # use django middleware to get session and authenticate on initial ws connection
        django_middleware = [
            'django.contrib.sessions.middleware.SessionMiddleware',
            'django.contrib.auth.middleware.AuthenticationMiddleware',
            # need to check csrf but unable to set that in the http request initiating ws, probably set it in query
        ]

        # build a django request from ws request/scope
        self.scope['method'] = 'WEBSOCKET'
        request = AsgiRequest(self.scope, '')

        # instantiate each middleware with a no-op get_response callable
        middleware = [
            import_string(path)(lambda x: None) for path in django_middleware
        ]

        # make sure ws request passes channel's django middleware
        for m in middleware:
            if m.process_request(request):
                raise DenyConnection()

        # set session and user
        if hasattr(request, 'session'):
            self.session = request.session
        if hasattr(request, 'user'):
            self.user = request.user

        # deny if we don't have a session and an authenticated user; use
        # getattr so a missing attribute denies the connection instead of
        # surfacing as an AttributeError
        session = getattr(self, 'session', None)
        user = getattr(self, 'user', None)
        if not session or not user or not user.is_authenticated:
            raise DenyConnection()
예제 #5
0
 def test_stream(self):
     """
     Checks that the request body can be consumed piecewise via ``read``.
     """
     scope = {
         "http_version": "1.1",
         "method": "PUT",
         "path": "/",
         "headers": {"host": b"example.com", "content-length": b"11"},
     }
     request = AsgiRequest(scope, BytesIO(b"onetwothree"))

     self.assertEqual(request.method, "PUT")
     # A sized read returns the first bytes; the next read returns the rest.
     leading = request.read(3)
     self.assertEqual(leading, b"one")
     remainder = request.read()
     self.assertEqual(remainder, b"twothree")
예제 #6
0
 def test_stream(self):
     """
     Checks that a body passed as raw bytes can be consumed piecewise via
     ``read``.
     """
     scope = {
         "http_version": "1.1",
         "method": "PUT",
         "path": "/",
         "headers": {"host": b"example.com", "content-length": b"11"},
     }
     request = AsgiRequest(scope, b"onetwothree")

     self.assertEqual(request.method, "PUT")
     # A sized read returns the first bytes; the next read returns the rest.
     leading = request.read(3)
     self.assertEqual(leading, b"one")
     remainder = request.read()
     self.assertEqual(remainder, b"twothree")
예제 #7
0
 def test_post(self):
     """
     Tests a POST body.

     The request carries both a query string and a URL-encoded form body;
     the test checks that each ends up only in GET or POST respectively.
     """
     body = b"djangoponies=are+awesome"
     request = AsgiRequest(
         {
             "http_version": "1.1",
             "method": "POST",
             "path": "/test2/",
             "query_string": "django=great",
             "headers": {
                 "host": b"example.com",
                 "content-type": b"application/x-www-form-urlencoded",
                 # Derived from the body so the header always matches the
                 # number of bytes actually sent.
                 "content-length": str(len(body)).encode("ascii"),
             },
         },
         body,
     )
     self.assertEqual(request.path, "/test2/")
     self.assertEqual(request.method, "POST")
     self.assertEqual(request.body, body)
     self.assertEqual(request.META["HTTP_HOST"], "example.com")
     self.assertEqual(request.META["CONTENT_TYPE"],
                      "application/x-www-form-urlencoded")
     self.assertEqual(request.GET["django"], "great")
     self.assertEqual(request.POST["djangoponies"], "are awesome")
     # Query-string keys must not leak into POST and vice versa.
     with self.assertRaises(KeyError):
         request.POST["django"]
     with self.assertRaises(KeyError):
         request.GET["djangoponies"]
예제 #8
0
 def test_basic(self):
     """
     Checks that a minimal request scope, with every optional field left
     out, is still decoded into a usable request.
     """
     scope = {
         "http_version": "1.1",
         "method": "GET",
         "path": "/test/",
     }
     request = AsgiRequest(scope, b"")

     self.assertEqual(request.path, "/test/")
     self.assertEqual(request.method, "GET")
     self.assertFalse(request.body)

     # No host header or client address was supplied.
     for absent_key in ("HTTP_HOST", "REMOTE_ADDR", "REMOTE_HOST", "REMOTE_PORT"):
         self.assertNotIn(absent_key, request.META)
     # Server name and port are expected even without a "server" entry.
     for present_key in ("SERVER_NAME", "SERVER_PORT"):
         self.assertIn(present_key, request.META)

     self.assertFalse(request.GET)
     self.assertFalse(request.POST)
     self.assertFalse(request.COOKIES)
예제 #9
0
 def test_extended(self):
     """
     Exercises a GET request that carries a query string, cookies and
     client/server address information.
     """
     scope = {
         "http_version": "1.1",
         "method": "GET",
         "path": "/test2/",
         "query_string": b"x=1&y=%26foo+bar%2Bbaz",
         "headers": {
             "host": b"example.com",
             "cookie": b"test-time=1448995585123; test-value=yeah",
         },
         "client": ["10.0.0.1", 1234],
         "server": ["10.0.0.2", 80],
     }
     request = AsgiRequest(scope, b"")

     self.assertEqual(request.path, "/test2/")
     self.assertEqual(request.method, "GET")
     self.assertFalse(request.body)

     expected_meta = (
         ("HTTP_HOST", "example.com"),
         ("REMOTE_ADDR", "10.0.0.1"),
         ("REMOTE_HOST", "10.0.0.1"),
         ("REMOTE_PORT", 1234),
         ("SERVER_NAME", "10.0.0.2"),
         ("SERVER_PORT", "80"),
     )
     for meta_key, meta_value in expected_meta:
         self.assertEqual(request.META[meta_key], meta_value)

     # Percent-escapes and "+" in the query string must be decoded.
     self.assertEqual(request.GET["x"], "1")
     self.assertEqual(request.GET["y"], "&foo bar+baz")

     expected_cookies = (
         ("test-time", "1448995585123"),
         ("test-value", "yeah"),
     )
     for cookie_name, cookie_value in expected_cookies:
         self.assertEqual(request.COOKIES[cookie_name], cookie_value)

     self.assertFalse(request.POST)
예제 #10
0
    def test_websocket_handshake(self):
        """
        Checks that a websocket connection scope can be decoded as the HTTP
        handshake request.

        The ASGI websocket scope carries no request method (per RFC 6455 the
        handshake must have been a GET), so the request is expected to
        report GET.
        """
        handshake_scope = {
            "type": "websocket",
            "http_version": "1.1",
            "query_string": "django=great",
            "path": "/test/",
            "headers": {
                "sec-websocket-protocol": b"350",
            },
        }
        request = AsgiRequest(handshake_scope, BytesIO(b""))

        self.assertEqual(request.path, "/test/")
        # The method is reported both on the request and in META.
        self.assertEqual(request.method, "GET")
        self.assertEqual(request.META["REQUEST_METHOD"], "GET")
        self.assertEqual(request.GET["django"], "great")
        self.assertEqual(request.META["HTTP_SEC_WEBSOCKET_PROTOCOL"], "350")
예제 #11
0
 def test_post_files(self):
     """
     Checks that a multipart form upload is split into POST fields and
     uploaded FILES.
     """
     parts = [
         b"--BOUNDARY\r\n",
         b'Content-Disposition: form-data; name="title"\r\n\r\n',
         b"My First Book\r\n",
         b"--BOUNDARY\r\n",
         b'Content-Disposition: form-data; name="pdf"; filename="book.pdf"\r\n\r\n',
         b"FAKEPDFBYTESGOHERE",
         b"--BOUNDARY--",
     ]
     body = b"".join(parts)
     scope = {
         "http_version": "1.1",
         "method": "POST",
         "path": "/test/",
         "headers": {
             "content-type": b"multipart/form-data; boundary=BOUNDARY",
             "content-length": str(len(body)).encode("ascii"),
         },
     }
     request = AsgiRequest(scope, body)

     self.assertEqual(request.method, "POST")
     self.assertEqual(len(request.body), len(body))
     content_type = request.META["CONTENT_TYPE"]
     self.assertTrue(content_type.startswith("multipart/form-data"))
     self.assertFalse(request._post_parse_error)
     # The text field lands in POST, the file part in FILES.
     self.assertEqual(request.POST["title"], "My First Book")
     self.assertEqual(request.FILES["pdf"].read(), b"FAKEPDFBYTESGOHERE")
예제 #12
0
    def connect(self):
        """
        Decide whether the incoming connection is accepted.

        A request object is built from a copy of the connection scope and
        populated with the scope's user and session. Users that are not
        authenticated are passed through ``self.authenticate`` first, then
        ``self.check_permissions`` decides the outcome.

        Raises:
            AcceptConnection: when the permission check passes.
            DenyConnection: when the permission check fails.
        """
        self.service = self.scope["service"]

        # Work on a copy so the consumer's own scope is left untouched.
        request_scope = dict(self.scope, method="get")
        request = AsgiRequest(request_scope, b"")
        request._request = request

        user = self.scope["user"]
        request.user = user
        request.session = self.scope["session"]

        if not user.is_authenticated:
            self.authenticate(request)

        if not self.check_permissions(request):
            raise DenyConnection()
        raise AcceptConnection()
예제 #13
0
 def test_reading_body_after_stream_raises(self):
     """
     Tests that accessing ``request.body`` after part of the stream has
     already been read raises ``RawPostDataException``.
     """
     body = b"djangoponies=are+awesome"
     request = AsgiRequest(
         {
             "http_version": "1.1",
             "method": "POST",
             "path": "/test2/",
             "query_string": "django=great",
             "headers": {
                 "host": b"example.com",
                 "content-type": b"application/x-www-form-urlencoded",
                 # Derived from the body so the header always matches the
                 # number of bytes actually sent.
                 "content-length": str(len(body)).encode("ascii"),
             },
         },
         BytesIO(body),
     )
     self.assertEqual(request.read(3), b"dja")
     with pytest.raises(RawPostDataException):
         request.body
예제 #14
0
    def test_script_name(self):
        """
        Checks that the scope's ``root_path`` is prepended to the request
        path.
        """
        scope = {
            "http_version": "1.1",
            "method": "GET",
            "path": "/test/",
            "root_path": "/path/to/",
        }
        request = AsgiRequest(scope, b"")

        self.assertEqual(request.path, "/path/to/test/")
예제 #15
0
 def test_size_exceeded(self):
     """
     Checks that a declared content length above
     DATA_UPLOAD_MAX_MEMORY_SIZE is rejected with RequestDataTooBig.
     """
     scope = {
         "http_version": "1.1",
         "method": "PUT",
         "path": "/",
         "headers": {"host": b"example.com", "content-length": b"1000"},
     }
     # Both building the request and reading its body happen inside the
     # overridden setting and the expected-exception context.
     with override_settings(DATA_UPLOAD_MAX_MEMORY_SIZE=1), \
             pytest.raises(RequestDataTooBig):
         AsgiRequest(scope, BytesIO(b"")).body
예제 #16
0
    def from_scope(cls, viewset_action, scope, view_kwargs, query_params):
        """
        Build a view instance that is ready to serve ``viewset_action`` for a
        websocket ``scope``.

        The instance receives the attributes that methods such as
        get_queryset(), get_serializer_class() and the permission checks
        expect to find (``kwargs``, ``request``, ``action`` and so on). The
        approach mirrors the view setup in Django REST framework
        (reference: https://github.com/encode/django-rest-framework/blob/1e383f/rest_framework/viewsets.py#L47).

        Producing a Django HttpRequest from a websocket ASGI scope instead of
        a real HTTP request is the biggest "hack" here. According to the ASGI
        spec (https://asgi.readthedocs.io/en/latest/specs/www.html) the only
        difference between the two scope kinds is the presence of an HTTP
        method: a websocket is established over HTTP, so headers and the rest
        are set as for a normal request. Subscriptions are retrieval
        operations, hence the method is fixed to GET.
        """
        view = cls()

        view.format_kwarg = None
        view.action_map = {}
        view.args = []
        view.kwargs = view_kwargs

        # Copy the scope and overlay the fields a plain HTTP GET would have.
        request_scope = dict(scope)
        request_scope["method"] = "GET"
        request_scope["query_string"] = urlencode(query_params)
        base_request = AsgiRequest(request_scope, BytesIO())

        # TODO: Run other middleware?
        base_request.user = scope.get("user")
        base_request.session = scope.get("session")

        view.request = view.initialize_request(base_request)
        view.action = viewset_action  # TODO: custom subscription actions?
        return view
예제 #17
0
    def test_latin1_headers(self):
        """
        Checks that a header value containing a latin-1 encoded non-ASCII
        character is exposed as the matching text.
        """
        scope = {
            "http_version": "1.1",
            "method": "GET",
            "path": "/test2/",
            "headers": {
                "host": b"example.com",
                "foo": bytes("äbcd", encoding="latin1"),
            },
        }
        request = AsgiRequest(scope, BytesIO(b""))

        self.assertEqual(request.headers["foo"], "äbcd")
예제 #18
0
 def receive(self, text_data=None, bytes_data=None):
     """
     Run an incoming GraphQL query and broadcast its result to the group.

     ``text_data`` is used as the GraphQL query string. It is wrapped in a
     synthetic HTTP POST to ``/graphql/``, executed through
     ``BetterGraphQLView``, and the decoded JSON response is sent to
     ``self.group_name`` as a ``query_result`` event. ``bytes_data`` is not
     used.
     """
     gql = text_data

     # Header values are byte strings, so the fallback must be bytes too.
     user_agent = b''
     for name, value in self.scope['headers']:
         if name == b'user-agent':
             user_agent = value

     # Build the body first so content-length can describe it accurately.
     asgi_body = json.dumps({"query": gql}).encode('utf-8')
     content_length = str(len(asgi_body)).encode('ascii')

     asgi_scope = {
         'client': self.scope['client'],
         'path': '/graphql/',
         'type': 'http',
         'headers': [(b'origin', b'null'), (b'connection', b'keep-alive'),
                     (b'accept-encoding', b'gzip, deflate, br'),
                     (b'user-agent', user_agent),
                     (b'content-type', b'application/json'),
                     (b'host', b'localhost:8000'),
                     (b'content-length', content_length),
                     (b'accept-language', b'zh-CN,zh;q=0.9'),
                     (b'accept', b'application/json')],
         'scheme': 'http',
         'method': 'POST',
         'query_string': b'',
         'root_path': '',
         'http_version': '1.1',
         'server': ['127.0.0.1', 8000]
     }
     asgi_request = AsgiRequest(asgi_scope, asgi_body)
     resp = BetterGraphQLView.as_view(schema=schema)(asgi_request)
     async_to_sync(self.channel_layer.group_send)(
         self.group_name, {
             'type': 'query_result',
             'data': json.loads(resp.content.decode('utf-8'))
         })
예제 #19
0
	async def handle(self, body):
		"""
		Handle an incoming event-stream HTTP request.

		The request is run through GripMiddleware and parsed into an event
		request. Parse errors, and requests marked as GRIP-proxied, are
		answered immediately with a complete response. Otherwise only the
		response headers are sent and ``self.stream`` is scheduled on the
		event loop to continue the response.
		"""
		from django_grip import GripMiddleware
		from .eventrequest import EventRequest
		from .eventstream import EventPermissionError
		from .utils import sse_error_response

		def as_bytes(item):
			# Text is UTF-8 encoded; anything else is passed through as is.
			if isinstance(item, six.text_type):
				return item.encode('utf-8')
			return item

		def encode_pairs(pairs):
			# Encode both the name and the value of every header pair.
			return [(as_bytes(name), as_bytes(value)) for name, value in pairs]

		self.listener = None

		request = AsgiRequest(self.scope, body)

		gm = GripMiddleware()
		gm.process_request(request)

		if 'user' in self.scope:
			request.user = self.scope['user']

		response = None
		try:
			event_request = await self.parse_request(request)
		except EventRequest.ResumeNotAllowedError as e:
			response = HttpResponseBadRequest(
				'Invalid request: %s.\n' % str(e))
		except EventRequest.GripError as e:
			if request.grip.proxied:
				response = sse_error_response(
					'internal-error',
					'Invalid internal request.')
			else:
				response = sse_error_response(
					'bad-request',
					'Invalid request: %s.' % str(e))
		except EventRequest.Error as e:
			response = sse_error_response(
				'bad-request',
				'Invalid request: %s.' % str(e))

		# for grip requests, prepare immediate response
		if not response and request.grip.proxied:
			try:
				event_response = await self.get_events(event_request)
				response = event_response.to_http_response(request)
			except EventPermissionError as e:
				response = sse_error_response(
					'forbidden',
					str(e),
					{'channels': e.channels})

		extra_headers = {
			'Cache-Control': 'no-cache',
			'X-Accel-Buffering': 'no',
		}
		augment_cors_headers(extra_headers)

		# if this was a grip request or we encountered an error, respond now
		if response:
			response = gm.process_response(request, response)

			# the response's own headers come first, then the extra ones
			headers = encode_pairs(response.items())
			headers.extend(encode_pairs(extra_headers.items()))

			await self.send_response(
				response.status_code,
				response.content,
				headers=headers
			)
			return

		# if we got here then the request was not a grip request, and there
		#   were no errors, so we can begin a local stream response

		headers = [(six.b('Content-Type'), six.b('text/event-stream'))]
		headers.extend(encode_pairs(extra_headers.items()))

		await self.send_headers(headers=headers)

		self.listener = Listener()
		self.is_streaming = True

		asyncio.get_event_loop().create_task(self.stream(event_request))
		asyncio.get_event_loop().create_task(self.stream(event_request))
예제 #20
0
	async def handle(self, body):
		"""
		Handle an incoming event-stream HTTP request.

		The request is parsed into an event request. Parse errors, and
		requests carrying a ``grip-sig`` header (treated as GRIP-proxied), are
		answered immediately with a complete response. Otherwise the response
		headers and an initial ``stream-open`` event are sent, and
		``self.stream`` is scheduled on the event loop to continue the
		response.

		All response headers are sent as byte strings.
		"""
		from .eventrequest import EventRequest
		from .eventstream import EventPermissionError
		from .utils import sse_encode_event, sse_error_response, make_id

		def encode_headers(pairs):
			# ASGI response headers are pairs of byte strings, so encode any
			# text name or value as UTF-8 and pass bytes through unchanged.
			encoded = []
			for name, value in pairs:
				if isinstance(name, str):
					name = name.encode('utf-8')
				if isinstance(value, str):
					value = value.encode('utf-8')
				encoded.append((name, value))
			return encoded

		self.listener = None

		request = AsgiRequest(self.scope, body)

		# TODO use GripMiddleware
		request.grip_proxied = False
		for name, value in self.scope['headers']:
			if name == b'grip-sig':
				request.grip_proxied = True
				break

		if 'user' in self.scope:
			request.user = self.scope['user']

		try:
			event_request = await self.parse_request(request)
			response = None
		except EventRequest.ResumeNotAllowedError as e:
			response = HttpResponseBadRequest(
				'Invalid request: %s.\n' % str(e))
		except EventRequest.GripError as e:
			if request.grip_proxied:
				response = sse_error_response(
					'internal-error',
					'Invalid internal request.')
			else:
				response = sse_error_response(
					'bad-request',
					'Invalid request: %s.' % str(e))
		except EventRequest.Error as e:
			response = sse_error_response(
				'bad-request',
				'Invalid request: %s.' % str(e))

		# only an authenticated user is passed on to the event lookups
		user = None
		if hasattr(request, 'user') and request.user.is_authenticated:
			user = request.user

		# for grip requests, prepare immediate response
		if not response and request.grip_proxied:
			try:
				event_response = await self.get_events(event_request, user)
				response = event_response.to_http_response(request)
			except EventPermissionError as e:
				response = sse_error_response(
					'forbidden',
					str(e),
					{'channels': e.channels})

		extra_headers = {}
		extra_headers['Cache-Control'] = 'no-cache'

		cors_origin = ''
		if hasattr(settings, 'EVENTSTREAM_ALLOW_ORIGIN'):
			cors_origin = settings.EVENTSTREAM_ALLOW_ORIGIN

		if cors_origin:
			extra_headers['Access-Control-Allow-Origin'] = cors_origin

		# if this was a grip request or we encountered an error, respond now
		if response:
			# the response's own headers come first, then the extra ones
			headers = encode_headers(response.items())
			headers.extend(encode_headers(extra_headers.items()))

			await self.send_response(
				response.status_code,
				response.content,
				headers=headers
			)
			return

		# if we got here then the request was not a grip request, and there
		#   were no errors, so we can begin a local stream response

		headers = [(b'Content-Type', b'text/event-stream')]
		headers.extend(encode_headers(extra_headers.items()))

		await self.send_headers(headers=headers)

		# open the stream with a 2048-space comment line followed by a
		# stream-open event
		body = b':' + (b' ' * 2048) + b'\n\n'
		body += b'event: stream-open\ndata:\n\n'
		await self.send_body(body, more_body=True)

		self.listener = Listener()
		self.is_streaming = True

		asyncio.get_event_loop().create_task(self.stream(event_request, user))