def render(self, request):
    """
    Render a WebSocket handshake request, overriding the error responses of
    the inherited resources where SockJS needs something different.

    @param request: The HTTP request to render.

    @return: C{""} with a I{405} response code when the method is not
        C{GET}; an error message with a I{400} response code when the
        C{Upgrade} or C{Connection} header is missing or does not mention
        C{websocket} / C{upgrade} respectively; otherwise the result of
        L{WebSocketsResource.render} if it is C{NOT_DONE_YET}, else the
        result of L{OldWebSocketsResource.render}.
    """
    # Get around .parent limitation
    if self._factory is None:
        self._makeFactory()

    # Override handling of invalid methods, returning 400 makes SockJS mad
    if request.method != 'GET':
        request.setResponseCode(405)
        request.defaultContentType = None  # SockJS wants this gone
        request.setHeader('Allow', 'GET')
        return ""

    # Override handling of lack of headers, again SockJS requires non-RFC
    # stuff
    upgrade = request.getHeader("Upgrade")
    if upgrade is None or "websocket" not in upgrade.lower():
        request.setResponseCode(400)
        return 'Can "Upgrade" only to "WebSocket".'

    connection = request.getHeader("Connection")
    if connection is None or "upgrade" not in connection.lower():
        request.setResponseCode(400)
        return '"Connection" must be "Upgrade".'

    # Defer to inherited methods
    # For RFC versions of websockets
    ret = WebSocketsResource.render(self, request)
    if ret is NOT_DONE_YET:
        return ret
    # For non-RFC versions of websockets
    return OldWebSocketsResource.render(self, request)
def render(self, request):
    """
    Render a WebSocket handshake request, overriding the error responses of
    the inherited resources where SockJS needs something different.

    @param request: The HTTP request to render.

    @return: C{""} with a I{405} response code when the method is not
        C{GET}; an error message with a I{400} response code when the
        C{Upgrade} or C{Connection} header is missing or does not mention
        C{websocket} / C{upgrade} respectively; otherwise the result of
        L{WebSocketsResource.render} if it is C{NOT_DONE_YET}, else the
        result of L{OldWebSocketsResource.render}.
    """
    # Get around .parent limitation
    if self._factory is None:
        self._makeFactory()

    # Override handling of invalid methods, returning 400 makes SockJS mad
    if request.method != 'GET':
        request.setResponseCode(405)
        request.defaultContentType = None  # SockJS wants this gone
        request.setHeader('Allow', 'GET')
        return ""

    # Override handling of lack of headers, again SockJS requires non-RFC
    # stuff
    upgrade = request.getHeader("Upgrade")
    if upgrade is None or "websocket" not in upgrade.lower():
        request.setResponseCode(400)
        return 'Can "Upgrade" only to "WebSocket".'

    connection = request.getHeader("Connection")
    if connection is None or "upgrade" not in connection.lower():
        request.setResponseCode(400)
        return '"Connection" must be "Upgrade".'

    # Defer to inherited methods
    # For RFC versions of websockets
    ret = WebSocketsResource.render(self, request)
    if ret is NOT_DONE_YET:
        return ret
    # For non-RFC versions of websockets
    return OldWebSocketsResource.render(self, request)
