def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=None, v6=False, sourcePort=None, destinationPort=None):
    """
    Assert that a Proxy Protocol v2 payload captured by a backend matches expectations.

    :param receivedProxyPayload: raw Proxy Protocol bytes (header, addresses/ports and TLVs).
    :param source: expected source address.
    :param destination: expected destination address.
    :param isTCP: True if the header must advertise a stream transport (protocol 0x01),
                  False for a datagram transport (protocol 0x02).
    :param values: expected [type, value] TLV pairs, compared order-insensitively.
                   The caller's list is never modified.
    :param v6: True if the header must advertise family 0x02 (IPv6) instead of 0x01 (IPv4).
    :param sourcePort: expected source port, or None to skip that check.
    :param destinationPort: expected destination port, or None to expect self._dnsDistPort.
    """
    proxy = ProxyProtocol()
    self.assertTrue(proxy.parseHeader(receivedProxyPayload))
    self.assertEqual(proxy.version, 0x02)
    self.assertEqual(proxy.command, 0x01)
    self.assertEqual(proxy.family, 0x02 if v6 else 0x01)
    self.assertEqual(proxy.protocol, 0x01 if isTCP else 0x02)
    self.assertGreater(proxy.contentLen, 0)

    self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
    self.assertEqual(proxy.source, source)
    self.assertEqual(proxy.destination, destination)
    if sourcePort is not None:
        self.assertEqual(proxy.sourcePort, sourcePort)
    expectedDestinationPort = self._dnsDistPort if destinationPort is None else destinationPort
    self.assertEqual(proxy.destinationPort, expectedDestinationPort)

    self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
    # compare sorted copies: sorting `values` in place used to reorder the
    # caller's list (and, with the old `values=[]` default, a shared object)
    self.assertEqual(sorted(proxy.values), sorted(values or []))
def testIncomingProxyDest(self):
    """
    Unexpected Proxy Protocol: should be dropped
    """
    name = 'with-proxy-payload.unexpected-protocol-incoming.tests.powerdns.com.'
    query = dns.message.make_query(name, 'A', 'IN')
    # Make sure that the proxy payload does NOT turn into a legal qname
    destAddr = "ff:db8::ffff"
    destPort = 65535
    srcAddr = "ff:db8::ffff"
    srcPort = 65535

    # UDP: header + query in a single datagram, no answer expected
    udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [[2, b'foo'], [3, b'proxy']])
    (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
    self.assertEqual(receivedResponse, None)

    # TCP: header, then a length-prefixed query, no answer expected
    tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [[2, b'foo'], [3, b'proxy']])
    wire = query.to_wire()

    receivedResponse = None
    conn = None
    try:
        conn = self.openTCPConnection(2.0)
        conn.send(tcpPayload)
        conn.send(struct.pack("!H", len(wire)))
        conn.send(wire)
        receivedResponse = self.recvTCPResponseOverConnection(conn)
    except socket.timeout:
        print('timeout')
    finally:
        # the connection used to be leaked on every run
        if conn is not None:
            conn.close()
    self.assertEqual(receivedResponse, None)
def testProxyUDPWithValueOverride(self):
    """
    Incoming Proxy Protocol: override existing value (UDP)
    """
    name = 'override.proxy-protocol-incoming.tests.powerdns.com.'
    query = dns.message.make_query(name, 'A', 'IN')
    response = dns.message.make_response(query)

    srcAddr, srcPort = "2001:db8::8", 8888
    destAddr, destPort = "2001:db8::9", 9999

    # the client sends TLV 50 with an initial value; the backend is expected
    # to see it replaced (see the final check below)
    clientValues = [[2, b'foo'], [3, b'proxy'], [50, b'initial-value']]
    udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, clientValues)

    toProxyQueue.put(response, True, 2.0)
    (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
    (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)

    for received in (receivedProxyPayload, receivedDNSData, receivedResponse):
        self.assertTrue(received)

    # align message IDs before comparing the DNS content
    receivedQuery = dns.message.from_wire(receivedDNSData)
    receivedQuery.id = query.id
    receivedResponse.id = response.id
    self.assertEqual(receivedQuery, query)
    self.assertEqual(receivedResponse, response)

    self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, False, [[50, b'overridden']], True, srcPort, destPort)
