async def test_get_request_empty_body(self):
     """ Check that a GET request whose dns parameter is empty is rejected
     with a 400 status and a 'Missing Body' payload.
     """
     query_params = utils.build_query_params(self.dnsq.to_wire())
     # Blank out the encoded query so the server sees an empty body.
     query_params[constants.DOH_DNS_PARAM] = ''
     response = await self.client.request(self.method, self.endpoint,
                                           params=query_params)
     self.assertEqual(response.status, 400)
     payload = await response.read()
     self.assertEqual(payload, b'Missing Body')
Example #2
0
 async def test_get_request_no_content_type(self):
     """ Check that omitting the ct query parameter makes the server answer
     400 with a missing content type parameter message.
     """
     query_params = utils.build_query_params(self.dnsq.to_wire())
     # Drop the content-type parameter (raises KeyError if it was absent).
     query_params.pop(constants.DOH_CONTENT_TYPE_PARAM)
     response = await self.client.request(self.method, self.endpoint,
                                           params=query_params)
     self.assertEqual(response.status, 400)
     payload = await response.read()
     self.assertEqual(payload, b'Missing Content Type Parameter')
Example #3
0
    async def test_get_valid_request(self, resolve):
        """ Check that a well-formed GET request reaches ``resolve`` and that
        its (mocked, echoing) answer is returned with a 200 status.
        """
        resolve.return_value = echo_dns_q(self.dnsq)
        response = await self.client.request(
            self.method,
            self.endpoint,
            params=utils.build_query_params(self.dnsq.to_wire()),
        )
        self.assertEqual(response.status, 200)
        payload = await response.read()

        self.assertEqual(self.dnsq, dns.message.from_wire(payload))
Example #4
0
 async def test_get_request_no_content_type(self, resolve):
     """ Check that a GET query without a ct parameter is still accepted.
     GET requests no longer use content-type; the server falls back to
     'application/dns-message'.
     """
     resolve.return_value = echo_dns_q(self.dnsq)
     response = await self.client.request(
         self.method,
         self.endpoint,
         params=utils.build_query_params(self.dnsq.to_wire()),
     )
     self.assertEqual(response.status, 200)
     payload = await response.read()
     self.assertEqual(self.dnsq, dns.message.from_wire(payload))
Example #5
0
 async def test_get_request_bad_content_type(self, resolve):
     """ Check that a bogus ct parameter on a GET query is ignored and the
     request succeeds with 200. GET requests no longer use content-type;
     the server falls back to 'application/dns-message'.
     """
     resolve.return_value = echo_dns_q(self.dnsq)
     query_params = utils.build_query_params(self.dnsq.to_wire())
     query_params.update({'ct': 'bad/type'})
     response = await self.client.request(self.method, self.endpoint,
                                           params=query_params)
     self.assertEqual(response.status, 200)
     payload = await response.read()
     self.assertEqual(self.dnsq, dns.message.from_wire(payload))
Example #6
0
    async def test_get_valid_request(self, resolve):
        """ Check that a well-formed GET request reaches ``resolve`` and that
        the mocked response (the query echoed back in DoH wire format) is
        returned with a 200 status.
        """
        wire_query = self.dnsq.to_wire()
        resolve.return_value = aiohttp.web.Response(
            status=200,
            body=wire_query,
            content_type=constants.DOH_MEDIA_TYPE,
        )
        response = await self.client.request(
            self.method,
            self.endpoint,
            params=utils.build_query_params(wire_query),
        )
        self.assertEqual(response.status, 200)
        payload = await response.read()

        self.assertEqual(self.dnsq, dns.message.from_wire(payload))
Example #7
0
    async def test_mock_dnsclient_assigned_logger(self, MockedDNSClient,
                                                  Mockedon_answer,
                                                  Mockedquery):
        """ Test that resolving a query builds the DNSClient with the
        application's upstream resolver and port, and with the
        'doh-httpproxy' logger configured at DEBUG level.
        """
        Mockedquery.return_value = self.dnsq
        Mockedon_answer.return_value = aiohttp.web.Response(status=200,
                                                            body=b'Done')
        params = utils.build_query_params(self.dnsq.to_wire())
        request = await self.client.request('GET',
                                            self.endpoint,
                                            params=params)
        request.remote = "127.0.0.1"
        app = await self.get_application()
        await app.resolve(request, self.dnsq)

        # NOTE(review): equality below presumably relies on configure_logger
        # returning the same named logger object each time — TODO confirm.
        mylogger = utils.configure_logger(name='doh-httpproxy', level='DEBUG')
        # assert_called_with only checks the most recent DNSClient call.
        MockedDNSClient.assert_called_with(app.upstream_resolver,
                                           app.upstream_port,
                                           logger=mylogger)