class WebSocketsResourceTest(TestCase):
    """
    Tests for L{WebSocketsResource}.
    """

    def setUp(self):
        """
        Build a L{WebSocketsResource} around a factory whose
        C{buildProtocol} returns C{self.echoProtocol}.
        """
        class SavingEchoFactory(Factory):
            # The factory instance is named C{oself} so that C{self} still
            # refers to the test case; the protocol is looked up at call
            # time, which lets a test replace C{self.echoProtocol}.
            def buildProtocol(oself, addr):
                return self.echoProtocol

        factory = SavingEchoFactory()
        self.echoProtocol = WebSocketsProtocol(SavingEchoReceiver())
        self.resource = WebSocketsResource(lookupProtocolForFactory(factory))

    def _handshakeHeaders(self, changed=None, dropped=()):
        """
        Build the request headers of a handshake, starting from a complete
        set and applying the requested alterations.

        @param changed: Header values replacing the default ones, or C{None}
            to keep all defaults.
        @type changed: C{dict} or C{None}

        @param dropped: Names of the headers to leave out.
        @type dropped: iterable of C{str}

        @return: A new C{dict} mapping lowercase header names to values.
        """
        headers = {
            "upgrade": "Websocket",
            "connection": "Upgrade",
            "sec-websocket-key": "secure",
            "sec-websocket-version": "13"}
        if changed is not None:
            headers.update(changed)
        for name in dropped:
            del headers[name]
        return headers

    def assertRequestFail(self, request):
        """
        Helper method checking that the provided C{request} fails with a
        I{400} request code, without data or headers.

        @param request: The request to render.
        @type request: L{DummyRequest}
        """
        result = self.resource.render(request)
        self.assertEqual("", result)
        self.assertEqual({}, request.outgoingHeaders)
        self.assertEqual([], request.written)
        self.assertEqual(400, request.responseCode)

    def test_getChildWithDefault(self):
        """
        L{WebSocketsResource.getChildWithDefault} raises a C{RuntimeError}
        when called.
        """
        self.assertRaises(
            RuntimeError, self.resource.getChildWithDefault,
            "foo", DummyRequest("/"))

    def test_putChild(self):
        """
        L{WebSocketsResource.putChild} raises C{RuntimeError} when called.
        """
        self.assertRaises(
            RuntimeError, self.resource.putChild, "foo", Resource())

    def test_IResource(self):
        """
        L{WebSocketsResource} implements L{IResource}.
        """
        self.assertTrue(verifyObject(IResource, self.resource))

    def test_render(self):
        """
        When rendering a request, L{WebSocketsResource} uses the
        C{Sec-WebSocket-Key} header to generate a C{Sec-WebSocket-Accept}
        value. It creates a L{WebSocketsProtocol} instance connected to the
        protocol provided by the user factory.
        """
        request = DummyRequest("/")
        request.requestHeaders = Headers()
        transport = StringTransportWithDisconnection()
        transport.protocol = Protocol()
        request.transport = transport
        request.headers.update(self._handshakeHeaders())
        result = self.resource.render(request)
        self.assertEqual(NOT_DONE_YET, result)
        self.assertEqual(
            {"connection": "Upgrade",
             "upgrade": "WebSocket",
             "sec-websocket-accept": "oYBv54i42V5dw6KnZqOFroecUTc="},
            request.outgoingHeaders)
        self.assertEqual([""], request.written)
        self.assertEqual(101, request.responseCode)
        self.assertIdentical(None, request.transport)
        self.assertIsInstance(transport.protocol._receiver,
                              SavingEchoReceiver)

    def test_renderProtocol(self):
        """
        If protocols are specified via the C{Sec-WebSocket-Protocol} header,
        L{WebSocketsResource} passes them to its C{lookupProtocol} argument,
        which can decide which protocol to return, and which is accepted.
        """

        def lookupProtocol(names, otherRequest):
            # C{request} is bound below, before the resource is rendered.
            self.assertEqual(["foo", "bar"], names)
            self.assertIdentical(request, otherRequest)
            return self.echoProtocol, "bar"

        self.resource = WebSocketsResource(lookupProtocol)

        request = DummyRequest("/")
        request.requestHeaders = Headers(
            {"sec-websocket-protocol": ["foo", "bar"]})
        transport = StringTransportWithDisconnection()
        transport.protocol = Protocol()
        request.transport = transport
        request.headers.update(self._handshakeHeaders())
        result = self.resource.render(request)
        self.assertEqual(NOT_DONE_YET, result)
        self.assertEqual(
            {"connection": "Upgrade",
             "upgrade": "WebSocket",
             "sec-websocket-protocol": "bar",
             "sec-websocket-accept": "oYBv54i42V5dw6KnZqOFroecUTc="},
            request.outgoingHeaders)
        self.assertEqual([""], request.written)
        self.assertEqual(101, request.responseCode)

    def test_renderWrongUpgrade(self):
        """
        If the C{Upgrade} header contains an invalid value,
        L{WebSocketsResource} returns a failed request.
        """
        request = DummyRequest("/")
        request.headers.update(
            self._handshakeHeaders(changed={"upgrade": "wrong"}))
        self.assertRequestFail(request)

    def test_renderNoUpgrade(self):
        """
        If the C{Upgrade} header is not set, L{WebSocketsResource} returns a
        failed request.
        """
        request = DummyRequest("/")
        request.headers.update(self._handshakeHeaders(dropped=["upgrade"]))
        self.assertRequestFail(request)

    def test_renderPOST(self):
        """
        If the method is not C{GET}, L{WebSocketsResource} returns a failed
        request.
        """
        request = DummyRequest("/")
        request.method = "POST"
        request.headers.update(self._handshakeHeaders())
        self.assertRequestFail(request)

    def test_renderWrongConnection(self):
        """
        If the C{Connection} header contains an invalid value,
        L{WebSocketsResource} returns a failed request.
        """
        request = DummyRequest("/")
        request.headers.update(
            self._handshakeHeaders(changed={"connection": "Wrong"}))
        self.assertRequestFail(request)

    def test_renderNoConnection(self):
        """
        If the C{Connection} header is not set, L{WebSocketsResource} returns
        a failed request.
        """
        request = DummyRequest("/")
        request.headers.update(
            self._handshakeHeaders(dropped=["connection"]))
        self.assertRequestFail(request)

    def test_renderNoKey(self):
        """
        If the C{Sec-WebSocket-Key} header is not set,
        L{WebSocketsResource} returns a failed request.
        """
        request = DummyRequest("/")
        request.headers.update(
            self._handshakeHeaders(dropped=["sec-websocket-key"]))
        self.assertRequestFail(request)

    def test_renderWrongVersion(self):
        """
        If the value of the C{Sec-WebSocket-Version} is not 13,
        L{WebSocketsResource} returns a failed request.
        """
        request = DummyRequest("/")
        request.headers.update(
            self._handshakeHeaders(changed={"sec-websocket-version": "11"}))
        result = self.resource.render(request)
        self.assertEqual("", result)
        # Unlike the other failures, this one is expected to carry a header.
        self.assertEqual({"sec-websocket-version": "13"},
                         request.outgoingHeaders)
        self.assertEqual([], request.written)
        self.assertEqual(400, request.responseCode)

    def test_renderNoProtocol(self):
        """
        If the underlying factory doesn't return any protocol,
        L{WebSocketsResource} returns a failed request with a C{502} code.
        """
        request = DummyRequest("/")
        request.requestHeaders = Headers()
        request.transport = StringTransportWithDisconnection()
        # The factory built in setUp returns this attribute.
        self.echoProtocol = None
        request.headers.update(self._handshakeHeaders())
        result = self.resource.render(request)
        self.assertEqual("", result)
        self.assertEqual({}, request.outgoingHeaders)
        self.assertEqual([], request.written)
        self.assertEqual(502, request.responseCode)

    def test_renderSecureRequest(self):
        """
        When the rendered request is over HTTPS, L{WebSocketsResource} wraps
        the protocol of the C{TLSMemoryBIOProtocol} instance.
        """
        request = DummyRequest("/")
        request.requestHeaders = Headers()
        transport = StringTransportWithDisconnection()
        secureProtocol = TLSMemoryBIOProtocol(Factory(), Protocol())
        transport.protocol = secureProtocol
        request.transport = transport
        request.headers.update(self._handshakeHeaders())
        result = self.resource.render(request)
        self.assertEqual(NOT_DONE_YET, result)
        self.assertEqual(
            {"connection": "Upgrade",
             "upgrade": "WebSocket",
             "sec-websocket-accept": "oYBv54i42V5dw6KnZqOFroecUTc="},
            request.outgoingHeaders)
        self.assertEqual([""], request.written)
        self.assertEqual(101, request.responseCode)
        self.assertIdentical(None, request.transport)
        self.assertIsInstance(
            transport.protocol.wrappedProtocol, WebSocketsProtocol)
        self.assertIsInstance(
            transport.protocol.wrappedProtocol._receiver,
            SavingEchoReceiver)

    def test_renderRealRequest(self):
        """
        The request managed by L{WebSocketsResource.render} doesn't contain
        unnecessary HTTP headers like I{Content-Type} or
        I{Transfer-Encoding}.
        """
        channel = DummyChannel()
        channel.transport = StringTransportWithDisconnection()
        channel.transport.protocol = channel
        request = Request(channel, False)
        headers = self._handshakeHeaders()
        for key, value in headers.items():
            request.requestHeaders.setRawHeaders(key, [value])
        request.method = "GET"
        request.clientproto = "HTTP/1.1"
        result = self.resource.render(request)
        self.assertEqual(NOT_DONE_YET, result)
        self.assertEqual(
            [("Connection", ["Upgrade"]),
             ("Upgrade", ["WebSocket"]),
             ("Sec-Websocket-Accept", ["oYBv54i42V5dw6KnZqOFroecUTc="])],
            list(request.responseHeaders.getAllRawHeaders()))
        self.assertEqual(
            "HTTP/1.1 101 Switching Protocols\r\n"
            "Connection: Upgrade\r\n"
            "Upgrade: WebSocket\r\n"
            "Sec-Websocket-Accept: oYBv54i42V5dw6KnZqOFroecUTc=\r\n\r\n",
            channel.transport.value())
        self.assertEqual(101, request.code)
        self.assertIdentical(None, request.transport)

    def test_renderIProtocol(self):
        """
        If the protocol returned by C{lookupProtocol} isn't a
        C{WebSocketsProtocol}, L{WebSocketsResource} wraps it automatically
        with L{WebSocketsProtocolWrapper}.
        """

        def lookupProtocol(names, otherRequest):
            return AccumulatingProtocol(), None

        self.resource = WebSocketsResource(lookupProtocol)

        request = DummyRequest("/")
        request.requestHeaders = Headers()
        transport = StringTransportWithDisconnection()
        transport.protocol = Protocol()
        request.transport = transport
        request.headers.update(self._handshakeHeaders())
        result = self.resource.render(request)
        self.assertEqual(NOT_DONE_YET, result)
        self.assertIsInstance(transport.protocol, WebSocketsProtocolWrapper)
        self.assertIsInstance(transport.protocol.wrappedProtocol,
                              AccumulatingProtocol)