def testNOTIFY(self):
    """
    Check that NOTIFY is properly accepted/rejected based on the PROXY header inner address
    """
    notify = dns.message.make_query('example.org', 'SOA')
    notify.set_opcode(dns.opcode.NOTIFY)
    notifyWire = notify.to_wire()

    # (inner source address advertised in the PROXY header, expected rcode)
    expectations = (
        ('192.0.2.1', dns.rcode.NOERROR),
        ('192.0.2.2', dns.rcode.REFUSED),
    )
    for innerSource, wantedRcode in expectations:
        header = ProxyProtocol.getPayload(False, False, False, innerSource, "10.1.2.3", 12345, 53, [])

        self._sock.settimeout(2.0)
        try:
            self._sock.send(header + notifyWire)
            reply = self._sock.recv(4096)
        except socket.timeout:
            reply = None
        finally:
            self._sock.settimeout(None)

        parsed = dns.message.from_wire(reply) if reply else None
        self.assertRcodeEqual(parsed, wantedRcode)
def sendUDPQueryWithProxyProtocol(cls, query, v6, source, destination, sourcePort, destinationPort, values=None, timeout=2.0):
    """
    Send `query` on `cls._sock`, prefixed with a Proxy Protocol v2 header, and read one reply.

    :param query: dns.message.Message to send.
    :param v6: whether the header advertises IPv6 addresses.
    :param source: source address put into the header.
    :param destination: destination address put into the header.
    :param sourcePort: source port put into the header.
    :param destinationPort: destination port put into the header.
    :param values: optional list of [type, value] TLVs for the header (defaults to none).
    :param timeout: seconds to wait for the reply; a falsy value skips settimeout()
                    on the socket entirely.
    :returns: the parsed dns.message.Message, or None if nothing arrived before the timeout.
    """
    # avoid a shared mutable default argument
    if values is None:
        values = []
    queryPayload = query.to_wire()
    ppPayload = ProxyProtocol.getPayload(False, False, v6, source, destination, sourcePort, destinationPort, values)
    payload = ppPayload + queryPayload
    if timeout:
        cls._sock.settimeout(timeout)

    try:
        cls._sock.send(payload)
        data = cls._sock.recv(4096)
    except socket.timeout:
        data = None
    finally:
        # restore the shared socket to blocking mode for other tests
        if timeout:
            cls._sock.settimeout(None)

    message = None
    if data:
        message = dns.message.from_wire(data)
    return message
def testTooLargeProxyProtocol(self):
    # the total payload (proxy protocol + DNS) is larger than proxy-protocol-maximum-size
    # so it should be dropped
    qname = 'too-large.proxy-protocol.recursor-tests.powerdns.com.'
    query = dns.message.make_query(qname, 'A', want_dnssec=True)
    queryPayload = query.to_wire()
    # the 512-byte TLV (type 1) is what inflates the header
    ppPayload = ProxyProtocol.getPayload(
        False, True, False, '127.0.0.42', '255.255.255.255', 0, 65535,
        [[0, b'foo'], [1, b'A' * 512], [255, b'bar']])
    payload = ppPayload + queryPayload

    # UDP
    self._sock.settimeout(2.0)
    try:
        self._sock.send(payload)
        data = self._sock.recv(4096)
    except socket.timeout:
        data = None
    finally:
        self._sock.settimeout(None)

    res = None
    if data:
        res = dns.message.from_wire(data)
    self.assertEqual(res, None)

    # TCP: the context manager also closes the socket if connect() raises
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(2.0)
        sock.connect(("127.0.0.1", self._recursorPort))
        try:
            sock.send(ppPayload)
            sock.send(struct.pack("!H", len(queryPayload)))
            sock.send(queryPayload)
            data = sock.recv(2)
            if data:
                (datalen, ) = struct.unpack("!H", data)
                data = sock.recv(datalen)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
            data = None
        except socket.error as e:
            print("Network error: %s" % (str(e)))
            data = None

    res = None
    if data:
        res = dns.message.from_wire(data)
    self.assertEqual(res, None)
def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
    """
    Run a UDP DNS responder that expects a Proxy Protocol header on every datagram.

    Binds 127.0.0.1:`port` and loops forever. Datagrams that are shorter than
    the fixed header or fail header parsing are ignored. Datagrams whose header
    has `local` set (likely health checks) are answered directly with an empty
    response built from the query. For all other datagrams:
    - the [proxy header bytes, DNS bytes] pair is pushed onto `toQueue`;
    - the next message taken from `fromQueue` is sent back with its ID set to
      the query's ID.
    The process exits if the bind fails.
    """
    udpSock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    udpSock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    try:
        udpSock.bind(("127.0.0.1", port))
    except socket.error as e:
        print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
        sys.exit(1)

    while True:
        datagram, peer = udpSock.recvfrom(4096)
        header = ProxyProtocol()
        if len(datagram) < header.HEADER_SIZE or not header.parseHeader(datagram):
            continue

        if header.local:
            # likely a healthcheck: answer without involving the queues
            healthQuery = dns.message.from_wire(datagram[header.HEADER_SIZE:])
            reply = dns.message.make_response(healthQuery).to_wire()
            udpSock.settimeout(2.0)
            udpSock.sendto(reply, peer)
            udpSock.settimeout(None)
            continue

        boundary = header.HEADER_SIZE + header.contentLen
        proxyPart = datagram[:boundary]
        dnsPart = datagram[boundary:]
        toQueue.put([proxyPart, dnsPart], True, 2.0)

        # the queued answer carries its own ID; align it with the query
        query = dns.message.from_wire(dnsPart)
        answer = fromQueue.get(True, 2.0)
        answer.id = query.id
        udpSock.settimeout(2.0)
        udpSock.sendto(answer.to_wire(), peer)
        udpSock.settimeout(None)

    udpSock.close()
