def safe_create_connection(
    address,
    timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
    source_address=None,
    socket_options=None,
):
    """
    Open a TCP connection to ``address``, refusing restricted IP addresses.

    Behaves like a ``create_connection`` implementation (resolve, then try
    each record in order until one connects). Before any connection is
    attempted, every resolved record is checked with ``is_ipaddress_allowed``.

    :param address: ``(host, port)`` tuple. A host wrapped in square brackets
        (e.g. an IPv6 literal such as ``"[::1]"``) has the brackets stripped.
    :param timeout: timeout applied with ``settimeout``; when left at
        ``socket._GLOBAL_DEFAULT_TIMEOUT`` the socket's timeout is not touched.
    :param source_address: if given, the socket is ``bind``-ed to it before
        connecting.
    :param socket_options: passed to ``_set_socket_options`` before connecting.
    :returns: the first socket that connects successfully.
    :raises RestrictedIPAddress: if ANY resolved address is not allowed.
    :raises socket.error: the last connection error if every record failed,
        or a new error if ``getaddrinfo`` returned no records.
    """
    host, port = address
    if host.startswith("["):
        host = host.strip("[]")
    host = ensure_fqdn(host)

    # Using the value from allowed_gai_family() in the context of getaddrinfo
    # lets us select whether to work with IPv4 DNS records, IPv6 records, or
    # both. The original create_connection function always returns all records.
    family = allowed_gai_family()
    records = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)

    # HACK(mattrobenolt): This IP check is the only code that diverges from
    # create_connection. It is deliberately aggressive: if any record falls
    # in a restricted network, all records are rejected. A host whose records
    # straddle safe and unsafe IPs is suspicious. Every record is checked up
    # front. Checking inside the connect loop would let an earlier record
    # connect before a later, restricted record was ever examined.
    for _af, _socktype, _proto, _canonname, sa in records:
        ip = sa[0]
        if not is_ipaddress_allowed(ip):
            if host == ip:
                raise RestrictedIPAddress("(%s) matches the URL blacklist" % ip)
            raise RestrictedIPAddress("(%s/%s) matches the URL blacklist" % (host, ip))

    err = None
    for af, socktype, proto, _canonname, sa in records:
        sock = None
        try:
            sock = socket.socket(af, socktype, proto)
            # If provided, set socket level options before connecting.
            _set_socket_options(sock, socket_options)
            if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
                sock.settimeout(timeout)
            if source_address:
                sock.bind(source_address)
            sock.connect(sa)
            return sock
        except socket.error as e:
            # Remember the failure and try the next record; don't leak the fd.
            err = e
            if sock is not None:
                sock.close()

    if err is not None:
        raise err
    raise socket.error("getaddrinfo returns an empty list")
def safe_socket_connect(address, timeout=30, ssl=False):
    """
    Create a socket connected to ``address``, refusing disallowed hosts and IPs.

    The hostname is checked with ``is_safe_hostname``. The connection is then
    opened through ``safe_create_connection``, which checks every address the
    hostname resolves to against the restricted IP networks.

    :param address: ``(host, port)`` tuple.
    :param timeout: socket timeout, in seconds.
    :param ssl: when true, the connected socket is passed through
        ``wrap_socket`` and the wrapped socket is returned.
    :returns: the connected (optionally wrapped) socket.
    :raises RestrictedIPAddress: if the hostname or any resolved IP is not
        allowed.
    """
    host = address[0]
    if not is_safe_hostname(host):
        raise RestrictedIPAddress('%s matches the hostname blacklist' % host)

    # Connecting with socket.connect((hostname, port)) would let the socket
    # resolve DNS itself, so the resolved IP would never be validated.
    # safe_create_connection resolves and validates every record first.
    # NOTE(review): the address family now follows allowed_gai_family() instead
    # of always being AF_INET. Confirm IPv6 is acceptable for callers.
    sock = safe_create_connection(address, timeout=timeout)
    if not ssl:
        return sock
    try:
        return wrap_socket(sock)
    except Exception:
        # Don't leak the connected socket if the TLS wrap fails.
        sock.close()
        raise
def send(self, request, *args, **kwargs):
    """
    Send ``request`` through the parent adapter, but only if its URL passes
    ``is_valid_url``. Otherwise raise ``RestrictedIPAddress`` without sending.
    """
    url = request.url
    if is_valid_url(url):
        parent = super(BlacklistAdapter, self)
        return parent.send(request, *args, **kwargs)
    raise RestrictedIPAddress('%s matches the URL blacklist' % (url,))