예제 #1
0
 def test_pool_manager_no_url_absolute_form(self) -> None:
     """Validate that no absolute-form request is sent without a proxy."""
     manager = PoolManager()
     # Both the plain and the TLS scheme must report False.
     for raw_url in ("http://example.com", "https://example.com"):
         needs_absolute = manager._proxy_requires_url_absolute_form(
             Url(raw_url))
         assert needs_absolute is False
예제 #2
0
파일: test_util.py 프로젝트: tyzbit/urllib3
    def test_parse_url_unicode_python_2(self):
        """Check that parsing a unicode URL yields text-typed components."""
        parsed = parse_url(u"https://www.google.com/")
        assert parsed == Url(u"https", host=u"www.google.com", path=u"/")

        # Each inspected component must be of the text type.
        for component in (parsed.scheme, parsed.host, parsed.path):
            assert isinstance(component, six.text_type)
예제 #3
0
def url_path_query(path: str, cntry: str=None, pgid: str=None,
                   lang=None, extra=None, month=None, day=None,
                   api_key=api_key)->str:
    """Build and return the string URL used to query `path`.

    Args:
        path: URL path on joshuaproject.net. When it contains ``'upgotd'``
            the month/day parameters are added as well.
        cntry: Optional value for the ``ROG3`` parameter.
        pgid: Optional value for the ``PeopleID3`` parameter.
        lang: Optional value for the ``ROL3`` parameter.
        extra: Optional pre-formatted ``key=value`` fragment appended as-is.
        month: Month for ``LRofTheDayMonth``; defaults to today's month.
        day: Day for ``LRofTheDayDay``; defaults to today's day of month.
        api_key: Value for the ``api_key`` parameter.

    Returns:
        The assembled ``https://joshuaproject.net`` URL as a string.
    """
    # Collect the pieces and join them once so the query never starts with a
    # stray '&' when `cntry` (or any other leading parameter) is omitted.
    params = []
    if cntry is not None:
        params.append('ROG3=' + str(cntry))
    if pgid is not None:
        params.append('PeopleID3=' + str(pgid))
    if lang is not None:
        params.append('ROL3=' + str(lang))
    if extra is not None:
        params.append(str(extra))
    if 'upgotd' in path:
        today = date.today()
        # Use the integer attributes instead of strftime("%-d") / ("%-m"):
        # the '-' flag is a platform-specific extension and is not portable.
        if month is None:
            month = today.month
        if day is None:
            day = today.day
        params.append('LRofTheDayMonth=' + str(month))
        params.append('LRofTheDayDay=' + str(day))
    params.append('api_key=' + api_key)
    return Url(scheme='https', host='joshuaproject.net',
               path=path, query='&'.join(params)).url
예제 #4
0
 def get_info(img_id: int) -> Info:
     """Request the info document for image `img_id` and wrap it in ``Info``."""
     info_path = PicsumPhotos._get_path(img_id=img_id, info=True)
     target = Url(scheme='https', host='picsum.photos', path=info_path)
     resp = requests.request('GET', target)
     return Info(resp.json())
예제 #5
0
파일: test_util.py 프로젝트: tyzbit/urllib3
    def test_parse_url_bytes_to_str_python_2(self):
        """Check that parsing a bytes URL yields ``str`` components."""
        parsed = parse_url(b"https://www.google.com/")
        assert parsed == Url("https", host="www.google.com", path="/")

        # Each inspected component must be a ``str`` instance.
        for component in (parsed.scheme, parsed.host, parsed.path):
            assert isinstance(component, str)
예제 #6
0
    def __init__(self, json_data: dict):
        """Populate this object from its JSON description.

        Reads the keys ``id``, ``name``, ``slug``, ``coverUrl``,
        ``abbreviation``, ``synopsis``, ``language``, ``timeCreated``,
        ``sneakPeek``, ``status``, ``chapterCount``, ``tags`` and ``genres``
        from `json_data`, then derives ``title`` and ``_url``.
        """
        site_root = Url('https', host='www.wuxiaworld.com')
        super().__init__(site_root)

        self.id = int(json_data['id'])
        self.name = json_data['name']
        self.slug = json_data['slug']
        self.cover_url = json_data['coverUrl']
        self.abbreviation = json_data['abbreviation']
        # The synopsis is parsed with the html5lib backend.
        self.synopsis = BeautifulSoup(json_data['synopsis'],
                                      features="html5lib")
        self.language = json_data['language']
        created_timestamp = float(json_data['timeCreated'])
        self.time_created = datetime.utcfromtimestamp(created_timestamp)
        self.sneakPeek = bool(json_data['sneakPeek'])
        self.status = Status.from_int(json_data['status'])
        self.chapter_count = int(json_data['chapterCount'])
        self.tags = [NovelTag.from_str(tag) for tag in json_data['tags']]
        self.genres = [
            Genre.from_str(genre) for genre in json_data['genres']
        ]

        self.title = self.name
        # Sneak-peek entries use the "preview/" path, all others "novel/".
        self._url = self.alter_url(
            f"preview/{self.slug}" if self.sneakPeek else f"novel/{self.slug}")
예제 #7
0
def service_url(function_scoped_container_getter):
    """Return the http URL of the "pokespear" service container.

    Host and port are taken from the first entry of the container's
    ``network_info``.
    """
    container = function_scoped_container_getter.get("pokespear")
    first_network = container.network_info[0]
    endpoint = Url(scheme="http",
                   host=first_network.hostname,
                   port=first_network.host_port)
    return endpoint.url
예제 #8
0
 def __init__(self):
     """Replace ``self.url`` with a ``Url`` whose path ends in a slash.

     The existing ``self.url`` is parsed, a trailing ``/`` is added to its
     path when missing, and the result is rebuilt from scheme, auth, host,
     port and path only.
     """
     super().__init__()
     # Validate and split URL using urllib3's Url class util
     url = parse_url(self.url)
     # A URL with no path component is parsed with path=None; fall back to an
     # empty string so the endswith() check below cannot raise AttributeError.
     path = url.path or ""
     # Add a / to the end of 'path' if it doesn't already have it
     if not path.endswith("/"):
         path += "/"
     # Set API URL to this version's root path
     self.url = Url(url.scheme, url.auth, url.host, url.port, path)
예제 #9
0
 def get_list(page: int = 1, limit: int = 30) -> List[Info]:
     """Request one page of the image listing and wrap each entry in ``Info``."""
     listing_url = Url(scheme='https',
                       host='picsum.photos',
                       path=PicsumPhotos._get_path(list=True),
                       query=PicsumPhotos._get_query(page=page, limit=limit))
     resp = requests.request('GET', listing_url)
     return [Info(entry) for entry in resp.json()]
예제 #10
0
 def get_image(width: int,
               height: int,
               img_id: int = -1,
               grayscale: bool = False,
               blur: int = 0) -> Tuple[int, Image]:
     """Download an image and return it together with its id.

     The id is read back from the third '/'-separated piece of the path of
     the final response URL; the image is opened from the response body.
     """
     image_url = Url(scheme='https',
                     host='picsum.photos',
                     path=PicsumPhotos._get_path(width, height, img_id),
                     query=PicsumPhotos._get_query(grayscale, blur))
     resp = requests.request('GET', image_url)

     final_path = parse_url(resp.url).path
     resolved_id = int(final_path.split('/')[2])
     picture = PILImg.open(BytesIO(resp.content))
     return resolved_id, picture
예제 #11
0
    def test_control_characters_are_percent_encoded(self, char):
        """Check that `char` is percent-encoded in auth, path, query and fragment."""
        # Two-digit, upper-case hex escape of the character, e.g. "%0A".
        encoded = "%{:02X}".format(ord(char))
        raw_url = (
            "http://user{0}@example.com/path{0}?query{0}#fragment{0}".format(
                char))
        parsed = parse_url(raw_url)

        expected = Url(
            "http",
            auth="user" + encoded,
            host="example.com",
            path="/path" + encoded,
            query="query" + encoded,
            fragment="fragment" + encoded,
        )
        assert parsed == expected
예제 #12
0
 def __extract_chapters(self,
                        book_html: Tag) -> List[WuxiaWorldComChapterEntry]:
     """Build one chapter entry per link matched by ``div div li a``.

     Entries are numbered from 1 in document order via their ``index``
     attribute.
     """
     chapters = []
     anchors = book_html.select('div div li a')
     for position, anchor in enumerate(anchors, start=1):
         entry = WuxiaWorldComChapterEntry(
             Url('https',
                 host='www.wuxiaworld.com',
                 path=anchor.get('href')),
             title=anchor.text.strip())
         entry.index = position
         chapters.append(entry)
     self.log.debug(f"Chapters found: {len(chapters)}")
     return chapters
예제 #13
0
    def _normalize_url(url):
        """Return `url` as a string with a default scheme and a trailing slash.

        The scheme defaults to ``http`` and the path always ends in ``/``.
        Only scheme, auth, host, port, path and fragment are carried over.
        """
        parsed = parse_url(url)

        raw_path = parsed.path or ''
        slashed_path = raw_path if raw_path.endswith('/') else raw_path + '/'

        rebuilt = Url(scheme=parsed.scheme or 'http',
                      auth=parsed.auth,
                      host=parsed.host,
                      port=parsed.port,
                      path=slashed_path,
                      fragment=parsed.fragment)
        return rebuilt.url
예제 #14
0
 def send(self):
     """Send the request over https and return the decrypted response body.

     The response, its headers, its text and the decrypted payload are
     logged at debug level; an HTTP error status raises via
     ``raise_for_status``.
     """
     endpoint = Url(scheme="https", host=self.host, path=self.path).url
     request_headers = {"User-Agent": "XiaomiPCSuite"}
     response = requests.request(self.method,
                                 endpoint,
                                 data=self.params,
                                 headers=request_headers,
                                 cookies=self.auth.cookies)
     logging.debug(response)
     logging.debug(response.headers)
     response.raise_for_status()

     body_text = response.text
     logging.debug(body_text)
     payload = self._decrypt(body_text)
     logging.debug("query returned %s", payload.decode("utf-8"))
     return payload
예제 #15
0
def _setup_application():
    """Setup things that can be taken care of before io loop is started.

    Adjusts a few ``config`` values in place, then assigns the module-level
    globals declared below: the public URL, the EasyClient, the thrift
    context, the tornado app, the client SSL context, the anonymous
    principal, the request scheduler, the HTTP server and the current IOLoop.
    """
    global io_loop, tornado_app, public_url, thrift_context, easy_client
    global server, client_ssl, request_scheduler, anonymous_principal

    # Tweak some config options
    config.web.url_prefix = normalize_url_prefix(config.web.url_prefix)
    if not config.auth.token.secret:
        # No secret configured: use 20 random bytes generated for this run.
        config.auth.token.secret = os.urandom(20)
        if config.auth.enabled:
            logger.warning(
                "Brew-view was started with authentication enabled and no "
                "Secret. Generated tokens will not be valid across Brew-view "
                "restarts. To prevent this set the auth.token.secret config."
            )

    # The scheme follows the SSL setting; host, port and prefix come from
    # the web section of the config.
    public_url = Url(scheme='https' if config.web.ssl.enabled else 'http',
                     host=config.web.public_fqdn,
                     port=config.web.port,
                     path=config.web.url_prefix).url

    # This is not super clean as we're pulling the config from different
    # 'sections,' but the scheduler is the only thing that uses this
    easy_client = EasyClient(
        host=config.web.public_fqdn,
        port=config.web.port,
        url_prefix=config.web.url_prefix,
        ssl_enabled=config.web.ssl.enabled,
        ca_cert=config.web.ssl.ca_cert,
        username=config.scheduler.auth.username,
        password=config.scheduler.auth.password,
    )

    thrift_context = _setup_thrift_context()
    tornado_app = _setup_tornado_app()
    # server_ssl stays local (only used for the HTTPServer below);
    # client_ssl is published as a global.
    server_ssl, client_ssl = _setup_ssl_context()
    anonymous_principal = load_anonymous()
    request_scheduler = _setup_scheduler()

    server = HTTPServer(tornado_app, ssl_options=server_ssl)
    io_loop = IOLoop.current()
예제 #16
0
    def configure(self, *args, **kwargs):
        """Load configuration from ``OS_*`` environment variables and args.

        Steps:
          1. For every environment variable whose name starts with ``OS_``,
             use its value as the default of the matching config option. The
             lower-cased full name is tried first, then the name without its
             first three characters.
          2. Call ``self.conf`` with ``args[0]``.
          3. Switch logging to debug output on stderr when ``conf.debug`` is
             set.
          4. When no service endpoint is configured, derive one from
             ``conf.auth_url`` and register it as the default.

        Raises:
            Exception: if ``use_keystoneauth1`` is requested but the module
                is not available.
        """
        conf_keys = self.conf.keys()
        # .items() exists on both Python 2 and Python 3, whereas the original
        # .iteritems() raises AttributeError on Python 3.
        for env_name, env_value in os.environ.items():
            if not env_name.startswith('OS_'):
                continue
            # Try the full var first
            lowered = env_name.lower()
            for candidate in (lowered, lowered[3:]):
                if candidate in conf_keys:
                    self.conf.set_default(name=candidate, default=env_value)
                    break

        self.conf(args[0])

        # bail using keystoneauth1 if not available.
        # FIXME: this is hacky...
        if self.conf.use_keystoneauth1 and not HAS_KEYSTONEAUTH1:
            raise Exception('Requested module keystoneauth1 is not available.')
        # adjust the logging
        if self.conf.debug:
            debug_handler = logging.StreamHandler(stream=sys.stderr)
            debug_handler.setLevel(logging.DEBUG)
            self.logger.addHandler(debug_handler)
            # This is questionable...
            self._logging_handlers['debug'] = debug_handler
            self.logger.removeHandler(self._logging_handlers['info'])
            self.logger.setLevel(logging.DEBUG)

        self.os_service_endpoint = self.conf.os_service_endpoint
        if self.os_service_endpoint is None:
            url = parse_url(self.conf.auth_url)
            # Keep the first four fields of the parsed URL and blank out the
            # remaining ones before rebuilding it.
            kept_fields = list(url)[:4] + [None] * (len(url._fields) - 4)
            self.os_service_endpoint = Url(*kept_fields).url
            self.conf.set_default('os_service_endpoint',
                                  default=self.os_service_endpoint)
예제 #17
0
 def _absolute_url(self, path):
     """Return the string URL for `path` on this object's scheme, host and port."""
     absolute = Url(scheme=self.scheme,
                    host=self.host,
                    port=self.port,
                    path=path)
     return absolute.url
예제 #18
0
 def __init__(self, michost, gpshost, sahost):
     """Store an http ``Url`` for each of the three hosts.

     Ports are taken from ``common.micport``, ``common.gpsport`` and
     ``common.saport`` respectively.
     """

     def _http_endpoint(host, port):
         # All three endpoints share the plain http scheme.
         return Url(scheme='http', host=host, port=port)

     self.MIC_HOST = _http_endpoint(michost, common.micport)
     self.GPS_HOST = _http_endpoint(gpshost, common.gpsport)
     self.SA_HOST = _http_endpoint(sahost, common.saport)