def testIncomingProxyDest(self):
    """
    Incoming Proxy Protocol: values from Lua
    """
    name = 'get-forwarded-dest.proxy-protocol-incoming.tests.powerdns.com.'
    query = dns.message.make_query(name, 'A', 'IN')
    # dnsdist set RA = RD for spoofed responses
    query.flags &= ~dns.flags.RD

    destAddr = "2001:db8::9"
    destPort = 9999
    srcAddr = "2001:db8::8"
    srcPort = 8888
    response = dns.message.make_response(query)
    # the expected CNAME target embeds the destination address and port sent
    # in the Proxy Protocol header
    rrset = dns.rrset.from_text(
        name, 60, dns.rdataclass.IN, dns.rdatatype.CNAME,
        "address-was-{}-port-was-{}.proxy-protocol-incoming.tests.powerdns.com.".format(destAddr, destPort))
    response.answer.append(rrset)

    # UDP
    udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [[2, b'foo'], [3, b'proxy']])
    (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
    self.assertEqual(receivedResponse, response)

    # TCP
    tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [[2, b'foo'], [3, b'proxy']])
    wire = query.to_wire()

    receivedResponse = None
    conn = None
    try:
        conn = self.openTCPConnection(2.0)
        conn.send(tcpPayload)
        conn.send(struct.pack("!H", len(wire)))
        conn.send(wire)
        receivedResponse = self.recvTCPResponseOverConnection(conn)
    except socket.timeout:
        print('timeout')
    finally:
        # the connection used to be leaked on every run
        if conn is not None:
            conn.close()
    self.assertEqual(receivedResponse, response)
def testLocalProxyProtocol(self):
    """A LOCAL Proxy Protocol header (no addresses, ports or TLVs) must still get an answer over UDP and TCP."""
    qname = 'local.proxy-protocol.recursor-tests.powerdns.com.'
    expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.255')
    query = dns.message.make_query(qname, 'A', want_dnssec=True)
    queryPayload = query.to_wire()
    ppPayload = ProxyProtocol.getPayload(True, False, False, None, None, None, None, [])

    def parse(raw):
        # empty or missing replies map to None
        return dns.message.from_wire(raw) if raw else None

    # UDP
    self._sock.settimeout(2.0)
    try:
        self._sock.send(ppPayload + queryPayload)
        udpReply = self._sock.recv(4096)
    except socket.timeout:
        udpReply = None
    finally:
        self._sock.settimeout(None)

    udpMessage = parse(udpReply)
    self.assertRcodeEqual(udpMessage, dns.rcode.NOERROR)
    self.assertRRsetInAnswer(udpMessage, expected)

    # TCP
    tcpSock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    tcpSock.settimeout(2.0)
    tcpSock.connect(("127.0.0.1", self._recursorPort))
    try:
        tcpSock.send(ppPayload)
        tcpSock.send(struct.pack("!H", len(queryPayload)))
        tcpSock.send(queryPayload)
        tcpReply = tcpSock.recv(2)
        if tcpReply:
            (replyLen, ) = struct.unpack("!H", tcpReply)
            tcpReply = tcpSock.recv(replyLen)
    except socket.timeout as e:
        print("Timeout: %s" % (str(e)))
        tcpReply = None
    except socket.error as e:
        print("Network error: %s" % (str(e)))
        tcpReply = None
    finally:
        tcpSock.close()

    tcpMessage = parse(tcpReply)
    self.assertRcodeEqual(tcpMessage, dns.rcode.NOERROR)
    self.assertRRsetInAnswer(tcpMessage, expected)