Example #8
0
 def test_body_b64encoded(self):
     """ Ensure the dns query parameter carries the b64-encoded body. """
     empty_query = b''
     encoded_param = utils.build_query_params(empty_query)[
         constants.DOH_DNS_PARAM]
     self.assertEqual(utils.doh_b64_encode(empty_query), encoded_param)
Example #9
0
 def test_query_accepts_bytes(self):
     """ Ensure a bytes query is accepted without raising. """
     empty_query = b''
     utils.build_query_params(empty_query)
Example #10
0
 def test_query_must_be_bytes(self):
     """ Ensure a str query is rejected with TypeError. """
     self.assertRaises(TypeError, utils.build_query_params, '')
Example #11
0
 def test_has_right_keys(self):
     """ Ensure the returned params contain the dns parameter and nothing
     else.
     """
     expected_keys = {constants.DOH_DNS_PARAM}
     actual_keys = utils.build_query_params(b'').keys()
     self.assertEqual(expected_keys, actual_keys)
Example #12
0
    async def make_request(self, addr, dnsq):
        """ Forward a DNS query to the upstream DoH server over HTTP/2 and
        pass the answer to ``self.on_answer``.

        The query is sent as the request body when ``self.args.post`` is set,
        otherwise as query parameters on ``self.args.uri``. The query id is
        zeroed before serialisation (``dnsq`` is modified in place), and the
        original id is restored on the answer before it is handed on.

        :param addr: client address, forwarded unchanged to ``on_answer``.
        :param dnsq: the ``dns.message`` query to forward.
        """
        # FIXME: maybe aioh2 should allow registering to connection_lost event
        # so we can find out when the connection get disconnected.
        # ``with await lock`` was deprecated in Python 3.7 and removed in 3.9;
        # ``async with`` is the supported way to hold an asyncio lock.
        async with self._lock:
            if self.client is None or self.client._conn is None:
                await self.setup_client()

            client = self.client

        path = self.args.uri
        qid = dnsq.id
        # The wire query uses id 0 (presumably for HTTP cache friendliness,
        # as RFC 8484 recommends); the real id is restored on the answer.
        dnsq.id = 0
        body = b''

        headers = [
            (':authority', self.args.domain),
            (':method', 'POST' if self.args.post else 'GET'),
            (':scheme', 'https'),
        ]
        if self.args.post:
            headers.append(('content-type', constants.DOH_MEDIA_TYPE))
            body = dnsq.to_wire()
        else:
            params = utils.build_query_params(dnsq.to_wire())
            self.logger.debug('Query parameters: {}'.format(params))
            params_str = urllib.parse.urlencode(params)
            if self.args.debug:
                url = utils.make_url(self.args.domain, self.args.uri)
                self.logger.debug('Sending {}?{}'.format(url, params_str))
            path = self.args.uri + '?' + params_str

        # The :path pseudo-header depends on the method, so it is added last
        # but kept first in the header list.
        headers.insert(0, (':path', path))
        headers.append(('content-length', str(len(body))))
        # Start request with headers; the stream ends here when there is no
        # body to send (GET).
        # FIXME: Find a better way to close old streams. See GH#11
        try:
            stream_id = await client.start_request(
                headers,
                end_stream=not body)
        except priority.priority.TooManyStreamsError:
            # Stream ids exhausted on this connection: open a fresh client and
            # retry once.
            await self.setup_client()
            client = self.client
            stream_id = await client.start_request(
                headers,
                end_stream=not body)
        self.logger.debug(
            'Stream ID: {} / Total streams: {}'.format(
                stream_id, len(client._streams)
            )
        )
        # POST only: send the wire-format query as the whole request body.
        if body:
            await client.send_data(stream_id, body, end_stream=True)

        # Receive response headers
        headers = await client.recv_response(stream_id)
        self.logger.debug('Response headers: {}'.format(headers))

        # Read all response body
        resp = await client.read_stream(stream_id, -1)
        dnsr = dns.message.from_wire(resp)
        dnsr.id = qid
        self.on_answer(addr, dnsr.to_wire())

        # Read response trailers
        trailers = await client.recv_trailers(stream_id)
        self.logger.debug('Response trailers: {}'.format(trailers))
Example #13
0
 def test_has_right_keys(self):
     """ Ensure the returned params contain exactly the body and ct
     parameters.
     """
     expected_keys = {
         constants.DOH_BODY_PARAM,
         constants.DOH_CONTENT_TYPE_PARAM,
     }
     actual_keys = utils.build_query_params(b'').keys()
     self.assertEqual(expected_keys, actual_keys)