예제 #19
0
class TestUtil:

    url_host_map = [
        # Hosts
        ("http://google.com/mail", ("http", "google.com", None)),
        ("http://google.com/mail/", ("http", "google.com", None)),
        ("google.com/mail", ("http", "google.com", None)),
        ("http://google.com/", ("http", "google.com", None)),
        ("http://google.com", ("http", "google.com", None)),
        ("http://www.google.com", ("http", "www.google.com", None)),
        ("http://mail.google.com", ("http", "mail.google.com", None)),
        ("http://google.com:8000/mail/", ("http", "google.com", 8000)),
        ("http://google.com:8000", ("http", "google.com", 8000)),
        ("https://google.com", ("https", "google.com", None)),
        ("https://google.com:8000", ("https", "google.com", 8000)),
        ("http://*****:*****@127.0.0.1:1234", ("http", "127.0.0.1", 1234)),
        ("http://google.com/foo=http://bar:42/baz", ("http", "google.com",
                                                     None)),
        ("http://google.com?foo=http://bar:42/baz", ("http", "google.com",
                                                     None)),
        ("http://google.com#foo=http://bar:42/baz", ("http", "google.com",
                                                     None)),
        # IPv4
        ("173.194.35.7", ("http", "173.194.35.7", None)),
        ("http://173.194.35.7", ("http", "173.194.35.7", None)),
        ("http://173.194.35.7/test", ("http", "173.194.35.7", None)),
        ("http://173.194.35.7:80", ("http", "173.194.35.7", 80)),
        ("http://173.194.35.7:80/test", ("http", "173.194.35.7", 80)),
        # IPv6
        ("[2a00:1450:4001:c01::67]", ("http", "[2a00:1450:4001:c01::67]", None)
         ),
        ("http://[2a00:1450:4001:c01::67]",
         ("http", "[2a00:1450:4001:c01::67]", None)),
        (
            "http://[2a00:1450:4001:c01::67]/test",
            ("http", "[2a00:1450:4001:c01::67]", None),
        ),
        (
            "http://[2a00:1450:4001:c01::67]:80",
            ("http", "[2a00:1450:4001:c01::67]", 80),
        ),
        (
            "http://[2a00:1450:4001:c01::67]:80/test",
            ("http", "[2a00:1450:4001:c01::67]", 80),
        ),
        # More IPv6 from http://www.ietf.org/rfc/rfc2732.txt
        (
            "http://[fedc:ba98:7654:3210:fedc:ba98:7654:3210]:8000/index.html",
            ("http", "[fedc:ba98:7654:3210:fedc:ba98:7654:3210]", 8000),
        ),
        (
            "http://[1080:0:0:0:8:800:200c:417a]/index.html",
            ("http", "[1080:0:0:0:8:800:200c:417a]", None),
        ),
        ("http://[3ffe:2a00:100:7031::1]", ("http", "[3ffe:2a00:100:7031::1]",
                                            None)),
        (
            "http://[1080::8:800:200c:417a]/foo",
            ("http", "[1080::8:800:200c:417a]", None),
        ),
        ("http://[::192.9.5.5]/ipng", ("http", "[::192.9.5.5]", None)),
        (
            "http://[::ffff:129.144.52.38]:42/index.html",
            ("http", "[::ffff:129.144.52.38]", 42),
        ),
        (
            "http://[2010:836b:4179::836b:4179]",
            ("http", "[2010:836b:4179::836b:4179]", None),
        ),
        # Scoped IPv6 (with ZoneID), both RFC 6874 compliant and not.
        ("http://[a::b%25zone]", ("http", "[a::b%zone]", None)),
        ("http://[a::b%zone]", ("http", "[a::b%zone]", None)),
        # Hosts
        ("HTTP://GOOGLE.COM/mail/", ("http", "google.com", None)),
        ("GOogle.COM/mail", ("http", "google.com", None)),
        ("HTTP://GoOgLe.CoM:8000/mail/", ("http", "google.com", 8000)),
        ("HTTP://*****:*****@EXAMPLE.COM:1234", ("http", "example.com",
                                                   1234)),
        ("173.194.35.7", ("http", "173.194.35.7", None)),
        ("HTTP://173.194.35.7", ("http", "173.194.35.7", None)),
        (
            "HTTP://[2a00:1450:4001:c01::67]:80/test",
            ("http", "[2a00:1450:4001:c01::67]", 80),
        ),
        (
            "HTTP://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8000/index.html",
            ("http", "[fedc:ba98:7654:3210:fedc:ba98:7654:3210]", 8000),
        ),
        (
            "HTTPS://[1080:0:0:0:8:800:200c:417A]/index.html",
            ("https", "[1080:0:0:0:8:800:200c:417a]", None),
        ),
        ("abOut://eXamPlE.com?info=1", ("about", "eXamPlE.com", None)),
        (
            "http+UNIX://%2fvar%2frun%2fSOCKET/path",
            ("http+unix", "%2fvar%2frun%2fSOCKET", None),
        ),
    ]

    @pytest.mark.parametrize(["url", "scheme_host_port"], url_host_map)
    def test_scheme_host_port(
            self, url: str, scheme_host_port: Tuple[str, str,
                                                    Optional[int]]) -> None:
        """Check scheme, host and port of the parsed `url` against the expected triple."""
        parsed = parse_url(url)
        expected_scheme, expected_host, expected_port = scheme_host_port

        # A missing scheme is compared as "http".
        assert (parsed.scheme or "http") == expected_scheme
        assert parsed.hostname == parsed.host == expected_host
        assert parsed.port == expected_port

    def test_encode_invalid_chars_none(self) -> None:
        """``_encode_invalid_chars`` returns ``None`` when given ``None``."""
        encoded = _encode_invalid_chars(None, set())
        assert encoded is None

    @pytest.mark.parametrize(
        "url",
        [
            "http://google.com:foo",
            "http://::1/",
            "http://::1:80/",
            "http://google.com:-80",
            "http://google.com:65536",
            "http://google.com:\xb2\xb2",  # \xb2 = ^2
            # Invalid IDNA labels
            "http://\uD7FF.com",
            "http://❤️",
            # Unicode surrogates
            "http://\uD800.com",
            "http://\uDC00.com",
        ],
    )
    def test_invalid_url(self, url: str) -> None:
        """Each parametrized `url` must make ``parse_url`` raise ``LocationParseError``."""
        with pytest.raises(LocationParseError):
            parse_url(url)

    @pytest.mark.parametrize(
        "url, expected_normalized_url",
        [
            ("HTTP://GOOGLE.COM/MAIL/", "http://google.com/MAIL/"),
            (
                "http://[email protected]:[email protected]/~tilde@?@",
                "http://user%40domain.com:[email protected]/~tilde@?@",
            ),
            (
                "HTTP://*****:*****@Example.com:8080/",
                "http://*****:*****@example.com:8080/",
            ),
            ("HTTPS://Example.Com/?Key=Value",
             "https://example.com/?Key=Value"),
            ("Https://Example.Com/#Fragment", "https://example.com/#Fragment"),
            # IPv6 addresses with zone IDs. Both RFC 6874 (%25) as well as
            # non-standard (unquoted %) variants.
            ("[::1%zone]", "[::1%zone]"),
            ("[::1%25zone]", "[::1%zone]"),
            ("[::1%25]", "[::1%25]"),
            ("[::Ff%etH0%Ff]/%ab%Af", "[::ff%etH0%FF]/%AB%AF"),
            (
                "http://*****:*****@[AaAa::Ff%25etH0%Ff]/%ab%Af",
                "http://*****:*****@[aaaa::ff%etH0%FF]/%AB%AF",
            ),
            # Invalid characters for the query/fragment getting encoded
            (
                'http://google.com/p[]?parameter[]="hello"#fragment#',
                "http://google.com/p%5B%5D?parameter%5B%5D=%22hello%22#fragment%23",
            ),
            # Percent encoding isn't applied twice despite '%' being invalid
            # but the percent encoding is still normalized.
            (
                "http://google.com/p%5B%5d?parameter%5b%5D=%22hello%22#fragment%23",
                "http://google.com/p%5B%5D?parameter%5B%5D=%22hello%22#fragment%23",
            ),
        ],
    )
    def test_parse_url_normalization(self, url: str,
                                     expected_normalized_url: str) -> None:
        """Assert the ``url`` property of the parsed `url` equals the expected normalized string."""
        assert parse_url(url).url == expected_normalized_url

    @pytest.mark.parametrize("char",
                             [chr(i) for i in range(0x00, 0x21)] + ["\x7F"])
    def test_control_characters_are_percent_encoded(self, char: str) -> None:
        """Check that `char` is percent-encoded in auth, path, query and fragment."""
        # Two-digit, upper-case hex escape of the character, e.g. "%0A".
        encoded = f"%{ord(char):02X}"
        parsed = parse_url(
            f"http://user{char}@example.com/path{char}?query{char}#fragment{char}"
        )

        expected = Url(
            "http",
            auth="user" + encoded,
            host="example.com",
            path="/path" + encoded,
            query="query" + encoded,
            fragment="fragment" + encoded,
        )
        assert parsed == expected

    parse_url_host_map = [
        ("http://google.com/mail", Url("http", host="google.com",
                                       path="/mail")),
        ("http://google.com/mail/",
         Url("http", host="google.com", path="/mail/")),
        ("http://google.com/mail", Url("http", host="google.com",
                                       path="mail")),
        ("google.com/mail", Url(host="google.com", path="/mail")),
        ("http://google.com/", Url("http", host="google.com", path="/")),
        ("http://google.com", Url("http", host="google.com")),
        ("http://google.com?foo",
         Url("http", host="google.com", path="", query="foo")),
        # Path/query/fragment
        ("", Url()),
        ("/", Url(path="/")),
        ("#?/!google.com/?foo", Url(path="", fragment="?/!google.com/?foo")),
        ("/foo", Url(path="/foo")),
        ("/foo?bar=baz", Url(path="/foo", query="bar=baz")),
        (
            "/foo?bar=baz#banana?apple/orange",
            Url(path="/foo", query="bar=baz", fragment="banana?apple/orange"),
        ),
        (
            "/redirect?target=http://localhost:61020/",
            Url(path="redirect", query="target=http://localhost:61020/"),
        ),
        # Port
        ("http://google.com/", Url("http", host="google.com", path="/")),
        ("http://google.com:80/",
         Url("http", host="google.com", port=80, path="/")),
        ("http://google.com:80", Url("http", host="google.com", port=80)),
        # Auth
        (
            "http://*****:*****@localhost/",
            Url("http", auth="foo:bar", host="localhost", path="/"),
        ),
        ("http://foo@localhost/",
         Url("http", auth="foo", host="localhost", path="/")),
        (
            "http://*****:*****@localhost/",
            Url("http", auth="foo:bar", host="localhost", path="/"),
        ),
    ]

    non_round_tripping_parse_url_host_map = [
        # Path/query/fragment
        ("?", Url(path="", query="")),
        ("#", Url(path="", fragment="")),
        # Path normalization
        ("/abc/../def", Url(path="/def")),
        # Empty Port
        ("http://google.com:", Url("http", host="google.com")),
        ("http://google.com:/", Url("http", host="google.com", path="/")),
        # Uppercase IRI
        (
            "http://Königsgäßchen.de/straße",
            Url("http", host="xn--knigsgchen-b4a3dun.de", path="/stra%C3%9Fe"),
        ),
        # Percent-encode in userinfo
        (
            "http://[email protected]:[email protected]/",
            Url("http",
                auth="user%40email.com:password",
                host="example.com",
                path="/"),
        ),
        (
            'http://user":[email protected]/',
            Url("http", auth="user%22:quoted", host="example.com", path="/"),
        ),
        # Unicode Surrogates
        ("http://google.com/\uD800",
         Url("http", host="google.com", path="%ED%A0%80")),
        (
            "http://google.com?q=\uDC00",
            Url("http", host="google.com", path="", query="q=%ED%B0%80"),
        ),
        (
            "http://google.com#\uDC00",
            Url("http", host="google.com", path="", fragment="%ED%B0%80"),
        ),
    ]

    @pytest.mark.parametrize(
        "url, expected_url",
        chain(parse_url_host_map, non_round_tripping_parse_url_host_map),
    )
    def test_parse_url(self, url: str, expected_url: Url) -> None:
        """Parsing `url` must produce `expected_url`, with ``hostname`` and ``host`` agreeing."""
        parsed = parse_url(url)
        assert parsed == expected_url
        assert parsed.hostname == parsed.host == expected_url.host

    @pytest.mark.parametrize("url, expected_url", parse_url_host_map)
    def test_unparse_url(self, url: str, expected_url: Url) -> None:
        """The ``url`` property of `expected_url` must reproduce the original `url` string."""
        assert url == expected_url.url

    @pytest.mark.parametrize(
        ["url", "expected_url"],
        [
            # RFC 3986 5.2.4
            ("/abc/../def", Url(path="/def")),
            ("/..", Url(path="/")),
            ("/./abc/./def/", Url(path="/abc/def/")),
            ("/.", Url(path="/")),
            ("/./", Url(path="/")),
            ("/abc/./.././d/././e/.././f/./../../ghi", Url(path="/ghi")),
        ],
    )
    def test_parse_and_normalize_url_paths(self, url: str,
                                           expected_url: Url) -> None:
        """Parsing `url` must match `expected_url` both as an object and as a string."""
        parsed = parse_url(url)
        assert parsed == expected_url
        assert parsed.url == expected_url.url

    def test_parse_url_invalid_IPv6(self) -> None:
        """An IPv6 literal with an unclosed bracket must raise ``LocationParseError``."""
        with pytest.raises(LocationParseError):
            parse_url("[::1")

    def test_parse_url_negative_port(self) -> None:
        """A URL with a negative port must raise ``LocationParseError``."""
        with pytest.raises(LocationParseError):
            parse_url("https://www.google.com:-80/")

    def test_Url_str(self) -> None:
        """``str()`` of a ``Url`` equals its ``url`` property."""
        google = Url("http", host="google.com")
        assert str(google) == google.url

    request_uri_map = [
        ("http://google.com/mail", "/mail"),
        ("http://google.com/mail/", "/mail/"),
        ("http://google.com/", "/"),
        ("http://google.com", "/"),
        ("", "/"),
        ("/", "/"),
        ("?", "/?"),
        ("#", "/"),
        ("/foo?bar=baz", "/foo?bar=baz"),
    ]

    @pytest.mark.parametrize("url, expected_request_uri", request_uri_map)
    def test_request_uri(self, url: str, expected_request_uri: str) -> None:
        """The ``request_uri`` of the parsed `url` must equal the expected value."""
        assert parse_url(url).request_uri == expected_request_uri

    url_authority_map: List[Tuple[str, Optional[str]]] = [
        ("http://*****:*****@google.com/mail", "user:[email protected]"),
        ("http://*****:*****@google.com:80/mail", "user:[email protected]:80"),
        ("http://[email protected]:80/mail", "[email protected]:80"),
        ("http://*****:*****@192.168.1.1/path", "user:[email protected]"),
        ("http://*****:*****@192.168.1.1:80/path", "user:[email protected]:80"),
        ("http://[email protected]:80/path", "[email protected]:80"),
        ("http://*****:*****@[::1]/path", "user:pass@[::1]"),
        ("http://*****:*****@[::1]:80/path", "user:pass@[::1]:80"),
        ("http://user@[::1]:80/path", "user@[::1]:80"),
        ("http://*****:*****@localhost/path", "user:pass@localhost"),
        ("http://*****:*****@localhost:80/path", "user:pass@localhost:80"),
        ("http://user@localhost:80/path", "user@localhost:80"),
    ]

    url_netloc_map = [
        ("http://google.com/mail", "google.com"),
        ("http://google.com:80/mail", "google.com:80"),
        ("http://192.168.0.1/path", "192.168.0.1"),
        ("http://192.168.0.1:80/path", "192.168.0.1:80"),
        ("http://[::1]/path", "[::1]"),
        ("http://[::1]:80/path", "[::1]:80"),
        ("http://localhost", "localhost"),
        ("http://*****:*****@pytest.mark.parametrize("url, expected_authority",
                             combined_netloc_authority_map)
    def test_authority(self, url: str,
                       expected_authority: Optional[str]) -> None:
        """The ``authority`` of the parsed `url` must equal `expected_authority`."""
        assert parse_url(url).authority == expected_authority

    @pytest.mark.parametrize("url, expected_authority",
                             url_authority_with_schemes_map)
    def test_authority_matches_urllib_netloc(
            self, url: str, expected_authority: Optional[str]) -> None:
        """Validate that the expected authority equals ``urlparse(url).netloc``."""
        stdlib_netloc = urlparse(url).netloc
        assert stdlib_netloc == expected_authority

    @pytest.mark.parametrize("url, expected_netloc", url_netloc_map)
    def test_netloc(self, url: str, expected_netloc: Optional[str]) -> None:
        """The ``netloc`` of the parsed `url` must equal `expected_netloc`."""
        assert parse_url(url).netloc == expected_netloc

    url_vulnerabilities = [
        # urlparse doesn't follow RFC 3986 Section 3.2
        (
            "http://google.com#@evil.com/",
            Url("http", host="google.com", path="", fragment="@evil.com/"),
        ),
        # CVE-2016-5699
        (
            "http://127.0.0.1%0d%0aConnection%3a%20keep-alive",
            Url("http", host="127.0.0.1%0d%0aconnection%3a%20keep-alive"),
        ),
        # NodeJS unicode -> double dot
        (
            "http://google.com/\uff2e\uff2e/abc",
            Url("http", host="google.com", path="/%EF%BC%AE%EF%BC%AE/abc"),
        ),
        # Scheme without ://
        (
            "javascript:a='@google.com:12345/';alert(0)",
            Url(scheme="javascript", path="a='@google.com:12345/';alert(0)"),
        ),
        ("//google.com/a/b/c", Url(host="google.com", path="/a/b/c")),
        # International URLs
        (
            "http://ヒ:キ@ヒ.abc.ニ/ヒ?キ#ワ",
            Url(
                "http",
                host="xn--pdk.abc.xn--idk",
                auth="%E3%83%92:%E3%82%AD",
                path="/%E3%83%92",
                query="%E3%82%AD",
                fragment="%E3%83%AF",
            ),
        ),
        # Injected headers (CVE-2016-5699, CVE-2019-9740, CVE-2019-9947)
        (
            "10.251.0.83:7777?a=1 HTTP/1.1\r\nX-injected: header",
            Url(
                host="10.251.0.83",
                port=7777,
                path="",
                query="a=1%20HTTP/1.1%0D%0AX-injected:%20header",
            ),
        ),
        (
            "http://127.0.0.1:6379?\r\nSET test failure12\r\n:8080/test/?test=a",
            Url(
                scheme="http",
                host="127.0.0.1",
                port=6379,
                path="",
                query="%0D%0ASET%20test%20failure12%0D%0A:8080/test/?test=a",
            ),
        ),
        # See https://bugs.xdavidhu.me/google/2020/03/08/the-unexpected-google-wide-domain-check-bypass/
        (
            "https://*****:*****@xdavidhu.me\\test.corp.google.com:8080/path/to/something?param=value#hash",
            Url(
                scheme="https",
                auth="user:pass",
                host="xdavidhu.me",
                path="/%5Ctest.corp.google.com:8080/path/to/something",
                query="param=value",
                fragment="hash",
            ),
        ),
        # Tons of '@' causing backtracking
        ("https://" + ("@" * 10000) + "[", False),
        (
            "https://*****:*****@" * 10000) + "example.com",
            Url(
                scheme="https",
                auth="user:"******"%40" * 9999),
                host="example.com",
            ),
        ),
    ]

    @pytest.mark.parametrize("url, expected_url", url_vulnerabilities)
    def test_url_vulnerabilities(
            self, url: str, expected_url: Union["Literal[False]",
                                                Url]) -> None:
        """Check how ``parse_url`` handles each hostile `url`.

        An expected value of ``False`` means parsing must raise
        ``LocationParseError``; otherwise the parsed result must equal
        `expected_url`.
        """
        if expected_url is not False:
            assert parse_url(url) == expected_url
            return

        with pytest.raises(LocationParseError):
            parse_url(url)

    def test_parse_url_bytes_type_error(self) -> None:
        """Passing bytes to ``parse_url`` must raise ``TypeError``."""
        with pytest.raises(TypeError):
            parse_url(b"https://www.google.com/")  # type: ignore[arg-type]

    @pytest.mark.parametrize(
        "kwargs, expected",
        [
            pytest.param(
                {"accept_encoding": True},
                {"accept-encoding": "gzip,deflate,br,zstd"},
                marks=[onlyBrotli(), onlyZstd()],  # type: ignore[list-item]
            ),
            pytest.param(
                {"accept_encoding": True},
                {"accept-encoding": "gzip,deflate,br"},
                marks=[onlyBrotli(), notZstd()],  # type: ignore[list-item]
            ),
            pytest.param(
                {"accept_encoding": True},
                {"accept-encoding": "gzip,deflate,zstd"},
                marks=[notBrotli(), onlyZstd()],  # type: ignore[list-item]
            ),
            pytest.param(
                {"accept_encoding": True},
                {"accept-encoding": "gzip,deflate"},
                marks=[notBrotli(), notZstd()],  # type: ignore[list-item]
            ),
            ({
                "accept_encoding": "foo,bar"
            }, {
                "accept-encoding": "foo,bar"
            }),
            ({
                "accept_encoding": ["foo", "bar"]
            }, {
                "accept-encoding": "foo,bar"
            }),
            pytest.param(
                {
                    "accept_encoding": True,
                    "user_agent": "banana"
                },
                {
                    "accept-encoding": "gzip,deflate,br,zstd",
                    "user-agent": "banana"
                },
                marks=[onlyBrotli(), onlyZstd()],  # type: ignore[list-item]
            ),
            pytest.param(
                {
                    "accept_encoding": True,
                    "user_agent": "banana"
                },
                {
                    "accept-encoding": "gzip,deflate,br",
                    "user-agent": "banana"
                },
                marks=[onlyBrotli(), notZstd()],  # type: ignore[list-item]
            ),
            pytest.param(
                {
                    "accept_encoding": True,
                    "user_agent": "banana"
                },
                {
                    "accept-encoding": "gzip,deflate,zstd",
                    "user-agent": "banana"
                },
                marks=[notBrotli(), onlyZstd()],  # type: ignore[list-item]
            ),
            pytest.param(
                {
                    "accept_encoding": True,
                    "user_agent": "banana"
                },
                {
                    "accept-encoding": "gzip,deflate",
                    "user-agent": "banana"
                },
                marks=[notBrotli(), notZstd()],  # type: ignore[list-item]
            ),
            ({
                "user_agent": "banana"
            }, {
                "user-agent": "banana"
            }),
            ({
                "keep_alive": True
            }, {
                "connection": "keep-alive"
            }),
            ({
                "basic_auth": "foo:bar"
            }, {
                "authorization": "Basic Zm9vOmJhcg=="
            }),
            (
                {
                    "proxy_basic_auth": "foo:bar"
                },
                {
                    "proxy-authorization": "Basic Zm9vOmJhcg=="
                },
            ),
            ({
                "disable_cache": True
            }, {
                "cache-control": "no-cache"
            }),
        ],
    )
    def test_make_headers(
        self, kwargs: Dict[str, Union[bool, str]], expected: Dict[str, str]
    ) -> None:
        """``make_headers(**kwargs)`` must return exactly ``expected``."""
        built_headers = make_headers(**kwargs)  # type: ignore[arg-type]
        assert built_headers == expected

    def test_rewind_body(self) -> None:
        """After ``rewind_body(body, 5)`` a read yields only the trailing data."""
        stream = io.BytesIO(b"test data")

        # The first read drains the stream ...
        assert stream.read() == b"test data"
        # ... so a second read has nothing left to return.
        assert stream.read() == b""

        # Rewinding to position 5 leaves just b"data" to be read.
        rewind_body(stream, 5)
        assert stream.read() == b"data"

    def test_rewind_body_failed_tell(self) -> None:
        """A ``_FAILEDTELL`` position must make ``rewind_body`` raise."""
        stream = io.BytesIO(b"test data")
        stream.read()  # drain the stream first

        # _FAILEDTELL stands in for a position that tell() could not provide.
        with pytest.raises(UnrewindableBodyError):
            rewind_body(stream, _FAILEDTELL)

    def test_rewind_body_bad_position(self) -> None:
        """Non-integer positions (``None``, a bare object) raise ``ValueError``."""
        stream = io.BytesIO(b"test data")
        stream.read()  # drain the stream first

        for bad_position in (None, object()):
            with pytest.raises(ValueError):
                rewind_body(stream, body_pos=bad_position)  # type: ignore[arg-type]

    def test_rewind_body_failed_seek(self) -> None:
        """A body whose ``seek`` raises ``OSError`` is reported as unrewindable."""

        class SeekAlwaysFails(io.StringIO):
            def seek(self, offset: int, whence: int = 0) -> NoReturn:
                raise OSError

        unseekable = SeekAlwaysFails()
        with pytest.raises(UnrewindableBodyError):
            rewind_body(unseekable, body_pos=2)

    def test_add_stderr_logger(self) -> None:
        """``add_stderr_logger`` attaches its handler to the "urllib3" logger.

        The handler is removed in a ``finally`` block so that a failing
        assertion cannot leave it attached to the shared "urllib3" logger,
        where it would affect later tests.
        """
        handler = add_stderr_logger(
            level=logging.INFO)  # Don't actually print debug
        logger = logging.getLogger("urllib3")
        try:
            assert handler in logger.handlers

            logger.debug("Testing add_stderr_logger")
        finally:
            # Always detach, even when the assertion above fails.
            logger.removeHandler(handler)

    def test_disable_warnings(self) -> None:
        """After ``disable_warnings()`` a further warning is not recorded."""
        with warnings.catch_warnings(record=True) as captured:
            clear_warnings()

            # Before disabling: the warning is recorded.
            warnings.warn("This is a test.", InsecureRequestWarning)
            assert len(captured) == 1

            # After disabling: the same warning adds no new entry.
            disable_warnings()
            warnings.warn("This is a test.", InsecureRequestWarning)
            assert len(captured) == 1

    def _make_time_pass(
        self, seconds: int, timeout: Timeout, time_mock: Mock
    ) -> Timeout:
        """Start the connect timer on ``timeout``, then advance the mocked clock.

        The mock reports ``TIMEOUT_EPOCH`` while ``start_connect()`` runs and
        ``TIMEOUT_EPOCH + seconds`` afterwards. The same ``timeout`` object is
        returned.
        """
        connect_started_at = TIMEOUT_EPOCH
        time_mock.return_value = connect_started_at
        timeout.start_connect()

        time_mock.return_value = connect_started_at + seconds
        return timeout

    @pytest.mark.parametrize(
        "kwargs, message",
        [
            ({
                "total": -1
            }, "less than"),
            ({
                "connect": 2,
                "total": -1
            }, "less than"),
            ({
                "read": -1
            }, "less than"),
            ({
                "connect": False
            }, "cannot be a boolean"),
            ({
                "read": True
            }, "cannot be a boolean"),
            ({
                "connect": 0
            }, "less than or equal"),
            ({
                "read": "foo"
            }, "int, float or None"),
            ({
                "read": "1.0"
            }, "int, float or None"),
        ],
    )
    def test_invalid_timeouts(
        self, kwargs: Dict[str, Union[int, bool]], message: str
    ) -> None:
        """Building a ``Timeout`` from ``kwargs`` raises ``ValueError`` matching ``message``."""
        expect_value_error = pytest.raises(ValueError, match=message)
        with expect_value_error:
            Timeout(**kwargs)

    @patch("time.monotonic")
    def test_timeout(self, time_monotonic: MagicMock) -> None:
        """Check ``Timeout``'s connect/read/total values under a mocked clock."""
        # make 'no time' elapse
        untouched = self._make_time_pass(
            seconds=0, timeout=Timeout(total=3), time_mock=time_monotonic
        )
        assert untouched.read_timeout == 3
        assert untouched.connect_timeout == 3

        assert Timeout(total=3, connect=2).connect_timeout == 2

        assert Timeout().connect_timeout == _DEFAULT_TIMEOUT

        # Connect takes 5 seconds, leaving 5 seconds for read
        after_slow_connect = self._make_time_pass(
            seconds=5, timeout=Timeout(total=10, read=7), time_mock=time_monotonic
        )
        assert after_slow_connect.read_timeout == 5

        # Connect takes 2 seconds, read timeout still 7 seconds
        after_fast_connect = self._make_time_pass(
            seconds=2, timeout=Timeout(total=10, read=7), time_mock=time_monotonic
        )
        assert after_fast_connect.read_timeout == 7

        # The connect timer was never started on this instance.
        assert Timeout(total=10, read=7).read_timeout == 7

        unlimited = Timeout(total=None, read=None, connect=None)
        assert unlimited.connect_timeout is None
        assert unlimited.read_timeout is None
        assert unlimited.total is None

        # A single positional argument is exposed as ``total``.
        assert Timeout(5).total == 5

    def test_timeout_default_resolve(self) -> None:
        """The timeout default is resolved when read_timeout is accessed."""
        timeout = Timeout()
        # The same instance must follow whatever the patched default is at
        # the moment ``read_timeout`` is read.
        for default_seconds in (2, 3):
            with patch(
                "urllib3.util.timeout.getdefaulttimeout",
                return_value=default_seconds,
            ):
                assert timeout.read_timeout == default_seconds

    def test_timeout_str(self) -> None:
        """``str(Timeout(...))`` lists connect, read and total in that order."""
        fully_set = str(Timeout(connect=1, read=2, total=3))
        assert fully_set == "Timeout(connect=1, read=2, total=3)"

        read_unset = str(Timeout(connect=1, read=None, total=3))
        assert read_unset == "Timeout(connect=1, read=None, total=3)"

    @patch("time.monotonic")
    def test_timeout_elapsed(self, time_monotonic: MagicMock) -> None:
        """Check ``get_connect_duration`` and the ``start_connect`` state errors."""
        time_monotonic.return_value = TIMEOUT_EPOCH
        clock_bound = Timeout(total=3)

        # Asking for the duration before the connect has started is an error.
        with pytest.raises(TimeoutStateError):
            clock_bound.get_connect_duration()

        # Starting the connect a second time is an error too.
        clock_bound.start_connect()
        with pytest.raises(TimeoutStateError):
            clock_bound.start_connect()

        # The reported duration follows the mocked clock.
        for elapsed in (2, 37):
            time_monotonic.return_value = TIMEOUT_EPOCH + elapsed
            assert clock_bound.get_connect_duration() == elapsed

    def test_is_fp_closed_object_supports_closed(self) -> None:
        """An object whose ``closed`` property is ``True`` counts as closed."""

        class AlwaysClosed:
            @property
            def closed(self) -> "Literal[True]":
                return True

        candidate = AlwaysClosed()
        assert is_fp_closed(candidate)

    def test_is_fp_closed_object_has_none_fp(self) -> None:
        """An object whose ``fp`` property is ``None`` counts as closed."""

        class FpIsNone:
            @property
            def fp(self) -> None:
                return None

        candidate = FpIsNone()
        assert is_fp_closed(candidate)

    def test_is_fp_closed_object_has_fp(self) -> None:
        """An object whose ``fp`` property is not ``None`` counts as open."""

        class FpIsSet:
            @property
            def fp(self) -> "Literal[True]":
                return True

        candidate = FpIsSet()
        assert not is_fp_closed(candidate)

    def test_is_fp_closed_object_has_neither_fp_nor_closed(self) -> None:
        """An object defining no attributes at all makes ``is_fp_closed`` raise."""

        class Featureless:
            pass

        candidate = Featureless()
        with pytest.raises(ValueError):
            is_fp_closed(candidate)

    def test_has_ipv6_disabled_on_compile(self) -> None:
        """With ``socket.has_ipv6`` patched to ``False``, ``_has_ipv6`` is falsy."""
        with patch("socket.has_ipv6", False):
            supported = _has_ipv6("::1")
        assert not supported

    def test_has_ipv6_enabled_but_fails(self) -> None:
        """A failing ``bind`` makes ``_has_ipv6`` falsy even if ``has_ipv6`` is set."""
        with patch("socket.has_ipv6", True), patch("socket.socket") as socket_cls:
            fake_sock = socket_cls.return_value
            fake_sock.bind = Mock(side_effect=Exception("No IPv6 here!"))
            assert not _has_ipv6("::1")

    def test_has_ipv6_enabled_and_working(self) -> None:
        """With ``has_ipv6`` set and ``bind`` succeeding, ``_has_ipv6`` is truthy."""
        with patch("socket.has_ipv6", True), patch("socket.socket") as socket_cls:
            fake_sock = socket_cls.return_value
            fake_sock.bind.return_value = True
            assert _has_ipv6("::1")

    def test_ip_family_ipv6_enabled(self) -> None:
        """With ``HAS_IPV6`` patched on, the allowed family is ``AF_UNSPEC``."""
        with patch("urllib3.util.connection.HAS_IPV6", True):
            family = allowed_gai_family()
        assert family == socket.AF_UNSPEC

    def test_ip_family_ipv6_disabled(self) -> None:
        """With ``HAS_IPV6`` patched off, the allowed family is ``AF_INET``."""
        with patch("urllib3.util.connection.HAS_IPV6", False):
            family = allowed_gai_family()
        assert family == socket.AF_INET

    @pytest.mark.parametrize("headers", [b"foo", None, object])
    def test_assert_header_parsing_throws_typeerror_with_non_headers(
        self, headers: Optional[Union[bytes, object]]
    ) -> None:
        """``assert_header_parsing`` raises ``TypeError`` for each parametrized value."""
        expect_type_error = pytest.raises(TypeError)
        with expect_type_error:
            assert_header_parsing(headers)  # type: ignore[arg-type]

    def test_connection_requires_http_tunnel_no_proxy(self) -> None:
        """With no proxy URL, config or scheme, no tunnel is required."""
        needs_tunnel = connection_requires_http_tunnel(
            proxy_url=None, proxy_config=None, destination_scheme=None
        )
        assert not needs_tunnel

    def test_connection_requires_http_tunnel_http_proxy(self) -> None:
        proxy = parse_url("http://*****:*****@pytest.mark.parametrize("host", [".localhost", "...", "t" * 64])
    def test_create_connection_with_invalid_idna_labels(
        self, host: str
    ) -> None:
        """``create_connection`` rejects ``host`` with a ``LocationParseError``."""
        expected_message = f"Failed to parse: '{host}', label empty or too long"
        with pytest.raises(LocationParseError, match=expected_message):
            create_connection((host, 80))

    @pytest.mark.parametrize(
        "host",
        [
            "a.example.com",
            "localhost.",
            "[dead::beef]",
            "[dead::beef%en5]",
            "[dead::beef%en5.]",
        ],
    )
    @patch("socket.getaddrinfo")
    @patch("socket.socket")
    def test_create_connection_with_valid_idna_labels(
        self, socket: MagicMock, getaddrinfo: MagicMock, host: str
    ) -> None:
        """``create_connection`` completes without raising for ``host``.

        ``socket.socket`` and ``socket.getaddrinfo`` are both patched by the
        decorators, which inject ``socket`` and ``getaddrinfo``.
        """
        socket.return_value = Mock()
        # One placeholder address record for the patched resolver to return.
        getaddrinfo.return_value = [(None,) * 5]

        create_connection((host, 80))

    @patch("socket.getaddrinfo")
    def test_create_connection_error(self, getaddrinfo: MagicMock) -> None:
        """An empty ``getaddrinfo`` result surfaces as ``OSError``."""
        getaddrinfo.return_value = []
        destination = ("example.com", 80)
        with pytest.raises(OSError, match="getaddrinfo returns an empty list"):
            create_connection(destination)

    @patch("socket.getaddrinfo")
    def test_dnsresolver_forced_error(self, getaddrinfo: MagicMock) -> None:
        """A ``gaierror`` raised by ``getaddrinfo`` propagates to the caller."""
        # The hostname itself is valid; the failure is injected through the
        # patched resolver purely for the sake of the test.
        getaddrinfo.side_effect = socket.gaierror()
        destination = ("example.com", 80)
        with pytest.raises(socket.gaierror):
            create_connection(destination)

    def test_dnsresolver_expected_error(self) -> None:
        """Connecting to ``badhost.invalid`` raises ``socket.gaierror``.

        Nothing is patched here. Typical platform error messages:

        * windows: [Errno 11001] getaddrinfo failed in windows
        * linux: [Errno -2] Name or service not known
        * macos: [Errno 8] nodename nor servname provided, or not known
        """
        unresolvable = ("badhost.invalid", 80)
        with pytest.raises(socket.gaierror):
            create_connection(unresolvable)

    @patch("socket.getaddrinfo")
    @patch("socket.socket")
    def test_create_connection_with_scoped_ipv6(
        self, socket: MagicMock, getaddrinfo: MagicMock
    ) -> None:
        """A scoped IPv6 host reaches ``getaddrinfo`` and the scoped sockaddr reaches ``connect``.

        The host string (including its ``%iface`` scope) must be passed
        unchanged as the first argument to ``getaddrinfo``, and the sockaddr
        the resolver returns (with its scope id) must be what the socket
        connects to.
        """
        scoped_sockaddr = ("a::b", 80, 0, 42)
        scoped_host = "a::b%iface"

        # ``socket`` is the patched ``socket.socket`` mock here, so these
        # "constants" are mock attributes rather than the real module values.
        resolved_record = (
            socket.AF_INET6,
            socket.SOCK_STREAM,
            socket.IPPROTO_TCP,
            "",
            scoped_sockaddr,
        )
        getaddrinfo.return_value = [resolved_record]

        fake_sock = MagicMock()
        socket.return_value = fake_sock

        create_connection((scoped_host, 80))

        assert getaddrinfo.call_args[0][0] == "a::b%iface"
        fake_sock.connect.assert_called_once_with(scoped_sockaddr)

    @pytest.mark.parametrize(
        "input,params,expected",
        (
            ("test", {}, "test"),  # str input
            (b"test", {}, "test"),  # bytes input
            (b"test", {
                "encoding": "utf-8"
            }, "test"),  # bytes input with utf-8
            (b"test", {
                "encoding": "ascii"
            }, "test"),  # bytes input with ascii
        ),
    )
    def test_to_str(
        self, input: Union[bytes, str], params: Dict[str, str], expected: str
    ) -> None:
        """``to_str(input, **params)`` must return ``expected``."""
        converted = to_str(input, **params)
        assert converted == expected

    def test_to_str_error(self) -> None:
        """Passing an ``int`` to ``to_str`` raises a descriptive ``TypeError``."""
        expect_type_error = pytest.raises(TypeError, match="not expecting type int")
        with expect_type_error:
            to_str(1)  # type: ignore[arg-type]

    @pytest.mark.parametrize(
        "input,params,expected",
        (
            (b"test", {}, b"test"),  # bytes input
            ("test", {}, b"test"),  # str input
            ("é", {}, b"\xc3\xa9"),  # non-ASCII str input, expected as UTF-8 bytes
            ("test", {"encoding": "utf-8"}, b"test"),  # str input with utf-8
            ("test", {"encoding": "ascii"}, b"test"),  # str input with ascii
        ),
    )
    def test_to_bytes(
        self, input: Union[bytes, str], params: Dict[str, str], expected: bytes
    ) -> None:
        """``to_bytes(input, **params)`` must return ``expected``."""
        converted = to_bytes(input, **params)
        assert converted == expected

    def test_to_bytes_error(self) -> None:
        """Passing an ``int`` to ``to_bytes`` raises a descriptive ``TypeError``."""
        expect_type_error = pytest.raises(TypeError, match="not expecting type int")
        with expect_type_error:
            to_bytes(1)  # type: ignore[arg-type]
예제 #20
0
 def test_Url_str(self):
     """``str()`` of a ``Url`` equals its ``url`` attribute."""
     google = Url("http", host="google.com")
     rendered = str(google)
     assert rendered == google.url
예제 #21
0
class TestUtil(object):

    url_host_map = [
        # Hosts
        ("http://google.com/mail", ("http", "google.com", None)),
        ("http://google.com/mail/", ("http", "google.com", None)),
        ("google.com/mail", ("http", "google.com", None)),
        ("http://google.com/", ("http", "google.com", None)),
        ("http://google.com", ("http", "google.com", None)),
        ("http://www.google.com", ("http", "www.google.com", None)),
        ("http://mail.google.com", ("http", "mail.google.com", None)),
        ("http://google.com:8000/mail/", ("http", "google.com", 8000)),
        ("http://google.com:8000", ("http", "google.com", 8000)),
        ("https://google.com", ("https", "google.com", None)),
        ("https://google.com:8000", ("https", "google.com", 8000)),
        ("http://*****:*****@127.0.0.1:1234", ("http", "127.0.0.1", 1234)),
        ("http://google.com/foo=http://bar:42/baz", ("http", "google.com", None)),
        ("http://google.com?foo=http://bar:42/baz", ("http", "google.com", None)),
        ("http://google.com#foo=http://bar:42/baz", ("http", "google.com", None)),
        # IPv4
        ("173.194.35.7", ("http", "173.194.35.7", None)),
        ("http://173.194.35.7", ("http", "173.194.35.7", None)),
        ("http://173.194.35.7/test", ("http", "173.194.35.7", None)),
        ("http://173.194.35.7:80", ("http", "173.194.35.7", 80)),
        ("http://173.194.35.7:80/test", ("http", "173.194.35.7", 80)),
        # IPv6
        ("[2a00:1450:4001:c01::67]", ("http", "[2a00:1450:4001:c01::67]", None)),
        ("http://[2a00:1450:4001:c01::67]", ("http", "[2a00:1450:4001:c01::67]", None)),
        (
            "http://[2a00:1450:4001:c01::67]/test",
            ("http", "[2a00:1450:4001:c01::67]", None),
        ),
        (
            "http://[2a00:1450:4001:c01::67]:80",
            ("http", "[2a00:1450:4001:c01::67]", 80),
        ),
        (
            "http://[2a00:1450:4001:c01::67]:80/test",
            ("http", "[2a00:1450:4001:c01::67]", 80),
        ),
        # More IPv6 from http://www.ietf.org/rfc/rfc2732.txt
        (
            "http://[fedc:ba98:7654:3210:fedc:ba98:7654:3210]:8000/index.html",
            ("http", "[fedc:ba98:7654:3210:fedc:ba98:7654:3210]", 8000),
        ),
        (
            "http://[1080:0:0:0:8:800:200c:417a]/index.html",
            ("http", "[1080:0:0:0:8:800:200c:417a]", None),
        ),
        ("http://[3ffe:2a00:100:7031::1]", ("http", "[3ffe:2a00:100:7031::1]", None)),
        (
            "http://[1080::8:800:200c:417a]/foo",
            ("http", "[1080::8:800:200c:417a]", None),
        ),
        ("http://[::192.9.5.5]/ipng", ("http", "[::192.9.5.5]", None)),
        (
            "http://[::ffff:129.144.52.38]:42/index.html",
            ("http", "[::ffff:129.144.52.38]", 42),
        ),
        (
            "http://[2010:836b:4179::836b:4179]",
            ("http", "[2010:836b:4179::836b:4179]", None),
        ),
        # Mixed-case schemes and hosts (expected tuples show the normalized form)
        ("HTTP://GOOGLE.COM/mail/", ("http", "google.com", None)),
        ("GOogle.COM/mail", ("http", "google.com", None)),
        ("HTTP://GoOgLe.CoM:8000/mail/", ("http", "google.com", 8000)),
        ("HTTP://*****:*****@EXAMPLE.COM:1234", ("http", "example.com", 1234)),
        ("173.194.35.7", ("http", "173.194.35.7", None)),
        ("HTTP://173.194.35.7", ("http", "173.194.35.7", None)),
        (
            "HTTP://[2a00:1450:4001:c01::67]:80/test",
            ("http", "[2a00:1450:4001:c01::67]", 80),
        ),
        (
            "HTTP://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8000/index.html",
            ("http", "[fedc:ba98:7654:3210:fedc:ba98:7654:3210]", 8000),
        ),
        (
            "HTTPS://[1080:0:0:0:8:800:200c:417A]/index.html",
            ("https", "[1080:0:0:0:8:800:200c:417a]", None),
        ),
        ("abOut://eXamPlE.com?info=1", ("about", "eXamPlE.com", None)),
        (
            "http+UNIX://%2fvar%2frun%2fSOCKET/path",
            ("http+unix", "%2fvar%2frun%2fSOCKET", None),
        ),
    ]

    @pytest.mark.parametrize("url, expected_host", url_host_map)
    def test_get_host(self, url, expected_host):
        """``get_host(url)`` must return the expected ``(scheme, host, port)`` tuple."""
        assert get_host(url) == expected_host

    # TODO: Add more tests
    @pytest.mark.parametrize(
        "location",
        [
            "http://google.com:foo",
            "http://::1/",
            "http://::1:80/",
            "http://google.com:-80",
            six.u("http://google.com:\xb2\xb2"),  # \xb2 = ^2
        ],
    )
    def test_invalid_host(self, location):
        """``get_host`` rejects ``location`` with a ``LocationParseError``."""
        expect_parse_error = pytest.raises(LocationParseError)
        with expect_parse_error:
            get_host(location)

    @pytest.mark.parametrize(
        "url",
        [
            # Invalid IDNA labels
            u"http://\uD7FF.com",
            u"http://❤️",
            # Unicode surrogates
            u"http://\uD800.com",
            u"http://\uDC00.com",
        ],
    )
    def test_invalid_url(self, url):
        """``parse_url`` rejects ``url`` with a ``LocationParseError``."""
        expect_parse_error = pytest.raises(LocationParseError)
        with expect_parse_error:
            parse_url(url)

    @pytest.mark.parametrize(
        "url, expected_normalized_url",
        [
            ("HTTP://GOOGLE.COM/MAIL/", "http://google.com/MAIL/"),
            (
                "http://[email protected]:[email protected]/~tilde@?@",
                "http://user%40domain.com:[email protected]/~tilde@?@",
            ),
            (
                "HTTP://*****:*****@Example.com:8080/",
                "http://*****:*****@example.com:8080/",
            ),
            ("HTTPS://Example.Com/?Key=Value", "https://example.com/?Key=Value"),
            ("Https://Example.Com/#Fragment", "https://example.com/#Fragment"),
            ("[::1%25]", "[::1%25]"),
            ("[::Ff%etH0%Ff]/%ab%Af", "[::ff%etH0%FF]/%AB%AF"),
            (
                "http://*****:*****@[AaAa::Ff%25etH0%Ff]/%ab%Af",
                "http://*****:*****@[aaaa::ff%etH0%FF]/%AB%AF",
            ),
            # Invalid characters for the query/fragment getting encoded
            (
                'http://google.com/p[]?parameter[]="hello"#fragment#',
                "http://google.com/p%5B%5D?parameter%5B%5D=%22hello%22#fragment%23",
            ),
            # Percent encoding isn't applied twice despite '%' being invalid
            # but the percent encoding is still normalized.
            (
                "http://google.com/p%5B%5d?parameter%5b%5D=%22hello%22#fragment%23",
                "http://google.com/p%5B%5D?parameter%5B%5D=%22hello%22#fragment%23",
            ),
        ],
    )
    def test_parse_url_normalization(self, url, expected_normalized_url):
        """The ``url`` attribute of ``parse_url(url)`` equals the expected normalized string."""
        parsed = parse_url(url)
        assert parsed.url == expected_normalized_url

    @pytest.mark.parametrize("char", [chr(i) for i in range(0x00, 0x21)] + ["\x7F"])
    def test_control_characters_are_percent_encoded(self, char):
        """``char`` is percent-encoded in the auth, path, query and fragment."""
        # Two uppercase hex digits, e.g. "\n" -> "%0A".
        encoded = "%{:02X}".format(ord(char))
        raw_url = "http://user{0}@example.com/path{0}?query{0}#fragment{0}".format(char)

        expected = Url(
            "http",
            auth="user" + encoded,
            host="example.com",
            path="/path" + encoded,
            query="query" + encoded,
            fragment="fragment" + encoded,
        )
        assert parse_url(raw_url) == expected

    parse_url_host_map = [
        ("http://google.com/mail", Url("http", host="google.com", path="/mail")),
        ("http://google.com/mail/", Url("http", host="google.com", path="/mail/")),
        ("http://google.com/mail", Url("http", host="google.com", path="mail")),
        ("google.com/mail", Url(host="google.com", path="/mail")),
        ("http://google.com/", Url("http", host="google.com", path="/")),
        ("http://google.com", Url("http", host="google.com")),
        ("http://google.com?foo", Url("http", host="google.com", path="", query="foo")),
        # Path/query/fragment
        ("", Url()),
        ("/", Url(path="/")),
        ("#?/!google.com/?foo", Url(path="", fragment="?/!google.com/?foo")),
        ("/foo", Url(path="/foo")),
        ("/foo?bar=baz", Url(path="/foo", query="bar=baz")),
        (
            "/foo?bar=baz#banana?apple/orange",
            Url(path="/foo", query="bar=baz", fragment="banana?apple/orange"),
        ),
        (
            "/redirect?target=http://localhost:61020/",
            Url(path="redirect", query="target=http://localhost:61020/"),
        ),
        # Port
        ("http://google.com/", Url("http", host="google.com", path="/")),
        ("http://google.com:80/", Url("http", host="google.com", port=80, path="/")),
        ("http://google.com:80", Url("http", host="google.com", port=80)),
        # Auth
        (
            "http://*****:*****@localhost/",
            Url("http", auth="foo:bar", host="localhost", path="/"),
        ),
        ("http://foo@localhost/", Url("http", auth="foo", host="localhost", path="/")),
        (
            "http://*****:*****@localhost/",
            Url("http", auth="foo:bar", host="localhost", path="/"),
        ),
        # Unicode type (Python 2.x)
        (
            u"http://*****:*****@localhost/",
            Url(u"http", auth=u"foo:bar", host=u"localhost", path=u"/"),
        ),
        (
            "http://*****:*****@localhost/",
            Url("http", auth="foo:bar", host="localhost", path="/"),
        ),
    ]

    non_round_tripping_parse_url_host_map = [
        # Path/query/fragment
        ("?", Url(path="", query="")),
        ("#", Url(path="", fragment="")),
        # Path normalization
        ("/abc/../def", Url(path="/def")),
        # Empty Port
        ("http://google.com:", Url("http", host="google.com")),
        ("http://google.com:/", Url("http", host="google.com", path="/")),
        # Uppercase IRI
        (
            u"http://Königsgäßchen.de/straße",
            Url("http", host="xn--knigsgchen-b4a3dun.de", path="/stra%C3%9Fe"),
        ),
        # Percent-encode in userinfo
        (
            u"http://[email protected]:[email protected]/",
            Url("http", auth="user%40email.com:password", host="example.com", path="/"),
        ),
        (
            u'http://user":[email protected]/',
            Url("http", auth="user%22:quoted", host="example.com", path="/"),
        ),
        # Unicode Surrogates
        (u"http://google.com/\uD800", Url("http", host="google.com", path="%ED%A0%80")),
        (
            u"http://google.com?q=\uDC00",
            Url("http", host="google.com", path="", query="q=%ED%B0%80"),
        ),
        (
            u"http://google.com#\uDC00",
            Url("http", host="google.com", path="", fragment="%ED%B0%80"),
        ),
    ]

    @pytest.mark.parametrize(
        "url, expected_url",
        chain(parse_url_host_map, non_round_tripping_parse_url_host_map),
    )
    def test_parse_url(self, url, expected_url):
        """``parse_url(url)`` must equal the expected ``Url``."""
        assert parse_url(url) == expected_url

    @pytest.mark.parametrize("url, expected_url", parse_url_host_map)
    def test_unparse_url(self, url, expected_url):
        """The ``url`` attribute of the expected ``Url`` reproduces the input string."""
        rebuilt = expected_url.url
        assert url == rebuilt

    @pytest.mark.parametrize(
        ["url", "expected_url"],
        [
            # RFC 3986 5.2.4
            ("/abc/../def", Url(path="/def")),
            ("/..", Url(path="/")),
            ("/./abc/./def/", Url(path="/abc/def/")),
            ("/.", Url(path="/")),
            ("/./", Url(path="/")),
            ("/abc/./.././d/././e/.././f/./../../ghi", Url(path="/ghi")),
        ],
    )
    def test_parse_and_normalize_url_paths(self, url, expected_url):
        """The parsed URL matches the expectation both as a ``Url`` and as a string."""
        parsed = parse_url(url)
        assert parsed == expected_url
        assert parsed.url == expected_url.url

    def test_parse_url_invalid_IPv6(self):
        """An IPv6 literal with no closing bracket raises ``LocationParseError``."""
        unterminated_literal = "[::1"
        with pytest.raises(LocationParseError):
            parse_url(unterminated_literal)

    def test_parse_url_negative_port(self):
        """A negative port number raises ``LocationParseError``."""
        negative_port_url = "https://www.google.com:-80/"
        with pytest.raises(LocationParseError):
            parse_url(negative_port_url)

    def test_Url_str(self):
        """``str()`` of a ``Url`` equals its ``url`` attribute."""
        google = Url("http", host="google.com")
        rendered = str(google)
        assert rendered == google.url

    request_uri_map = [
        ("http://google.com/mail", "/mail"),
        ("http://google.com/mail/", "/mail/"),
        ("http://google.com/", "/"),
        ("http://google.com", "/"),
        ("", "/"),
        ("/", "/"),
        ("?", "/?"),
        ("#", "/"),
        ("/foo?bar=baz", "/foo?bar=baz"),
    ]

    @pytest.mark.parametrize("url, expected_request_uri", request_uri_map)
    def test_request_uri(self, url, expected_request_uri):
        """The parsed URL's ``request_uri`` equals the expected string."""
        assert parse_url(url).request_uri == expected_request_uri

    url_netloc_map = [
        ("http://google.com/mail", "google.com"),
        ("http://google.com:80/mail", "google.com:80"),
        ("google.com/foobar", "google.com"),
        ("google.com:12345", "google.com:12345"),
    ]

    @pytest.mark.parametrize("url, expected_netloc", url_netloc_map)
    def test_netloc(self, url, expected_netloc):
        """The parsed URL's ``netloc`` equals the expected string."""
        parsed = parse_url(url)
        assert parsed.netloc == expected_netloc

    url_vulnerabilities = [
        # urlparse doesn't follow RFC 3986 Section 3.2
        (
            "http://google.com#@evil.com/",
            Url("http", host="google.com", path="", fragment="@evil.com/"),
        ),
        # CVE-2016-5699
        (
            "http://127.0.0.1%0d%0aConnection%3a%20keep-alive",
            Url("http", host="127.0.0.1%0d%0aconnection%3a%20keep-alive"),
        ),
        # NodeJS unicode -> double dot
        (
            u"http://google.com/\uff2e\uff2e/abc",
            Url("http", host="google.com", path="/%EF%BC%AE%EF%BC%AE/abc"),
        ),
        # Scheme without ://
        (
            "javascript:a='@google.com:12345/';alert(0)",
            Url(scheme="javascript", path="a='@google.com:12345/';alert(0)"),
        ),
        ("//google.com/a/b/c", Url(host="google.com", path="/a/b/c")),
        # International URLs
        (
            u"http://ヒ:キ@ヒ.abc.ニ/ヒ?キ#ワ",
            Url(
                u"http",
                host=u"xn--pdk.abc.xn--idk",
                auth=u"%E3%83%92:%E3%82%AD",
                path=u"/%E3%83%92",
                query=u"%E3%82%AD",
                fragment=u"%E3%83%AF",
            ),
        ),
        # Injected headers (CVE-2016-5699, CVE-2019-9740, CVE-2019-9947)
        (
            "10.251.0.83:7777?a=1 HTTP/1.1\r\nX-injected: header",
            Url(
                host="10.251.0.83",
                port=7777,
                path="",
                query="a=1%20HTTP/1.1%0D%0AX-injected:%20header",
            ),
        ),
        (
            "http://127.0.0.1:6379?\r\nSET test failure12\r\n:8080/test/?test=a",
            Url(
                scheme="http",
                host="127.0.0.1",
                port=6379,
                path="",
                query="%0D%0ASET%20test%20failure12%0D%0A:8080/test/?test=a",
            ),
        ),
        # See https://bugs.xdavidhu.me/google/2020/03/08/the-unexpected-google-wide-domain-check-bypass/
        (
            "https://*****:*****@xdavidhu.me\\test.corp.google.com:8080/path/to/something?param=value#hash",
            Url(
                scheme="https",
                auth="user:pass",
                host="xdavidhu.me",
                path="/%5Ctest.corp.google.com:8080/path/to/something",
                query="param=value",
                fragment="hash",
            ),
        ),
    ]

    @pytest.mark.parametrize("url, expected_url", url_vulnerabilities)
    def test_url_vulnerabilities(self, url, expected_url):
        """Check how ``parse_url`` treats a URL from the vulnerability table.

        A row whose expectation is the literal ``False`` must be rejected with
        ``LocationParseError``; any other row must parse to exactly the given
        ``Url``.
        """
        if expected_url is not False:
            assert parse_url(url) == expected_url
            return

        with pytest.raises(LocationParseError):
            parse_url(url)

    @onlyPy2
    def test_parse_url_bytes_to_str_python_2(self):
        """Bytes input parses to the expected ``Url`` with ``str`` components."""
        parsed = parse_url(b"https://www.google.com/")
        assert parsed == Url("https", host="www.google.com", path="/")

        for component in (parsed.scheme, parsed.host, parsed.path):
            assert isinstance(component, str)

    @onlyPy2
    def test_parse_url_unicode_python_2(self):
        """Unicode input parses to the expected ``Url`` with text-type components."""
        parsed = parse_url(u"https://www.google.com/")
        assert parsed == Url(u"https", host=u"www.google.com", path=u"/")

        for component in (parsed.scheme, parsed.host, parsed.path):
            assert isinstance(component, six.text_type)

    @onlyPy3
    def test_parse_url_bytes_type_error_python_3(self):
        """parse_url raises TypeError when handed a bytes URL."""
        bytes_url = b"https://www.google.com/"
        with pytest.raises(TypeError):
            parse_url(bytes_url)

    @pytest.mark.parametrize(
        "kwargs, expected",
        [
            # accept_encoding=True: the expected default list depends on
            # which of the Brotli marks applies.
            pytest.param(
                {"accept_encoding": True},
                {"accept-encoding": "gzip,deflate,br"},
                marks=onlyBrotlipy(),
            ),
            pytest.param(
                {"accept_encoding": True},
                {"accept-encoding": "gzip,deflate"},
                marks=notBrotlipy(),
            ),
            # Explicit encodings, as a string or as a list.
            ({"accept_encoding": "foo,bar"}, {"accept-encoding": "foo,bar"}),
            ({"accept_encoding": ["foo", "bar"]}, {"accept-encoding": "foo,bar"}),
            pytest.param(
                {"accept_encoding": True, "user_agent": "banana"},
                {"accept-encoding": "gzip,deflate,br", "user-agent": "banana"},
                marks=onlyBrotlipy(),
            ),
            pytest.param(
                {"accept_encoding": True, "user_agent": "banana"},
                {"accept-encoding": "gzip,deflate", "user-agent": "banana"},
                marks=notBrotlipy(),
            ),
            ({"user_agent": "banana"}, {"user-agent": "banana"}),
            ({"keep_alive": True}, {"connection": "keep-alive"}),
            ({"basic_auth": "foo:bar"}, {"authorization": "Basic Zm9vOmJhcg=="}),
            (
                {"proxy_basic_auth": "foo:bar"},
                {"proxy-authorization": "Basic Zm9vOmJhcg=="},
            ),
            ({"disable_cache": True}, {"cache-control": "no-cache"}),
        ],
    )
    def test_make_headers(self, kwargs, expected):
        """make_headers(**kwargs) returns exactly the expected header dict."""
        headers = make_headers(**kwargs)
        assert headers == expected

    def test_rewind_body(self):
        """rewind_body seeks a consumed body back to the given offset."""
        payload = b"test data"
        stream = io.BytesIO(payload)
        assert stream.read() == payload

        # A second read returns nothing: the stream has been consumed.
        assert stream.read() == b""

        # Seeking back to offset 5 leaves only the trailing b"data".
        rewind_body(stream, 5)
        assert stream.read() == b"data"

    def test_rewind_body_failed_tell(self):
        """A position of _FAILEDTELL makes rewind_body raise UnrewindableBodyError."""
        stream = io.BytesIO(b"test data")
        stream.read()  # exhaust the stream

        # Passing _FAILEDTELL as the position simulates a failed tell().
        with pytest.raises(UnrewindableBodyError):
            rewind_body(stream, _FAILEDTELL)

    def test_rewind_body_bad_position(self):
        """Non-integer positions make rewind_body raise ValueError."""
        stream = io.BytesIO(b"test data")
        stream.read()  # exhaust the stream

        for bad_position in (None, object()):
            with pytest.raises(ValueError):
                rewind_body(stream, body_pos=bad_position)

    def test_rewind_body_failed_seek(self):
        """A body whose seek() raises IOError is reported as unrewindable."""

        class BadSeek:
            def seek(self, pos, offset=0):
                raise IOError

        unseekable = BadSeek()
        with pytest.raises(UnrewindableBodyError):
            rewind_body(unseekable, body_pos=2)

    @pytest.mark.parametrize(
        "input, expected",
        [
            (("abcd", "b"), ("a", "cd", "b")),
            (("abcd", "cb"), ("a", "cd", "b")),
            (("abcd", ""), ("abcd", "", None)),
            (("abcd", "a"), ("", "bcd", "a")),
            (("abcd", "ab"), ("", "bcd", "a")),
            (("abcd", "eb"), ("a", "cd", "b")),
        ],
    )
    def test_split_first(self, input, expected):
        """split_first(text, delims) returns (before, after, matched delimiter).

        Each ``input`` is a ``(text, delims)`` pair; with no delimiters the
        expected result is ``(text, "", None)``.
        """
        text, delimiters = input
        assert split_first(text, delimiters) == expected

    def test_add_stderr_logger(self):
        """add_stderr_logger attaches its handler to the "urllib3" logger."""
        # INFO level so the debug record below is not actually printed.
        stderr_handler = add_stderr_logger(level=logging.INFO)
        urllib3_logger = logging.getLogger("urllib3")
        assert stderr_handler in urllib3_logger.handlers

        urllib3_logger.debug("Testing add_stderr_logger")
        urllib3_logger.removeHandler(stderr_handler)

    def test_disable_warnings(self):
        """After disable_warnings(), InsecureRequestWarning is not recorded."""
        message = "This is a test."
        with warnings.catch_warnings(record=True) as recorded:
            clear_warnings()
            warnings.warn(message, InsecureRequestWarning)
            assert len(recorded) == 1

            # The second warning must not grow the recorded list.
            disable_warnings()
            warnings.warn(message, InsecureRequestWarning)
            assert len(recorded) == 1

    def _make_time_pass(self, seconds, timeout, time_mock):
        """Start ``timeout``'s connect clock, then advance the mocked time.

        ``time_mock`` returns TIMEOUT_EPOCH while ``start_connect()`` runs and
        TIMEOUT_EPOCH + ``seconds`` afterwards. The same ``timeout`` object is
        returned.
        """
        started_at = TIMEOUT_EPOCH
        finished_at = TIMEOUT_EPOCH + seconds

        time_mock.return_value = started_at
        timeout.start_connect()
        time_mock.return_value = finished_at
        return timeout

    @pytest.mark.parametrize(
        "kwargs, message",
        [
            ({"total": -1}, "less than"),
            ({"connect": 2, "total": -1}, "less than"),
            ({"read": -1}, "less than"),
            ({"connect": False}, "cannot be a boolean"),
            ({"read": True}, "cannot be a boolean"),
            ({"connect": 0}, "less than or equal"),
            ({"read": "foo"}, "int, float or None"),
        ],
    )
    def test_invalid_timeouts(self, kwargs, message):
        """Invalid Timeout arguments raise ValueError mentioning ``message``."""
        with pytest.raises(ValueError) as excinfo:
            Timeout(**kwargs)
        error_text = str(excinfo.value)
        assert message in error_text

    @patch("urllib3.util.timeout.current_time")
    def test_timeout(self, current_time):
        """Check connect/read timeouts for various Timeout configurations."""
        # With no time elapsed, both timeouts equal the total.
        fresh = self._make_time_pass(
            seconds=0, timeout=Timeout(total=3), time_mock=current_time
        )
        assert fresh.read_timeout == 3
        assert fresh.connect_timeout == 3

        # total=3 with connect=2 gives a connect timeout of 2.
        assert Timeout(total=3, connect=2).connect_timeout == 2

        # No arguments: connect_timeout is Timeout.DEFAULT_TIMEOUT.
        assert Timeout().connect_timeout == Timeout.DEFAULT_TIMEOUT

        # Connect takes 5 seconds, leaving 5 seconds for read
        slow_connect = self._make_time_pass(
            seconds=5, timeout=Timeout(total=10, read=7), time_mock=current_time
        )
        assert slow_connect.read_timeout == 5

        # Connect takes 2 seconds, read timeout still 7 seconds
        fast_connect = self._make_time_pass(
            seconds=2, timeout=Timeout(total=10, read=7), time_mock=current_time
        )
        assert fast_connect.read_timeout == 7

        # Connect clock never started: read timeout is the configured 7.
        assert Timeout(total=10, read=7).read_timeout == 7

        unlimited = Timeout(total=None, read=None, connect=None)
        assert unlimited.connect_timeout is None
        assert unlimited.read_timeout is None
        assert unlimited.total is None

        # A single positional argument is the total.
        assert Timeout(5).total == 5

    def test_timeout_str(self):
        """str(Timeout) lists connect, read and total, including None values."""
        with_read = Timeout(connect=1, read=2, total=3)
        without_read = Timeout(connect=1, read=None, total=3)

        assert str(with_read) == "Timeout(connect=1, read=2, total=3)"
        assert str(without_read) == "Timeout(connect=1, read=None, total=3)"

    @patch("urllib3.util.timeout.current_time")
    def test_timeout_elapsed(self, current_time):
        """get_connect_duration reports the mocked time since start_connect."""
        current_time.return_value = TIMEOUT_EPOCH
        timeout = Timeout(total=3)

        # Asking for the duration before the clock was started is an error.
        with pytest.raises(TimeoutStateError):
            timeout.get_connect_duration()

        # Starting the connect clock a second time is an error as well.
        timeout.start_connect()
        with pytest.raises(TimeoutStateError):
            timeout.start_connect()

        for elapsed in (2, 37):
            current_time.return_value = TIMEOUT_EPOCH + elapsed
            assert timeout.get_connect_duration() == elapsed

    @pytest.mark.parametrize(
        "candidate, requirements",
        [
            (None, ssl.CERT_REQUIRED),
            (ssl.CERT_NONE, ssl.CERT_NONE),
            (ssl.CERT_REQUIRED, ssl.CERT_REQUIRED),
            ("REQUIRED", ssl.CERT_REQUIRED),
            ("CERT_REQUIRED", ssl.CERT_REQUIRED),
        ],
    )
    def test_resolve_cert_reqs(self, candidate, requirements):
        """resolve_cert_reqs maps None, ssl constants and names to a constant."""
        resolved = resolve_cert_reqs(candidate)
        assert resolved == requirements

    @pytest.mark.parametrize(
        "candidate, version",
        [
            (ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1),
            ("PROTOCOL_TLSv1", ssl.PROTOCOL_TLSv1),
            ("TLSv1", ssl.PROTOCOL_TLSv1),
            (ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23),
        ],
    )
    def test_resolve_ssl_version(self, candidate, version):
        """resolve_ssl_version maps ssl constants and names to a constant."""
        resolved = resolve_ssl_version(candidate)
        assert resolved == version

    def test_is_fp_closed_object_supports_closed(self):
        """An object whose ``closed`` property is True counts as closed."""

        class ClosedFile(object):
            closed = property(lambda self: True)

        closed_file = ClosedFile()
        assert is_fp_closed(closed_file)

    def test_is_fp_closed_object_has_none_fp(self):
        """An object whose ``fp`` property is None counts as closed."""

        class NoneFpFile(object):
            fp = property(lambda self: None)

        none_fp_file = NoneFpFile()
        assert is_fp_closed(none_fp_file)

    def test_is_fp_closed_object_has_fp(self):
        """An object whose ``fp`` property is not None counts as open."""

        class FpFile(object):
            fp = property(lambda self: True)

        open_file = FpFile()
        assert not is_fp_closed(open_file)

    def test_is_fp_closed_object_has_neither_fp_nor_closed(self):
        """An object with neither ``fp`` nor ``closed`` raises ValueError."""

        class NotReallyAFile(object):
            pass

        not_a_file = NotReallyAFile()
        with pytest.raises(ValueError):
            is_fp_closed(not_a_file)

    def test_ssl_wrap_socket_loads_the_cert_chain(self):
        """A certfile is passed to the context's load_cert_chain."""
        certfile = "/path/to/certfile"
        fake_sock = object()
        context = Mock()

        ssl_wrap_socket(ssl_context=context, sock=fake_sock, certfile=certfile)

        context.load_cert_chain.assert_called_once_with(certfile, None)

    @patch("urllib3.util.ssl_.create_urllib3_context")
    def test_ssl_wrap_socket_creates_new_context(self, create_urllib3_context):
        """Without an ssl_context, create_urllib3_context is called once."""
        fake_sock = object()
        ssl_wrap_socket(sock=fake_sock, cert_reqs="CERT_REQUIRED")

        expected_args = (None, "CERT_REQUIRED")
        create_urllib3_context.assert_called_once_with(*expected_args, ciphers=None)

    def test_ssl_wrap_socket_loads_verify_locations(self):
        """ca_certs is passed as the first load_verify_locations argument."""
        ca_certs = "/path/to/pem"
        fake_sock = object()
        context = Mock()

        ssl_wrap_socket(ssl_context=context, ca_certs=ca_certs, sock=fake_sock)

        context.load_verify_locations.assert_called_once_with(ca_certs, None, None)

    def test_ssl_wrap_socket_loads_certificate_directories(self):
        """ca_cert_dir is passed as the second load_verify_locations argument."""
        ca_cert_dir = "/path/to/pems"
        fake_sock = object()
        context = Mock()

        ssl_wrap_socket(ssl_context=context, ca_cert_dir=ca_cert_dir, sock=fake_sock)

        context.load_verify_locations.assert_called_once_with(None, ca_cert_dir, None)

    def test_ssl_wrap_socket_loads_certificate_data(self):
        """ca_cert_data is passed as the third load_verify_locations argument."""
        ca_cert_data = "TOTALLY PEM DATA"
        fake_sock = object()
        context = Mock()

        ssl_wrap_socket(ssl_context=context, ca_cert_data=ca_cert_data, sock=fake_sock)

        context.load_verify_locations.assert_called_once_with(None, None, ca_cert_data)

    def test_ssl_wrap_socket_with_no_sni_warns(self):
        """With HAS_SNI off, a server_hostname triggers SNIMissingWarning.

        The socket is also wrapped without the hostname being passed on.
        """
        fake_sock = object()
        context = Mock()

        # patch.object puts the original ssl_.HAS_SNI back on exit.
        with patch.object(ssl_, "HAS_SNI", False):
            with patch("warnings.warn") as warn:
                ssl_wrap_socket(
                    ssl_context=context,
                    sock=fake_sock,
                    server_hostname="www.google.com",
                )
            context.wrap_socket.assert_called_once_with(fake_sock)
            assert warn.call_count >= 1

            # The second positional argument of each warn() call is collected
            # and checked for the warning class.
            categories = [call[0][1] for call in warn.call_args_list]
            assert SNIMissingWarning in categories

    def test_const_compare_digest_fallback(self):
        """The backported digest comparison accepts only an identical digest."""
        target = hashlib.sha256(b"abcdef").digest()
        assert _const_compare_digest_backport(target, target)

        mismatches = (
            target[:-1],  # the digest minus its last byte
            target + b"0",  # the digest plus one extra byte
            hashlib.sha256(b"xyz").digest(),  # digest of different data
        )
        for other in mismatches:
            assert not _const_compare_digest_backport(target, other)

    def test_has_ipv6_disabled_on_compile(self):
        """With socket.has_ipv6 patched to False, _has_ipv6 is falsy."""
        with patch("socket.has_ipv6", False):
            ipv6_available = _has_ipv6("::1")
            assert not ipv6_available

    def test_has_ipv6_enabled_but_fails(self):
        """If binding the socket raises, _has_ipv6 is falsy."""
        with patch("socket.has_ipv6", True), patch("socket.socket") as socket_cls:
            bind_failure = Exception("No IPv6 here!")
            socket_cls.return_value.bind = Mock(side_effect=bind_failure)
            assert not _has_ipv6("::1")

    def test_has_ipv6_enabled_and_working(self):
        """If binding the mocked socket succeeds, _has_ipv6 is truthy."""
        with patch("socket.has_ipv6", True), patch("socket.socket") as socket_cls:
            socket_cls.return_value.bind.return_value = True
            assert _has_ipv6("::1")

    def test_has_ipv6_disabled_on_appengine(self):
        """With is_appengine_sandbox patched to return True, _has_ipv6 is falsy."""
        sandbox_check = "urllib3.contrib._appengine_environ.is_appengine_sandbox"
        with patch(sandbox_check, return_value=True):
            assert not _has_ipv6("::1")

    def test_ip_family_ipv6_enabled(self):
        """With HAS_IPV6 true, allowed_gai_family returns AF_UNSPEC."""
        with patch("urllib3.util.connection.HAS_IPV6", True):
            family = allowed_gai_family()
            assert family == socket.AF_UNSPEC

    def test_ip_family_ipv6_disabled(self):
        """With HAS_IPV6 false, allowed_gai_family returns AF_INET."""
        with patch("urllib3.util.connection.HAS_IPV6", False):
            family = allowed_gai_family()
            assert family == socket.AF_INET

    @pytest.mark.parametrize("headers", [b"foo", None, object])
    def test_assert_header_parsing_throws_typeerror_with_non_headers(self, headers):
        """assert_header_parsing raises TypeError for each non-header input."""
        not_headers = headers
        with pytest.raises(TypeError):
            assert_header_parsing(not_headers)
예제 #22
0
파일: test_util.py 프로젝트: toywei/urllib3
class TestUtil(object):

    # (url, (scheme, host, port)) pairs fed to test_get_host; the tuple is the
    # exact value get_host(url) is expected to return.
    url_host_map = [
        # Hosts
        ('http://google.com/mail', ('http', 'google.com', None)),
        ('http://google.com/mail/', ('http', 'google.com', None)),
        ('google.com/mail', ('http', 'google.com', None)),
        ('http://google.com/', ('http', 'google.com', None)),
        ('http://google.com', ('http', 'google.com', None)),
        ('http://www.google.com', ('http', 'www.google.com', None)),
        ('http://mail.google.com', ('http', 'mail.google.com', None)),
        ('http://google.com:8000/mail/', ('http', 'google.com', 8000)),
        ('http://google.com:8000', ('http', 'google.com', 8000)),
        ('https://google.com', ('https', 'google.com', None)),
        ('https://google.com:8000', ('https', 'google.com', 8000)),
        # NOTE(review): the userinfo looks redacted ('*****:*****'); only
        # scheme/host/port are compared here, so the case should still hold.
        ('http://*****:*****@127.0.0.1:1234', ('http', '127.0.0.1', 1234)),
        ('http://google.com/foo=http://bar:42/baz', ('http', 'google.com', None)),
        ('http://google.com?foo=http://bar:42/baz', ('http', 'google.com', None)),
        ('http://google.com#foo=http://bar:42/baz', ('http', 'google.com', None)),

        # IPv4
        ('173.194.35.7', ('http', '173.194.35.7', None)),
        ('http://173.194.35.7', ('http', '173.194.35.7', None)),
        ('http://173.194.35.7/test', ('http', '173.194.35.7', None)),
        ('http://173.194.35.7:80', ('http', '173.194.35.7', 80)),
        ('http://173.194.35.7:80/test', ('http', '173.194.35.7', 80)),

        # IPv6
        ('[2a00:1450:4001:c01::67]', ('http', '[2a00:1450:4001:c01::67]', None)),
        ('http://[2a00:1450:4001:c01::67]', ('http', '[2a00:1450:4001:c01::67]', None)),
        ('http://[2a00:1450:4001:c01::67]/test', ('http', '[2a00:1450:4001:c01::67]', None)),
        ('http://[2a00:1450:4001:c01::67]:80', ('http', '[2a00:1450:4001:c01::67]', 80)),
        ('http://[2a00:1450:4001:c01::67]:80/test', ('http', '[2a00:1450:4001:c01::67]', 80)),

        # More IPv6 from http://www.ietf.org/rfc/rfc2732.txt
        ('http://[fedc:ba98:7654:3210:fedc:ba98:7654:3210]:8000/index.html', (
            'http', '[fedc:ba98:7654:3210:fedc:ba98:7654:3210]', 8000)),
        ('http://[1080:0:0:0:8:800:200c:417a]/index.html', (
            'http', '[1080:0:0:0:8:800:200c:417a]', None)),
        ('http://[3ffe:2a00:100:7031::1]', ('http', '[3ffe:2a00:100:7031::1]', None)),
        ('http://[1080::8:800:200c:417a]/foo', ('http', '[1080::8:800:200c:417a]', None)),
        ('http://[::192.9.5.5]/ipng', ('http', '[::192.9.5.5]', None)),
        ('http://[::ffff:129.144.52.38]:42/index.html', ('http', '[::ffff:129.144.52.38]', 42)),
        ('http://[2010:836b:4179::836b:4179]', ('http', '[2010:836b:4179::836b:4179]', None)),

        # Mixed-case input: in the expected tuples the scheme is always
        # lower-cased; the host is lower-cased for the http/https entries but
        # kept as written for the 'about' and 'http+unix' entries.
        ('HTTP://GOOGLE.COM/mail/', ('http', 'google.com', None)),
        ('GOogle.COM/mail', ('http', 'google.com', None)),
        ('HTTP://GoOgLe.CoM:8000/mail/', ('http', 'google.com', 8000)),
        # NOTE(review): userinfo looks redacted here as well.
        ('HTTP://*****:*****@EXAMPLE.COM:1234', ('http', 'example.com', 1234)),
        ('173.194.35.7', ('http', '173.194.35.7', None)),
        ('HTTP://173.194.35.7', ('http', '173.194.35.7', None)),
        ('HTTP://[2a00:1450:4001:c01::67]:80/test', ('http', '[2a00:1450:4001:c01::67]', 80)),
        ('HTTP://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8000/index.html', (
            'http', '[fedc:ba98:7654:3210:fedc:ba98:7654:3210]', 8000)),
        ('HTTPS://[1080:0:0:0:8:800:200c:417A]/index.html', (
            'https', '[1080:0:0:0:8:800:200c:417a]', None)),
        ('abOut://eXamPlE.com?info=1', ('about', 'eXamPlE.com', None)),
        ('http+UNIX://%2fvar%2frun%2fSOCKET/path', (
            'http+unix', '%2fvar%2frun%2fSOCKET', None)),
    ]

    @pytest.mark.parametrize('url, expected_host', url_host_map)
    def test_get_host(self, url, expected_host):
        """get_host(url) returns the expected (scheme, host, port) tuple."""
        assert get_host(url) == expected_host

    # TODO: Add more tests
    @pytest.mark.parametrize('location', [
        'http://google.com:foo',
        'http://::1/',
        'http://::1:80/',
        'http://google.com:-80',
        six.u('http://google.com:\xb2\xb2'),  # \xb2 = ^2
    ])
    def test_invalid_host(self, location):
        """get_host raises LocationParseError for each malformed location."""
        malformed = location
        with pytest.raises(LocationParseError):
            get_host(malformed)

    @pytest.mark.parametrize('url, expected_normalized_url', [
        ('HTTP://GOOGLE.COM/MAIL/', 'http://google.com/MAIL/'),
        ('HTTP://*****:*****@Example.com:8080/',
         'http://*****:*****@example.com:8080/'),
        ('HTTPS://Example.Com/?Key=Value', 'https://example.com/?Key=Value'),
        ('Https://Example.Com/#Fragment', 'https://example.com/#Fragment'),
    ])
    def test_parse_url_normalization(self, url, expected_normalized_url):
        """parse_url lower-cases the scheme and host and nothing else."""
        parsed = parse_url(url)
        assert parsed.url == expected_normalized_url

    # (url, expected Url) pairs shared by test_parse_url and test_unparse_url.
    # Every entry has to round-trip: parse_url(url) must equal the expected
    # Url, and the expected Url's ``.url`` must equal the input string.
    parse_url_host_map = [
        ('http://google.com/mail', Url('http', host='google.com', path='/mail')),
        ('http://google.com/mail/', Url('http', host='google.com', path='/mail/')),
        # NOTE(review): expected path has no leading slash; presumably Url
        # normalizes it to '/mail' -- confirm against the Url constructor.
        ('http://google.com/mail', Url('http', host='google.com', path='mail')),
        ('google.com/mail', Url(host='google.com', path='/mail')),
        ('http://google.com/', Url('http', host='google.com', path='/')),
        ('http://google.com', Url('http', host='google.com')),
        ('http://google.com?foo', Url('http', host='google.com', path='', query='foo')),

        # Path/query/fragment
        ('', Url()),
        ('/', Url(path='/')),
        ('/abc/../def', Url(path="/abc/../def")),
        ('#?/!google.com/?foo#bar', Url(path='', fragment='?/!google.com/?foo#bar')),
        ('/foo', Url(path='/foo')),
        ('/foo?bar=baz', Url(path='/foo', query='bar=baz')),
        ('/foo?bar=baz#banana?apple/orange', Url(path='/foo',
                                                 query='bar=baz',
                                                 fragment='banana?apple/orange')),
        ('/redirect?target=http://localhost:61020/', Url(path='redirect',
                                                         query='target=http://localhost:61020/')),

        # Port
        ('http://google.com/', Url('http', host='google.com', path='/')),
        ('http://google.com:80/', Url('http', host='google.com', port=80, path='/')),
        ('http://google.com:80', Url('http', host='google.com', port=80)),

        # Auth
        # The userinfo in each input must spell out the expected ``auth``
        # value; a masked '*****:*****' input can neither parse to
        # auth='foo:bar' nor be reproduced from the expected Url.
        ('http://foo:bar@localhost/', Url('http', auth='foo:bar', host='localhost', path='/')),
        ('http://foo@localhost/', Url('http', auth='foo', host='localhost', path='/')),
        ('http://foo:bar@baz@localhost/', Url('http',
                                              auth='foo:bar@baz',
                                              host='localhost',
                                              path='/'))
    ]

    # (url, expected Url) pairs used only by test_parse_url, not by
    # test_unparse_url.
    non_round_tripping_parse_url_host_map = [
        # Path/query/fragment
        ('?', Url(path='', query='')),
        ('#', Url(path='', fragment='')),

        # Empty Port
        ('http://google.com:', Url('http', host='google.com')),
        ('http://google.com:/', Url('http', host='google.com', path='/')),
    ]

    @pytest.mark.parametrize(
        'url, expected_url',
        chain(parse_url_host_map, non_round_tripping_parse_url_host_map)
    )
    def test_parse_url(self, url, expected_url):
        """parse_url(url) equals the expected Url for both case tables."""
        assert parse_url(url) == expected_url

    @pytest.mark.parametrize('url, expected_url', parse_url_host_map)
    def test_unparse_url(self, url, expected_url):
        """The expected Url's ``.url`` reproduces the original input string."""
        rebuilt = expected_url.url
        assert url == rebuilt

    def test_parse_url_invalid_IPv6(self):
        """An unterminated IPv6 literal makes parse_url raise LocationParseError."""
        unterminated = '[::1'
        with pytest.raises(LocationParseError):
            parse_url(unterminated)

    def test_parse_url_negative_port(self):
        """A negative port makes parse_url raise LocationParseError."""
        negative_port_url = "https://www.google.com:-80/"
        with pytest.raises(LocationParseError):
            parse_url(negative_port_url)

    def test_Url_str(self):
        """str() of a Url equals its ``url`` property."""
        url = Url('http', host='google.com')
        rendered = str(url)
        assert rendered == url.url

    # (url, expected request_uri) pairs for test_request_uri. In these cases
    # an empty path yields '/', a query is kept, and a fragment is dropped.
    request_uri_map = [
        ('http://google.com/mail', '/mail'),
        ('http://google.com/mail/', '/mail/'),
        ('http://google.com/', '/'),
        ('http://google.com', '/'),
        ('', '/'),
        ('/', '/'),
        ('?', '/?'),
        ('#', '/'),
        ('/foo?bar=baz', '/foo?bar=baz'),
    ]

    @pytest.mark.parametrize('url, expected_request_uri', request_uri_map)
    def test_request_uri(self, url, expected_request_uri):
        """The parsed URL's request_uri matches the expected string."""
        request_uri = parse_url(url).request_uri
        assert request_uri == expected_request_uri

    # (url, expected netloc) pairs for test_netloc: the host, plus ':port'
    # when the URL carries a port.
    url_netloc_map = [
        ('http://google.com/mail', 'google.com'),
        ('http://google.com:80/mail', 'google.com:80'),
        ('google.com/foobar', 'google.com'),
        ('google.com:12345', 'google.com:12345'),
    ]

    @pytest.mark.parametrize('url, expected_netloc', url_netloc_map)
    def test_netloc(self, url, expected_netloc):
        """The parsed URL's netloc matches the expected string."""
        netloc = parse_url(url).netloc
        assert netloc == expected_netloc

    # (url, expected) pairs for test_url_vulnerabilities. ``expected`` is the
    # Url that parse_url must return; that test also accepts ``False`` to mean
    # "must raise LocationParseError", though no entry here uses it.
    url_vulnerabilities = [
        # urlparse doesn't follow RFC 3986 Section 3.2
        ("http://google.com#@evil.com/", Url("http",
                                             host="google.com",
                                             path="",
                                             fragment="@evil.com/")),

        # CVE-2016-5699
        ("http://127.0.0.1%0d%0aConnection%3a%20keep-alive",
         Url("http", host="127.0.0.1%0d%0aConnection%3a%20keep-alive")),

        # NodeJS unicode -> double dot
        (u"http://google.com/\uff2e\uff2e/abc", Url("http",
                                                    host="google.com",
                                                    path='/%ef%bc%ae%ef%bc%ae/abc'))
    ]

    @pytest.mark.parametrize("url, expected_url", url_vulnerabilities)
    def test_url_vulnerabilities(self, url, expected_url):
        """Check parse_url against the known URL-parsing attack inputs.

        An expected value of ``False`` marks an input that must be rejected
        with ``LocationParseError``; any other value is compared for equality.
        """
        if expected_url is not False:
            assert parse_url(url) == expected_url
            return

        with pytest.raises(LocationParseError):
            parse_url(url)

    @pytest.mark.parametrize('kwargs, expected', [
        ({'accept_encoding': True},
         {'accept-encoding': 'gzip,deflate'}),
        ({'accept_encoding': 'foo,bar'},
         {'accept-encoding': 'foo,bar'}),
        ({'accept_encoding': ['foo', 'bar']},
         {'accept-encoding': 'foo,bar'}),
        ({'accept_encoding': True, 'user_agent': 'banana'},
         {'accept-encoding': 'gzip,deflate', 'user-agent': 'banana'}),
        ({'user_agent': 'banana'},
         {'user-agent': 'banana'}),
        ({'keep_alive': True},
         {'connection': 'keep-alive'}),
        ({'basic_auth': 'foo:bar'},
         {'authorization': 'Basic Zm9vOmJhcg=='}),
        ({'proxy_basic_auth': 'foo:bar'},
         {'proxy-authorization': 'Basic Zm9vOmJhcg=='}),
        ({'disable_cache': True},
         {'cache-control': 'no-cache'}),
    ])
    def test_make_headers(self, kwargs, expected):
        """make_headers(**kwargs) returns exactly the expected header dict."""
        headers = make_headers(**kwargs)
        assert headers == expected

    def test_rewind_body(self):
        """rewind_body seeks a consumed body back to the given offset."""
        payload = b'test data'
        stream = io.BytesIO(payload)
        assert stream.read() == payload

        # A second read returns nothing: the stream has been consumed.
        assert stream.read() == b''

        # Seeking back to offset 5 leaves only the trailing b'data'.
        rewind_body(stream, 5)
        assert stream.read() == b'data'

    def test_rewind_body_failed_tell(self):
        """A position of _FAILEDTELL makes rewind_body raise UnrewindableBodyError."""
        stream = io.BytesIO(b'test data')
        stream.read()  # exhaust the stream

        # Passing _FAILEDTELL as the position simulates a failed tell().
        with pytest.raises(UnrewindableBodyError):
            rewind_body(stream, _FAILEDTELL)

    def test_rewind_body_bad_position(self):
        """Non-integer positions make rewind_body raise ValueError."""
        stream = io.BytesIO(b'test data')
        stream.read()  # exhaust the stream

        for bad_position in (None, object()):
            with pytest.raises(ValueError):
                rewind_body(stream, body_pos=bad_position)

    def test_rewind_body_failed_seek(self):
        """A body whose seek() raises IOError is reported as unrewindable."""

        class BadSeek():
            def seek(self, pos, offset=0):
                raise IOError

        unseekable = BadSeek()
        with pytest.raises(UnrewindableBodyError):
            rewind_body(unseekable, body_pos=2)

    @pytest.mark.parametrize('input, expected', [
        (('abcd', 'b'),  ('a', 'cd', 'b')),
        (('abcd', 'cb'), ('a', 'cd', 'b')),
        (('abcd', ''),   ('abcd', '', None)),
        (('abcd', 'a'),  ('', 'bcd', 'a')),
        (('abcd', 'ab'), ('', 'bcd', 'a')),
    ])
    def test_split_first(self, input, expected):
        """split_first(text, delims) returns (before, after, matched delimiter).

        Each ``input`` is a ``(text, delims)`` pair; with no delimiters the
        expected result is ``(text, '', None)``.
        """
        text, delimiters = input
        assert split_first(text, delimiters) == expected

    def test_add_stderr_logger(self):
        """add_stderr_logger attaches its handler to the 'urllib3' logger."""
        # INFO level so the debug record below is not actually printed.
        stderr_handler = add_stderr_logger(level=logging.INFO)
        urllib3_logger = logging.getLogger('urllib3')
        assert stderr_handler in urllib3_logger.handlers

        urllib3_logger.debug('Testing add_stderr_logger')
        urllib3_logger.removeHandler(stderr_handler)

    def test_disable_warnings(self):
        """After disable_warnings(), InsecureRequestWarning is not recorded."""
        message = 'This is a test.'
        with warnings.catch_warnings(record=True) as recorded:
            clear_warnings()
            warnings.warn(message, InsecureRequestWarning)
            assert len(recorded) == 1

            # The second warning must not grow the recorded list.
            disable_warnings()
            warnings.warn(message, InsecureRequestWarning)
            assert len(recorded) == 1

    def _make_time_pass(self, seconds, timeout, time_mock):
        """Start ``timeout``'s connect clock, then advance the mocked time.

        ``time_mock`` returns TIMEOUT_EPOCH while ``start_connect()`` runs and
        TIMEOUT_EPOCH + ``seconds`` afterwards. The same ``timeout`` object is
        returned.
        """
        started_at = TIMEOUT_EPOCH
        finished_at = TIMEOUT_EPOCH + seconds

        time_mock.return_value = started_at
        timeout.start_connect()
        time_mock.return_value = finished_at
        return timeout

    @pytest.mark.parametrize('kwargs, message', [
        ({'total': -1},                 'less than'),
        ({'connect': 2, 'total': -1},   'less than'),
        ({'read': -1},                  'less than'),
        ({'connect': False},            'cannot be a boolean'),
        ({'read': True},                'cannot be a boolean'),
        ({'connect': 0},                'less than or equal'),
        ({'read': 'foo'},               'int, float or None')
    ])
    def test_invalid_timeouts(self, kwargs, message):
        """Invalid Timeout arguments raise ValueError mentioning ``message``."""
        with pytest.raises(ValueError) as excinfo:
            Timeout(**kwargs)
        error_text = str(excinfo.value)
        assert message in error_text

    @patch('urllib3.util.timeout.current_time')
    def test_timeout(self, current_time):
        """Check connect/read timeouts for various Timeout configurations."""
        # With no time elapsed, both timeouts equal the total.
        fresh = self._make_time_pass(seconds=0, timeout=Timeout(total=3),
                                     time_mock=current_time)
        assert fresh.read_timeout == 3
        assert fresh.connect_timeout == 3

        # total=3 with connect=2 gives a connect timeout of 2.
        assert Timeout(total=3, connect=2).connect_timeout == 2

        # No arguments: connect_timeout is Timeout.DEFAULT_TIMEOUT.
        assert Timeout().connect_timeout == Timeout.DEFAULT_TIMEOUT

        # Connect takes 5 seconds, leaving 5 seconds for read
        slow_connect = self._make_time_pass(seconds=5,
                                            timeout=Timeout(total=10, read=7),
                                            time_mock=current_time)
        assert slow_connect.read_timeout == 5

        # Connect takes 2 seconds, read timeout still 7 seconds
        fast_connect = self._make_time_pass(seconds=2,
                                            timeout=Timeout(total=10, read=7),
                                            time_mock=current_time)
        assert fast_connect.read_timeout == 7

        # Connect clock never started: read timeout is the configured 7.
        assert Timeout(total=10, read=7).read_timeout == 7

        unlimited = Timeout(total=None, read=None, connect=None)
        assert unlimited.connect_timeout is None
        assert unlimited.read_timeout is None
        assert unlimited.total is None

        # A single positional argument is the total.
        assert Timeout(5).total == 5

    def test_timeout_str(self):
        """str(Timeout) lists connect, read and total, including None values."""
        with_read = Timeout(connect=1, read=2, total=3)
        without_read = Timeout(connect=1, read=None, total=3)

        assert str(with_read) == "Timeout(connect=1, read=2, total=3)"
        assert str(without_read) == "Timeout(connect=1, read=None, total=3)"

    @patch('urllib3.util.timeout.current_time')
    def test_timeout_elapsed(self, current_time):
        """get_connect_duration reports the mocked time since start_connect."""
        current_time.return_value = TIMEOUT_EPOCH
        timeout = Timeout(total=3)

        # Asking for the duration before the clock was started is an error.
        with pytest.raises(TimeoutStateError):
            timeout.get_connect_duration()

        # Starting the connect clock a second time is an error as well.
        timeout.start_connect()
        with pytest.raises(TimeoutStateError):
            timeout.start_connect()

        for elapsed in (2, 37):
            current_time.return_value = TIMEOUT_EPOCH + elapsed
            assert timeout.get_connect_duration() == elapsed

    @pytest.mark.parametrize('candidate, requirements', [
        (None, ssl.CERT_NONE),
        (ssl.CERT_NONE, ssl.CERT_NONE),
        (ssl.CERT_REQUIRED, ssl.CERT_REQUIRED),
        ('REQUIRED', ssl.CERT_REQUIRED),
        ('CERT_REQUIRED', ssl.CERT_REQUIRED),
    ])
    def test_resolve_cert_reqs(self, candidate, requirements):
        """resolve_cert_reqs maps None, ssl constants and names to a constant."""
        resolved = resolve_cert_reqs(candidate)
        assert resolved == requirements

    @pytest.mark.parametrize('candidate, version', [
        (ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1),
        ("PROTOCOL_TLSv1", ssl.PROTOCOL_TLSv1),
        ("TLSv1", ssl.PROTOCOL_TLSv1),
        (ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23),
    ])
    def test_resolve_ssl_version(self, candidate, version):
        """resolve_ssl_version maps ssl constants and names to a constant."""
        resolved = resolve_ssl_version(candidate)
        assert resolved == version

    def test_is_fp_closed_object_supports_closed(self):
        """An object whose ``closed`` property is True counts as closed."""

        class ClosedFile(object):
            closed = property(lambda self: True)

        closed_file = ClosedFile()
        assert is_fp_closed(closed_file)

    def test_is_fp_closed_object_has_none_fp(self):
        """An object whose ``fp`` property is None counts as closed."""

        class NoneFpFile(object):
            fp = property(lambda self: None)

        none_fp_file = NoneFpFile()
        assert is_fp_closed(none_fp_file)

    def test_is_fp_closed_object_has_fp(self):
        """An object whose ``fp`` property is not None counts as open."""

        class FpFile(object):
            fp = property(lambda self: True)

        open_file = FpFile()
        assert not is_fp_closed(open_file)

    def test_is_fp_closed_object_has_neither_fp_nor_closed(self):
        """An object with neither ``fp`` nor ``closed`` raises ValueError."""

        class NotReallyAFile(object):
            pass

        not_a_file = NotReallyAFile()
        with pytest.raises(ValueError):
            is_fp_closed(not_a_file)

    def test_ssl_wrap_socket_loads_the_cert_chain(self):
        """A certfile is passed to the context's load_cert_chain."""
        certfile = '/path/to/certfile'
        fake_sock = object()
        context = Mock()

        ssl_wrap_socket(ssl_context=context, sock=fake_sock, certfile=certfile)

        context.load_cert_chain.assert_called_once_with(certfile, None)

    @patch('urllib3.util.ssl_.create_urllib3_context')
    def test_ssl_wrap_socket_creates_new_context(self,
                                                 create_urllib3_context):
        """Without an ssl_context, create_urllib3_context is called once."""
        fake_sock = object()
        ssl_wrap_socket(sock=fake_sock, cert_reqs='CERT_REQUIRED')

        expected_args = (None, 'CERT_REQUIRED')
        create_urllib3_context.assert_called_once_with(*expected_args, ciphers=None)

    def test_ssl_wrap_socket_loads_verify_locations(self):
        """Passing ``ca_certs`` forwards the file path to the context."""
        fake_sock = object()
        context = Mock()
        pem_path = '/path/to/pem'

        ssl_wrap_socket(sock=fake_sock, ssl_context=context,
                        ca_certs=pem_path)

        context.load_verify_locations.assert_called_once_with(pem_path, None)

    def test_ssl_wrap_socket_loads_certificate_directories(self):
        """Passing ``ca_cert_dir`` forwards the directory to the context."""
        fake_sock = object()
        context = Mock()
        pem_dir = '/path/to/pems'

        ssl_wrap_socket(sock=fake_sock, ssl_context=context,
                        ca_cert_dir=pem_dir)

        context.load_verify_locations.assert_called_once_with(None, pem_dir)

    def test_ssl_wrap_socket_with_no_sni_warns(self):
        """With HAS_SNI off, the socket is wrapped bare and a warning is sent."""
        fake_sock = object()
        context = Mock()
        # patch.object puts the real HAS_SNI value back when the block exits,
        # whether or not one of the assertions below fails.
        with patch.object(ssl_, 'HAS_SNI', False):
            with patch('warnings.warn') as warn_mock:
                ssl_wrap_socket(ssl_context=context, sock=fake_sock,
                                server_hostname='www.google.com')
            context.wrap_socket.assert_called_once_with(fake_sock)
            assert warn_mock.call_count >= 1
            # Second positional argument of each recorded warn() call.
            categories = [recorded[0][1]
                          for recorded in warn_mock.call_args_list]
            assert SNIMissingWarning in categories

    def test_const_compare_digest_fallback(self):
        """The backport accepts an equal digest and rejects differing ones."""
        reference = hashlib.sha256(b'abcdef').digest()
        assert _const_compare_digest_backport(reference, reference)

        mismatches = (
            reference[:-1],                    # one byte short
            reference + b'0',                  # one byte long
            hashlib.sha256(b'xyz').digest(),   # same length, other content
        )
        for other in mismatches:
            assert not _const_compare_digest_backport(reference, other)

    def test_has_ipv6_disabled_on_compile(self):
        """``socket.has_ipv6`` being False makes _has_ipv6 report no support."""
        compiled_without_ipv6 = patch('socket.has_ipv6', False)
        with compiled_without_ipv6:
            supported = _has_ipv6('::1')
        assert not supported

    def test_has_ipv6_enabled_but_fails(self):
        """A failing ``bind`` makes _has_ipv6 report no support."""
        with patch('socket.has_ipv6', True), \
                patch('socket.socket') as socket_cls:
            failing_bind = Mock(side_effect=Exception('No IPv6 here!'))
            socket_cls.return_value.bind = failing_bind
            assert not _has_ipv6('::1')

    def test_has_ipv6_enabled_and_working(self):
        """A successful ``bind`` makes _has_ipv6 report support."""
        with patch('socket.has_ipv6', True), \
                patch('socket.socket') as socket_cls:
            socket_cls.return_value.bind.return_value = True
            assert _has_ipv6('::1')

    def test_has_ipv6_disabled_on_appengine(self):
        """With the App Engine sandbox check patched True, IPv6 is off."""
        sandbox_check = (
            'urllib3.contrib._appengine_environ.is_appengine_sandbox')
        with patch(sandbox_check, return_value=True):
            assert not _has_ipv6('::1')

    def test_ip_family_ipv6_enabled(self):
        """HAS_IPV6 True makes allowed_gai_family return AF_UNSPEC."""
        with patch('urllib3.util.connection.HAS_IPV6', True):
            family = allowed_gai_family()
        assert family == socket.AF_UNSPEC

    def test_ip_family_ipv6_disabled(self):
        """HAS_IPV6 False makes allowed_gai_family return AF_INET."""
        with patch('urllib3.util.connection.HAS_IPV6', False):
            family = allowed_gai_family()
        assert family == socket.AF_INET

    @pytest.mark.parametrize('value', [
        "-1",
        "+1",
        "1.0",
        six.u("\xb2"),  # \xb2 is the superscript-two character
    ])
    def test_parse_retry_after_invalid(self, value):
        """Each listed Retry-After value must raise InvalidHeader."""
        parse = Retry().parse_retry_after
        with pytest.raises(InvalidHeader):
            parse(value)

    @pytest.mark.parametrize(['value', 'expected'], [
        ("0", 0),
        ("1000", 1000),
        ("\t42 ", 42),
    ])
    def test_parse_retry_after(self, value, expected):
        """Each listed Retry-After value parses to the expected number."""
        parsed = Retry().parse_retry_after(value)
        assert parsed == expected

    @pytest.mark.parametrize('headers', [
        b'foo',
        None,
        object,
    ])
    def test_assert_header_parsing_throws_typeerror_with_non_headers(self, headers):
        """Each listed non-header argument must raise TypeError."""
        expectation = pytest.raises(TypeError)
        with expectation:
            assert_header_parsing(headers)
예제 #23
0
파일: test_util.py 프로젝트: toywei/urllib3
 def test_Url_str(self):
     """``str()`` of a Url equals its ``url`` attribute."""
     google = Url('http', host='google.com')
     rendered = str(google)
     assert rendered == google.url
예제 #24
0
class TestUtil(unittest.TestCase):
    def test_get_host(self):
        """get_host must return the expected 3-tuple for every listed URL."""
        expectations = {
            # Hosts
            'http://google.com/mail': ('http', 'google.com', None),
            'http://google.com/mail/': ('http', 'google.com', None),
            'google.com/mail': ('http', 'google.com', None),
            'http://google.com/': ('http', 'google.com', None),
            'http://google.com': ('http', 'google.com', None),
            'http://www.google.com': ('http', 'www.google.com', None),
            'http://mail.google.com': ('http', 'mail.google.com', None),
            'http://google.com:8000/mail/': ('http', 'google.com', 8000),
            'http://google.com:8000': ('http', 'google.com', 8000),
            'https://google.com': ('https', 'google.com', None),
            'https://google.com:8000': ('https', 'google.com', 8000),
            'http://*****:*****@127.0.0.1:1234': ('http', '127.0.0.1', 1234),
            'http://google.com/foo=http://bar:42/baz':
                ('http', 'google.com', None),
            'http://google.com?foo=http://bar:42/baz':
                ('http', 'google.com', None),
            'http://google.com#foo=http://bar:42/baz':
                ('http', 'google.com', None),

            # IPv4
            '173.194.35.7': ('http', '173.194.35.7', None),
            'http://173.194.35.7': ('http', '173.194.35.7', None),
            'http://173.194.35.7/test': ('http', '173.194.35.7', None),
            'http://173.194.35.7:80': ('http', '173.194.35.7', 80),
            'http://173.194.35.7:80/test': ('http', '173.194.35.7', 80),

            # IPv6
            '[2a00:1450:4001:c01::67]':
                ('http', '[2a00:1450:4001:c01::67]', None),
            'http://[2a00:1450:4001:c01::67]':
                ('http', '[2a00:1450:4001:c01::67]', None),
            'http://[2a00:1450:4001:c01::67]/test':
                ('http', '[2a00:1450:4001:c01::67]', None),
            'http://[2a00:1450:4001:c01::67]:80':
                ('http', '[2a00:1450:4001:c01::67]', 80),
            'http://[2a00:1450:4001:c01::67]:80/test':
                ('http', '[2a00:1450:4001:c01::67]', 80),

            # More IPv6 from http://www.ietf.org/rfc/rfc2732.txt
            'http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8000/index.html':
                ('http', '[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]', 8000),
            'http://[1080:0:0:0:8:800:200C:417A]/index.html':
                ('http', '[1080:0:0:0:8:800:200C:417A]', None),
            'http://[3ffe:2a00:100:7031::1]':
                ('http', '[3ffe:2a00:100:7031::1]', None),
            'http://[1080::8:800:200C:417A]/foo':
                ('http', '[1080::8:800:200C:417A]', None),
            'http://[::192.9.5.5]/ipng': ('http', '[::192.9.5.5]', None),
            'http://[::FFFF:129.144.52.38]:42/index.html':
                ('http', '[::FFFF:129.144.52.38]', 42),
            'http://[2010:836B:4179::836B:4179]':
                ('http', '[2010:836B:4179::836B:4179]', None),
        }
        for location in expectations:
            self.assertEqual(get_host(location), expectations[location])

    def test_invalid_host(self):
        """get_host must raise LocationParseError for each bad location."""
        # TODO: Add more tests
        bad_locations = (
            'http://google.com:foo',
            'http://::1/',
            'http://::1:80/',
        )
        for bad in bad_locations:
            self.assertRaises(LocationParseError, get_host, bad)

    # Maps a URL string to the Url that parse_url must return for it.
    # test_unparse_url additionally requires every Url here to render back to
    # exactly its key, so each key must be the canonical text of its value.
    parse_url_host_map = {
        'http://google.com/mail':
        Url('http', host='google.com', path='/mail'),
        'http://google.com/mail/':
        Url('http', host='google.com', path='/mail/'),
        # NOTE: this repeats the first key, so in a dict literal it replaces
        # that entry; only the path='mail' variant is actually exercised.
        'http://google.com/mail':
        Url('http', host='google.com', path='mail'),
        'google.com/mail':
        Url(host='google.com', path='/mail'),
        'http://google.com/':
        Url('http', host='google.com', path='/'),
        'http://google.com':
        Url('http', host='google.com'),
        'http://google.com?foo':
        Url('http', host='google.com', path='', query='foo'),

        # Path/query/fragment
        '':
        Url(),
        '/':
        Url(path='/'),
        '#?/!google.com/?foo#bar':
        Url(path='', fragment='?/!google.com/?foo#bar'),
        '/foo':
        Url(path='/foo'),
        '/foo?bar=baz':
        Url(path='/foo', query='bar=baz'),
        '/foo?bar=baz#banana?apple/orange':
        Url(path='/foo', query='bar=baz', fragment='banana?apple/orange'),

        # Port
        # NOTE: 'http://google.com/' also appears above with the same value.
        'http://google.com/':
        Url('http', host='google.com', path='/'),
        'http://google.com:80/':
        Url('http', host='google.com', port=80, path='/'),
        'http://google.com:80':
        Url('http', host='google.com', port=80),

        # Auth: the userinfo in each key must match the expected ``auth``,
        # otherwise neither parsing nor the round trip can succeed.
        'http://foo:bar@localhost/':
        Url('http', auth='foo:bar', host='localhost', path='/'),
        'http://foo@localhost/':
        Url('http', auth='foo', host='localhost', path='/'),
        'http://foo:bar@baz@localhost/':
        Url('http', auth='foo:bar@baz', host='localhost', path='/'),
        'http://@':
        Url('http', host=None, auth='')
    }

    # Maps a URL string to the Url that parse_url must return for it.
    # These cases are consumed by test_parse_url only; test_unparse_url does
    # not iterate this mapping, so the Url need not render back to its key.
    non_round_tripping_parse_url_host_map = {
        # Path/query/fragment
        '?': Url(path='', query=''),
        '#': Url(path='', fragment=''),

        # Empty Port
        'http://google.com:': Url('http', host='google.com'),
        'http://google.com:/': Url('http', host='google.com', path='/'),
    }

    def test_parse_url(self):
        """parse_url must return the expected Url for every mapped string."""
        all_cases = chain(
            self.parse_url_host_map.items(),
            self.non_round_tripping_parse_url_host_map.items())
        for raw, wanted in all_cases:
            self.assertEqual(parse_url(raw), wanted)

    def test_unparse_url(self):
        """Each Url in parse_url_host_map must render back to its key."""
        round_trip_cases = self.parse_url_host_map
        for text in round_trip_cases:
            self.assertEqual(text, round_trip_cases[text].url)

    def test_parse_url_invalid_IPv6(self):
        """An unterminated IPv6 bracket must make parse_url raise ValueError."""
        unterminated = '[::1'
        self.assertRaises(ValueError, parse_url, unterminated)

    def test_Url_str(self):
        """``str()`` of a Url equals its ``url`` attribute."""
        google = Url('http', host='google.com')
        rendered = str(google)
        self.assertEqual(rendered, google.url)

    def test_request_uri(self):
        """``request_uri`` of each parsed URL must match the expected text."""
        expected_by_url = {
            'http://google.com/mail': '/mail',
            'http://google.com/mail/': '/mail/',
            'http://google.com/': '/',
            'http://google.com': '/',
            '': '/',
            '/': '/',
            '?': '/?',
            '#': '/',
            '/foo?bar=baz': '/foo?bar=baz',
        }
        for raw in expected_by_url:
            parsed = parse_url(raw)
            self.assertEqual(parsed.request_uri, expected_by_url[raw])

    def test_netloc(self):
        """``netloc`` of each parsed URL must match the expected text."""
        expected_by_url = {
            'http://google.com/mail': 'google.com',
            'http://google.com:80/mail': 'google.com:80',
            'google.com/foobar': 'google.com',
            'google.com:12345': 'google.com:12345',
        }

        for raw in expected_by_url:
            parsed = parse_url(raw)
            self.assertEqual(parsed.netloc, expected_by_url[raw])

    def test_make_headers(self):
        """make_headers must map each keyword combination to its headers."""
        # (keyword arguments, expected header dict), checked in this order.
        scenarios = [
            ({'accept_encoding': True},
             {'accept-encoding': 'gzip,deflate'}),
            ({'accept_encoding': 'foo,bar'},
             {'accept-encoding': 'foo,bar'}),
            ({'accept_encoding': ['foo', 'bar']},
             {'accept-encoding': 'foo,bar'}),
            ({'accept_encoding': True, 'user_agent': 'banana'},
             {'accept-encoding': 'gzip,deflate', 'user-agent': 'banana'}),
            ({'user_agent': 'banana'},
             {'user-agent': 'banana'}),
            ({'keep_alive': True},
             {'connection': 'keep-alive'}),
            ({'basic_auth': 'foo:bar'},
             {'authorization': 'Basic Zm9vOmJhcg=='}),
            ({'proxy_basic_auth': 'foo:bar'},
             {'proxy-authorization': 'Basic Zm9vOmJhcg=='}),
            ({'disable_cache': True},
             {'cache-control': 'no-cache'}),
        ]
        for keywords, headers in scenarios:
            self.assertEqual(make_headers(**keywords), headers)

    def test_split_first(self):
        """split_first must return the expected triple for each argument pair."""
        # (string, delimiters) -> expected return value
        expectations = {
            ('abcd', 'b'): ('a', 'cd', 'b'),
            ('abcd', 'cb'): ('a', 'cd', 'b'),
            ('abcd', ''): ('abcd', '', None),
            ('abcd', 'a'): ('', 'bcd', 'a'),
            ('abcd', 'ab'): ('', 'bcd', 'a'),
        }
        for arguments in expectations:
            result = split_first(*arguments)
            self.assertEqual(result, expectations[arguments])

    def test_add_stderr_logger(self):
        """add_stderr_logger must attach its handler to the 'urllib3' logger.

        The handler is detached in a ``finally`` block so that a failing
        assertion cannot leave it installed on the shared logger for the
        tests that run afterwards.
        """
        handler = add_stderr_logger(
            level=logging.INFO)  # Don't actually print debug
        logger = logging.getLogger('urllib3')
        try:
            self.assertTrue(handler in logger.handlers)

            logger.debug('Testing add_stderr_logger')
        finally:
            logger.removeHandler(handler)

    def test_disable_warnings(self):
        """Check that disable_warnings stops InsecureRequestWarning recording.

        One warning is recorded before disable_warnings() is called; the
        identical warning issued afterwards must not add a second record.
        """
        # record=True collects emitted warnings in the list ``w``.
        with warnings.catch_warnings(record=True) as w:
            clear_warnings()
            warnings.warn('This is a test.', InsecureRequestWarning)
            self.assertEqual(len(w), 1)
            disable_warnings()
            warnings.warn('This is a test.', InsecureRequestWarning)
            # Still exactly one record: the second warning was not captured.
            self.assertEqual(len(w), 1)

    def _make_time_pass(self, seconds, timeout, time_mock):
        """Start ``timeout``'s connect timer, then advance the mocked clock.

        ``time_mock`` first reports TIMEOUT_EPOCH while ``start_connect`` runs
        and afterwards reports TIMEOUT_EPOCH + ``seconds``.  The same
        ``timeout`` object is returned.
        """
        start = TIMEOUT_EPOCH
        time_mock.return_value = start
        timeout.start_connect()
        time_mock.return_value = start + seconds
        return timeout

    def test_invalid_timeouts(self):
        """Timeout must reject negative and non-numeric values."""
        negative_cases = (
            {'total': -1},
            {'connect': 2, 'total': -1},
            {'read': -1},
        )
        for kwargs in negative_cases:
            try:
                Timeout(**kwargs)
            except ValueError as err:
                self.assertTrue('less than' in str(err))
            else:
                self.fail("negative value should throw exception")

        # Booleans are allowed also by socket.settimeout and converted to the
        # equivalent float (1.0 for True, 0.0 for False)
        Timeout(connect=False, read=True)

        try:
            Timeout(read="foo")
        except ValueError as err:
            self.assertTrue('int or float' in str(err))
        else:
            self.fail("string value should not be allowed")

    @patch('urllib3.util.timeout.current_time')
    def test_timeout(self, current_time):
        """Check connect/read timeouts derived from total, connect and read."""
        def after(seconds, **timeout_kwargs):
            # Fresh Timeout whose connect phase took ``seconds`` on the
            # mocked clock.
            return self._make_time_pass(seconds=seconds,
                                        timeout=Timeout(**timeout_kwargs),
                                        time_mock=current_time)

        # make 'no time' elapse
        untouched = after(0, total=3)
        self.assertEqual(untouched.read_timeout, 3)
        self.assertEqual(untouched.connect_timeout, 3)

        self.assertEqual(Timeout(total=3, connect=2).connect_timeout, 2)

        self.assertEqual(Timeout().connect_timeout, Timeout.DEFAULT_TIMEOUT)

        # Connect takes 5 seconds, leaving 5 seconds for read
        slow_connect = after(5, total=10, read=7)
        self.assertEqual(slow_connect.read_timeout, 5)

        # Connect takes 2 seconds, read timeout still 7 seconds
        quick_connect = after(2, total=10, read=7)
        self.assertEqual(quick_connect.read_timeout, 7)

        self.assertEqual(Timeout(total=10, read=7).read_timeout, 7)

        unlimited = Timeout(total=None, read=None, connect=None)
        self.assertEqual(unlimited.connect_timeout, None)
        self.assertEqual(unlimited.read_timeout, None)
        self.assertEqual(unlimited.total, None)

        positional = Timeout(5)
        self.assertEqual(positional.total, 5)

    def test_timeout_str(self):
        """``str()`` of a Timeout lists connect, read and total."""
        fully_set = Timeout(connect=1, read=2, total=3)
        self.assertEqual(str(fully_set), "Timeout(connect=1, read=2, total=3)")

        no_read = Timeout(connect=1, read=None, total=3)
        expected_text = "Timeout(connect=1, read=None, total=3)"
        self.assertEqual(str(no_read), expected_text)

    @patch('urllib3.util.timeout.current_time')
    def test_timeout_elapsed(self, current_time):
        """Check get_connect_duration before and after start_connect."""
        current_time.return_value = TIMEOUT_EPOCH
        timer = Timeout(total=3)
        # Asking for the duration before the timer started is an error.
        self.assertRaises(TimeoutStateError, timer.get_connect_duration)

        timer.start_connect()
        # Starting the same timer twice is an error as well.
        self.assertRaises(TimeoutStateError, timer.start_connect)

        for elapsed in (2, 37):
            current_time.return_value = TIMEOUT_EPOCH + elapsed
            self.assertEqual(timer.get_connect_duration(), elapsed)

    def test_resolve_cert_reqs(self):
        """resolve_cert_reqs must map each candidate to the ssl constant."""
        cases = (
            (None, ssl.CERT_NONE),
            (ssl.CERT_NONE, ssl.CERT_NONE),
            (ssl.CERT_REQUIRED, ssl.CERT_REQUIRED),
            ('REQUIRED', ssl.CERT_REQUIRED),
            ('CERT_REQUIRED', ssl.CERT_REQUIRED),
        )
        for candidate, requirement in cases:
            self.assertEqual(resolve_cert_reqs(candidate), requirement)

    def test_is_fp_closed_object_supports_closed(self):
        """An object whose ``closed`` property is True counts as closed."""
        class AlreadyClosed(object):
            closed = property(lambda self: True)

        stub = AlreadyClosed()
        self.assertTrue(is_fp_closed(stub))

    def test_is_fp_closed_object_has_none_fp(self):
        """An object whose ``fp`` property is None counts as closed."""
        class FpIsNone(object):
            fp = property(lambda self: None)

        stub = FpIsNone()
        self.assertTrue(is_fp_closed(stub))

    def test_is_fp_closed_object_has_fp(self):
        """An object whose ``fp`` property is truthy is not closed."""
        class FpIsSet(object):
            fp = property(lambda self: True)

        stub = FpIsSet()
        self.assertFalse(is_fp_closed(stub))

    def test_is_fp_closed_object_has_neither_fp_nor_closed(self):
        """An object with neither ``fp`` nor ``closed`` raises ValueError."""
        class Featureless(object):
            pass

        stub = Featureless()
        self.assertRaises(ValueError, is_fp_closed, stub)

    def test_ssl_wrap_socket_loads_the_cert_chain(self):
        """Passing ``certfile`` loads it into the supplied context."""
        fake_sock = object()
        context = Mock()
        cert_path = '/path/to/certfile'

        ssl_wrap_socket(sock=fake_sock, certfile=cert_path,
                        ssl_context=context)

        context.load_cert_chain.assert_called_once_with(cert_path, None)

    def test_ssl_wrap_socket_loads_verify_locations(self):
        """Passing ``ca_certs`` forwards the file path to the context."""
        fake_sock = object()
        context = Mock()
        pem_path = '/path/to/pem'

        ssl_wrap_socket(sock=fake_sock, ssl_context=context,
                        ca_certs=pem_path)

        context.load_verify_locations.assert_called_once_with(pem_path, None)

    def test_ssl_wrap_socket_loads_certificate_directories(self):
        """Passing ``ca_cert_dir`` forwards the directory to the context."""
        fake_sock = object()
        context = Mock()
        pem_dir = '/path/to/pems'

        ssl_wrap_socket(sock=fake_sock, ssl_context=context,
                        ca_cert_dir=pem_dir)

        context.load_verify_locations.assert_called_once_with(None, pem_dir)

    def test_ssl_wrap_socket_with_no_sni(self):
        """With HAS_SNI off, the context wraps the socket with no extras.

        ``ssl_.HAS_SNI`` is module state shared by every test, so it is put
        back in a ``finally`` block: a failing assertion must not leave it
        set to False for whatever runs next.
        """
        socket = object()
        mock_context = Mock()
        # Ugly preservation of original value
        HAS_SNI = ssl_.HAS_SNI
        ssl_.HAS_SNI = False
        try:
            ssl_wrap_socket(ssl_context=mock_context, sock=socket)
            mock_context.wrap_socket.assert_called_once_with(socket)
        finally:
            ssl_.HAS_SNI = HAS_SNI

    def test_const_compare_digest_fallback(self):
        """The backport accepts an equal digest and rejects differing ones."""
        reference = hashlib.sha256(b'abcdef').digest()
        self.assertTrue(_const_compare_digest_backport(reference, reference))

        mismatches = (
            reference[:-1],                    # one byte short
            reference + b'0',                  # one byte long
            hashlib.sha256(b'xyz').digest(),   # same length, other content
        )
        for other in mismatches:
            self.assertFalse(
                _const_compare_digest_backport(reference, other))
예제 #25
0
 def test_Url_str(self):
     """``str()`` of a Url equals its ``url`` attribute."""
     google = Url('http', host='google.com')
     rendered = str(google)
     self.assertEqual(rendered, google.url)
예제 #26
0
class TestUtil:

    # (url, expected) pairs fed to test_get_host through parametrize; each
    # expected value is the 3-tuple that get_host(url) must equal.
    url_host_map = [
        # Hosts
        ("http://google.com/mail", ("http", "google.com", None)),
        ("http://google.com/mail/", ("http", "google.com", None)),
        ("google.com/mail", ("http", "google.com", None)),
        ("http://google.com/", ("http", "google.com", None)),
        ("http://google.com", ("http", "google.com", None)),
        ("http://www.google.com", ("http", "www.google.com", None)),
        ("http://mail.google.com", ("http", "mail.google.com", None)),
        ("http://google.com:8000/mail/", ("http", "google.com", 8000)),
        ("http://google.com:8000", ("http", "google.com", 8000)),
        ("https://google.com", ("https", "google.com", None)),
        ("https://google.com:8000", ("https", "google.com", 8000)),
        ("http://*****:*****@127.0.0.1:1234", ("http", "127.0.0.1", 1234)),
        ("http://google.com/foo=http://bar:42/baz", ("http", "google.com",
                                                     None)),
        ("http://google.com?foo=http://bar:42/baz", ("http", "google.com",
                                                     None)),
        ("http://google.com#foo=http://bar:42/baz", ("http", "google.com",
                                                     None)),
        # IPv4
        ("173.194.35.7", ("http", "173.194.35.7", None)),
        ("http://173.194.35.7", ("http", "173.194.35.7", None)),
        ("http://173.194.35.7/test", ("http", "173.194.35.7", None)),
        ("http://173.194.35.7:80", ("http", "173.194.35.7", 80)),
        ("http://173.194.35.7:80/test", ("http", "173.194.35.7", 80)),
        # IPv6
        ("[2a00:1450:4001:c01::67]", ("http", "[2a00:1450:4001:c01::67]", None)
         ),
        ("http://[2a00:1450:4001:c01::67]",
         ("http", "[2a00:1450:4001:c01::67]", None)),
        (
            "http://[2a00:1450:4001:c01::67]/test",
            ("http", "[2a00:1450:4001:c01::67]", None),
        ),
        (
            "http://[2a00:1450:4001:c01::67]:80",
            ("http", "[2a00:1450:4001:c01::67]", 80),
        ),
        (
            "http://[2a00:1450:4001:c01::67]:80/test",
            ("http", "[2a00:1450:4001:c01::67]", 80),
        ),
        # More IPv6 from http://www.ietf.org/rfc/rfc2732.txt
        (
            "http://[fedc:ba98:7654:3210:fedc:ba98:7654:3210]:8000/index.html",
            ("http", "[fedc:ba98:7654:3210:fedc:ba98:7654:3210]", 8000),
        ),
        (
            "http://[1080:0:0:0:8:800:200c:417a]/index.html",
            ("http", "[1080:0:0:0:8:800:200c:417a]", None),
        ),
        ("http://[3ffe:2a00:100:7031::1]", ("http", "[3ffe:2a00:100:7031::1]",
                                            None)),
        (
            "http://[1080::8:800:200c:417a]/foo",
            ("http", "[1080::8:800:200c:417a]", None),
        ),
        ("http://[::192.9.5.5]/ipng", ("http", "[::192.9.5.5]", None)),
        (
            "http://[::ffff:129.144.52.38]:42/index.html",
            ("http", "[::ffff:129.144.52.38]", 42),
        ),
        (
            "http://[2010:836b:4179::836b:4179]",
            ("http", "[2010:836b:4179::836b:4179]", None),
        ),
        # Mixed-case input: the expected scheme is always lower-cased; the
        # expected host is lower-cased too, except in the last two entries
        # ("about" and "http+unix" schemes), where it keeps its case.
        ("HTTP://GOOGLE.COM/mail/", ("http", "google.com", None)),
        ("GOogle.COM/mail", ("http", "google.com", None)),
        ("HTTP://GoOgLe.CoM:8000/mail/", ("http", "google.com", 8000)),
        ("HTTP://*****:*****@EXAMPLE.COM:1234", ("http", "example.com",
                                                   1234)),
        ("173.194.35.7", ("http", "173.194.35.7", None)),
        ("HTTP://173.194.35.7", ("http", "173.194.35.7", None)),
        (
            "HTTP://[2a00:1450:4001:c01::67]:80/test",
            ("http", "[2a00:1450:4001:c01::67]", 80),
        ),
        (
            "HTTP://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8000/index.html",
            ("http", "[fedc:ba98:7654:3210:fedc:ba98:7654:3210]", 8000),
        ),
        (
            "HTTPS://[1080:0:0:0:8:800:200c:417A]/index.html",
            ("https", "[1080:0:0:0:8:800:200c:417a]", None),
        ),
        ("abOut://eXamPlE.com?info=1", ("about", "eXamPlE.com", None)),
        (
            "http+UNIX://%2fvar%2frun%2fSOCKET/path",
            ("http+unix", "%2fvar%2frun%2fSOCKET", None),
        ),
    ]

    @pytest.mark.parametrize(["url", "expected_host"], url_host_map)
    def test_get_host(self, url, expected_host):
        """get_host must return the expected tuple for each mapped URL."""
        assert get_host(url) == expected_host

    # TODO: Add more tests
    @pytest.mark.parametrize(
        "location",
        [
            "http://google.com:foo",
            "http://::1/",
            "http://::1:80/",
            "http://google.com:-80",
            "http://google.com:\xb2\xb2",  # \xb2 is the superscript-two character
        ],
    )
    def test_invalid_host(self, location):
        """get_host must raise LocationParseError for each bad location."""
        expectation = pytest.raises(LocationParseError)
        with expectation:
            get_host(location)

    @pytest.mark.parametrize(
        "url",
        [
            # Invalid IDNA labels
            "http://\uD7FF.com",
            "http://❤️",
            # Unicode surrogates
            "http://\uD800.com",
            "http://\uDC00.com",
        ],
    )
    def test_invalid_url(self, url):
        """parse_url must raise LocationParseError for each listed URL."""
        expectation = pytest.raises(LocationParseError)
        with expectation:
            parse_url(url)

    @pytest.mark.parametrize(
        "url, expected_normalized_url",
        [
            ("HTTP://GOOGLE.COM/MAIL/", "http://google.com/MAIL/"),
            # An '@' inside the userinfo must come back percent-encoded.
            (
                "http://user@domain.com:password@example.com/~tilde@?@",
                "http://user%40domain.com:password@example.com/~tilde@?@",
            ),
            (
                "HTTP://*****:*****@Example.com:8080/",
                "http://*****:*****@example.com:8080/",
            ),
            ("HTTPS://Example.Com/?Key=Value",
             "https://example.com/?Key=Value"),
            ("Https://Example.Com/#Fragment", "https://example.com/#Fragment"),
            ("[::1%25]", "[::1%25]"),
            ("[::Ff%etH0%Ff]/%ab%Af", "[::ff%etH0%FF]/%AB%AF"),
            (
                "http://*****:*****@[AaAa::Ff%25etH0%Ff]/%ab%Af",
                "http://*****:*****@[aaaa::ff%etH0%FF]/%AB%AF",
            ),
            # Invalid characters for the query/fragment getting encoded
            (
                'http://google.com/p[]?parameter[]="hello"#fragment#',
                "http://google.com/p%5B%5D?parameter%5B%5D=%22hello%22#fragment%23",
            ),
            # Percent encoding isn't applied twice despite '%' being invalid
            # but the percent encoding is still normalized.
            (
                "http://google.com/p%5B%5d?parameter%5b%5D=%22hello%22#fragment%23",
                "http://google.com/p%5B%5D?parameter%5B%5D=%22hello%22#fragment%23",
            ),
        ],
    )
    def test_parse_url_normalization(self, url, expected_normalized_url):
        """Assert parse_url(url).url equals the expected normalized text.

        The cases cover lower-casing of scheme and host, percent-encoding of
        characters that are invalid in their component, and upper-casing of
        existing percent-escapes.
        """
        actual_normalized_url = parse_url(url).url
        assert actual_normalized_url == expected_normalized_url

    @pytest.mark.parametrize("char",
                             [chr(i) for i in range(0x00, 0x21)] + ["\x7F"])
    def test_control_characters_are_percent_encoded(self, char):
        """Each listed character is %XX-encoded in auth, path, query, fragment."""
        # Two upper-case hex digits, e.g. "%0A" for a newline.
        encoded = "%{:02X}".format(ord(char))
        template = "http://user{0}@example.com/path{0}?query{0}#fragment{0}"
        parsed = parse_url(template.format(char))

        expected = Url(
            "http",
            auth="user" + encoded,
            host="example.com",
            path="/path" + encoded,
            query="query" + encoded,
            fragment="fragment" + encoded,
        )
        assert parsed == expected

    # (url, expected Url) pairs.  test_parse_url checks that parsing the
    # string gives the Url; test_unparse_url checks that the Url renders back
    # to exactly the string, so every entry here must round-trip.
    parse_url_host_map = [
        ("http://google.com/mail", Url("http", host="google.com",
                                       path="/mail")),
        ("http://google.com/mail/",
         Url("http", host="google.com", path="/mail/")),
        ("http://google.com/mail", Url("http", host="google.com",
                                       path="mail")),
        ("google.com/mail", Url(host="google.com", path="/mail")),
        ("http://google.com/", Url("http", host="google.com", path="/")),
        ("http://google.com", Url("http", host="google.com")),
        ("http://google.com?foo",
         Url("http", host="google.com", path="", query="foo")),
        # Path/query/fragment
        ("", Url()),
        ("/", Url(path="/")),
        ("#?/!google.com/?foo", Url(path="", fragment="?/!google.com/?foo")),
        ("/foo", Url(path="/foo")),
        ("/foo?bar=baz", Url(path="/foo", query="bar=baz")),
        (
            "/foo?bar=baz#banana?apple/orange",
            Url(path="/foo", query="bar=baz", fragment="banana?apple/orange"),
        ),
        (
            "/redirect?target=http://localhost:61020/",
            Url(path="redirect", query="target=http://localhost:61020/"),
        ),
        # Port
        ("http://google.com/", Url("http", host="google.com", path="/")),
        ("http://google.com:80/",
         Url("http", host="google.com", port=80, path="/")),
        ("http://google.com:80", Url("http", host="google.com", port=80)),
        # Auth: the userinfo in each URL must match the expected ``auth``,
        # otherwise neither parsing nor the round trip can succeed.
        (
            "http://foo:bar@localhost/",
            Url("http", auth="foo:bar", host="localhost", path="/"),
        ),
        ("http://foo@localhost/",
         Url("http", auth="foo", host="localhost", path="/")),
        (
            "http://foo:bar@localhost/",
            Url("http", auth="foo:bar", host="localhost", path="/"),
        ),
    ]

    # (url, expected Url) pairs used by test_parse_url only.  They are kept
    # out of parse_url_host_map because the parsed Url does not render back
    # to the original string.
    non_round_tripping_parse_url_host_map = [
        # Path/query/fragment
        ("?", Url(path="", query="")),
        ("#", Url(path="", fragment="")),
        # Path normalization
        ("/abc/../def", Url(path="/def")),
        # Empty Port
        ("http://google.com:", Url("http", host="google.com")),
        ("http://google.com:/", Url("http", host="google.com", path="/")),
        # Uppercase IRI
        (
            "http://Königsgäßchen.de/straße",
            Url("http", host="xn--knigsgchen-b4a3dun.de", path="/stra%C3%9Fe"),
        ),
        # Percent-encode in userinfo: the last '@' separates userinfo from
        # host, so an earlier '@' (or a '"') is expected back as %40 / %22.
        (
            "http://user@email.com:password@example.com/",
            Url("http",
                auth="user%40email.com:password",
                host="example.com",
                path="/"),
        ),
        (
            'http://user":quoted@example.com/',
            Url("http", auth="user%22:quoted", host="example.com", path="/"),
        ),
        # Unicode Surrogates
        ("http://google.com/\uD800",
         Url("http", host="google.com", path="%ED%A0%80")),
        (
            "http://google.com?q=\uDC00",
            Url("http", host="google.com", path="", query="q=%ED%B0%80"),
        ),
        (
            "http://google.com#\uDC00",
            Url("http", host="google.com", path="", fragment="%ED%B0%80"),
        ),
    ]

    @pytest.mark.parametrize(
        "url, expected_url",
        [*parse_url_host_map, *non_round_tripping_parse_url_host_map],
    )
    def test_parse_url(self, url, expected_url):
        """parse_url() must yield the expected Url for every table entry."""
        assert parse_url(url) == expected_url

    @pytest.mark.parametrize("url, expected_url", parse_url_host_map)
    def test_unparse_url(self, url, expected_url):
        """The ``url`` property of the expected Url must equal the input string."""
        rebuilt = expected_url.url
        assert url == rebuilt

    @pytest.mark.parametrize(
        "url, expected_url",
        [
            # RFC 3986 5.2.4
            ("/abc/../def", Url(path="/def")),
            ("/..", Url(path="/")),
            ("/./abc/./def/", Url(path="/abc/def/")),
            ("/.", Url(path="/")),
            ("/./", Url(path="/")),
            ("/abc/./.././d/././e/.././f/./../../ghi", Url(path="/ghi")),
        ],
    )
    def test_parse_and_normalize_url_paths(self, url, expected_url):
        """Dot segments are removed, both in the parsed Url and in its ``url`` string."""
        parsed = parse_url(url)
        assert parsed == expected_url
        assert parsed.url == expected_url.url

    def test_parse_url_invalid_IPv6(self):
        """An unterminated IPv6 literal is rejected with LocationParseError."""
        malformed = "[::1"
        with pytest.raises(LocationParseError):
            parse_url(malformed)

    def test_parse_url_negative_port(self):
        """A negative port number is rejected with LocationParseError."""
        negative_port_url = "https://www.google.com:-80/"
        with pytest.raises(LocationParseError):
            parse_url(negative_port_url)

    def test_Url_str(self):
        """``str()`` of a Url must equal its ``url`` property."""
        instance = Url("http", host="google.com")
        rendered = str(instance)
        assert rendered == instance.url

    # (input URL, expected ``request_uri``) pairs consumed by
    # test_request_uri; note that empty and fragment-only inputs map to "/".
    request_uri_map = [
        ("http://google.com/mail", "/mail"),
        ("http://google.com/mail/", "/mail/"),
        ("http://google.com/", "/"),
        ("http://google.com", "/"),
        ("", "/"),
        ("/", "/"),
        ("?", "/?"),
        ("#", "/"),
        ("/foo?bar=baz", "/foo?bar=baz"),
    ]

    @pytest.mark.parametrize("url, expected_request_uri", request_uri_map)
    def test_request_uri(self, url, expected_request_uri):
        """``request_uri`` of the parsed URL must match the table entry."""
        request_uri = parse_url(url).request_uri
        assert request_uri == expected_request_uri

    # (input URL, expected ``netloc``) pairs consumed by test_netloc;
    # the last two entries have no scheme.
    url_netloc_map = [
        ("http://google.com/mail", "google.com"),
        ("http://google.com:80/mail", "google.com:80"),
        ("google.com/foobar", "google.com"),
        ("google.com:12345", "google.com:12345"),
    ]

    @pytest.mark.parametrize("url, expected_netloc", url_netloc_map)
    def test_netloc(self, url, expected_netloc):
        """``netloc`` of the parsed URL must match the table entry."""
        parsed = parse_url(url)
        assert parsed.netloc == expected_netloc

    # (input, expected) pairs consumed by test_url_vulnerabilities.
    # ``expected`` is the Url the input must parse to; that test also accepts
    # ``False`` as the expected value, meaning parse_url() must raise
    # LocationParseError.
    url_vulnerabilities = [
        # urlparse doesn't follow RFC 3986 Section 3.2
        (
            "http://google.com#@evil.com/",
            Url("http", host="google.com", path="", fragment="@evil.com/"),
        ),
        # CVE-2016-5699
        (
            "http://127.0.0.1%0d%0aConnection%3a%20keep-alive",
            Url("http", host="127.0.0.1%0d%0aconnection%3a%20keep-alive"),
        ),
        # NodeJS unicode -> double dot
        (
            "http://google.com/\uff2e\uff2e/abc",
            Url("http", host="google.com", path="/%EF%BC%AE%EF%BC%AE/abc"),
        ),
        # Scheme without ://
        (
            "javascript:a='@google.com:12345/';alert(0)",
            Url(scheme="javascript", path="a='@google.com:12345/';alert(0)"),
        ),
        ("//google.com/a/b/c", Url(host="google.com", path="/a/b/c")),
        # International URLs
        (
            "http://ヒ:キ@ヒ.abc.ニ/ヒ?キ#ワ",
            Url(
                "http",
                host="xn--pdk.abc.xn--idk",
                auth="%E3%83%92:%E3%82%AD",
                path="/%E3%83%92",
                query="%E3%82%AD",
                fragment="%E3%83%AF",
            ),
        ),
        # Injected headers (CVE-2016-5699, CVE-2019-9740, CVE-2019-9947)
        (
            "10.251.0.83:7777?a=1 HTTP/1.1\r\nX-injected: header",
            Url(
                host="10.251.0.83",
                port=7777,
                path="",
                query="a=1%20HTTP/1.1%0D%0AX-injected:%20header",
            ),
        ),
        (
            "http://127.0.0.1:6379?\r\nSET test failure12\r\n:8080/test/?test=a",
            Url(
                scheme="http",
                host="127.0.0.1",
                port=6379,
                path="",
                query="%0D%0ASET%20test%20failure12%0D%0A:8080/test/?test=a",
            ),
        ),
        # See https://bugs.xdavidhu.me/google/2020/03/08/the-unexpected-google-wide-domain-check-bypass/
        # The backslash must end up in the path, not extend the host.
        (
            "https://user:pass@xdavidhu.me\\test.corp.google.com:8080/path/to/something?param=value#hash",
            Url(
                scheme="https",
                auth="user:pass",
                host="xdavidhu.me",
                path="/%5Ctest.corp.google.com:8080/path/to/something",
                query="param=value",
                fragment="hash",
            ),
        ),
    ]

    @pytest.mark.parametrize("url, expected_url", url_vulnerabilities)
    def test_url_vulnerabilities(self, url, expected_url):
        """Check each hostile input against its expected parse result.

        An expected value of ``False`` means parse_url() must raise
        LocationParseError; anything else is compared with the parsed Url.
        """
        if expected_url is not False:
            assert parse_url(url) == expected_url
            return

        with pytest.raises(LocationParseError):
            parse_url(url)

    def test_parse_url_bytes_type_error(self):
        """Passing bytes instead of str to parse_url() raises TypeError."""
        bytes_url = b"https://www.google.com/"
        with pytest.raises(TypeError):
            parse_url(bytes_url)

    @pytest.mark.parametrize(
        "kwargs, expected",
        [
            pytest.param(
                {"accept_encoding": True},
                {"accept-encoding": "gzip,deflate,br"},
                marks=onlyBrotlipy(),
            ),
            pytest.param(
                {"accept_encoding": True},
                {"accept-encoding": "gzip,deflate"},
                marks=notBrotlipy(),
            ),
            ({"accept_encoding": "foo,bar"}, {"accept-encoding": "foo,bar"}),
            ({"accept_encoding": ["foo", "bar"]}, {"accept-encoding": "foo,bar"}),
            pytest.param(
                {"accept_encoding": True, "user_agent": "banana"},
                {"accept-encoding": "gzip,deflate,br", "user-agent": "banana"},
                marks=onlyBrotlipy(),
            ),
            pytest.param(
                {"accept_encoding": True, "user_agent": "banana"},
                {"accept-encoding": "gzip,deflate", "user-agent": "banana"},
                marks=notBrotlipy(),
            ),
            ({"user_agent": "banana"}, {"user-agent": "banana"}),
            ({"keep_alive": True}, {"connection": "keep-alive"}),
            ({"basic_auth": "foo:bar"}, {"authorization": "Basic Zm9vOmJhcg=="}),
            (
                {"proxy_basic_auth": "foo:bar"},
                {"proxy-authorization": "Basic Zm9vOmJhcg=="},
            ),
            ({"disable_cache": True}, {"cache-control": "no-cache"}),
        ],
    )
    def test_make_headers(self, kwargs, expected):
        """make_headers(**kwargs) must return exactly the expected header dict."""
        headers = make_headers(**kwargs)
        assert headers == expected

    def test_rewind_body(self):
        """rewind_body() moves a consumed stream back to the given offset."""
        payload = b"test data"
        stream = io.BytesIO(payload)

        # First read drains the stream; the second one proves it is exhausted.
        assert stream.read() == payload
        assert stream.read() == b""

        # Offset 5 leaves only the trailing b'data' to be read.
        rewind_body(stream, 5)
        remainder = stream.read()
        assert remainder == b"data"

    def test_rewind_body_failed_tell(self):
        """The _FAILEDTELL marker as position makes rewind_body() raise UnrewindableBodyError."""
        stream = io.BytesIO(b"test data")
        stream.read()  # drain the stream first

        # _FAILEDTELL stands in for a body whose tell() failed earlier
        with pytest.raises(UnrewindableBodyError):
            rewind_body(stream, _FAILEDTELL)

    def test_rewind_body_bad_position(self):
        """Non-integer positions make rewind_body() raise ValueError."""
        stream = io.BytesIO(b"test data")
        stream.read()  # drain the stream first

        # Neither None nor an arbitrary object is a usable position.
        for bad_position in (None, object()):
            with pytest.raises(ValueError):
                rewind_body(stream, body_pos=bad_position)

    def test_rewind_body_failed_seek(self):
        """An OSError raised by seek() surfaces as UnrewindableBodyError."""

        class BadSeek:
            """Body stand-in whose seek() always fails."""

            def seek(self, pos, offset=0):
                raise OSError

        unseekable = BadSeek()
        with pytest.raises(UnrewindableBodyError):
            rewind_body(unseekable, body_pos=2)

    @pytest.mark.parametrize(
        "input, expected",
        [
            (("abcd", "b"), ("a", "cd", "b")),
            (("abcd", "cb"), ("a", "cd", "b")),
            (("abcd", ""), ("abcd", "", None)),
            (("abcd", "a"), ("", "bcd", "a")),
            (("abcd", "ab"), ("", "bcd", "a")),
            (("abcd", "eb"), ("a", "cd", "b")),
        ],
    )
    def test_split_first(self, input, expected):
        """split_first(string, delimiters) must return the expected 3-tuple."""
        string, delimiters = input
        assert split_first(string, delimiters) == expected

    def test_add_stderr_logger(self):
        """add_stderr_logger() must attach the handler it returns to the "urllib3" logger.

        The handler is detached in a ``finally`` clause so that a failing
        assertion cannot leave it attached to the logger for later tests.
        """
        handler = add_stderr_logger(
            level=logging.INFO)  # Don't actually print debug
        logger = logging.getLogger("urllib3")
        try:
            assert handler in logger.handlers

            # Exercise the attached handler with a record below its level.
            logger.debug("Testing add_stderr_logger")
        finally:
            logger.removeHandler(handler)

    def test_disable_warnings(self):
        """After disable_warnings(), a further InsecureRequestWarning is not recorded."""
        message = "This is a test."
        with warnings.catch_warnings(record=True) as recorded:
            clear_warnings()

            # Before disabling, the warning is recorded.
            warnings.warn(message, InsecureRequestWarning)
            assert len(recorded) == 1

            # After disabling, the count must stay at one.
            disable_warnings()
            warnings.warn(message, InsecureRequestWarning)
            assert len(recorded) == 1

    def _make_time_pass(self, seconds, timeout, time_mock):
        """Start the connect timer on ``timeout`` and advance the mocked clock.

        ``time_mock`` reports TIMEOUT_EPOCH while ``start_connect()`` runs and
        TIMEOUT_EPOCH + ``seconds`` afterwards. Returns the same ``timeout``.
        """
        started_at = TIMEOUT_EPOCH
        time_mock.return_value = started_at
        timeout.start_connect()
        time_mock.return_value = started_at + seconds
        return timeout

    @pytest.mark.parametrize(
        "kwargs, message",
        [
            ({"total": -1}, "less than"),
            ({"connect": 2, "total": -1}, "less than"),
            ({"read": -1}, "less than"),
            ({"connect": False}, "cannot be a boolean"),
            ({"read": True}, "cannot be a boolean"),
            ({"connect": 0}, "less than or equal"),
            ({"read": "foo"}, "int, float or None"),
        ],
    )
    def test_invalid_timeouts(self, kwargs, message):
        """Invalid Timeout arguments raise ValueError containing ``message``."""
        with pytest.raises(ValueError) as excinfo:
            Timeout(**kwargs)
        error_text = str(excinfo.value)
        assert message in error_text

    @patch("urllib3.util.timeout.current_time")
    def test_timeout(self, current_time):
        """Check connect/read timeout values for several Timeout configurations."""
        # No time elapses: both phases report the total.
        only_total = self._make_time_pass(
            seconds=0, timeout=Timeout(total=3), time_mock=current_time
        )
        assert only_total.read_timeout == 3
        assert only_total.connect_timeout == 3

        assert Timeout(total=3, connect=2).connect_timeout == 2
        assert Timeout().connect_timeout == Timeout.DEFAULT_TIMEOUT

        # Connect takes 5 seconds, leaving 5 seconds for read
        slow_connect = self._make_time_pass(
            seconds=5, timeout=Timeout(total=10, read=7), time_mock=current_time
        )
        assert slow_connect.read_timeout == 5

        # Connect takes 2 seconds, read timeout still 7 seconds
        fast_connect = self._make_time_pass(
            seconds=2, timeout=Timeout(total=10, read=7), time_mock=current_time
        )
        assert fast_connect.read_timeout == 7

        # Connect timer never started: the read value is reported as given.
        assert Timeout(total=10, read=7).read_timeout == 7

        unlimited = Timeout(total=None, read=None, connect=None)
        assert unlimited.connect_timeout is None
        assert unlimited.read_timeout is None
        assert unlimited.total is None

        # A single positional argument is the total.
        assert Timeout(5).total == 5

    def test_timeout_str(self):
        """``str(Timeout)`` lists connect, read and total in that order."""
        with_read = str(Timeout(connect=1, read=2, total=3))
        assert with_read == "Timeout(connect=1, read=2, total=3)"

        without_read = str(Timeout(connect=1, read=None, total=3))
        assert without_read == "Timeout(connect=1, read=None, total=3)"

    @patch("urllib3.util.timeout.current_time")
    def test_timeout_elapsed(self, current_time):
        """Check the connect-timer state errors and the reported duration."""
        current_time.return_value = TIMEOUT_EPOCH
        timer = Timeout(total=3)

        # Asking for the duration before the timer is started is an error.
        with pytest.raises(TimeoutStateError):
            timer.get_connect_duration()

        # Starting the timer a second time is an error too.
        timer.start_connect()
        with pytest.raises(TimeoutStateError):
            timer.start_connect()

        # The duration follows the mocked clock.
        for elapsed in (2, 37):
            current_time.return_value = TIMEOUT_EPOCH + elapsed
            assert timer.get_connect_duration() == elapsed

    def test_is_fp_closed_object_supports_closed(self):
        """An object whose ``closed`` property is True counts as closed."""

        class ClosedFile:
            closed = property(lambda self: True)

        candidate = ClosedFile()
        assert is_fp_closed(candidate)

    def test_is_fp_closed_object_has_none_fp(self):
        """An object whose ``fp`` property is None counts as closed."""

        class NoneFpFile:
            fp = property(lambda self: None)

        candidate = NoneFpFile()
        assert is_fp_closed(candidate)

    def test_is_fp_closed_object_has_fp(self):
        """An object whose ``fp`` property is truthy does not count as closed."""

        class FpFile:
            fp = property(lambda self: True)

        candidate = FpFile()
        assert not is_fp_closed(candidate)

    def test_is_fp_closed_object_has_neither_fp_nor_closed(self):
        """An object with neither ``fp`` nor ``closed`` makes is_fp_closed() raise ValueError."""

        class NotReallyAFile:
            pass

        candidate = NotReallyAFile()
        with pytest.raises(ValueError):
            is_fp_closed(candidate)

    def test_has_ipv6_disabled_on_compile(self):
        """With ``socket.has_ipv6`` patched to False, _has_ipv6() is falsy."""
        with patch("socket.has_ipv6", False):
            supported = _has_ipv6("::1")
            assert not supported

    def test_has_ipv6_enabled_but_fails(self):
        """If binding the patched socket raises, _has_ipv6() is falsy."""
        with patch("socket.has_ipv6", True), patch("socket.socket") as socket_cls:
            fake_socket = socket_cls.return_value
            fake_socket.bind = Mock(side_effect=Exception("No IPv6 here!"))
            assert not _has_ipv6("::1")

    def test_has_ipv6_enabled_and_working(self):
        """If binding the patched socket succeeds, _has_ipv6() is truthy."""
        with patch("socket.has_ipv6", True), patch("socket.socket") as socket_cls:
            fake_socket = socket_cls.return_value
            fake_socket.bind.return_value = True
            assert _has_ipv6("::1")

    def test_ip_family_ipv6_enabled(self):
        """With HAS_IPV6 patched to True, allowed_gai_family() is AF_UNSPEC."""
        with patch("urllib3.util.connection.HAS_IPV6", True):
            family = allowed_gai_family()
            assert family == socket.AF_UNSPEC

    def test_ip_family_ipv6_disabled(self):
        """With HAS_IPV6 patched to False, allowed_gai_family() is AF_INET."""
        with patch("urllib3.util.connection.HAS_IPV6", False):
            family = allowed_gai_family()
            assert family == socket.AF_INET

    @pytest.mark.parametrize("headers", [b"foo", None, object])
    def test_assert_header_parsing_throws_typeerror_with_non_headers(
        self, headers
    ):
        """Bytes, None and a class object are all rejected with TypeError."""
        with pytest.raises(TypeError):
            assert_header_parsing(headers)

    def test_connection_requires_http_tunnel_no_proxy(self):
        """With every argument None, no HTTP tunnel is required."""
        needs_tunnel = connection_requires_http_tunnel(
            proxy_url=None,
            proxy_config=None,
            destination_scheme=None,
        )
        assert not needs_tunnel

    def test_connection_requires_http_tunnel_http_proxy(self):
        proxy = parse_url("http://*****:*****@pytest.mark.parametrize("host", [".localhost", "...", "t" * 64])
    def test_create_connection_with_invalid_idna_labels(self, host):
        with pytest.raises(LocationParseError) as ctx:
            create_connection((host, 80))
        assert str(
            ctx.value) == f"Failed to parse: '{host}', label empty or too long"

    @pytest.mark.parametrize(
        "host",
        [
            "a.example.com",
            "localhost.",
            "[dead::beef]",
            "[dead::beef%en5]",
            "[dead::beef%en5.]",
        ],
    )
    @patch("socket.getaddrinfo")
    @patch("socket.socket")
    def test_create_connection_with_valid_idna_labels(
        self, socket, getaddrinfo, host
    ):
        """create_connection() accepts each of these hosts without raising.

        ``socket.socket`` and ``socket.getaddrinfo`` are patched with mocks.
        """
        single_result = (None, None, None, None, None)
        getaddrinfo.return_value = [single_result]
        socket.return_value = Mock()

        address = (host, 80)
        create_connection(address)

    @pytest.mark.parametrize(
        "input,params,expected",
        (
            ("test", {}, "test"),  # str input
            (b"test", {}, "test"),  # bytes input
            (b"test", {"encoding": "utf-8"}, "test"),  # bytes input with utf-8
            (b"test", {"encoding": "ascii"}, "test"),  # bytes input with ascii
        ),
    )
    def test_to_str(self, input, params, expected):
        """to_str() returns the expected str for str and bytes input."""
        converted = to_str(input, **params)
        assert converted == expected

    def test_to_str_error(self):
        """An int argument makes to_str() raise TypeError naming the type."""
        expected_message = "not expecting type int"
        with pytest.raises(TypeError, match=expected_message):
            to_str(1)

    @pytest.mark.parametrize(
        "input,params,expected",
        (
            (b"test", {}, b"test"),  # bytes input
            ("test", {}, b"test"),  # str input
            ("é", {}, b"\xc3\xa9"),  # non-ASCII str input
            ("test", {"encoding": "utf-8"}, b"test"),  # str input with utf-8
            ("test", {"encoding": "ascii"}, b"test"),  # str input with ascii
        ),
    )
    def test_to_bytes(self, input, params, expected):
        """to_bytes() returns the expected bytes for bytes and str input."""
        converted = to_bytes(input, **params)
        assert converted == expected

    def test_to_bytes_error(self):
        """An int argument makes to_bytes() raise TypeError naming the type."""
        expected_message = "not expecting type int"
        with pytest.raises(TypeError, match=expected_message):
            to_bytes(1)
예제 #27
0
class TestUtil(unittest.TestCase):
    def test_get_host(self):
        """get_host() must return the expected (scheme, host, port) triple."""
        expectations = {
            # Hosts
            'http://google.com/mail': ('http', 'google.com', None),
            'http://google.com/mail/': ('http', 'google.com', None),
            'google.com/mail': ('http', 'google.com', None),
            'http://google.com/': ('http', 'google.com', None),
            'http://google.com': ('http', 'google.com', None),
            'http://www.google.com': ('http', 'www.google.com', None),
            'http://mail.google.com': ('http', 'mail.google.com', None),
            'http://google.com:8000/mail/': ('http', 'google.com', 8000),
            'http://google.com:8000': ('http', 'google.com', 8000),
            'https://google.com': ('https', 'google.com', None),
            'https://google.com:8000': ('https', 'google.com', 8000),
            'http://*****:*****@127.0.0.1:1234': ('http', '127.0.0.1', 1234),
            'http://google.com/foo=http://bar:42/baz':
                ('http', 'google.com', None),
            'http://google.com?foo=http://bar:42/baz':
                ('http', 'google.com', None),
            'http://google.com#foo=http://bar:42/baz':
                ('http', 'google.com', None),

            # IPv4
            '173.194.35.7': ('http', '173.194.35.7', None),
            'http://173.194.35.7': ('http', '173.194.35.7', None),
            'http://173.194.35.7/test': ('http', '173.194.35.7', None),
            'http://173.194.35.7:80': ('http', '173.194.35.7', 80),
            'http://173.194.35.7:80/test': ('http', '173.194.35.7', 80),

            # IPv6
            '[2a00:1450:4001:c01::67]':
                ('http', '[2a00:1450:4001:c01::67]', None),
            'http://[2a00:1450:4001:c01::67]':
                ('http', '[2a00:1450:4001:c01::67]', None),
            'http://[2a00:1450:4001:c01::67]/test':
                ('http', '[2a00:1450:4001:c01::67]', None),
            'http://[2a00:1450:4001:c01::67]:80':
                ('http', '[2a00:1450:4001:c01::67]', 80),
            'http://[2a00:1450:4001:c01::67]:80/test':
                ('http', '[2a00:1450:4001:c01::67]', 80),

            # More IPv6 from http://www.ietf.org/rfc/rfc2732.txt
            'http://[fedc:ba98:7654:3210:fedc:ba98:7654:3210]:8000/index.html':
                ('http', '[fedc:ba98:7654:3210:fedc:ba98:7654:3210]', 8000),
            'http://[1080:0:0:0:8:800:200c:417a]/index.html':
                ('http', '[1080:0:0:0:8:800:200c:417a]', None),
            'http://[3ffe:2a00:100:7031::1]':
                ('http', '[3ffe:2a00:100:7031::1]', None),
            'http://[1080::8:800:200c:417a]/foo':
                ('http', '[1080::8:800:200c:417a]', None),
            'http://[::192.9.5.5]/ipng': ('http', '[::192.9.5.5]', None),
            'http://[::ffff:129.144.52.38]:42/index.html':
                ('http', '[::ffff:129.144.52.38]', 42),
            'http://[2010:836b:4179::836b:4179]':
                ('http', '[2010:836b:4179::836b:4179]', None),
        }
        for location, expected_triple in expectations.items():
            self.assertEqual(get_host(location), expected_triple)

    def test_invalid_host(self):
        """get_host() must raise LocationParseError for each malformed location."""
        # TODO: Add more tests
        malformed_locations = [
            'http://google.com:foo',
            'http://::1/',
            'http://::1:80/',
            'http://google.com:-80',
            six.u('http://google.com:\xb2\xb2'),  # \xb2 = ^2
        ]

        for candidate in malformed_locations:
            with self.assertRaises(LocationParseError):
                get_host(candidate)

    def test_host_normalization(self):
        """
        Check that get_host() lower-cases the scheme, and lower-cases the
        host only for schemes that are normalized (the 'about' and
        'http+unix' entries keep their host casing).
        """
        expectations = {
            # Hosts
            'HTTP://GOOGLE.COM/mail/': ('http', 'google.com', None),
            'GOogle.COM/mail': ('http', 'google.com', None),
            'HTTP://GoOgLe.CoM:8000/mail/': ('http', 'google.com', 8000),
            'HTTP://*****:*****@EXAMPLE.COM:1234':
                ('http', 'example.com', 1234),
            '173.194.35.7': ('http', '173.194.35.7', None),
            'HTTP://173.194.35.7': ('http', '173.194.35.7', None),
            'HTTP://[2a00:1450:4001:c01::67]:80/test':
                ('http', '[2a00:1450:4001:c01::67]', 80),
            'HTTP://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8000/index.html':
                ('http', '[fedc:ba98:7654:3210:fedc:ba98:7654:3210]', 8000),
            'HTTPS://[1080:0:0:0:8:800:200c:417A]/index.html':
                ('https', '[1080:0:0:0:8:800:200c:417a]', None),
            'abOut://eXamPlE.com?info=1': ('about', 'eXamPlE.com', None),
            'http+UNIX://%2fvar%2frun%2fSOCKET/path':
                ('http+unix', '%2fvar%2frun%2fSOCKET', None),
        }
        for location, expected_triple in expectations.items():
            self.assertEqual(get_host(location), expected_triple)

    def test_parse_url_normalization(self):
        """Scheme and host are lower-cased; userinfo, path, query and fragment keep their case."""
        cases = [
            ('HTTP://GOOGLE.COM/MAIL/', 'http://google.com/MAIL/'),
            ('HTTP://*****:*****@Example.com:8080/',
             'http://*****:*****@example.com:8080/'),
            ('HTTPS://Example.Com/?Key=Value',
             'https://example.com/?Key=Value'),
            ('Https://Example.Com/#Fragment', 'https://example.com/#Fragment'),
        ]
        for raw_url, normalized_url in cases:
            self.assertEqual(parse_url(raw_url).url, normalized_url)

    # (url, expected Url) pairs checked in both directions: test_parse_url
    # parses the string, test_unparse_url rebuilds it from the Url. Every
    # entry must therefore round-trip exactly.
    parse_url_host_map = [
        ('http://google.com/mail', Url('http', host='google.com',
                                       path='/mail')),
        ('http://google.com/mail/',
         Url('http', host='google.com', path='/mail/')),
        ('http://google.com/mail', Url('http', host='google.com',
                                       path='mail')),
        ('google.com/mail', Url(host='google.com', path='/mail')),
        ('http://google.com/', Url('http', host='google.com', path='/')),
        ('http://google.com', Url('http', host='google.com')),
        ('http://google.com?foo',
         Url('http', host='google.com', path='', query='foo')),

        # Path/query/fragment
        ('', Url()),
        ('/', Url(path='/')),
        ('#?/!google.com/?foo#bar',
         Url(path='', fragment='?/!google.com/?foo#bar')),
        ('/foo', Url(path='/foo')),
        ('/foo?bar=baz', Url(path='/foo', query='bar=baz')),
        ('/foo?bar=baz#banana?apple/orange',
         Url(path='/foo', query='bar=baz', fragment='banana?apple/orange')),

        # Port
        ('http://google.com/', Url('http', host='google.com', path='/')),
        ('http://google.com:80/',
         Url('http', host='google.com', port=80, path='/')),
        ('http://google.com:80', Url('http', host='google.com', port=80)),

        # Auth
        ('http://foo:bar@localhost/',
         Url('http', auth='foo:bar', host='localhost', path='/')),
        ('http://foo@localhost/',
         Url('http', auth='foo', host='localhost', path='/')),
        # An "@" inside the userinfo: only the last "@" separates the host.
        ('http://foo:bar@baz@localhost/',
         Url('http', auth='foo:bar@baz', host='localhost', path='/')),
        ('http://@', Url('http', host=None, auth=''))
    ]

    # URL -> expected Url for inputs that test_parse_url checks but that are
    # kept out of parse_url_host_map, whose entries test_unparse_url also
    # checks in the Url -> string direction.
    non_round_tripping_parse_url_host_map = {
        # Path/query/fragment
        '?': Url(path='', query=''),
        '#': Url(path='', fragment=''),

        # Empty Port
        'http://google.com:': Url('http', host='google.com'),
        'http://google.com:/': Url('http', host='google.com', path='/'),
    }

    def test_parse_url(self):
        """parse_url() must yield the expected Url for every entry of both tables."""
        one_way_cases = self.non_round_tripping_parse_url_host_map.items()
        all_cases = chain(self.parse_url_host_map, one_way_cases)
        for location, expected in all_cases:
            self.assertEqual(parse_url(location), expected)

    def test_unparse_url(self):
        """The ``url`` property of each expected Url must equal the input string."""
        for location, expected in self.parse_url_host_map:
            rebuilt = expected.url
            self.assertEqual(location, rebuilt)

    def test_parse_url_invalid_IPv6(self):
        """An unterminated IPv6 literal makes parse_url() raise ValueError."""
        with self.assertRaises(ValueError):
            parse_url('[::1')

    def test_Url_str(self):
        """``str()`` of a Url must equal its ``url`` property."""
        instance = Url('http', host='google.com')
        rendered = str(instance)
        self.assertEqual(rendered, instance.url)

    def test_request_uri(self):
        """``request_uri`` of each parsed URL must match the expected value."""
        expectations = {
            'http://google.com/mail': '/mail',
            'http://google.com/mail/': '/mail/',
            'http://google.com/': '/',
            'http://google.com': '/',
            '': '/',
            '/': '/',
            '?': '/?',
            '#': '/',
            '/foo?bar=baz': '/foo?bar=baz',
        }
        for location, expected_uri in expectations.items():
            parsed = parse_url(location)
            self.assertEqual(parsed.request_uri, expected_uri)

    def test_netloc(self):
        """``netloc`` of each parsed URL must match the expected value."""
        expectations = {
            'http://google.com/mail': 'google.com',
            'http://google.com:80/mail': 'google.com:80',
            'google.com/foobar': 'google.com',
            'google.com:12345': 'google.com:12345',
        }

        for location, expected in expectations.items():
            parsed = parse_url(location)
            self.assertEqual(parsed.netloc, expected)

    def test_make_headers(self):
        """make_headers() must return exactly the expected dict for each keyword set."""
        cases = [
            (dict(accept_encoding=True),
             {'accept-encoding': 'gzip,deflate'}),
            (dict(accept_encoding='foo,bar'),
             {'accept-encoding': 'foo,bar'}),
            (dict(accept_encoding=['foo', 'bar']),
             {'accept-encoding': 'foo,bar'}),
            (dict(accept_encoding=True, user_agent='banana'),
             {'accept-encoding': 'gzip,deflate', 'user-agent': 'banana'}),
            (dict(user_agent='banana'),
             {'user-agent': 'banana'}),
            (dict(keep_alive=True),
             {'connection': 'keep-alive'}),
            (dict(basic_auth='foo:bar'),
             {'authorization': 'Basic Zm9vOmJhcg=='}),
            (dict(proxy_basic_auth='foo:bar'),
             {'proxy-authorization': 'Basic Zm9vOmJhcg=='}),
            (dict(disable_cache=True),
             {'cache-control': 'no-cache'}),
        ]
        for kwargs, expected_headers in cases:
            self.assertEqual(make_headers(**kwargs), expected_headers)

    def test_rewind_body(self):
        """rewind_body() moves a consumed stream back to the given offset."""
        payload = b'test data'
        stream = io.BytesIO(payload)

        # First read drains the stream; the second one proves it is exhausted.
        self.assertEqual(stream.read(), payload)
        self.assertEqual(stream.read(), b'')

        # Offset 5 leaves only the trailing b'data' to be read.
        rewind_body(stream, 5)
        remainder = stream.read()
        self.assertEqual(remainder, b'data')

    def test_rewind_body_failed_tell(self):
        """The _FAILEDTELL marker as position makes rewind_body() raise UnrewindableBodyError."""
        stream = io.BytesIO(b'test data')
        stream.read()  # drain the stream first

        # _FAILEDTELL stands in for a body whose tell() failed earlier
        with self.assertRaises(UnrewindableBodyError):
            rewind_body(stream, _FAILEDTELL)

    def test_rewind_body_bad_position(self):
        """Non-integer positions make rewind_body() raise ValueError."""
        stream = io.BytesIO(b'test data')
        stream.read()  # drain the stream first

        # Neither None nor an arbitrary object is a usable position.
        for bad_position in (None, object()):
            with self.assertRaises(ValueError):
                rewind_body(stream, bad_position)

    def test_rewind_body_failed_seek(self):
        """An IOError raised by seek() surfaces as UnrewindableBodyError."""

        class BadSeek():
            """Body stand-in whose seek() always fails."""

            def seek(self, pos, offset=0):
                raise IOError

        unseekable = BadSeek()
        with self.assertRaises(UnrewindableBodyError):
            rewind_body(unseekable, 2)

    def test_split_first(self):
        """split_first(string, delimiters) must return the expected 3-tuple."""
        cases = [
            (('abcd', 'b'), ('a', 'cd', 'b')),
            (('abcd', 'cb'), ('a', 'cd', 'b')),
            (('abcd', ''), ('abcd', '', None)),
            (('abcd', 'a'), ('', 'bcd', 'a')),
            (('abcd', 'ab'), ('', 'bcd', 'a')),
        ]
        for (string, delimiters), expected in cases:
            self.assertEqual(split_first(string, delimiters), expected)

    def test_add_stderr_logger(self):
        """add_stderr_logger() must attach the handler it returns to the 'urllib3' logger.

        The handler is detached in a ``finally`` clause so that a failing
        assertion cannot leave it attached to the logger for later tests.
        """
        handler = add_stderr_logger(
            level=logging.INFO)  # Don't actually print debug
        logger = logging.getLogger('urllib3')
        try:
            self.assertTrue(handler in logger.handlers)

            # Exercise the attached handler with a record below its level.
            logger.debug('Testing add_stderr_logger')
        finally:
            logger.removeHandler(handler)

    def test_disable_warnings(self):
        """After disable_warnings(), a further InsecureRequestWarning is not recorded."""
        message = 'This is a test.'
        with warnings.catch_warnings(record=True) as recorded:
            clear_warnings()

            # Before disabling, the warning is recorded.
            warnings.warn(message, InsecureRequestWarning)
            self.assertEqual(len(recorded), 1)

            # After disabling, the count must stay at one.
            disable_warnings()
            warnings.warn(message, InsecureRequestWarning)
            self.assertEqual(len(recorded), 1)

    def _make_time_pass(self, seconds, timeout, time_mock):
        """Start the connect timer on ``timeout`` and advance the mocked clock.

        ``time_mock`` reports TIMEOUT_EPOCH while ``start_connect()`` runs and
        TIMEOUT_EPOCH + ``seconds`` afterwards. Returns the same ``timeout``.
        """
        started_at = TIMEOUT_EPOCH
        time_mock.return_value = started_at
        timeout.start_connect()
        time_mock.return_value = started_at + seconds
        return timeout

    def test_invalid_timeouts(self):
        """Invalid Timeout arguments raise ValueError with the expected wording."""
        # (constructor kwargs, fragment expected in the error text,
        #  failure message used when no exception is raised)
        cases = [
            ({'total': -1}, 'less than',
             "negative value should throw exception"),
            ({'connect': 2, 'total': -1}, 'less than',
             "negative value should throw exception"),
            ({'read': -1}, 'less than',
             "negative value should throw exception"),
            ({'connect': False}, 'cannot be a boolean',
             "boolean values should throw exception"),
            ({'read': True}, 'cannot be a boolean',
             "boolean values should throw exception"),
            ({'connect': 0}, 'less than or equal',
             "value <= 0 should throw exception"),
            ({'read': "foo"}, 'int, float or None',
             "string value should not be allowed"),
        ]
        for kwargs, fragment, complaint in cases:
            try:
                Timeout(**kwargs)
            except ValueError as err:
                self.assertTrue(fragment in str(err))
            else:
                self.fail(complaint)

    @patch('urllib3.util.timeout.current_time')
    def test_timeout(self, current_time):
        """Check how connect/read timeouts are derived from the settings."""
        # make 'no time' elapse
        untouched = self._make_time_pass(seconds=0,
                                         timeout=Timeout(total=3),
                                         time_mock=current_time)
        self.assertEqual(untouched.read_timeout, 3)
        self.assertEqual(untouched.connect_timeout, 3)

        self.assertEqual(Timeout(total=3, connect=2).connect_timeout, 2)

        self.assertEqual(Timeout().connect_timeout, Timeout.DEFAULT_TIMEOUT)

        # Connect takes 5 seconds, leaving 5 seconds for read
        slow_connect = self._make_time_pass(seconds=5,
                                            timeout=Timeout(total=10, read=7),
                                            time_mock=current_time)
        self.assertEqual(slow_connect.read_timeout, 5)

        # Connect takes 2 seconds, read timeout still 7 seconds
        fast_connect = self._make_time_pass(seconds=2,
                                            timeout=Timeout(total=10, read=7),
                                            time_mock=current_time)
        self.assertEqual(fast_connect.read_timeout, 7)

        # Connect clock never started: the explicit read value is used.
        self.assertEqual(Timeout(total=10, read=7).read_timeout, 7)

        unlimited = Timeout(total=None, read=None, connect=None)
        self.assertEqual(unlimited.connect_timeout, None)
        self.assertEqual(unlimited.read_timeout, None)
        self.assertEqual(unlimited.total, None)

        # A single positional argument is taken as the total.
        self.assertEqual(Timeout(5).total, 5)

    def test_timeout_str(self):
        """str(Timeout) spells out connect, read and total."""
        cases = (
            ({'connect': 1, 'read': 2, 'total': 3},
             "Timeout(connect=1, read=2, total=3)"),
            ({'connect': 1, 'read': None, 'total': 3},
             "Timeout(connect=1, read=None, total=3)"),
        )
        for kwargs, rendered in cases:
            self.assertEqual(str(Timeout(**kwargs)), rendered)

    @patch('urllib3.util.timeout.current_time')
    def test_timeout_elapsed(self, current_time):
        """get_connect_duration() tracks the mocked clock after start."""
        current_time.return_value = TIMEOUT_EPOCH
        timeout = Timeout(total=3)

        # Asking for the duration before the clock is started is an error.
        self.assertRaises(TimeoutStateError, timeout.get_connect_duration)

        timeout.start_connect()
        # Starting the clock a second time is an error as well.
        self.assertRaises(TimeoutStateError, timeout.start_connect)

        for elapsed in (2, 37):
            current_time.return_value = TIMEOUT_EPOCH + elapsed
            self.assertEqual(timeout.get_connect_duration(), elapsed)

    def test_resolve_cert_reqs(self):
        """resolve_cert_reqs() accepts None, ssl constants and names."""
        cases = (
            (None, ssl.CERT_NONE),
            (ssl.CERT_NONE, ssl.CERT_NONE),
            (ssl.CERT_REQUIRED, ssl.CERT_REQUIRED),
            ('REQUIRED', ssl.CERT_REQUIRED),
            ('CERT_REQUIRED', ssl.CERT_REQUIRED),
        )
        for candidate, resolved in cases:
            self.assertEqual(resolve_cert_reqs(candidate), resolved)

    def test_resolve_ssl_version(self):
        """resolve_ssl_version() accepts ssl constants and their names."""
        cases = (
            (ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1),
            ("PROTOCOL_TLSv1", ssl.PROTOCOL_TLSv1),
            ("TLSv1", ssl.PROTOCOL_TLSv1),
            (ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23),
        )
        for candidate, resolved in cases:
            self.assertEqual(resolve_ssl_version(candidate), resolved)

    def test_is_fp_closed_object_supports_closed(self):
        """An object whose `closed` property is True counts as closed."""
        class ClosedFile(object):
            closed = property(lambda self: True)

        already_closed = ClosedFile()
        self.assertTrue(is_fp_closed(already_closed))

    def test_is_fp_closed_object_has_none_fp(self):
        """An object whose `fp` property is None counts as closed."""
        class NoneFpFile(object):
            fp = property(lambda self: None)

        without_fp = NoneFpFile()
        self.assertTrue(is_fp_closed(without_fp))

    def test_is_fp_closed_object_has_fp(self):
        """An object whose `fp` property is truthy counts as open."""
        class FpFile(object):
            fp = property(lambda self: True)

        with_fp = FpFile()
        self.assertFalse(is_fp_closed(with_fp))

    def test_is_fp_closed_object_has_neither_fp_nor_closed(self):
        """An object with neither `fp` nor `closed` raises ValueError."""
        class NotReallyAFile(object):
            """Deliberately exposes no file-like attributes."""

        impostor = NotReallyAFile()
        self.assertRaises(ValueError, is_fp_closed, impostor)

    def test_ssl_wrap_socket_loads_the_cert_chain(self):
        """A `certfile` argument is forwarded to context.load_cert_chain()."""
        certfile = '/path/to/certfile'
        fake_sock = object()
        context = Mock()

        ssl_wrap_socket(sock=fake_sock,
                        ssl_context=context,
                        certfile=certfile)

        # No keyfile was given, so the second argument is None.
        context.load_cert_chain.assert_called_once_with(certfile, None)

    @patch('urllib3.util.ssl_.create_urllib3_context')
    def test_ssl_wrap_socket_creates_new_context(self, create_urllib3_context):
        """Without an ssl_context, create_urllib3_context() is called once."""
        fake_sock = object()
        requirement = 'CERT_REQUIRED'

        ssl_wrap_socket(cert_reqs=requirement, sock=fake_sock)

        create_urllib3_context.assert_called_once_with(
            None, requirement, ciphers=None)

    def test_ssl_wrap_socket_loads_verify_locations(self):
        """`ca_certs` is passed as the first load_verify_locations() arg."""
        bundle = '/path/to/pem'
        fake_sock = object()
        context = Mock()

        ssl_wrap_socket(sock=fake_sock,
                        ssl_context=context,
                        ca_certs=bundle)

        context.load_verify_locations.assert_called_once_with(bundle, None)

    def test_ssl_wrap_socket_loads_certificate_directories(self):
        """`ca_cert_dir` is passed as the second load_verify_locations() arg."""
        directory = '/path/to/pems'
        fake_sock = object()
        context = Mock()

        ssl_wrap_socket(sock=fake_sock,
                        ssl_context=context,
                        ca_cert_dir=directory)

        context.load_verify_locations.assert_called_once_with(None, directory)

    def test_ssl_wrap_socket_with_no_sni(self):
        """With HAS_SNI off, wrap_socket() receives only the socket.

        ``patch.object`` restores ``ssl_.HAS_SNI`` on exit even if the call
        raises, so the module-level flag cannot stay False for later tests.
        """
        fake_sock = object()
        mock_context = Mock()
        with patch.object(ssl_, 'HAS_SNI', False):
            ssl_wrap_socket(ssl_context=mock_context, sock=fake_sock)
        mock_context.wrap_socket.assert_called_once_with(fake_sock)

    def test_ssl_wrap_socket_with_no_sni_warns(self):
        """With HAS_SNI off, ssl_wrap_socket() warns with SNIMissingWarning.

        ``patch.object`` restores ``ssl_.HAS_SNI`` on exit even if the call
        raises, so the module-level flag cannot stay False for later tests.
        """
        fake_sock = object()
        mock_context = Mock()
        with patch.object(ssl_, 'HAS_SNI', False):
            with patch('warnings.warn') as warn:
                ssl_wrap_socket(ssl_context=mock_context, sock=fake_sock)
        mock_context.wrap_socket.assert_called_once_with(fake_sock)
        self.assertTrue(warn.call_count >= 1)
        # call[0] holds the positional arguments of each warnings.warn()
        # call; index 1 is the warning category. Named `categories` so the
        # `warnings` module is not shadowed.
        categories = [call[0][1] for call in warn.call_args_list]
        self.assertTrue(SNIMissingWarning in categories)

    def test_const_compare_digest_fallback(self):
        """The compare_digest backport is True only for identical digests."""
        target = hashlib.sha256(b'abcdef').digest()
        self.assertTrue(_const_compare_digest_backport(target, target))

        mismatches = (
            target[:-1],  # last byte dropped
            target + b'0',  # one extra byte
            hashlib.sha256(b'xyz').digest(),  # digest of different input
        )
        for other in mismatches:
            self.assertFalse(_const_compare_digest_backport(target, other))

    def test_has_ipv6_disabled_on_compile(self):
        """_has_ipv6() is False when socket.has_ipv6 is False."""
        ipv6_compiled_out = patch('socket.has_ipv6', False)
        with ipv6_compiled_out:
            result = _has_ipv6('::1')
            self.assertFalse(result)

    def test_has_ipv6_enabled_but_fails(self):
        """_has_ipv6() is False when binding the probe socket raises."""
        failing_bind = Mock(side_effect=Exception('No IPv6 here!'))
        with patch('socket.has_ipv6', True):
            with patch('socket.socket') as socket_cls:
                socket_cls.return_value.bind = failing_bind
                self.assertFalse(_has_ipv6('::1'))

    def test_has_ipv6_enabled_and_working(self):
        """_has_ipv6() is True when the mocked socket binds successfully."""
        with patch('socket.has_ipv6', True):
            with patch('socket.socket') as socket_cls:
                socket_cls.return_value.bind.return_value = True
                result = _has_ipv6('::1')
                self.assertTrue(result)

    def test_ip_family_ipv6_enabled(self):
        """allowed_gai_family() is AF_UNSPEC when HAS_IPV6 is True."""
        with patch('urllib3.util.connection.HAS_IPV6', True):
            family = allowed_gai_family()
            self.assertEqual(family, socket.AF_UNSPEC)

    def test_ip_family_ipv6_disabled(self):
        """allowed_gai_family() is AF_INET when HAS_IPV6 is False."""
        with patch('urllib3.util.connection.HAS_IPV6', False):
            family = allowed_gai_family()
            self.assertEqual(family, socket.AF_INET)

    def test_parse_retry_after(self):
        """parse_retry_after() rejects malformed values, parses digit strings."""
        retry = Retry()

        rejected = (
            "-1",
            "+1",
            "1.0",
            six.u("\xb2"),  # \xb2 is the superscript-two character
        )
        for header_value in rejected:
            self.assertRaises(InvalidHeader, retry.parse_retry_after,
                              header_value)

        # The last case shows surrounding whitespace being tolerated.
        accepted = (("0", 0), ("1000", 1000), ("\t42 ", 42))
        for header_value, expected in accepted:
            self.assertEqual(retry.parse_retry_after(header_value), expected)