def testInvalidMagicProxyProtocol(self):
    """A Proxy Protocol header whose first byte is replaced by 0x00 must get no answer over UDP or TCP."""
    qname = 'invalid-magic.proxy-protocol.recursor-tests.powerdns.com.'
    query = dns.message.make_query(qname, 'A', want_dnssec=True)
    queryPayload = query.to_wire()
    validHeader = ProxyProtocol.getPayload(True, False, False, None, None, None, None, [])
    # clobber the first byte of the signature
    brokenHeader = b'\x00' + validHeader[1:]

    # UDP
    self._sock.settimeout(2.0)
    try:
        self._sock.send(brokenHeader + queryPayload)
        udpReply = self._sock.recv(4096)
    except socket.timeout:
        udpReply = None
    finally:
        self._sock.settimeout(None)

    self.assertEqual(dns.message.from_wire(udpReply) if udpReply else None, None)

    # TCP
    tcpSock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    tcpSock.settimeout(2.0)
    tcpSock.connect(("127.0.0.1", self._recursorPort))
    try:
        tcpSock.send(brokenHeader)
        tcpSock.send(struct.pack("!H", len(queryPayload)))
        tcpSock.send(queryPayload)
        tcpReply = tcpSock.recv(2)
        if tcpReply:
            (replyLen, ) = struct.unpack("!H", tcpReply)
            tcpReply = tcpSock.recv(replyLen)
    except socket.timeout as e:
        print("Timeout: %s" % (str(e)))
        tcpReply = None
    except socket.error as e:
        print("Network error: %s" % (str(e)))
        tcpReply = None
    finally:
        tcpSock.close()

    self.assertEqual(dns.message.from_wire(tcpReply) if tcpReply else None, None)
