Example #1
    async def _create_aiohttp_session(self):
        if self.loop is None:
            self.loop = asyncio.get_running_loop()

        if self.use_static_responses:
            connector = StaticConnector(
                limit=self._limit,
                enable_cleanup_closed=self._enable_cleanup_closed)
        else:
            connector = aiohttp.TCPConnector(
                limit=self._limit,
                use_dns_cache=True,
                ssl_context=self._ssl_context,
                enable_cleanup_closed=self._enable_cleanup_closed)

        self.session = aiohttp.ClientSession(
            headers=self.headers,
            auto_decompress=True,
            loop=self.loop,
            cookie_jar=aiohttp.DummyCookieJar(),
            request_class=self._request_class,
            response_class=self._response_class,
            connector=connector,
            trace_configs=self._trace_configs,
        )
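Every example on this page passes cookie_jar=aiohttp.DummyCookieJar() to switch cookie handling off. A minimal sketch of what that buys (the httpbin.org endpoint is only an illustrative stand-in):

import asyncio
import aiohttp


async def demo():
    async with aiohttp.ClientSession(cookie_jar=aiohttp.DummyCookieJar()) as session:
        # The server replies with a Set-Cookie header, but DummyCookieJar
        # silently drops it, so no state leaks into later requests.
        await session.get('https://httpbin.org/cookies/set?k=v')
        assert len(session.cookie_jar) == 0

asyncio.run(demo())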
Example #2
async def main(request_body):
    try:
        request_urls = json.loads(request_body)
    except json.decoder.JSONDecodeError:
        return {
            "statusCode": HTTPStatus.UNPROCESSABLE_ENTITY,
            "body": "Bad json payload",
        }
    if not all(map(url_validator, request_urls)):
        return {
            "statusCode": HTTPStatus.UNPROCESSABLE_ENTITY,
            "body": "Not a valid URL in payload.",
        }
    urls_to_measure = []
    async with aiohttp.ClientSession(
            cookie_jar=aiohttp.DummyCookieJar()) as session:
        for url in request_urls:
            measure_data = measure_response_time(session, url)
            urls_to_measure.append(measure_data)
        measurements_results = await asyncio.gather(*urls_to_measure)
    response_data = {"results": [], "errors": []}
    for original_url, measurement_result in zip(request_urls,
                                                measurements_results):
        if measurement_result["status_code"] is None:
            response_data["errors"].append(measurement_result)
        else:
            response_data["results"].append(measurement_result)
    return {
        "statusCode": HTTPStatus.OK,
        "body": json.dumps(response_data),
        "headers": {
            "Content-Type": "application/json"
        },
    }
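The measure_response_time coroutine is not shown in this example. A plausible minimal sketch, inferred from how its result is consumed above (the "status_code is None on failure" convention comes from the error-splitting loop; everything else is an assumption):

import asyncio
import time

import aiohttp


async def measure_response_time(session, url):
    # Hypothetical helper: time one GET and report failures via status_code=None
    start = time.monotonic()
    try:
        async with session.get(url) as resp:
            await resp.read()
            return {"url": url, "status_code": resp.status,
                    "elapsed_ms": (time.monotonic() - start) * 1000}
    except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
        return {"url": url, "status_code": None, "error": str(exc)}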
Example #3
 def _generate_session(concurrency_limit=None, loop=None, with_cookie_jar=False):
     if not loop:
         loop = asyncio.get_event_loop()
     concurrency_limit = (main_config()["main"].getint("concurrency")
                          if concurrency_limit is None else concurrency_limit)
     jar = aiohttp.DummyCookieJar() if with_cookie_jar else None
     return aiohttp.ClientSession(connector=_SessionManger._generate_connector(limit=concurrency_limit, loop=loop),
                                  loop=loop, cookie_jar=jar)
Example #4
    async def init_idle_session_pool(self):
        """When to initialize the idle session pool:
        Object construction returns the instance itself, so __init__ cannot be a
        coroutine, which means the idle queue cannot be filled inside __init__;
        instead, initialize the idle session queue right after the download
        manager is created.
        A session is like an open browser, and one session serves many page
        downloads. Think of opening several browsers up front, all idle, so all
        of them go into the idle queue.

        Depending on whether proxy IPs are needed, pick the right coroutine
        builder for download here, to avoid re-deciding on every download.

        If proxy IPs are used, a ProxyDirector also has to be created and
        initialized."""
        while True:
            try:
                session = aiohttp.ClientSession(
                    headers=DEFAULT_REQUEST_HEADERS,
                    timeout=aiohttp.ClientTimeout(60),
                    cookie_jar=aiohttp.DummyCookieJar())
                self.idle_session_pool.put_nowait(session)
            except QueueFull:
                # The last session created cannot fit into the full queue and must be closed
                await session.close()
                break

        if self.proxy_used:
            self.download = self.download_proxy
            self.proxyDirector = ProxyDirector()
        else:
            self.download = self.download_straight
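The docstring's point about __init__ generalizes: asynchronous setup has to live in a separate coroutine called right after construction. A stripped-down sketch of that two-step pattern (the Downloader class here is illustrative, not the original):

import asyncio

import aiohttp


class Downloader:
    def __init__(self, pool_size=5):
        # __init__ must return the instance itself, so it cannot be a coroutine;
        # the queue is created empty here and filled later.
        self.idle_session_pool = asyncio.Queue(maxsize=pool_size)

    async def init_idle_session_pool(self):
        while not self.idle_session_pool.full():
            self.idle_session_pool.put_nowait(aiohttp.ClientSession(
                cookie_jar=aiohttp.DummyCookieJar()))


async def main():
    downloader = Downloader()
    await downloader.init_idle_session_pool()  # step two of the construction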
Example #5
async def getJSONwithToken(requestURL, discordID):
    """
        Takes url and discordID, returns dict with [token] = JSON
        otherwise [error] has a errormessage
    """

    # handle and return working token
    ret = await handleAndReturnToken(discordID)
    if ret["result"]:
        token = ret["result"]
    else:
        return ret

    token_headers = {
        'Authorization': f'Bearer {token}',
        'x-api-key': BUNGIE_TOKEN,
        'Accept': 'application/json'
    }

    no_jar = aiohttp.DummyCookieJar()
    async with aiohttp.ClientSession(cookie_jar=no_jar) as session:
        # abort after 10 tries
        for _ in range(10):
            # wait for a token from the rate limiter
            # (a fresh asyncio.Lock() per iteration guards nothing, so the
            # limiter itself has to serialize access)
            await limiter.wait_for_token()

            async with session.get(url=requestURL, headers=token_headers) as r:
                # might be 'application/json; charset=utf-8'
                if 'application/json' not in r.headers.get('Content-Type', ''):
                    # print(await r.text())
                    print(f'Wrong content type! {r.status}: {r.reason}')
                    continue
                res = await r.json()

                # ok
                if r.status == 200:
                    return {'result': res, 'error': None}

                # handling any errors if not ok
                else:
                    if await errorCodeHandling(requestURL, r, res):
                        return {'result': None, 'error': f"Status Code <{r.status}>"}
                    if res["ErrorStatus"] == "PerEndpointRequestThrottleExceeded":
                        return await getJSONwithToken(requestURL, discordID)

        print('Request failed 10 times, aborting')
        
        error = await r.json()
        msg = f"""Didn't get a valid response. Bungie returned status {r.status}: \n\
            `ErrorCode - {error["ErrorCode"]} \n\
            ErrorStatus - {error["ErrorStatus"]} \n\
            Message - {error["Message"]}`"""
        # TODO add more specific exception
        # except:
        #     msg = ("Bungie is doing weird stuff right now or there is a big error in my programming, "
        #            "the first is definitely more likely. Try again in a sec.")

        return {'result': None, 'error': msg}
Example #6
 def client_session(self) -> aiohttp.ClientSession:
     return aiohttp.ClientSession(
         connector=self.client_session_connector,
         headers={
             "User-Agent":
             "Mozilla/5.0 (X11; Linux x86_64; rv:92.0) Gecko/20100101 Firefox/92.0"
         },
         cookie_jar=aiohttp.DummyCookieJar())
Example #7
    def __init__(self,
                 job_queue,
                 url="http://localhost:9001",
                 tlp=tlp.AMBER,
                 api_token="",
                 poll_interval=5,
                 submit_original_filename=True,
                 max_job_age=900,
                 retries=5,
                 backoff=0.5):
        """ Initialize the object.

        @param job_queue: The job queue to use from now on
        @type job_queue: JobQueue object
        @param url: Where to reach the Cortex REST API
        @type url: string
        @param tlp: colour according to traffic light protocol
        @type tlp: tlp
        @param api_token: API token to use for authentication
        @type api_token: string
        @param poll_interval: How long to wait between job status checks
        @type poll_interval: int
        @param submit_original_filename: Whether to provide the original
                                         filename to Cortex to enhance
                                         analysis.
        @type submit_original_filename: bool
        @param max_job_age: How long to track jobs before declaring them
                            failed.
        @type max_job_age: int (seconds)
        @param retries: Number of retries on API requests
        @type retries: int
        @param backoff: Backoff factor for urllib3
        @type backoff: float
        """
        self.job_queue = job_queue
        self.running_jobs = {}
        self.url = url
        self.tlp = tlp
        self.poll_interval = poll_interval
        self.submit_original_filename = submit_original_filename
        self.max_job_age = max_job_age

        self.retrier = tenacity.AsyncRetrying(
            stop=tenacity.stop_after_attempt(retries),
            wait=tenacity.wait_exponential(multiplier=backoff),
            retry=tenacity.retry_if_exception_type(aiohttp.ClientError),
            before_sleep=log_retry)

        headers = {"Authorization": f"Bearer {api_token}"}
        cookie_jar = aiohttp.DummyCookieJar()
        self.session = aiohttp.ClientSession(
            raise_for_status=True,
            response_class=PickyClientResponse,
            headers=headers,
            cookie_jar=cookie_jar)

        self.tracker = None
Example #8
 def __init__(self):
     self.session = aiohttp.ClientSession(
         headers=DEFAULT_REQUEST_HEADERS,
         timeout=aiohttp.ClientTimeout(60),
         cookie_jar=aiohttp.DummyCookieJar())
     self.proxy_ip_pool = Queue(maxsize=PROXY_TIMES_SIZE)
     self.semaphore = asyncio.Semaphore()
     self.proxy_key_url = PROXY_KEY_URL
     self.proxy_times_size = PROXY_TIMES_SIZE
Example #9
 def get_client(self):
     if self._client is None:
         jar = aiohttp.DummyCookieJar()
         if self.outbound_unix_socket:
             conn = aiohttp.UnixConnector(path=self.outbound_unix_socket)
         else:
             conn = aiohttp.TCPConnector(limit=30)
         self._client = aiohttp.ClientSession(
             connector=conn, auto_decompress=False, cookie_jar=jar,
         )
     return self._client
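With a UnixConnector, the host in the URL is only used for the Host header; the bytes actually flow over the unix socket. A hedged usage sketch of the client above (the /healthz path is an assumption):

async def ping(client):
    # client comes from get_client(); 'localhost' is never resolved when the
    # session is backed by a UnixConnector
    async with client.get('http://localhost/healthz') as resp:
        return resp.status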
Example #10
 async def open(self):
     """Opens the connection.
     """
     if not self.session and self._session_owner:
         jar = aiohttp.DummyCookieJar()
         self.session = aiohttp.ClientSession(
             loop=self._loop,
             trust_env=self._use_env_settings,
             cookie_jar=jar)
     if self.session is not None:
         await self.session.__aenter__()
Example #11
    async def get_proxy_ip(self):
        """The empty check on the proxy IP pool has to happen twice. At the
        start, n coroutines all arrive to fetch a proxy IP, all find the pool
        empty, and all try to acquire the semaphore; exactly one of them gets it.
        That lucky coroutine downloads proxy IPs from the web, stores them in
        the pool, releases the semaphore, takes one proxy IP and finishes.
        Each remaining coroutine, once it acquires the semaphore, checks the
        pool again, finds it non-empty, skips the download, and simply takes an
        existing proxy IP from the pool.

        One call to the get_proxy_ip() coroutine either yields a valid proxy IP
        or raises the fetch exception ProxyFectchError.

        Both a network error while downloading proxy IPs and a downloaded proxy
        IP that fails validation must raise an exception.

        Whether the coroutine exits normally or with an exception, the semaphore
        must be released before returning so the other coroutines can proceed."""
        if self.proxy_ip_pool.empty():
            # This semaphore ensures only one coroutine at a time downloads proxy IPs with the key
            await self.semaphore.acquire()
            # A context manager could own the semaphore to guarantee release on exit:
            # async with (await self.semaphore):
            try:
                if self.proxy_ip_pool.empty():
                    async with self.session.get(self.proxy_key_url) as resp:
                        if resp.status == 200:
                            proxy_ip = await resp.text()
                            # print(proxy_ip)

                            mo = PROXY_IP_PATTERN.search(proxy_ip)
                            if mo:
                                for _ in range(0, self.proxy_times_size):
                                    await self.proxy_ip_pool.put('http://%s' %
                                                                 mo.group(0))

                            else:
                                # The download got a normal response, but the proxy IP is invalid (e.g. the account is out of credit)
                                # print('zzz')
                                # self.semaphore.release()
                                raise ProxyFectchError(self.proxy_key_url)

                        else:
                            # The proxy-IP request got a response, but with a bad HTTP status
                            # print('yyy')
                            # self.semaphore.release()
                            raise ProxyFectchError(self.proxy_key_url)

            except (aiohttp.ClientError, asyncio.TimeoutError) as err:
                # Network error while downloading proxy IPs
                # print('xxx')
                if self.session.closed:
                    self.session = aiohttp.ClientSession(
                        headers=DEFAULT_REQUEST_HEADERS,
                        timeout=aiohttp.ClientTimeout(60),
                        cookie_jar=aiohttp.DummyCookieJar())
                raise ProxyFectchError(err)
            finally:
                # The semaphore must be released even on exceptions: if a coroutine failed while fetching proxy IPs without releasing it, every coroutine would block here waiting for the semaphore
                self.semaphore.release()

        return (await self.proxy_ip_pool.get())
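The docstring describes double-checked locking built from a semaphore of size 1. The same pattern in isolation (refill() stands in for the HTTP fetch above):

import asyncio

pool = asyncio.Queue()
semaphore = asyncio.Semaphore(1)


async def refill():
    # stand-in for downloading a batch of proxy IPs over HTTP
    for _ in range(10):
        await pool.put('http://proxy.example:8080')


async def get_proxy_ip():
    if pool.empty():              # first check, without synchronization
        await semaphore.acquire()
        try:
            if pool.empty():      # second check, under the semaphore:
                await refill()    # only one coroutine ever refills
        finally:
            semaphore.release()   # always release, even on exceptions
    return await pool.get()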
Example #12
 async def async_init(self):
     """
     初始化
     """
     cookie_jar = aiohttp.DummyCookieJar()
     conn = aiohttp.TCPConnector(limit=0)
     timeout = aiohttp.ClientTimeout(
         total=Config().get_config("scanner.request_timeout"))
     self.session = aiohttp.ClientSession(cookie_jar=cookie_jar,
                                          connector=conn,
                                          timeout=timeout)
Example #13
    async def init(self, http_client_kwargs=None):
        if http_client_kwargs:
            self.session = aiohttp.ClientSession(**http_client_kwargs)
        else:
            jar = aiohttp.DummyCookieJar()
            self.tc = TCPConnector(limit=100,
                                   force_close=True,
                                   enable_cleanup_closed=True,
                                   verify_ssl=False)
            self.session = aiohttp.ClientSession(connector=self.tc,
                                                 cookie_jar=jar)

        return self
Example #14
    async def process_urls(self, urls: Iterable[str]) -> None:
        resolver = PrecachedAsyncResolver()

        connector4 = aiohttp.TCPConnector(resolver=resolver, use_dns_cache=False, limit_per_host=1, family=socket.AF_INET)
        connector6 = aiohttp.TCPConnector(resolver=resolver, use_dns_cache=False, limit_per_host=1, family=socket.AF_INET6)

        timeout = aiohttp.ClientTimeout(total=self._timeout)
        headers = {'User-Agent': USER_AGENT}

        async with aiohttp.ClientSession(cookie_jar=aiohttp.DummyCookieJar(), timeout=timeout, headers=headers, connector=connector4) as session4:
            async with aiohttp.ClientSession(cookie_jar=aiohttp.DummyCookieJar(), timeout=timeout, headers=headers, connector=connector6) as session6:
                for url in urls:
                    try:
                        host = yarl.URL(url).host
                    except Exception:
                        host = None

                    if host is None:
                        errstatus = UrlStatus(False, ExtendedStatusCodes.INVALID_URL)
                        await self._url_updater.update(url, errstatus, errstatus)
                        continue

                    dns = await resolver.get_host_status(host)

                    if dns.ipv4.exception is not None:
                        status4 = UrlStatus(False, classify_exception(dns.ipv4.exception, url))
                    else:
                        status4 = await self._check_url(url, session4)

                    if self._skip_ipv6:
                        status6 = None
                    elif dns.ipv6.exception is not None:
                        status6 = UrlStatus(False, classify_exception(dns.ipv6.exception, url))
                    else:
                        status6 = await self._check_url(url, session6)

                    await self._url_updater.update(url, status4, status6)

        await resolver.close()
Example #15
    async def download_straight(self, req, headers=None, proxy=None):
        """Every download runs in its own download coroutine, and the session is
        held by that coroutine; sessions are limited, download coroutines are not."""
        # POST vs GET cannot be inferred just from whether data is empty:
        # some requests are clearly POST yet carry an empty body
        headers = dict(headers) if headers else {}  # avoid a shared mutable default
        session = await self.idle_session_pool.get()
        try:
            # _log_req('%s %s' % (proxy, _format_req(req)))

            method = req[1]
            url = req[2]

            referer = req[3]
            if referer:
                headers['Referer'] = referer

            payload = req[4]

            if method == 'POST':
                async with session.post(url=url,
                                        headers=headers,
                                        data=payload,
                                        proxy=proxy) as resp:
                    # print(url, resp.status)
                    if resp.status == 200:
                        source = await resp.text()
                        # print(source)
                    else:
                        # HTTP status error during download
                        raise DownLoadError(url)
            else:
                async with session.get(url=url, headers=headers,
                                       proxy=proxy) as resp:
                    if resp.status == 200:
                        source = await resp.text()
                    else:
                        # HTTP status error during download
                        raise DownLoadError(url)

            return source
        except (aiohttp.ClientError, asyncio.TimeoutError,
                UnicodeDecodeError) as err:
            # Network error during download
            if session.closed:
                # If this session has been closed, create a new one
                session = aiohttp.ClientSession(
                    headers=DEFAULT_REQUEST_HEADERS,
                    timeout=aiohttp.ClientTimeout(60),
                    cookie_jar=aiohttp.DummyCookieJar())
            raise DownLoadError(err)
        finally:
            # Whether this coroutine exits normally or not, it must return the session it holds
            await self.idle_session_pool.put(session)
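The layout of the positional req tuple above is implicit. A sketch of a matching value (field meanings are inferred from the indices the method reads; index 0 is never touched here):

# indices: 1 = method, 2 = url, 3 = referer, 4 = payload; index 0 is unused here
req = (None, 'POST', 'http://example.com/api', 'http://example.com/', {'q': '1'})
# inside a coroutine, given `manager`, an instance of the class above:
#     source = await manager.download_straight(req)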
Example #16
 def http_client(self, *args, **kwargs) -> aiohttp.ClientSession:
     """
     Construct an HTTP client using the shared connection pool
     """
     no_reuse = os.getenv('HTTP_NO_CONNECTOR_REUSE')
     no_reuse = bool(no_reuse) and no_reuse != 'false'
     if 'connector' not in kwargs and not no_reuse:
         kwargs['connector'] = self.tcp_connector
         kwargs['connector_owner'] = False
     keep_cookies = os.getenv('HTTP_PRESERVE_COOKIES')
     keep_cookies = bool(keep_cookies) and keep_cookies != 'false'
     if 'cookie_jar' not in kwargs and not keep_cookies:
         kwargs['cookie_jar'] = aiohttp.DummyCookieJar()
     return aiohttp.ClientSession(*args, **kwargs)
Example #17
    def __init__(self, *, prefix, token, if_bot):
        self.if_bot = if_bot
        self.token = token
        self.loop = asyncio.get_event_loop()
        self.cookies = aiohttp.DummyCookieJar()
        self.connector = aiohttp.TCPConnector()
        self.session = aiohttp.ClientSession(loop=self.loop,
                                             connector=self.connector,
                                             cookie_jar=self.cookies)

        self.core = commands.Bot(command_prefix=prefix,
                                 fetch_offline_members=True,
                                 loop=self.loop,
                                 connector=self.connector)
Example #18
 async def open(self):
     """Opens the connection.
     """
     if not self.session and self._session_owner:
         jar = aiohttp.DummyCookieJar()
         clientsession_kwargs = {
             "trust_env": self._use_env_settings,
             "cookie_jar": jar,
             "auto_decompress": False,
         }
         if self._loop is not None:
             clientsession_kwargs["loop"] = self._loop
         self.session = aiohttp.ClientSession(**clientsession_kwargs)
     if self.session is not None:
         await self.session.__aenter__()
Example #19
    async def open_session(self):
        """Open the web API session (not the websocket session).
        If there already was one open, close that first.
        """
        if self.session:
            await self.session.close()
            self.session = None

        headers = {
            'user-agent': self.client.get_useragent(),
            'Authorization': 'Bearer ' + self.access_token,
        }
        # Again, we disable the cookie jar. See above.
        self.session = aiohttp.ClientSession(
            headers=headers, cookie_jar=aiohttp.DummyCookieJar())
Example #20
 async def _create_aiohttp_session(self):
     """Creates an aiohttp.ClientSession(). This is delayed until
     the first call to perform_request() so that AsyncTransport has
     a chance to set AIOHttpConnection.loop
     """
     if self.loop is None:
         self.loop = get_running_loop()
     self.session = aiohttp.ClientSession(
         headers=self.headers,
         auto_decompress=True,
         loop=self.loop,
         cookie_jar=aiohttp.DummyCookieJar(),
         response_class=ESClientResponse,
         connector=aiohttp.TCPConnector(
             limit=self._limit, use_dns_cache=True, ssl=self._ssl_context,
         ),
     )
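A sketch of the call site described by the docstring above, where session creation is deferred until the first request (the perform_request name comes from the docstring; its body is an assumption):

 async def perform_request(self, method, url):
     if self.session is None:
         await self._create_aiohttp_session()  # get_running_loop() now succeeds
     async with self.session.request(method, url) as resp:
         return resp.status, await resp.text()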
Example #21
 async def create_session(self):
     """
     Creates the http session in a coro, like that aiohttp stops whining
     """
     if isinstance(self.session, aiohttp.ClientSession):
         await self.session.close()
     self.session = aiohttp.ClientSession(
         loop=self.loop,
         cookie_jar=aiohttp.DummyCookieJar(
             loop=self.loop
         ),
         connector=aiohttp.TCPConnector(
             resolver=Resolver(),
             family=socket.AF_INET,
             loop=self.loop
         )
     )
Example #22
    def get_client(self):
        import aiohttp

        if self._client is None or self._client.closed:
            jar = aiohttp.DummyCookieJar()
            if self.timeout:
                timeout = aiohttp.ClientTimeout(total=self.timeout)
            else:
                timeout = None
            self._client = aiohttp.ClientSession(
                connector=self.get_conn(),
                auto_decompress=False,
                cookie_jar=jar,
                connector_owner=False,
                timeout=timeout,
            )
        return self._client
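Both this example and Example #16 pass connector_owner=False, which decouples session lifetime from the shared pool. A minimal sketch of the behaviour (standalone, not the class above):

import asyncio

import aiohttp


async def main():
    conn = aiohttp.TCPConnector(limit=30)
    for _ in range(3):
        session = aiohttp.ClientSession(connector=conn, connector_owner=False,
                                        cookie_jar=aiohttp.DummyCookieJar())
        await session.close()  # does not close the shared connector
    await conn.close()         # the real owner closes it exactly once

asyncio.run(main())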
Example #23
def get_client():
    with open('config.py') as f:
        config = eval(f.read(), {})

    client = TelegramClient(config['session_name'], config['api_id'],
                            config['api_hash'])
    client.config = config

    for handler in event_handlers:
        client.add_event_handler(handler)

    client.http = aiohttp.ClientSession(
        headers={
            'Cookie': 'session=' + client.config['aoc_session_cookie'],
        },
        cookie_jar=aiohttp.DummyCookieJar(),  # suppress normal cookie handling
    )

    return client
Example #24
    async def open(self):
        """Open web sessions for the client, and one for each team,
        and then load the team data. (This does not open the websockets.)
        """
        headers = {
            'user-agent': self.client.get_useragent(),
        }
        # We use a disabled cookie jar because if we store cookies, Mattermost tries to store a MMCSRF cookie and (eventually) fails to recognize it. Not sure if this is a bug.
        self.session = aiohttp.ClientSession(
            headers=headers, cookie_jar=aiohttp.DummyCookieJar())

        if self.teams:
            (done, pending) = await asyncio.wait(
                [team.open() for team in self.teams.values()],
                loop=self.client.evloop)
            for res in done:
                self.print_exception(res.exception(), 'Could not set up team')

        self.waketask = self.client.evloop.create_task(self.wakeloop_async())
Example #25
async def main(loop):
    async with aiohttp.ClientSession(
            loop=loop, cookie_jar=aiohttp.DummyCookieJar()) as session:

        pmid_count, webenv, query_key = await get_pmid_count(session)
        print(pmid_count)

        for pmid_block, percent_done, block_num in reg_get_pmid_block(
                session, pmid_count, webenv, query_key):
            print(percent_done)
            icite_recs = reg_get_icite(pmid_block)
            efetch_recs = reg_get_author_count(pmid_block)

            with open('data/pubmed_esearch/block_{}.pickle'.format(block_num),
                      'wb') as out_conn:
                pickle.dump(efetch_recs, out_conn)

            with open('data/icite/block_{}.pickle'.format(block_num),
                      'wb') as out_conn:
                pickle.dump(icite_recs, out_conn)
Example #26
 async def mainloop(self):
     '''create tasks and wait for responses & update self.results'''
     jar = aiohttp.DummyCookieJar()
     timeout = aiohttp.ClientTimeout(total=self.mco.options['total_timeout'])
     tasks = []
     self.mco.runfile['start_time'] = self.nowtime()
     async with aiohttp.ClientSession(cookie_jar=jar,
                                      timeout=timeout) as session:
         # for url in [chk['url'] for chk in self.mco.runfile['urls']]:
         for chk in self.mco.runfile['urls']:
             # get defaultcheck if no test defined
             test = chk.get('test', self.mco.options['defaultcheck'])
             tasks.append(asyncio.ensure_future(self.fetch(chk['url'],
                                                           test, session)))
         try:
             self.results = await asyncio.gather(*tasks)
         except asyncio.TimeoutError:
             # exceeded total_timeout
             error = 'Exceeded total timeout (%ss)' % \
                 self.mco.options['total_timeout']
             self.mco.logger.warning(error)
             self.info = error
     self.mco.runfile['finish_time'] = self.nowtime()
Example #27
    def _get_client(
        self,
        timeout_sec: t.Optional[float] = None,
    ) -> "ClientSession":
        import aiohttp

        if (self._loop is None or self._client is None or self._client.closed
                or self._loop.is_closed()):
            import yarl
            from opentelemetry.instrumentation.aiohttp_client import create_trace_config

            def strip_query_params(url: yarl.URL) -> str:
                return str(url.with_query(None))

            jar = aiohttp.DummyCookieJar()
            if timeout_sec is not None:
                timeout = aiohttp.ClientTimeout(total=timeout_sec)
            else:
                DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=5 * 60)
                timeout = DEFAULT_TIMEOUT
            self._client = aiohttp.ClientSession(
                trace_configs=[
                    create_trace_config(
                        # Remove all query params from the URL attribute on the span.
                        url_filter=strip_query_params,
                        tracer_provider=DeploymentContainer.tracer_provider.get(),
                    )
                ],
                connector=self._get_conn(),
                auto_decompress=False,
                cookie_jar=jar,
                connector_owner=False,
                timeout=timeout,
                loop=self._loop,
            )
        return self._client
Example #28
    def __init__(self, load=None, no_test=False, paused=False):
        self.loop = asyncio.get_event_loop()
        self.burner = burner.Burner('parser')
        self.stopping = False
        self.paused = paused
        self.no_test = no_test
        self.next_minute = 0
        self.next_hour = time.time() + 3600
        self.max_page_size = int(config.read('Crawl', 'MaxPageSize'))
        self.prevent_compression = config.read('Crawl', 'PreventCompression')
        self.upgrade_insecure_requests = config.read(
            'Crawl', 'UpgradeInsecureRequests')
        self.max_workers = int(config.read('Crawl', 'MaxWorkers'))
        self.workers = []

        try:
            # this works for the installed package
            self.version = get_distribution(__name__).version
        except DistributionNotFound:
            # this works for an uninstalled git repo, like in the CI infrastructure
            self.version = get_version(root='..', relative_to=__file__)
        self.warcheader_version = '0.99'

        self.robotname, self.ua = useragent.useragent(self.version)

        self.resolver = dns.get_resolver()

        geoip.init()

        self.conn_kwargs = {
            'use_dns_cache': False,
            'resolver': self.resolver,
            'limit': max(1, self.max_workers // 2),
            'enable_cleanup_closed': True
        }
        local_addr = config.read('Fetcher', 'LocalAddr')
        if local_addr:
            self.conn_kwargs['local_addr'] = (local_addr, 0)
        self.conn_kwargs['family'] = socket.AF_INET  # XXX config option -- this is ipv4 only
        self.conn_kwargs['ssl'] = ssl.create_default_context(
            cafile=certifi.where())
        # see https://bugs.python.org/issue27970 for python not handling missing intermediates

        conn = aiohttp.connector.TCPConnector(**self.conn_kwargs)
        self.connector = conn

        connect_timeout = float(config.read('Crawl', 'ConnectTimeout'))
        page_timeout = float(config.read('Crawl', 'PageTimeout'))
        timeout_kwargs = {}
        if connect_timeout:
            timeout_kwargs['sock_connect'] = connect_timeout
        if page_timeout:
            timeout_kwargs['total'] = page_timeout
        timeout = aiohttp.ClientTimeout(**timeout_kwargs)

        cookie_jar = aiohttp.DummyCookieJar()
        self.session = aiohttp.ClientSession(connector=conn,
                                             cookie_jar=cookie_jar,
                                             auto_decompress=False,
                                             timeout=timeout)

        self.datalayer = datalayer.Datalayer()
        self.robots = robots.Robots(self.robotname, self.session,
                                    self.datalayer)
        self.scheduler = scheduler.Scheduler(self.robots, self.resolver)

        self.crawllog = config.read('Logging', 'Crawllog')
        if self.crawllog:
            self.crawllogfd = open(self.crawllog, 'a')
        else:
            self.crawllogfd = None

        self.frontierlog = config.read('Logging', 'Frontierlog')
        if self.frontierlog:
            self.frontierlogfd = open(self.frontierlog, 'a')
        else:
            self.frontierlogfd = None

        self.rejectedaddurl = config.read('Logging', 'RejectedAddUrllog')
        if self.rejectedaddurl:
            self.rejectedaddurlfd = open(self.rejectedaddurl, 'a')
        else:
            self.rejectedaddurlfd = None

        self.facetlog = config.read('Logging', 'Facetlog')
        if self.facetlog:
            self.facetlogfd = open(self.facetlog, 'a')
        else:
            self.facetlogfd = None

        self.warcwriter = warc.setup(self.version, self.warcheader_version,
                                     local_addr)

        url_allowed.setup()
        stats.init()

        if load is not None:
            self.load_all(load)
            LOGGER.info('after loading saved state, work queue is %r urls',
                        self.scheduler.qsize())
            LOGGER.info('at time of loading, stats are')
            stats.report()
        else:
            self._seeds = seeds.expand_seeds_config(self)
            LOGGER.info('after adding seeds, work queue is %r urls',
                        self.scheduler.qsize())
            stats.stats_max('initial seeds', self.scheduler.qsize())

        self.stop_crawler = os.path.expanduser('~/STOPCRAWLER.{}'.format(
            os.getpid()))
        LOGGER.info('Touch %s to stop the crawler.', self.stop_crawler)

        self.pause_crawler = os.path.expanduser('~/PAUSECRAWLER.{}'.format(
            os.getpid()))
        LOGGER.info('Touch %s to pause the crawler.', self.pause_crawler)

        self.memory_crawler = os.path.expanduser('~/MEMORYCRAWLER.{}'.format(
            os.getpid()))
        LOGGER.info('Use %s to debug objects in the crawler.',
                    self.memory_crawler)

        fetcher.establish_filters()
Example #29
async def async_main():
    stdout_handler = logging.StreamHandler(sys.stdout)
    for logger_name in [
            'aiohttp.server', 'aiohttp.web', 'aiohttp.access', 'proxy'
    ]:
        logger = logging.getLogger(logger_name)
        logger.setLevel(logging.INFO)
        logger.addHandler(stdout_handler)

    env = normalise_environment(os.environ)
    port = int(env['PROXY_PORT'])
    admin_root = env['UPSTREAM_ROOT']
    superset_root = env['SUPERSET_ROOT']
    hawk_senders = env['HAWK_SENDERS']
    sso_base_url = env['AUTHBROKER_URL']
    sso_host = URL(sso_base_url).host
    sso_client_id = env['AUTHBROKER_CLIENT_ID']
    sso_client_secret = env['AUTHBROKER_CLIENT_SECRET']
    redis_url = env['REDIS_URL']
    root_domain = env['APPLICATION_ROOT_DOMAIN']
    basic_auth_user = env['METRICS_SERVICE_DISCOVERY_BASIC_AUTH_USER']
    basic_auth_password = env['METRICS_SERVICE_DISCOVERY_BASIC_AUTH_PASSWORD']
    x_forwarded_for_trusted_hops = int(env['X_FORWARDED_FOR_TRUSTED_HOPS'])
    application_ip_whitelist = env['APPLICATION_IP_WHITELIST']
    ga_tracking_id = env.get('GA_TRACKING_ID')
    mirror_remote_root = env['MIRROR_REMOTE_ROOT']
    mirror_local_root = '/__mirror/'

    root_domain_no_port, _, root_port_str = root_domain.partition(':')
    try:
        root_port = int(root_port_str)
    except ValueError:
        root_port = None

    csp_common = "object-src 'none';"
    if root_domain not in ['dataworkspace.test:8000']:
        csp_common += 'upgrade-insecure-requests;'

    # A spawning application on <my-application>.<root_domain> shows the admin-styled site,
    # fetching assets from <root_domain>, but also makes requests to the current domain
    csp_application_spawning = csp_common + (
        f'default-src {root_domain};'
        f'base-uri {root_domain};'
        f'font-src {root_domain} data:  https://fonts.gstatic.com;'
        f'form-action {root_domain} *.{root_domain};'
        f'frame-ancestors {root_domain};'
        f'img-src {root_domain} data: https://www.googletagmanager.com https://www.google-analytics.com https://ssl.gstatic.com https://www.gstatic.com;'  # pylint: disable=line-too-long
        f"script-src 'unsafe-inline' {root_domain} https://www.googletagmanager.com https://www.google-analytics.com https://tagmanager.google.com;"  # pylint: disable=line-too-long
        f"style-src 'unsafe-inline' {root_domain} https://tagmanager.google.com https://fonts.googleapis.com;"
        f"connect-src {root_domain} 'self';")

    # A running wrapped application on <my-application>.<root_domain>  has an
    # iframe that directly routes to the app on <my-application>--8888.<root_domain>
    def csp_application_running_wrapped(direct_host):
        return csp_common + (
            f"default-src 'none';"
            f'base-uri {root_domain};'
            f"form-action 'none';"
            f"frame-ancestors 'none';"
            f'frame-src {direct_host} {sso_host} https://www.googletagmanager.com;'
            f'img-src {root_domain} https://www.googletagmanager.com https://www.google-analytics.com https://ssl.gstatic.com https://www.gstatic.com;'  # pylint: disable=line-too-long
            f"font-src {root_domain} data: https://fonts.gstatic.com;"
            f"script-src 'unsafe-inline' https://www.googletagmanager.com https://www.google-analytics.com https://tagmanager.google.com;"  # pylint: disable=line-too-long
            f"style-src 'unsafe-inline' {root_domain} https://tagmanager.google.com https://fonts.googleapis.com;"
        )

    # A running application should only connect to self: this is where we have the most
    # concern because we run the least-trusted code
    def csp_application_running_direct(host, public_host):
        return csp_common + (
            "default-src 'self';"
            "base-uri 'self';"
            # Safari does not have a 'self' for WebSockets
            f"connect-src 'self' wss://{host};"
            "font-src 'self' data:;"
            "form-action 'self';"
            f"frame-ancestors 'self' {root_domain} {public_host}.{root_domain};"
            "img-src 'self' data: blob:;"
            # Both JupyterLab and RStudio need `unsafe-eval`
            "script-src 'unsafe-inline' 'unsafe-eval' 'self';"
            "style-src 'unsafe-inline' 'self';"
            "worker-src 'self' blob:;")

    redis_pool = await aioredis.create_redis_pool(redis_url)

    default_http_timeout = aiohttp.ClientTimeout()

    # When spawning and trying to detect if the app is running,
    # we fail quickly and often so a connection check is quick
    spawning_http_timeout = aiohttp.ClientTimeout(sock_read=5, sock_connect=2)

    def get_random_context_logger():
        return ContextAdapter(
            logger,
            {'context': ''.join(random.choices(CONTEXT_ALPHABET, k=8))})

    def without_transfer_encoding(request_or_response):
        return tuple((key, value)
                     for key, value in request_or_response.headers.items()
                     if key.lower() != 'transfer-encoding')

    def admin_headers(downstream_request):
        return (without_transfer_encoding(downstream_request) +
                downstream_request['sso_profile_headers'])

    def mirror_headers(downstream_request):
        return tuple((key, value)
                     for key, value in downstream_request.headers.items()
                     if key.lower() not in ['host', 'transfer-encoding'])

    def application_headers(downstream_request):
        return without_transfer_encoding(downstream_request) + (
            (('x-scheme', downstream_request.headers['x-forwarded-proto']), )
            if 'x-forwarded-proto' in downstream_request.headers else ())

    def superset_headers(downstream_request):
        return (without_transfer_encoding(downstream_request) +
                downstream_request['sso_profile_headers'])

    def is_service_discovery(request):
        return (request.url.path == '/api/v1/application'
                and request.url.host == root_domain_no_port
                and request.method == 'GET')

    def is_superset_requested(request):
        return request.url.host == f'superset.{root_domain_no_port}'

    def is_app_requested(request):
        return (request.url.host.endswith(f'.{root_domain_no_port}')
                and not request.url.path.startswith(mirror_local_root)
                and not is_superset_requested(request))

    def is_mirror_requested(request):
        return request.url.host.endswith(
            f'.{root_domain_no_port}') and request.url.path.startswith(
                mirror_local_root)

    def is_requesting_credentials(request):
        return (request.url.host == root_domain_no_port
                and request.url.path == '/api/v1/aws_credentials')

    def is_requesting_files(request):
        return request.url.host == root_domain_no_port and request.url.path == '/files'

    def is_dataset_requested(request):
        return ((request.url.path.startswith('/api/v1/dataset/')
                 or request.url.path.startswith('/api/v1/reference-dataset/')
                 or request.url.path.startswith('/api/v1/eventlog/')
                 or request.url.path.startswith('/api/v1/account/')
                 or request.url.path.startswith('/api/v1/application-instance/'))
                and request.url.host == root_domain_no_port)

    def is_hawk_auth_required(request):
        return is_dataset_requested(request)

    def is_healthcheck_requested(request):
        return (request.url.path == '/healthcheck' and request.method == 'GET'
                and not is_app_requested(request))

    def is_table_requested(request):
        return (request.url.path.startswith('/api/v1/table/')
                and request.url.host == root_domain_no_port
                and request.method == 'POST')

    def is_sso_auth_required(request):
        return (not is_healthcheck_requested(request)
                and not is_service_discovery(request)
                and not is_table_requested(request)
                and not is_dataset_requested(request))

    def get_peer_ip(request):
        peer_ip = (request.headers['x-forwarded-for'].split(',')
                   [-x_forwarded_for_trusted_hops].strip())

        is_private = True
        try:
            is_private = ipaddress.ip_address(peer_ip).is_private
        except ValueError:
            is_private = False

        return peer_ip, is_private

    def request_scheme(request):
        return request.headers.get('x-forwarded-proto', request.url.scheme)

    def request_url(request):
        return str(request.url.with_scheme(request_scheme(request)))

    async def handle(downstream_request):
        method = downstream_request.method
        path = downstream_request.url.path
        query = downstream_request.url.query
        app_requested = is_app_requested(downstream_request)
        mirror_requested = is_mirror_requested(downstream_request)
        superset_requested = is_superset_requested(downstream_request)

        # Websocket connections
        # - tend to close unexpectedly, both from the client and app
        # - don't need to show anything nice to the user on error
        is_websocket = (downstream_request.headers.get('connection',
                                                       '').lower() == 'upgrade'
                        and downstream_request.headers.get(
                            'upgrade', '').lower() == 'websocket')

        try:
            if app_requested:
                return await handle_application(is_websocket,
                                                downstream_request, method,
                                                path, query)
            if mirror_requested:
                return await handle_mirror(downstream_request, method, path)
            if superset_requested:
                return await handle_superset(downstream_request, method, path,
                                             query)
            return await handle_admin(downstream_request, method, path, query)
        except Exception as exception:
            user_exception = isinstance(exception, UserException)
            if not user_exception or (user_exception
                                      and exception.args[1] == 500):
                logger.exception(
                    'Exception during %s %s %s',
                    downstream_request.method,
                    downstream_request.url,
                    type(exception),
                )

            if is_websocket:
                raise

            params = {'message': exception.args[0]} if user_exception else {}

            status = exception.args[1] if user_exception else 500

            return await handle_http(
                downstream_request,
                'GET',
                CIMultiDict(admin_headers(downstream_request)),
                URL(admin_root).with_path(f'/error_{status}'),
                params,
                default_http_timeout,
            )

    async def handle_application(is_websocket, downstream_request, method,
                                 path, query):
        public_host, _, _ = downstream_request.url.host.partition(
            f'.{root_domain_no_port}')
        possible_public_host, _, public_host_or_port_override = public_host.rpartition(
            '--')
        try:
            port_override = int(public_host_or_port_override)
        except ValueError:
            port_override = None
        else:
            public_host = possible_public_host
        host_api_url = admin_root + '/api/v1/application/' + public_host
        host_html_path = '/tools/' + public_host

        async with client_session.request(
                'GET',
                host_api_url,
                headers=CIMultiDict(
                    admin_headers(downstream_request))) as response:
            host_exists = response.status == 200
            application = await response.json()

        if response.status != 200 and response.status != 404:
            raise UserException('Unable to start the application',
                                response.status)

        if host_exists and application['state'] not in ['SPAWNING', 'RUNNING']:
            if ('x-data-workspace-no-modify-application-instance'
                    not in downstream_request.headers):
                async with client_session.request(
                        'DELETE',
                        host_api_url,
                        headers=CIMultiDict(admin_headers(downstream_request)),
                ) as delete_response:
                    await delete_response.read()
            raise UserException('Application ' + application['state'], 500)

        if not host_exists:
            if ('x-data-workspace-no-modify-application-instance'
                    not in downstream_request.headers):
                params = {
                    key: value
                    for key, value in downstream_request.query.items()
                    if key == '__memory_cpu'
                }
                async with client_session.request(
                        'PUT',
                        host_api_url,
                        params=params,
                        headers=CIMultiDict(admin_headers(downstream_request)),
                ) as response:
                    host_exists = response.status == 200
                    application = await response.json()
                if params:
                    return web.Response(status=302, headers={'location': '/'})
            else:
                raise UserException('Application stopped while starting', 500)

        if response.status != 200:
            raise UserException('Unable to start the application',
                                response.status)

        if application['state'] not in ['SPAWNING', 'RUNNING']:
            raise UserException(
                'Attempted to start the application, but it ' +
                application['state'],
                500,
            )

        if not application['proxy_url']:
            return await handle_http(
                downstream_request,
                'GET',
                CIMultiDict(admin_headers(downstream_request)),
                admin_root + host_html_path + '/spawning',
                {},
                default_http_timeout,
                (('content-security-policy', csp_application_spawning), ),
            )

        if is_websocket:
            return await handle_application_websocket(downstream_request,
                                                      application['proxy_url'],
                                                      path, query,
                                                      port_override)

        if application['state'] == 'SPAWNING':
            return await handle_application_http_spawning(
                downstream_request,
                method,
                application_upstream(application['proxy_url'], path,
                                     port_override),
                query,
                host_html_path,
                host_api_url,
                public_host,
            )

        if (application['state'] == 'RUNNING' and application['wrap'] != 'NONE'
                and not port_override):
            return await handle_application_http_running_wrapped(
                downstream_request,
                application_upstream(application['proxy_url'], path,
                                     port_override),
                host_html_path,
                public_host,
            )

        return await handle_application_http_running_direct(
            downstream_request,
            method,
            application_upstream(application['proxy_url'], path,
                                 port_override),
            query,
            public_host,
        )

    async def handle_application_websocket(downstream_request, proxy_url, path,
                                           query, port_override):
        upstream_url = application_upstream(proxy_url, path,
                                            port_override).with_query(query)
        return await handle_websocket(
            downstream_request,
            CIMultiDict(application_headers(downstream_request)),
            upstream_url,
        )

    def application_upstream(proxy_url, path, port_override):
        return (URL(proxy_url).with_path(path) if port_override is None else
                URL(proxy_url).with_path(path).with_port(port_override))

    async def handle_application_http_spawning(
        downstream_request,
        method,
        upstream_url,
        query,
        host_html_path,
        host_api_url,
        public_host,
    ):
        host = downstream_request.headers['host']
        try:
            logger.info('Spawning: Attempting to connect to %s', upstream_url)
            response = await handle_http(
                downstream_request,
                method,
                CIMultiDict(application_headers(downstream_request)),
                upstream_url,
                query,
                spawning_http_timeout,
                # Although the application is spawning, if the response makes it back to the client,
                # we know the application is running, so we return the _running_ CSP headers
                (
                    (
                        'content-security-policy',
                        csp_application_running_direct(host, public_host),
                    ), ),
            )

        except Exception:
            logger.info('Spawning: Failed to connect to %s', upstream_url)
            return await handle_http(
                downstream_request,
                'GET',
                CIMultiDict(admin_headers(downstream_request)),
                admin_root + host_html_path + '/spawning',
                {},
                default_http_timeout,
                (('content-security-policy', csp_application_spawning), ),
            )

        else:
            # Once a streaming response is done, if we have not yet returned
            # from the handler, it looks like aiohttp can cancel the current
            # task. We set RUNNING in another task to avoid it being cancelled
            async def set_application_running():
                async with client_session.request(
                        'PATCH',
                        host_api_url,
                        json={'state': 'RUNNING'},
                        headers=CIMultiDict(admin_headers(downstream_request)),
                        timeout=default_http_timeout,
                ) as patch_response:
                    await patch_response.read()

            asyncio.ensure_future(set_application_running())

            await send_to_google_analytics(downstream_request)

            return response

    async def handle_application_http_running_wrapped(downstream_request,
                                                      upstream_url,
                                                      host_html_path,
                                                      public_host):
        upstream = URL(upstream_url)
        direct_host = f'{public_host}--{upstream.port}.{root_domain}'
        return await handle_http(
            downstream_request,
            'GET',
            CIMultiDict(admin_headers(downstream_request)),
            admin_root + host_html_path + '/running',
            {},
            default_http_timeout,
            ((
                'content-security-policy',
                csp_application_running_wrapped(direct_host),
            ), ),
        )

    async def handle_application_http_running_direct(downstream_request,
                                                     method, upstream_url,
                                                     query, public_host):
        host = downstream_request.headers['host']

        await send_to_google_analytics(downstream_request)

        return await handle_http(
            downstream_request,
            method,
            CIMultiDict(application_headers(downstream_request)),
            upstream_url,
            query,
            default_http_timeout,
            ((
                'content-security-policy',
                csp_application_running_direct(host, public_host),
            ), ),
        )

    async def handle_mirror(downstream_request, method, path):
        mirror_path = path[len(mirror_local_root):]
        upstream_url = URL(mirror_remote_root + mirror_path)
        return await handle_http(
            downstream_request,
            method,
            CIMultiDict(mirror_headers(downstream_request)),
            upstream_url,
            {},
            default_http_timeout,
        )

    async def handle_superset(downstream_request, method, path, query):
        upstream_url = URL(superset_root).with_path(path)
        host = downstream_request.headers['host']
        return await handle_http(
            downstream_request,
            method,
            CIMultiDict(superset_headers(downstream_request)),
            upstream_url,
            query,
            default_http_timeout,
            ((
                'content-security-policy',
                csp_application_running_direct(host, 'superset'),
            ), ),
        )

    async def handle_admin(downstream_request, method, path, query):
        upstream_url = URL(admin_root).with_path(path)
        return await handle_http(
            downstream_request,
            method,
            CIMultiDict(admin_headers(downstream_request)),
            upstream_url,
            query,
            default_http_timeout,
        )

    async def handle_websocket(downstream_request, upstream_headers,
                               upstream_url):
        protocol = downstream_request.headers.get('Sec-WebSocket-Protocol')
        protocols = (protocol, ) if protocol else ()

        async def proxy_msg(msg, to_ws):
            if msg.type == aiohttp.WSMsgType.TEXT:
                await to_ws.send_str(msg.data)

            elif msg.type == aiohttp.WSMsgType.BINARY:
                await to_ws.send_bytes(msg.data)

            elif msg.type == aiohttp.WSMsgType.CLOSE:
                await to_ws.close()

            elif msg.type == aiohttp.WSMsgType.ERROR:
                await to_ws.close()

        async def upstream():
            try:
                async with client_session.ws_connect(
                        str(upstream_url),
                        headers=upstream_headers,
                        protocols=protocols) as upstream_ws:
                    upstream_connection.set_result(upstream_ws)
                    downstream_ws = await downstream_connection
                    async for msg in upstream_ws:
                        await proxy_msg(msg, downstream_ws)
            except BaseException as exception:
                if not upstream_connection.done():
                    upstream_connection.set_exception(exception)
                raise
            finally:
                try:
                    await downstream_ws.close()
                except UnboundLocalError:
                    # If we didn't get to the line that creates `downstream_ws`
                    pass

        # This is slightly convoluted, but aiohttp documents that reading
        # from websockets should be done in the same task as the websocket was
        # created, so we read from downstream in _this_ task, and create
        # another task to connect to and read from the upstream socket. We
        # also need to make sure we wait for each connection before sending
        # data to it
        downstream_connection = asyncio.Future()
        upstream_connection = asyncio.Future()
        upstream_task = asyncio.ensure_future(upstream())

        try:
            upstream_ws = await upstream_connection
            _, _, _, with_session_cookie = downstream_request[SESSION_KEY]
            downstream_ws = await with_session_cookie(
                web.WebSocketResponse(protocols=protocols))

            await downstream_ws.prepare(downstream_request)
            downstream_connection.set_result(downstream_ws)

            async for msg in downstream_ws:
                await proxy_msg(msg, upstream_ws)

        finally:
            upstream_task.cancel()

        return downstream_ws

    async def send_to_google_analytics(downstream_request):
        # Not perfect, but a good enough guide for usage
        _, extension = os.path.splitext(downstream_request.url.path)
        send_to_google = ga_tracking_id and extension in {
            '',
            '.doc',
            '.docx',
            '.html',
            '.pdf',
            '.ppt',
            '.pptx',
            '.xlsx',
        }

        if not send_to_google:
            return

        async def _send():
            logger.info("Sending to Google Analytics %s...",
                        downstream_request.url)
            peer_ip, _ = get_peer_ip(downstream_request)

            # Use the response as a context manager so the connection is
            # released back to the pool even though the body is never read
            async with client_session.request(
                'POST',
                'https://www.google-analytics.com/collect',
                data={
                    'v': '1',
                    'tid': ga_tracking_id,
                    'cid': str(uuid.uuid4()),
                    't': 'pageview',
                    'uip': peer_ip,
                    'dh': downstream_request.url.host,
                    'dp': downstream_request.url.path_qs,
                    'ds': 'data-workspace-server',
                    'dr': downstream_request.headers.get('referer', ''),
                    'ua': downstream_request.headers.get('user-agent', ''),
                },
                timeout=default_http_timeout,
            ) as response:
                logger.info("Sending to Google Analytics %s... %s",
                            downstream_request.url, response.status)

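        # Fire-and-forget: the task is deliberately not awaited, so analytics
        # latency or failure never delays the response to the client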
        asyncio.create_task(_send())

    async def handle_http(
            downstream_request,
            upstream_method,
            upstream_headers,
            upstream_url,
            upstream_query,
            timeout,
            response_headers=tuple(),
    ):
        # Avoid aiohttp treating request as chunked unnecessarily, which works
        # for some upstream servers, but not all. Specifically RStudio drops
        # GET responses halfway through if the request specified a chunked
        # encoding. AFAIK RStudio uses a custom webserver, so this behaviour
        # is not documented anywhere.

        # fmt: off
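        # Three cases: an empty body if the request declared neither a
        # content-length nor chunked encoding, the fully-buffered body if it
        # has already been read, and otherwise the body is streamed through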
        data = \
            b'' if (
                'content-length' not in upstream_headers
                and downstream_request.headers.get('transfer-encoding', '').lower() != 'chunked'
            ) else \
            await downstream_request.read() if downstream_request.content.at_eof() else \
            downstream_request.content
        # fmt: on
        async with client_session.request(
            upstream_method,
            str(upstream_url),
            params=upstream_query,
            headers=upstream_headers,
            data=data,
            allow_redirects=False,
            timeout=timeout,
        ) as upstream_response:

            _, _, _, with_session_cookie = downstream_request[SESSION_KEY]
            downstream_response = await with_session_cookie(
                web.StreamResponse(
                    status=upstream_response.status,
                    headers=CIMultiDict(
                        without_transfer_encoding(upstream_response) +
                        response_headers),
                ))
            await downstream_response.prepare(downstream_request)
            async for chunk in upstream_response.content.iter_any():
                await downstream_response.write(chunk)

        return downstream_response

    def server_logger():
        @web.middleware
        async def _server_logger(request, handler):

            request_logger = get_random_context_logger()
            request['logger'] = request_logger
            url = request_url(request)

            request_logger.info(
                'Receiving (%s) (%s) (%s) (%s)',
                request.method,
                url,
                request.headers.get('User-Agent', '-'),
                request.headers.get('X-Forwarded-For', '-'),
            )

            response = await handler(request)

            request_logger.info(
                'Responding (%s) (%s) (%s) (%s) (%s) (%s)',
                request.method,
                url,
                request.headers.get('User-Agent', '-'),
                request.headers.get('X-Forwarded-For', '-'),
                response.status,
                response.content_length,
            )

            return response

        return _server_logger

    def authenticate_by_staff_sso_token():

        me_path = 'api/v1/user/me/'

        @web.middleware
        async def _authenticate_by_staff_sso_token(request, handler):
            staff_sso_token_required = is_table_requested(request)
            request.setdefault('sso_profile_headers', ())

            if not staff_sso_token_required:
                return await handler(request)

            if 'Authorization' not in request.headers:
                request['logger'].info(
                    'SSO-token unauthenticated: missing authorization header')
                return await handle_admin(request, 'GET', '/error_403', {})

            async with client_session.get(
                    f'{sso_base_url}{me_path}',
                    headers={
                        'Authorization': request.headers['Authorization']
                    },
            ) as me_response:
                me_profile = (await me_response.json()
                              if me_response.status == 200 else None)

            if not me_profile:
                request['logger'].info(
                    'SSO-token unauthenticated: bad authorization header')
                return await handle_admin(request, 'GET', '/error_403', {})

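            # Stored on the request so later handlers can forward the profile
            # to the upstream application as headers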
            request['sso_profile_headers'] = (
                ('sso-profile-email', me_profile['email']),
                ('sso-profile-contact-email', me_profile['contact_email']),
                (
                    'sso-profile-related-emails',
                    ','.join(me_profile.get('related_emails', [])),
                ),
                ('sso-profile-user-id', me_profile['user_id']),
                ('sso-profile-first-name', me_profile['first_name']),
                ('sso-profile-last-name', me_profile['last_name']),
            )

            request['logger'].info(
                'SSO-token authenticated: %s %s',
                me_profile['email'],
                me_profile['user_id'],
            )

            return await handler(request)

        return _authenticate_by_staff_sso_token

    def authenticate_by_staff_sso():

        auth_path = 'o/authorize/'
        token_path = 'o/token/'
        me_path = 'api/v1/user/me/'
        grant_type = 'authorization_code'
        scope = 'read write'
        response_type = 'code'

        redirect_from_sso_path = '/__redirect_from_sso'
        session_token_key = 'staff_sso_access_token'

        async def get_redirect_uri_authenticate(set_session_value,
                                                redirect_uri_final):
            scheme = URL(redirect_uri_final).scheme
            sso_state = await set_redirect_uri_final(set_session_value,
                                                     redirect_uri_final)

            redirect_uri_callback = urllib.parse.quote(
                get_redirect_uri_callback(scheme), safe='')
            return (f'{sso_base_url}{auth_path}?'
                    f'scope={scope}&state={sso_state}&'
                    f'redirect_uri={redirect_uri_callback}&'
                    f'response_type={response_type}&'
                    f'client_id={sso_client_id}')

        def get_redirect_uri_callback(scheme):
            return str(
                URL.build(
                    host=root_domain_no_port,
                    port=root_port,
                    scheme=scheme,
                    path=redirect_from_sso_path,
                ))

        async def set_redirect_uri_final(set_session_value,
                                         redirect_uri_final):
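            # The state is `<session_key>_<final_url>`; token_hex never
            # contains an underscore, so partition('_') can split it back out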
            session_key = secrets.token_hex(32)
            sso_state = urllib.parse.quote(
                f'{session_key}_{redirect_uri_final}', safe='')

            await set_session_value(session_key, redirect_uri_final)

            return sso_state

        async def get_redirect_uri_final(get_session_value, sso_state):
            session_key, _, state_redirect_url = urllib.parse.unquote(
                sso_state).partition('_')
            return state_redirect_url, await get_session_value(session_key)

        async def redirection_to_sso(with_new_session_cookie,
                                     set_session_value, redirect_uri_final):
            return await with_new_session_cookie(
                web.Response(
                    status=302,
                    headers={
                        'Location': await get_redirect_uri_authenticate(
                            set_session_value, redirect_uri_final)
                    },
                ))

        @web.middleware
        async def _authenticate_by_sso(request, handler):
            sso_auth_required = is_sso_auth_required(request)

            if not sso_auth_required:
                request.setdefault('sso_profile_headers', ())
                return await handler(request)

            get_session_value, set_session_value, with_new_session_cookie, _ = request[
                SESSION_KEY]

            token = await get_session_value(session_token_key)
            if request.path != redirect_from_sso_path and token is None:
                return await redirection_to_sso(with_new_session_cookie,
                                                set_session_value,
                                                request_url(request))

            if request.path == redirect_from_sso_path:
                code = request.query['code']
                sso_state = request.query['state']
                (
                    redirect_uri_final_from_url,
                    redirect_uri_final_from_session,
                ) = await get_redirect_uri_final(get_session_value, sso_state)

                if redirect_uri_final_from_url != redirect_uri_final_from_session:
                    # We might have been overtaken by a parallel request initiating another auth
                    # flow, and so another session. Because we haven't retrieved the final
                    # URL from this session, we can't be sure this is the same client that
                    # initiated the flow, so the safe option is to redirect back to SSO
                    return await redirection_to_sso(
                        with_new_session_cookie,
                        set_session_value,
                        redirect_uri_final_from_url,
                    )

                async with client_session.post(
                        f'{sso_base_url}{token_path}',
                        data={
                            'grant_type': grant_type,
                            'code': code,
                            'client_id': sso_client_id,
                            'client_secret': sso_client_secret,
                            'redirect_uri':
                            get_redirect_uri_callback(request_scheme(request)),
                        },
                ) as sso_response:
                    sso_response_json = await sso_response.json()
                await set_session_value(session_token_key,
                                        sso_response_json['access_token'])
                return await with_new_session_cookie(
                    web.Response(
                        status=302,
                        headers={'Location': redirect_uri_final_from_session},
                    ))

            # Get profile from Redis cache to avoid calling SSO on every request
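            # The cache key embeds the access token, so each login gets its
            # own entry; entries expire 60 seconds after the SET below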
            redis_profile_key = f'{PROFILE_CACHE_PREFIX}___{session_token_key}___{token}'.encode(
                'ascii')
            with await redis_pool as conn:
                me_profile_raw = await conn.execute('GET', redis_profile_key)
            me_profile = json.loads(me_profile_raw) if me_profile_raw else None

            async def handler_with_sso_headers():
                request['sso_profile_headers'] = (
                    ('sso-profile-email', me_profile['email']),
                    # The default of '' can be removed once cached profiles in
                    # Redis without contact_email have expired, i.e. 60 seconds
                    # after this change is deployed
                    ('sso-profile-contact-email',
                     me_profile.get('contact_email', '')),
                    (
                        'sso-profile-related-emails',
                        ','.join(me_profile.get('related_emails', [])),
                    ),
                    ('sso-profile-user-id', me_profile['user_id']),
                    ('sso-profile-first-name', me_profile['first_name']),
                    ('sso-profile-last-name', me_profile['last_name']),
                )

                request['logger'].info(
                    'SSO-authenticated: %s %s %s',
                    me_profile['email'],
                    me_profile['user_id'],
                    request_url(request),
                )

                return await handler(request)

            if me_profile:
                return await handler_with_sso_headers()

            async with client_session.get(
                    f'{sso_base_url}{me_path}',
                    headers={'Authorization':
                             f'Bearer {token}'}) as me_response:
                me_profile_full = (await me_response.json()
                                   if me_response.status == 200 else None)

            if not me_profile_full:
                return await redirection_to_sso(with_new_session_cookie,
                                                set_session_value,
                                                request_url(request))

            me_profile = {
                'email': me_profile_full['email'],
                'related_emails': me_profile_full['related_emails'],
                'contact_email': me_profile_full['contact_email'],
                'user_id': me_profile_full['user_id'],
                'first_name': me_profile_full['first_name'],
                'last_name': me_profile_full['last_name'],
            }
            with await redis_pool as conn:
                await conn.execute(
                    'SET',
                    redis_profile_key,
                    json.dumps(me_profile).encode('utf-8'),
                    'EX',
                    60,
                )

            return await handler_with_sso_headers()

        return _authenticate_by_sso

    def authenticate_by_basic_auth():
        @web.middleware
        async def _authenticate_by_basic_auth(request, handler):
            basic_auth_required = is_service_discovery(request)

            if not basic_auth_required:
                return await handler(request)

            if 'Authorization' not in request.headers:
                return web.Response(status=401)

            basic_auth_prefix = 'Basic '
            auth_value = (request.headers['Authorization']
                          [len(basic_auth_prefix):].strip().encode('ascii'))
            required_auth_value = base64.b64encode(
                f'{basic_auth_user}:{basic_auth_password}'.encode('ascii'))

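            # hmac.compare_digest compares in constant time, so the check does
            # not leak how much of the supplied credential matched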
            if len(auth_value) != len(
                    required_auth_value) or not hmac.compare_digest(
                        auth_value, required_auth_value):
                return web.Response(status=401)

            request['logger'].info('Basic-authenticated: %s', basic_auth_user)

            return await handler(request)

        return _authenticate_by_basic_auth

    def authenticate_by_hawk_auth():
        async def lookup_credentials(sender_id):
            for hawk_sender in hawk_senders:
                if hawk_sender['id'] == sender_id:
                    return hawk_sender

        async def seen_nonce(nonce, sender_id):
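            # SET ... NX succeeds only if the key is new: a non-OK reply means
            # this nonce was already seen within the last 60 seconds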
            nonce_key = f'nonce-{sender_id}-{nonce}'
            with await redis_pool as conn:
                response = await conn.execute('SET', nonce_key, '1', 'EX', 60,
                                              'NX')
                return response != b'OK'

        @web.middleware
        async def _authenticate_by_hawk_auth(request, handler):
            hawk_auth_required = is_hawk_auth_required(request)

            if not hawk_auth_required:
                return await handler(request)

            try:
                authorization_header = request.headers['Authorization']
            except KeyError:
                request['logger'].info('Hawk missing header')
                return web.Response(status=401)

            content = await request.read()

            error_message, creds = await authenticate_hawk_header(
                lookup_credentials,
                seen_nonce,
                15,
                authorization_header,
                request.method,
                request.url.host,
                request.url.port,
                request.url.path_qs,
                request.headers['Content-Type'],
                content,
            )
            if error_message is not None:
                request['logger'].info('Hawk unauthenticated: %s',
                                       error_message)
                return web.Response(status=401)

            request['logger'].info('Hawk authenticated: %s', creds['id'])

            return await handler(request)

        return _authenticate_by_hawk_auth

    def authenticate_by_ip_whitelist():
        @web.middleware
        async def _authenticate_by_ip_whitelist(request, handler):
            ip_whitelist_required = (is_app_requested(request)
                                     or is_superset_requested(request)
                                     or is_mirror_requested(request)
                                     or is_requesting_credentials(request)
                                     or is_requesting_files(request))

            if not ip_whitelist_required:
                return await handler(request)

            peer_ip, _ = get_peer_ip(request)
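            # Note: this assumes IPv4 peers; an IPv6 peer address here would
            # raise ValueError rather than simply fail the whitelist check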
            peer_ip_in_whitelist = any(
                ipaddress.IPv4Address(peer_ip) in ipaddress.IPv4Network(
                    address_or_subnet)
                for address_or_subnet in application_ip_whitelist)

            if not peer_ip_in_whitelist:
                request['logger'].info('IP-whitelist unauthenticated: %s',
                                       peer_ip)
                return await handle_admin(request, 'GET', '/error_403', {})

            request['logger'].info('IP-whitelist authenticated: %s', peer_ip)
            return await handler(request)

        return _authenticate_by_ip_whitelist

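    # DummyCookieJar ignores all Set-Cookie headers, so upstream cookies are
    # never stored or replayed across unrelated downstream clients, and
    # auto_decompress=False passes compressed upstream bodies through untouched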
    async with aiohttp.ClientSession(
            auto_decompress=False,
            cookie_jar=aiohttp.DummyCookieJar()) as client_session:
        app = web.Application(middlewares=[
            server_logger(),
            redis_session_middleware(redis_pool, root_domain_no_port),
            authenticate_by_staff_sso_token(),
            authenticate_by_staff_sso(),
            authenticate_by_basic_auth(),
            authenticate_by_hawk_auth(),
            authenticate_by_ip_whitelist(),
        ])
        app.add_routes([
            getattr(web, method)(r'/{path:.*}', handle) for method in [
                'delete',
                'get',
                'head',
                'options',
                'patch',
                'post',
                'put',
            ]
        ])

        elastic_apm_url = env.get("ELASTIC_APM_URL")
        elastic_apm_secret_token = env.get("ELASTIC_APM_SECRET_TOKEN")
        elastic_apm = ({
            'SERVICE_NAME': 'data-workspace',
            'SECRET_TOKEN': elastic_apm_secret_token,
            'SERVER_URL': elastic_apm_url,
            'ENVIRONMENT': env.get('ENVIRONMENT', 'development'),
        } if elastic_apm_secret_token else {})

        app['ELASTIC_APM'] = elastic_apm

        if elastic_apm:
            ElasticAPM(app)

        runner = web.AppRunner(app)
        await runner.setup()
        site = web.TCPSite(runner, '0.0.0.0', port)
        await site.start()
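        # Serve forever: await a future that is never resolved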
        await asyncio.Future()
Exemple #30
0
def _get_client_no_session(**kwargs):
    # DummyCookieJar discards all cookies, so this client carries no state
    # between requests; extra ClientSession options pass through via **kwargs
    return aiohttp.ClientSession(connector=_get_connector(),
                                 cookie_jar=aiohttp.DummyCookieJar(),
                                 **kwargs)