class WebSocketsResourceTest(TestCase):
    """
    Tests for L{WebSocketsResource}.
    """

    def setUp(self):
        """
        Build a L{WebSocketsResource} around a factory whose
        C{buildProtocol} returns C{self.echoProtocol}.
        """
        class SavingEchoFactory(Factory):
            # The factory instance is named C{oself} so that C{self} still
            # refers to the test case; the protocol is looked up at call
            # time, which lets a test replace C{self.echoProtocol}.
            def buildProtocol(oself, addr):
                return self.echoProtocol

        factory = SavingEchoFactory()
        self.echoProtocol = WebSocketsProtocol(SavingEchoReceiver())
        self.resource = WebSocketsResource(lookupProtocolForFactory(factory))

    def _handshakeHeaders(self, changed=None, dropped=()):
        """
        Build the request headers of a handshake, starting from a complete
        set and applying the requested alterations.

        @param changed: Header values replacing the default ones, or C{None}
            to keep all defaults.
        @type changed: C{dict} or C{None}

        @param dropped: Names of the headers to leave out.
        @type dropped: iterable of C{str}

        @return: A new C{dict} mapping lowercase header names to values.
        """
        headers = {
            "upgrade": "Websocket",
            "connection": "Upgrade",
            "sec-websocket-key": "secure",
            "sec-websocket-version": "13"}
        if changed is not None:
            headers.update(changed)
        for name in dropped:
            del headers[name]
        return headers

    def assertRequestFail(self, request):
        """
        Helper method checking that the provided C{request} fails with a
        I{400} request code, without data or headers.

        @param request: The request to render.
        @type request: L{DummyRequest}
        """
        result = self.resource.render(request)
        self.assertEqual("", result)
        self.assertEqual({}, request.outgoingHeaders)
        self.assertEqual([], request.written)
        self.assertEqual(400, request.responseCode)

    def test_getChildWithDefault(self):
        """
        L{WebSocketsResource.getChildWithDefault} raises a C{RuntimeError}
        when called.
        """
        self.assertRaises(
            RuntimeError, self.resource.getChildWithDefault,
            "foo", DummyRequest("/"))

    def test_putChild(self):
        """
        L{WebSocketsResource.putChild} raises C{RuntimeError} when called.
        """
        self.assertRaises(
            RuntimeError, self.resource.putChild, "foo", Resource())

    def test_IResource(self):
        """
        L{WebSocketsResource} implements L{IResource}.
        """
        self.assertTrue(verifyObject(IResource, self.resource))

    def test_render(self):
        """
        When rendering a request, L{WebSocketsResource} uses the
        C{Sec-WebSocket-Key} header to generate a C{Sec-WebSocket-Accept}
        value. It creates a L{WebSocketsProtocol} instance connected to the
        protocol provided by the user factory.
        """
        request = DummyRequest("/")
        request.requestHeaders = Headers()
        transport = StringTransportWithDisconnection()
        transport.protocol = Protocol()
        request.transport = transport
        request.headers.update(self._handshakeHeaders())
        result = self.resource.render(request)
        self.assertEqual(NOT_DONE_YET, result)
        self.assertEqual(
            {"connection": "Upgrade",
             "upgrade": "WebSocket",
             "sec-websocket-accept": "oYBv54i42V5dw6KnZqOFroecUTc="},
            request.outgoingHeaders)
        self.assertEqual([""], request.written)
        self.assertEqual(101, request.responseCode)
        self.assertIdentical(None, request.transport)
        self.assertIsInstance(transport.protocol._receiver,
                              SavingEchoReceiver)

    def test_renderProtocol(self):
        """
        If protocols are specified via the C{Sec-WebSocket-Protocol} header,
        L{WebSocketsResource} passes them to its C{lookupProtocol} argument,
        which can decide which protocol to return, and which is accepted.
        """

        def lookupProtocol(names, otherRequest):
            # C{request} is bound below, before the resource is rendered.
            self.assertEqual(["foo", "bar"], names)
            self.assertIdentical(request, otherRequest)
            return self.echoProtocol, "bar"

        self.resource = WebSocketsResource(lookupProtocol)

        request = DummyRequest("/")
        request.requestHeaders = Headers(
            {"sec-websocket-protocol": ["foo", "bar"]})
        transport = StringTransportWithDisconnection()
        transport.protocol = Protocol()
        request.transport = transport
        request.headers.update(self._handshakeHeaders())
        result = self.resource.render(request)
        self.assertEqual(NOT_DONE_YET, result)
        self.assertEqual(
            {"connection": "Upgrade",
             "upgrade": "WebSocket",
             "sec-websocket-protocol": "bar",
             "sec-websocket-accept": "oYBv54i42V5dw6KnZqOFroecUTc="},
            request.outgoingHeaders)
        self.assertEqual([""], request.written)
        self.assertEqual(101, request.responseCode)

    def test_renderWrongUpgrade(self):
        """
        If the C{Upgrade} header contains an invalid value,
        L{WebSocketsResource} returns a failed request.
        """
        request = DummyRequest("/")
        request.headers.update(
            self._handshakeHeaders(changed={"upgrade": "wrong"}))
        self.assertRequestFail(request)

    def test_renderNoUpgrade(self):
        """
        If the C{Upgrade} header is not set, L{WebSocketsResource} returns a
        failed request.
        """
        request = DummyRequest("/")
        request.headers.update(self._handshakeHeaders(dropped=["upgrade"]))
        self.assertRequestFail(request)

    def test_renderPOST(self):
        """
        If the method is not C{GET}, L{WebSocketsResource} returns a failed
        request.
        """
        request = DummyRequest("/")
        request.method = "POST"
        request.headers.update(self._handshakeHeaders())
        self.assertRequestFail(request)

    def test_renderWrongConnection(self):
        """
        If the C{Connection} header contains an invalid value,
        L{WebSocketsResource} returns a failed request.
        """
        request = DummyRequest("/")
        request.headers.update(
            self._handshakeHeaders(changed={"connection": "Wrong"}))
        self.assertRequestFail(request)

    def test_renderNoConnection(self):
        """
        If the C{Connection} header is not set, L{WebSocketsResource} returns
        a failed request.
        """
        request = DummyRequest("/")
        request.headers.update(
            self._handshakeHeaders(dropped=["connection"]))
        self.assertRequestFail(request)

    def test_renderNoKey(self):
        """
        If the C{Sec-WebSocket-Key} header is not set,
        L{WebSocketsResource} returns a failed request.
        """
        request = DummyRequest("/")
        request.headers.update(
            self._handshakeHeaders(dropped=["sec-websocket-key"]))
        self.assertRequestFail(request)

    def test_renderWrongVersion(self):
        """
        If the value of the C{Sec-WebSocket-Version} is not 13,
        L{WebSocketsResource} returns a failed request.
        """
        request = DummyRequest("/")
        request.headers.update(
            self._handshakeHeaders(changed={"sec-websocket-version": "11"}))
        result = self.resource.render(request)
        self.assertEqual("", result)
        # Unlike the other failures, this one is expected to carry a header.
        self.assertEqual({"sec-websocket-version": "13"},
                         request.outgoingHeaders)
        self.assertEqual([], request.written)
        self.assertEqual(400, request.responseCode)

    def test_renderNoProtocol(self):
        """
        If the underlying factory doesn't return any protocol,
        L{WebSocketsResource} returns a failed request with a C{502} code.
        """
        request = DummyRequest("/")
        request.requestHeaders = Headers()
        request.transport = StringTransportWithDisconnection()
        # The factory built in setUp returns this attribute.
        self.echoProtocol = None
        request.headers.update(self._handshakeHeaders())
        result = self.resource.render(request)
        self.assertEqual("", result)
        self.assertEqual({}, request.outgoingHeaders)
        self.assertEqual([], request.written)
        self.assertEqual(502, request.responseCode)

    def test_renderSecureRequest(self):
        """
        When the rendered request is over HTTPS, L{WebSocketsResource} wraps
        the protocol of the C{TLSMemoryBIOProtocol} instance.
        """
        request = DummyRequest("/")
        request.requestHeaders = Headers()
        transport = StringTransportWithDisconnection()
        secureProtocol = TLSMemoryBIOProtocol(Factory(), Protocol())
        transport.protocol = secureProtocol
        request.transport = transport
        request.headers.update(self._handshakeHeaders())
        result = self.resource.render(request)
        self.assertEqual(NOT_DONE_YET, result)
        self.assertEqual(
            {"connection": "Upgrade",
             "upgrade": "WebSocket",
             "sec-websocket-accept": "oYBv54i42V5dw6KnZqOFroecUTc="},
            request.outgoingHeaders)
        self.assertEqual([""], request.written)
        self.assertEqual(101, request.responseCode)
        self.assertIdentical(None, request.transport)
        self.assertIsInstance(
            transport.protocol.wrappedProtocol, WebSocketsProtocol)
        self.assertIsInstance(
            transport.protocol.wrappedProtocol._receiver,
            SavingEchoReceiver)

    def test_renderRealRequest(self):
        """
        The request managed by L{WebSocketsResource.render} doesn't contain
        unnecessary HTTP headers like I{Content-Type} or
        I{Transfer-Encoding}.
        """
        channel = DummyChannel()
        channel.transport = StringTransportWithDisconnection()
        channel.transport.protocol = channel
        request = Request(channel, False)
        headers = self._handshakeHeaders()
        for key, value in headers.items():
            request.requestHeaders.setRawHeaders(key, [value])
        request.method = "GET"
        request.clientproto = "HTTP/1.1"
        result = self.resource.render(request)
        self.assertEqual(NOT_DONE_YET, result)
        self.assertEqual(
            [("Connection", ["Upgrade"]),
             ("Upgrade", ["WebSocket"]),
             ("Sec-Websocket-Accept", ["oYBv54i42V5dw6KnZqOFroecUTc="])],
            list(request.responseHeaders.getAllRawHeaders()))
        self.assertEqual(
            "HTTP/1.1 101 Switching Protocols\r\n"
            "Connection: Upgrade\r\n"
            "Upgrade: WebSocket\r\n"
            "Sec-Websocket-Accept: oYBv54i42V5dw6KnZqOFroecUTc=\r\n\r\n",
            channel.transport.value())
        self.assertEqual(101, request.code)
        self.assertIdentical(None, request.transport)

    def test_renderIProtocol(self):
        """
        If the protocol returned by C{lookupProtocol} isn't a
        C{WebSocketsProtocol}, L{WebSocketsResource} wraps it automatically
        with L{WebSocketsProtocolWrapper}.
        """

        def lookupProtocol(names, otherRequest):
            return AccumulatingProtocol(), None

        self.resource = WebSocketsResource(lookupProtocol)

        request = DummyRequest("/")
        request.requestHeaders = Headers()
        transport = StringTransportWithDisconnection()
        transport.protocol = Protocol()
        request.transport = transport
        request.headers.update(self._handshakeHeaders())
        result = self.resource.render(request)
        self.assertEqual(NOT_DONE_YET, result)
        self.assertIsInstance(transport.protocol, WebSocketsProtocolWrapper)
        self.assertIsInstance(transport.protocol.wrappedProtocol,
                              AccumulatingProtocol)