def testProxyTCPWithValuesFromLua(self):
    """
    Incoming Proxy Protocol: values from Lua (TCP)
    """
    name = 'values-lua.proxy-protocol-incoming.tests.powerdns.com.'
    query = dns.message.make_query(name, 'A', 'IN')
    response = dns.message.make_response(query)

    destAddr = "2001:db8::9"
    destPort = 9999
    srcAddr = "2001:db8::8"
    srcPort = 8888

    tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [[2, b'foo'], [3, b'proxy']])

    toProxyQueue.put(response, True, 2.0)

    wire = query.to_wire()

    receivedResponse = None
    conn = None
    try:
        conn = self.openTCPConnection(2.0)
        conn.send(tcpPayload)
        conn.send(struct.pack("!H", len(wire)))
        conn.send(wire)
        receivedResponse = self.recvTCPResponseOverConnection(conn)
    except socket.timeout:
        print('timeout')
    finally:
        # the connection used to be leaked on every run
        if conn is not None:
            conn.close()
    self.assertEqual(receivedResponse, response)

    (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
    self.assertTrue(receivedProxyPayload)
    self.assertTrue(receivedDNSData)
    self.assertTrue(receivedResponse)

    # align message IDs before comparing the DNS content
    receivedQuery = dns.message.from_wire(receivedDNSData)
    receivedQuery.id = query.id
    receivedResponse.id = response.id
    self.assertEqual(receivedQuery, query)
    self.assertEqual(receivedResponse, response)

    # the backend must see the client's TLVs (2, 3) plus extra ones
    # (0, 1, 42, 255) — presumably added by the Lua configuration; confirm there
    self.checkMessageProxyProtocol(
        receivedProxyPayload, srcAddr, destAddr, True,
        [[0, b'foo'], [1, b'dnsdist'], [2, b'foo'], [3, b'proxy'], [42, b'bar'], [255, b'proxy-protocol']],
        True, srcPort, destPort)
def testIPv6ProxyProtocolSeveralQueriesOverTCP(self):
    """Send one IPv6 Proxy Protocol header, then five queries on the same TCP connection; all must be answered."""
    qname = 'several-queries-tcp.proxy-protocol.recursor-tests.powerdns.com.'
    expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1')
    query = dns.message.make_query(qname, 'A', want_dnssec=True)
    queryPayload = query.to_wire()
    ppPayload = ProxyProtocol.getPayload(False, True, True, '::42', '2001:db8::ff', 0, 65535, [[0, b'foo'], [255, b'bar']])

    # TCP: the context manager closes the socket even when an assertion fails
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(2.0)
        sock.connect(("127.0.0.1", self._recursorPort))
        # the header is sent once, before the first query only
        sock.send(ppPayload)
        count = 0
        for _ in range(5):
            try:
                sock.send(struct.pack("!H", len(queryPayload)))
                sock.send(queryPayload)
                data = sock.recv(2)
                if data:
                    (datalen, ) = struct.unpack("!H", data)
                    data = sock.recv(datalen)
            except socket.timeout as e:
                print("Timeout: %s" % (str(e)))
                break
            except socket.error as e:
                print("Network error: %s" % (str(e)))
                break

            res = None
            if data:
                res = dns.message.from_wire(data)
            self.assertRcodeEqual(res, dns.rcode.NOERROR)
            self.assertRRsetInAnswer(res, expected)
            count += 1

        # fewer than 5 means the loop stopped on a timeout or network error
        self.assertEqual(count, 5)
def testTCPOneByteAtATimeProxyProtocol(self):
    """Trickle the header, length prefix and query one byte at a time (10 ms apart) and expect a normal answer."""
    qname = 'tcp-one-byte-at-a-time.proxy-protocol.recursor-tests.powerdns.com.'
    expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1')
    query = dns.message.make_query(qname, 'A', want_dnssec=True)
    queryPayload = query.to_wire()
    ppPayload = ProxyProtocol.getPayload(False, True, False, '127.0.0.42', '255.255.255.255', 0, 65535, [[0, b'foo'], [255, b'bar']])

    # TCP
    conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    conn.settimeout(2.0)
    conn.connect(("127.0.0.1", self._recursorPort))
    try:
        stream = ppPayload + struct.pack("!H", len(queryPayload)) + queryPayload
        for octet in stream:
            conn.send(bytes((octet,)))
            time.sleep(0.01)

        reply = conn.recv(2)
        if reply:
            (replyLen, ) = struct.unpack("!H", reply)
            reply = conn.recv(replyLen)
    except socket.timeout as e:
        print("Timeout: %s" % (str(e)))
        reply = None
    except socket.error as e:
        print("Network error: %s" % (str(e)))
        reply = None
    finally:
        conn.close()

    message = dns.message.from_wire(reply) if reply else None
    self.assertRcodeEqual(message, dns.rcode.NOERROR)
    self.assertRRsetInAnswer(message, expected)
def testAXFR(self):
    """
    Check that AXFR is properly accepted/rejected based on the PROXY header inner address
    """
    axfrWire = dns.message.make_query('example.org', 'AXFR').to_wire()

    # (inner source address advertised in the PROXY header, expected rcode)
    cases = [
        ('192.0.2.1', dns.rcode.NOTAUTH),
        ('127.0.0.1', dns.rcode.NOTAUTH),
        ('192.0.2.53', dns.rcode.NOERROR),
    ]
    for innerSource, wantedRcode in cases:
        header = ProxyProtocol.getPayload(False, True, False, innerSource, "10.1.2.3", 12345, 53, [])

        # TCP
        tcpSock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        tcpSock.settimeout(2.0)
        tcpSock.connect(("127.0.0.1", self._authPort))
        try:
            tcpSock.send(header)
            tcpSock.send(struct.pack("!H", len(axfrWire)))
            tcpSock.send(axfrWire)
            reply = tcpSock.recv(2)
            if reply:
                (replyLen, ) = struct.unpack("!H", reply)
                reply = tcpSock.recv(replyLen)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
            reply = None
        except socket.error as e:
            print("Network error: %s" % (str(e)))
            reply = None
        finally:
            tcpSock.close()

        message = dns.message.from_wire(reply) if reply else None
        self.assertRcodeEqual(message, wantedRcode)
def sendTCPQueryWithProxyProtocol(cls, query, v6, source, destination, sourcePort, destinationPort, values=None, timeout=2.0):
    """
    Send `query` over a new TCP connection to the recursor, preceded by a Proxy Protocol v2 header.

    :param query: dns.message.Message to send (length-prefixed on the wire).
    :param v6: whether the header advertises IPv6 addresses.
    :param source: source address put into the header.
    :param destination: destination address put into the header.
    :param sourcePort: source port put into the header.
    :param destinationPort: destination port put into the header.
    :param values: optional list of [type, value] TLVs for the header (defaults to none).
    :param timeout: socket timeout in seconds; a falsy value skips settimeout().
    :returns: the parsed dns.message.Message, or None on timeout, network error or empty reply.
    """
    # avoid a shared mutable default argument
    if values is None:
        values = []
    queryPayload = query.to_wire()
    # The second argument selects the stream (TCP) transport, as the TCP payloads
    # built elsewhere in these tests do. It used to be False here, so the header
    # advertised a datagram transport on a TCP connection.
    ppPayload = ProxyProtocol.getPayload(False, True, v6, source, destination, sourcePort, destinationPort, values)

    # the context manager also closes the socket if connect() raises
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        if timeout:
            sock.settimeout(timeout)
        sock.connect(("127.0.0.1", cls._recursorPort))
        try:
            sock.send(ppPayload)
            sock.send(struct.pack("!H", len(queryPayload)))
            sock.send(queryPayload)
            data = sock.recv(2)
            if data:
                (datalen, ) = struct.unpack("!H", data)
                data = sock.recv(datalen)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
            data = None
        except socket.error as e:
            print("Network error: %s" % (str(e)))
            data = None

    message = None
    if data:
        message = dns.message.from_wire(data)
    return message
def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol):
    """
    Serve DNS-over-HTTPS (HTTP/2) queries on an already-accepted connection until it goes away.

    :param config: h2 configuration used for the server-side H2Connection.
    :param conn: accepted client socket. It is closed before returning, except when the
                 initial HTTP/2 preface fails with ssl.SSLEOFError.
    :param fromQueue: queue handed to `cls._getResponse` / `callback`.
    :param toQueue: queue handed to `cls._getResponse` / `callback`. With `useProxyProtocol`,
                    the raw Proxy Protocol header is also put on it.
    :param trailingDataResponse: True ignores trailing data after the DNS message; False
                                 re-raises dns.message.TrailingJunk; any other value is passed
                                 to `cls._getResponse` as `synthesize`.
    :param multipleResponses: not used by this method.
    :param callback: optional callable(request, requestHeaders, fromQueue, toQueue) returning
                     (status, wire); when set it replaces `cls._getResponse`.
    :param tlsContext: not used by this method.
    :param useProxyProtocol: read a Proxy Protocol header before the HTTP/2 traffic.
    """
    ignoreTrailing = trailingDataResponse is True
    try:
        h2conn = h2.connection.H2Connection(config=config)
        h2conn.initiate_connection()
        conn.sendall(h2conn.data_to_send())
    except ssl.SSLEOFError as e:
        print("Unexpected EOF: %s" % (e))
        return

    dnsData = {}

    if useProxyProtocol:
        # try to read the entire Proxy Protocol header
        proxy = ProxyProtocol()
        header = conn.recv(proxy.HEADER_SIZE)
        if not header:
            print('unable to get header')
            conn.close()
            return

        if not proxy.parseHeader(header):
            print('unable to parse header')
            print(header)
            conn.close()
            return

        proxyContent = conn.recv(proxy.contentLen)
        if not proxyContent:
            print('unable to get content')
            conn.close()
            return

        payload = header + proxyContent
        toQueue.put(payload, True, cls._queueTimeout)

    # be careful, HTTP/2 headers and data might be in different recv() results
    requestHeaders = None
    while True:
        data = conn.recv(65535)
        if not data:
            break

        events = h2conn.receive_data(data)
        for event in events:
            if isinstance(event, h2.events.RequestReceived):
                requestHeaders = event.headers
            if isinstance(event, h2.events.DataReceived):
                h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
                # accumulate the body per stream until the stream ends
                dnsData[event.stream_id] = dnsData.get(event.stream_id, b'') + event.data
                if event.stream_ended:
                    forceRcode = None
                    status = 200
                    try:
                        request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
                    except dns.message.TrailingJunk:
                        if trailingDataResponse is False:
                            raise
                        print("DOH query with trailing data, synthesizing response")
                        request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
                        forceRcode = trailingDataResponse

                    # Reset per stream. Without this, a missing response either raised
                    # UnboundLocalError or re-sent the previous stream's answer.
                    wire = None
                    if callback:
                        status, wire = callback(request, requestHeaders, fromQueue, toQueue)
                    else:
                        response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
                        if response:
                            wire = response.to_wire(max_size=65535)

                    if not wire:
                        # nothing to answer: drop the connection
                        conn.close()
                        conn = None
                        break

                    headers = [
                        (':status', str(status)),
                        ('content-length', str(len(wire))),
                        ('content-type', 'application/dns-message'),
                    ]
                    h2conn.send_headers(stream_id=event.stream_id, headers=headers)
                    h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)

            data_to_send = h2conn.data_to_send()
            if data_to_send:
                conn.sendall(data_to_send)

        if conn is None:
            break

    if conn is not None:
        conn.close()
def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
    """
    Run a TCP DNS responder that expects a Proxy Protocol header at the start of each connection.

    Binds 127.0.0.1:`port` and serves one connection at a time. The process exits
    if the bind fails. Connections with a missing or unparsable header are closed.
    For every length-prefixed query on a connection:
    - the [proxy header bytes, DNS bytes] pair is pushed onto `toQueue`;
    - the next message from `fromQueue` is sent back with its ID set to the query's ID.
    A falsy message from the queue, an empty read or a read timeout ends the connection.
    """
    import copy  # local import: the file-level imports are outside this block

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    try:
        sock.bind(("127.0.0.1", port))
    except socket.error as e:
        print("Error binding in the TCP responder: %s" % str(e))
        sys.exit(1)

    sock.listen(100)
    while True:
        (conn, _) = sock.accept()
        conn.settimeout(5.0)
        # try to read the entire Proxy Protocol header
        proxy = ProxyProtocol()
        header = conn.recv(proxy.HEADER_SIZE)
        if not header:
            conn.close()
            continue

        if not proxy.parseHeader(header):
            conn.close()
            continue

        proxyContent = conn.recv(proxy.contentLen)
        if not proxyContent:
            conn.close()
            continue

        payload = header + proxyContent
        while True:
            try:
                data = conn.recv(2)
            except socket.timeout:
                data = None

            if not data:
                conn.close()
                break

            (datalen, ) = struct.unpack("!H", data)
            data = conn.recv(datalen)

            toQueue.put([payload, data], True, 2.0)

            # Copy the queued message so that setting its ID below does not mutate
            # the object the test thread still holds (same as the connection-reuse
            # variant of this responder).
            response = copy.deepcopy(fromQueue.get(True, 2.0))
            if not response:
                conn.close()
                break

            # computing the correct ID for the response
            request = dns.message.from_wire(data)
            response.id = request.id

            wire = response.to_wire()
            conn.send(struct.pack("!H", len(wire)))
            conn.send(wire)

        conn.close()

    sock.close()
def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
    """
    Run a TCP DNS responder that expects a Proxy Protocol header at the start of each connection.

    Binds 127.0.0.1:`port` and exits the process if the bind fails. Each accepted
    connection must start with a valid Proxy Protocol header; otherwise it is closed.
    After the header, each length-prefixed query is handled as follows:
    - the [proxy header bytes, DNS bytes] pair is pushed onto `toQueue`;
    - a deep copy of the next message from `fromQueue` is sent back, with its ID
      set to the query's ID.
    """
    # be aware that this responder will not accept a new connection
    # until the last one has been closed. This is done on purpose to
    # check for connection reuse, making sure that a lot of connections
    # are not opened in parallel.

    def readProxyHeader(client):
        """Return the raw Proxy Protocol header read from `client`, or None if it is missing or invalid."""
        parser = ProxyProtocol()
        fixedPart = client.recv(parser.HEADER_SIZE)
        if not fixedPart or not parser.parseHeader(fixedPart):
            return None
        variablePart = client.recv(parser.contentLen)
        if not variablePart:
            return None
        return fixedPart + variablePart

    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    listener.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    try:
        listener.bind(("127.0.0.1", port))
    except socket.error as e:
        print("Error binding in the TCP responder: %s" % str(e))
        sys.exit(1)

    listener.listen(100)
    while True:
        (client, _) = listener.accept()
        client.settimeout(5.0)

        proxyHeader = readProxyHeader(client)
        if proxyHeader is None:
            client.close()
            continue

        while True:
            try:
                lengthPrefix = client.recv(2)
            except socket.timeout:
                lengthPrefix = None

            if not lengthPrefix:
                client.close()
                break

            (queryLen,) = struct.unpack("!H", lengthPrefix)
            queryWire = client.recv(queryLen)
            toQueue.put([proxyHeader, queryWire], True, 2.0)

            # copy so the object owned by the test thread is never modified
            answer = copy.deepcopy(fromQueue.get(True, 2.0))
            if not answer:
                client.close()
                break

            # computing the correct ID for the response
            answer.id = dns.message.from_wire(queryWire).id
            answerWire = answer.to_wire()
            client.send(struct.pack("!H", len(answerWire)))
            client.send(answerWire)

        client.close()

    listener.close()
def testWhoAmI(self):
    """
    See if LUA who picks up the inner address from the PROXY protocol
    """
    for testWithECS in (True, False):
        # first test with an unproxied query - should get ignored
        if testWithECS:
            options = [clientsubnetoption.ClientSubnetOption('192.0.2.5', 32)]
            expectedText = '192.0.2.1/192.0.2.5'
        else:
            options = []
            expectedText = '192.0.2.1/192.0.2.1'

        query = dns.message.make_query('myip.example.org', 'TXT', 'IN', use_edns=testWithECS, options=options, payload=512)
        self.assertEqual(self.sendUDPQuery(query), None)  # query was ignored correctly

        # now send a proxied query
        queryPayload = query.to_wire()
        ppPayload = ProxyProtocol.getPayload(False, False, False, "192.0.2.1", "10.1.2.3", 12345, 53, [])

        # UDP
        self._sock.settimeout(2.0)
        try:
            self._sock.send(ppPayload + queryPayload)
            udpReply = self._sock.recv(4096)
        except socket.timeout:
            udpReply = None
        finally:
            self._sock.settimeout(None)

        udpMessage = dns.message.from_wire(udpReply) if udpReply else None
        expected = [dns.rrset.from_text('myip.example.org.', 0, dns.rdataclass.IN, 'TXT', expectedText)]
        self.assertRcodeEqual(udpMessage, dns.rcode.NOERROR)
        self.assertEqual(udpMessage.answer, expected)

        # TCP
        tcpSock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        tcpSock.settimeout(2.0)
        tcpSock.connect(("127.0.0.1", self._authPort))
        try:
            tcpSock.send(ppPayload)
            tcpSock.send(struct.pack("!H", len(queryPayload)))
            tcpSock.send(queryPayload)
            tcpReply = tcpSock.recv(2)
            if tcpReply:
                (replyLen, ) = struct.unpack("!H", tcpReply)
                tcpReply = tcpSock.recv(replyLen)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
            tcpReply = None
        except socket.error as e:
            print("Network error: %s" % (str(e)))
            tcpReply = None
        finally:
            tcpSock.close()

        tcpMessage = dns.message.from_wire(tcpReply) if tcpReply else None
        self.assertRcodeEqual(tcpMessage, dns.rcode.NOERROR)
        self.assertEqual(tcpMessage.answer, expected)
def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
    """
    Run a DNS-over-HTTPS (HTTP/2) responder on 127.0.0.1:`port`, serving one connection at a time.

    :param port: TCP port to bind on 127.0.0.1; the process exits if the bind fails.
    :param fromQueue: queue handed to `cls._getResponse`.
    :param toQueue: queue handed to `cls._getResponse`. With `useProxyProtocol`, each
                    connection's raw Proxy Protocol header is also put on it.
    :param trailingDataResponse: see the comment below.
    :param multipleResponses: not used by this function.
    :param callback: optional callable(request) returning (status, wire); when set it
                     replaces `cls._getResponse`.
    :param tlsContext: optional ssl context used to wrap the listening socket.
    :param useProxyProtocol: read a Proxy Protocol header on each connection before the
                             HTTP/2 traffic.
    """
    # trailingDataResponse=True means "ignore trailing data".
    # Other values are either False (meaning "raise an exception")
    # or are interpreted as a response RCODE for queries with trailing data.
    # callback is invoked for every -even healthcheck ones- query and should return a raw response
    ignoreTrailing = trailingDataResponse is True

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    try:
        sock.bind(("127.0.0.1", port))
    except socket.error as e:
        print("Error binding in the TCP responder: %s" % str(e))
        sys.exit(1)

    sock.listen(100)
    if tlsContext:
        sock = tlsContext.wrap_socket(sock, server_side=True)

    config = h2.config.H2Configuration(client_side=False)

    while True:
        try:
            (conn, _) = sock.accept()
        except ssl.SSLError:
            continue
        except ConnectionResetError:
            continue

        conn.settimeout(5.0)
        h2conn = h2.connection.H2Connection(config=config)
        h2conn.initiate_connection()
        conn.sendall(h2conn.data_to_send())
        dnsData = {}

        if useProxyProtocol:
            # try to read the entire Proxy Protocol header
            proxy = ProxyProtocol()
            header = conn.recv(proxy.HEADER_SIZE)
            if not header:
                print('unable to get header')
                conn.close()
                continue

            if not proxy.parseHeader(header):
                print('unable to parse header')
                print(header)
                conn.close()
                continue

            proxyContent = conn.recv(proxy.contentLen)
            if not proxyContent:
                print('unable to get content')
                conn.close()
                continue

            payload = header + proxyContent
            toQueue.put(payload, True, cls._queueTimeout)

        while True:
            data = conn.recv(65535)
            if not data:
                break

            events = h2conn.receive_data(data)
            for event in events:
                if isinstance(event, h2.events.DataReceived):
                    h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
                    # accumulate the body per stream until the stream ends
                    dnsData[event.stream_id] = dnsData.get(event.stream_id, b'') + event.data
                    if event.stream_ended:
                        forceRcode = None
                        status = 200
                        try:
                            request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
                        except dns.message.TrailingJunk:
                            if trailingDataResponse is False:
                                raise
                            print("DOH query with trailing data, synthesizing response")
                            request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
                            forceRcode = trailingDataResponse

                        # Reset per stream. Without this, a missing response either raised
                        # UnboundLocalError or re-sent an answer from an earlier stream
                        # (even one from a previous connection).
                        wire = None
                        if callback:
                            status, wire = callback(request)
                        else:
                            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
                            if response:
                                wire = response.to_wire(max_size=65535)

                        if not wire:
                            # nothing to answer: drop the connection
                            conn.close()
                            conn = None
                            break

                        headers = [
                            (':status', str(status)),
                            ('content-length', str(len(wire))),
                            ('content-type', 'application/dns-message'),
                        ]
                        h2conn.send_headers(stream_id=event.stream_id, headers=headers)
                        h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)

                data_to_send = h2conn.data_to_send()
                if data_to_send:
                    conn.sendall(data_to_send)

            if conn is None:
                break

        if conn is not None:
            conn.close()

    sock.close()