Пример #1
0
    def build_session(self,
                      apply_for: Optional[ClientSession] = None,
                      *args,
                      **kwargs) -> ClientSession:
        """Return a ClientSession instrumented with this object's trace hooks.

        A fresh TraceConfig routes ``self.on_request_start`` and
        ``self.on_request_end`` into the session's request lifecycle.

        Args:
            apply_for: Existing session to instrument. When omitted (or
                falsy) a new ClientSession is built from ``*args`` and
                ``**kwargs`` with the trace config already attached.

        Returns:
            The instrumented session. Its ``get`` and ``post`` attributes are
            replaced by wrappers that hand the call's own positional and
            keyword arguments to the hooks as ``trace_request_ctx``.
        """
        tracer = TraceConfig()
        tracer.on_request_start.append(self.on_request_start)
        tracer.on_request_end.append(self.on_request_end)

        session = apply_for or ClientSession(
            *args, trace_configs=[tracer], **kwargs)
        if apply_for is not None:
            # The session already exists, so the config is attached by hand
            # and frozen here rather than through the constructor.
            session.trace_configs.append(tracer)
            tracer.freeze()

        def _inject_request_ctx(bound_method):
            def wrapper(*call_args, **call_kwargs):
                # Any caller-supplied trace_request_ctx is overridden.
                forwarded = dict(call_kwargs)
                forwarded["trace_request_ctx"] = (call_args, call_kwargs)
                return bound_method(*call_args, **forwarded)

            return wrapper

        for verb in ("get", "post"):
            setattr(session, verb, _inject_request_ctx(getattr(session, verb)))
        return session
Пример #2
0
def _get_timing_trace_config(host_pool):
    """Create a TraceConfig that times requests and reports failures.

    Each request start opens a timer with ``_start_timer``. A completed
    response closes it with ``_stop_timer``. A request exception asks
    ``host_pool`` to increase the error delay for the timed URL.
    """
    config = TraceConfig()

    async def on_request_start(session, trace_config_ctx, params):
        started_at = session.loop.time()
        trace_config_ctx.timer = _start_timer(
            host_pool, str(params.url), started_at)

    async def on_request_end(session, trace_config_ctx, params):
        finished_at = session.loop.time()
        _stop_timer(host_pool, params.response, trace_config_ctx.timer,
                    end=finished_at)

    async def on_request_exception(session, trace_config_ctx, params):
        # The timer returned by _start_timer is read as a mapping with a
        # 'url' key.
        host_pool.increase_error_delay(trace_config_ctx.timer['url'])

    hooks = (
        (config.on_request_start, on_request_start),
        (config.on_request_end, on_request_end),
        (config.on_request_exception, on_request_exception),
    )
    for signal, handler in hooks:
        signal.append(handler)
    return config
Пример #3
0
async def test_add_trace_request_ctx(aiohttp_client, loop):
    """RetryClient merges the attempt number into the user trace_request_ctx.

    '/sometimes_error' needs three attempts, so the request-start hook must
    observe three contexts carrying ``current_attempt`` 1, 2 and 3 alongside
    the caller's own ``{'foo': 'bar'}``.
    """
    seen_contexts = []

    async def on_request_start(
        _: ClientSession,
        trace_config_ctx: SimpleNamespace,
        __: TraceRequestStartParams,
    ) -> None:
        seen_contexts.append(trace_config_ctx)

    app_under_test = App()

    tracing = TraceConfig()
    tracing.on_request_start.append(on_request_start)  # type: ignore

    retry_client = RetryClient()
    retry_client._client = await aiohttp_client(
        app_under_test.get_app(), trace_configs=[tracing])

    async with retry_client.get('/sometimes_error', trace_request_ctx={'foo': 'bar'}):
        assert app_under_test.counter == 3

    expected = [
        SimpleNamespace(
            trace_request_ctx={'foo': 'bar', 'current_attempt': attempt},
        )
        for attempt in range(1, 4)
    ]
    assert seen_contexts == expected
Пример #4
0
    def __init__(self,
                 headers=None,
                 cookies=None,
                 raise_for_status=True,
                 timeout=60,
                 debug=False,
                 trace_configs=None,
                 **kwargs):
        """Create the Behance client and its underlying ClientSession.

        Args:
            headers: Extra request headers. The Behance-specific headers set
                below are layered on top of (and override) these. The
                caller's mapping is copied, never modified.
            cookies: Passed straight to ClientSession.
            raise_for_status: Passed straight to ClientSession.
            timeout: Total request timeout in seconds, wrapped in
                ClientTimeout.
            debug: When true, add a TraceConfig that calls
                ``self.on_debug_req_start`` / ``self.on_debug_req_end``.
            trace_configs: Extra TraceConfig objects. The caller's list is
                copied, never modified.
            **kwargs: Any other ClientSession keyword arguments.
        """
        self._url_base = "https://www.behance.net"

        # Work on a copy: updating the caller's mapping in place would leak
        # these headers into whatever else shares that dict.
        headers = headers.copy() if headers else {}
        headers.update({
            'User-Agent': useragent_generator.random,
            'Host': 'www.behance.net',
            'Origin': 'https://www.behance.net',
            'Referer': 'https://www.behance.net',
            'X-Requested-With': 'XMLHttpRequest'
        })

        # Likewise copy the list so appending the debug config does not
        # modify the caller's list.
        trace_configs = list(trace_configs) if trace_configs else []

        if debug:
            t_conf = TraceConfig()
            t_conf.on_request_start.append(self.on_debug_req_start)
            t_conf.on_request_end.append(self.on_debug_req_end)
            trace_configs.append(t_conf)

        timeout = ClientTimeout(total=timeout)

        self.session = ClientSession(headers=headers,
                                     cookies=cookies,
                                     raise_for_status=raise_for_status,
                                     timeout=timeout,
                                     trace_configs=trace_configs,
                                     **kwargs)
 def get_client_session(self, config: SeekerConfig) -> ClientSession:
     """Build the retrying HTTP client used by the seeker.

     Retries are configured with ExponentialRetry(attempts=config.max_tries,
     max_timeout=config.max_time) on ``aiohttp.ClientError`` and
     ``asyncio.TimeoutError``. A trace hook logs a warning for every attempt
     after the first.
     """
     async def _on_request_start(
         session: ClientSession,
         trace_config_ctx: SimpleNamespace,
         params: TraceRequestStartParams
     ) -> None:
         attempt = trace_config_ctx.trace_request_ctx['current_attempt']
         if attempt <= 1:
             return
         logger.warning(
             f'::warn ::Retry Attempt #{attempt} of {config.max_tries}: {params.url}')

     tracing = TraceConfig()
     tracing.on_request_start.append(_on_request_start)

     connector = TCPConnector(
         limit_per_host=max(0, config.connect_limit_per_host),  # clamp negatives to 0
         ttl_dns_cache=600  # 10-minute DNS cache
     )
     retry_options = ExponentialRetry(
         attempts=config.max_tries,
         max_timeout=config.max_time,
         exceptions=[aiohttp.ClientError, asyncio.TimeoutError],
     )
     return RetryClient(
         raise_for_status=True,
         connector=connector,
         timeout=ClientTimeout(total=config.timeout),
         headers={'User-Agent': config.agent},
         retry_options=retry_options,
         trace_configs=[tracing])
Пример #6
0
    def __init__(self):
        """Create the retrying iNat session, the lookup caches and rate limiter."""
        # pylint: disable=unused-argument
        async def on_request_start(
            session: ClientSession,
            trace_config_ctx: SimpleNamespace,
            params: TraceRequestStartParams,
        ) -> None:
            """Log every request attempt after the first."""
            attempt = trace_config_ctx.trace_request_ctx["current_attempt"]
            if attempt <= 1:
                return
            LOG.info("iNat request attempt #%d: %s", attempt, repr(params))

        tracing = TraceConfig()
        tracing.on_request_start.append(on_request_start)
        self.session = RetryClient(
            raise_for_status=False,
            trace_configs=[tracing],
        )
        self.request_time = time()
        # Independent per-entity caches (each its own dict).
        self.places_cache = {}
        self.projects_cache = {}
        self.users_cache = {}
        self.users_login_cache = {}
        self.taxa_cache = {}
        # api_v1_limiter:
        # ---------------
        # - Allow a burst of 60 requests (i.e. equal to max_rate) in the initial
        #   seconds of the 60 second time_period before enforcing a rate limit of
        #   60 requests per minute (max_rate).
        # - This honours "try to keep it to 60 requests per minute or lower":
        #   - https://api.inaturalist.org/v1/docs/
        # - Since the iNat API doesn't throttle until 100 requests per minute,
        #   this should ensure we never get throttled.
        self.api_v1_limiter = AsyncLimiter(60, 60)
Пример #7
0
def time_tracer(collector: dict):
    """Return a TraceConfig that records one WebsitePollingResult per URL.

    Results are stored in *collector* keyed by the request URL string. On
    success they hold the HTTP status and the response time in whole
    milliseconds. When the request raises they hold status 999 and response
    time 0.
    """
    # https://github.com/aio-libs/aiohttp/issues/670
    # https://github.com/aio-libs/aiohttp/issues/1692

    def _record(url, status_code, response_time):
        collector[url] = WebsitePollingResult(
            url=url, status_code=status_code, response_time=response_time)

    async def on_request_start(session, context, params):
        context.on_request_start = session.loop.time()
        context.raw_url = str(params.url)

    async def on_request_end(session, context, params):
        elapsed_ms = math.floor(
            (session.loop.time() - context.on_request_start) * 1000)
        _record(context.raw_url, params.response.status, elapsed_ms)

    async def on_request_exception(session, context, params):
        _record(context.raw_url, 999, 0)

    tracing = TraceConfig()
    for signal, handler in (
        (tracing.on_request_start, on_request_start),
        (tracing.on_request_end, on_request_end),
        (tracing.on_request_exception, on_request_exception),
    ):
        signal.append(handler)
    return tracing
Пример #8
0
 def __init__(self, url: str, message_broker: str, regexp: str):
     """Initialise a check for *url*.

     Args:
         url: Address to check.
         message_broker: Stored unchanged on the instance.
         regexp: Stored unchanged on the instance.
     """
     from datetime import timezone

     self.url = url
     self.message_broker = message_broker
     self.regexp = regexp
     # -1 placeholders until a measurement is recorded (presumably).
     self.latency = -1
     self.error_code = -1
     self.trace_config = TraceConfig()
     # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now"
     # formats to exactly the same string.
     self.checked_at = datetime.now(timezone.utc).strftime('%Y-%m-%d_%H:%M')
Пример #9
0
    def __init__(self,
                 context: "Context",
                 name: str,
                 exchange_setting: dict,
                 loop: Optional[asyncio.AbstractEventLoop] = None):
        """Set up the BitMEX exchange: HTTP session, websocket and REST interface.

        :param context: runtime context; ``context.settings.SSL_PATH`` and
            ``context.settings.HTTP_PROXY`` are read here and
            ``context.strategy`` is handed to the websocket client.
        :param name: exchange name, forwarded to the base class.
        :param exchange_setting: exchange configuration, for example:
        {
            'engine': 'monkq.exchange.bitmex',
            "IS_TEST": True,
            "API_KEY": '',
            "API_SECRET": ''
        }
        :param loop: event loop to use; ``asyncio.get_event_loop()`` when
            not given.
        """
        super(BitmexExchange, self).__init__(context=context,
                                             name=name,
                                             exchange_setting=exchange_setting)
        self._loop = loop if loop else asyncio.get_event_loop()

        # IS_TEST selects the testnet websocket endpoint.
        ws_url = (BITMEX_TESTNET_WEBSOCKET_URL
                  if self.exchange_setting['IS_TEST']
                  else BITMEX_WEBSOCKET_URL)

        self._trace_config = TraceConfig()
        self._ssl = ssl.create_default_context()
        ssl_path = context.settings.SSL_PATH  # type:ignore
        if ssl_path:
            self._ssl.load_verify_locations(ssl_path)

        self.api_key = exchange_setting.get("API_KEY", '')
        self.api_secret = exchange_setting.get("API_SECRET", '')

        self._available_instrument_cache: Dict[str, Instrument] = {}

        self._connector = TCPConnector(keepalive_timeout=90)  # type:ignore
        self.session = ClientSession(trace_configs=[self._trace_config],
                                     loop=self._loop,
                                     connector=self._connector)

        self.ws = BitmexWebsocket(strategy=context.strategy,
                                  loop=self._loop,
                                  session=self.session,
                                  ws_url=ws_url,
                                  api_key=self.api_key,
                                  api_secret=self.api_secret,
                                  ssl=self._ssl,
                                  http_proxy=None)
        # Any falsy HTTP_PROXY setting is normalised to None.
        http_proxy = self.context.settings.HTTP_PROXY or None  # type:ignore

        # NOTE(review): the raw ``loop`` argument (possibly None) is passed
        # here, not ``self._loop`` as used for the websocket — confirm intent.
        self.http_interface = BitMexHTTPInterface(exchange_setting,
                                                  self._connector,
                                                  self.session, self._ssl,
                                                  http_proxy, loop)
Пример #10
0
 def __init__(self, api_data):
     """Create the traced HTTP session and the ZalandoAPI client built on it.

     Args:
         api_data: keyword arguments forwarded to ``ZalandoAPI``.
     """
     tracing = TraceConfig()
     for signal, hook in (
             (tracing.on_request_end, on_request_end),
             (tracing.on_request_chunk_sent, on_request_chunk_sent),
             (tracing.on_request_redirect, on_request_redirect),
     ):
         signal.append(hook)
     session_timeout = ClientTimeout(total=SESSION_TIMEOUT)
     self.session = ClientSession(
         raise_for_status=True,
         headers=ZalandoAPI.CONSTANT_HEADERS,
         timeout=session_timeout,
         trace_configs=[tracing])
     self.api = ZalandoAPI(session=self.session, **api_data)
     logging.info('task for %s is initialized', self.api)
Пример #11
0
def create_session():
    """Return a ClientSession that records the timing of each completed request.

    Every finished request appends ``(start, end, elapsed)`` to
    ``session.latency_record``. The values are wall-clock ``time.time()``
    readings in seconds.
    """
    async def on_request_start(session, trace_config_ctx, params):
        trace_config_ctx.start = time.time()

    async def on_request_end(session, trace_config_ctx, params):
        started = trace_config_ctx.start
        finished = time.time()
        session.latency_record.append((started, finished, finished - started))

    tracing = TraceConfig()
    tracing.on_request_start.append(on_request_start)
    tracing.on_request_end.append(on_request_end)

    session = ClientSession(trace_configs=[tracing])
    # NOTE(review): newer aiohttp versions may warn about custom attributes
    # on ClientSession — confirm against the pinned version.
    session.latency_record = []
    return session
Пример #12
0
def setup_client_logger():
    """Return a TraceConfig that logs each client request and response at DEBUG.

    The request start time is stored on the per-request trace context. When
    the response arrives, the elapsed time is logged in milliseconds.
    """
    async def on_request_start(session, trace_config_ctx, params):
        trace_config_ctx.start = asyncio.get_event_loop().time()
        logging.debug('Request [%s] (%s), headers: %s', params.method,
                      params.url, params.headers)

    async def on_request_end(session, trace_config_ctx, params):
        # loop.time() is in seconds; convert so the value matches the "ms"
        # label (previously seconds were printed as milliseconds).
        elapsed_ms = (asyncio.get_event_loop().time()
                      - trace_config_ctx.start) * 1000
        # Pure %-style lazy formatting instead of mixing in an f-string.
        logging.debug('Response [%s] %d %s (%s) Take %.3f ms',
                      params.method, params.response.status,
                      params.response.reason, params.url, elapsed_ms)

    trace_config = TraceConfig()
    trace_config.on_request_start.append(on_request_start)
    trace_config.on_request_end.append(on_request_end)

    return trace_config
Пример #13
0
def http_trace_config():
    """Return a TraceConfig that logs per-request latency.

    Only requests made with a truthy ``trace_request_ctx`` are timed. That
    object must accept attribute assignment (``start`` is stored on it) and
    expose ``logger`` and ``flag``. The elapsed time is logged in
    milliseconds, rounded to three decimals.
    """
    async def _on_request_start(session, trace_config_ctx, params):
        request_ctx = trace_config_ctx.trace_request_ctx
        if request_ctx:
            request_ctx.start = session.loop.time()

    async def _on_request_end(session, trace_config_ctx, params):
        request_ctx = trace_config_ctx.trace_request_ctx
        if not request_ctx:
            return
        elapsed = round((session.loop.time() - request_ctx.start) * 1000, 3)
        request_ctx.logger.info(
            f"flag:{request_ctx.flag}|spend:{elapsed}|context:|error:|error_detail:"
        )

    tracing = TraceConfig()
    tracing.on_request_start.append(_on_request_start)
    tracing.on_request_end.append(_on_request_end)
    return tracing
Пример #14
0
async def main(websocket_url, service_base_url, websocket_proxy=None):
    """Keep the gateway websocket bridge alive, reconnecting after failures.

    Two sessions live for the whole loop: a plain one for the gateway
    websocket and a traced one (``on_request_start`` hook) for the backing
    service. Errors from ``connect_ws`` are logged and the connection is
    retried after 10 seconds. An ``asyncio.CancelledError`` raised by
    ``connect_ws`` ends the loop.
    """
    service_tracing = TraceConfig()
    service_tracing.on_request_start.append(on_request_start)

    async with aiohttp.ClientSession() as gateway_session, \
            aiohttp.ClientSession(trace_configs=[service_tracing]) as service_session:
        while True:
            try:
                await connect_ws(gateway_session, websocket_url,
                                 service_session, service_base_url,
                                 websocket_proxy=websocket_proxy)
            except asyncio.CancelledError:
                break
            except Exception as err:
                logger_proxy_client.error(err)

            logger_proxy_client.info(
                'WS DISCONNECTED: waiting 10 seconds for reconnect')
            await asyncio.sleep(10)
Пример #15
0
async def test_request_tracing(loop, test_client):
    """Request, redirect and connection-creation trace hooks all fire.

    GET /redirector answers with HTTPFound to /redirected. A single client
    call therefore exercises the redirect hook as well as the request and
    connection hooks.
    """

    def _async_mock():
        # asyncio.coroutine() was removed in Python 3.11. A Mock whose
        # side_effect is a coroutine function returns an awaitable when
        # called (aiohttp awaits it) and still records ``.called``.
        async def _noop(*args, **kwargs):
            return None

        return mock.Mock(side_effect=_noop)

    on_request_start = _async_mock()
    on_request_end = _async_mock()
    on_request_redirect = _async_mock()
    on_connection_create_start = _async_mock()
    on_connection_create_end = _async_mock()

    async def redirector(request):
        raise web.HTTPFound(location=URL('/redirected'))

    async def redirected(request):
        return web.Response()

    trace_config = TraceConfig()

    trace_config.on_request_start.append(on_request_start)
    trace_config.on_request_end.append(on_request_end)
    trace_config.on_request_redirect.append(on_request_redirect)
    trace_config.on_connection_create_start.append(on_connection_create_start)
    trace_config.on_connection_create_end.append(on_connection_create_end)

    app = web.Application()
    app.router.add_get('/redirector', redirector)
    app.router.add_get('/redirected', redirected)

    client = await test_client(app, trace_configs=[trace_config])

    await client.get('/redirector', data="foo")

    assert on_request_start.called
    assert on_request_end.called
    assert on_request_redirect.called
    assert on_connection_create_start.called
    assert on_connection_create_end.called
Пример #16
0
    def __init__(self,
                 headers=None,
                 cookies=None,
                 raise_for_status=True,
                 timeout=60,
                 debug=False,
                 trace_configs=None,
                 **kwargs):
        """Create the ArtStation client and its underlying ClientSession.

        Args:
            headers: Extra request headers. The ArtStation-specific headers
                set below are layered on top of (and override) these. The
                caller's mapping is copied, never modified.
            cookies: Passed straight to ClientSession.
            raise_for_status: Passed straight to ClientSession.
            timeout: Total request timeout in seconds, wrapped in
                ClientTimeout.
            debug: When true, add a TraceConfig that calls
                ``self.on_debug_req_start`` / ``self.on_debug_req_end``.
            trace_configs: Extra TraceConfig objects. The caller's list is
                copied, never modified.
            **kwargs: Any other ClientSession keyword arguments.
        """
        self._url_base = "https://www.artstation.com"

        # Work on a copy: updating the caller's mapping in place would leak
        # these headers into whatever else shares that dict.
        headers = headers.copy() if headers else {}
        headers.update({
            'User-Agent': useragent_generator.random,
            'Host': 'www.artstation.com',
            'Origin': 'https://www.artstation.com',
            'Referer': 'https://www.artstation.com',
            'Accept': 'application/json, text/plain, */*',
            'Content-Type': 'application/json'
        })

        # Likewise copy the list so appending the debug config does not
        # modify the caller's list.
        trace_configs = list(trace_configs) if trace_configs else []

        if debug:
            t_conf = TraceConfig()
            t_conf.on_request_start.append(self.on_debug_req_start)
            t_conf.on_request_end.append(self.on_debug_req_end)
            trace_configs.append(t_conf)

        timeout = ClientTimeout(total=timeout)

        self.session = ClientSession(headers=headers,
                                     cookies=cookies,
                                     raise_for_status=raise_for_status,
                                     timeout=timeout,
                                     trace_configs=trace_configs,
                                     **kwargs)
Пример #17
0

async def on_request_end(session: ClientSession,
                         trace_config_ctx: SimpleNamespace,
                         params: TraceRequestEndParams):
    """Log the response status at DEBUG; non-200 responses include method and URL."""
    status: int = params.response.status
    if status != 200:
        logger.debug(
            f"Ending request {params.method}: {params.url} with status [{status}]")
        return
    logger.debug(f"[{status}]")


async def on_request_exception(session: ClientSession,
                               trace_config_ctx: SimpleNamespace,
                               params: TraceRequestExceptionParams):
    """Log the method, URL and exception class name of a failed request."""
    exc_name = type(params.exception).__name__
    logger.error(f"{params.method}: {params.url} raised {exc_name}")


async def on_request_redirect(session: ClientSession,
                              trace_config_ctx: SimpleNamespace,
                              params: TraceRequestRedirectParams):
    """Log the target of a redirect response.

    ``params.response`` is the 3xx response itself, so its ``url`` is the
    URL that was requested. The redirect target is carried in the response's
    ``Location`` header, so that header is what gets logged.
    """
    location = params.response.headers.get("Location")
    if location is None:
        logger.info(
            f"Redirect from {params.response.url} without a Location header")
        return
    logger.info(f"Redirected to {location}")

# Module-level TraceConfig that wires this module's request hooks into any
# ClientSession created with ``trace_configs=[trace_config]``.
# NOTE(review): on_request_start is not visible in this excerpt; presumably
# it is defined earlier in the module — confirm.
trace_config = TraceConfig()
trace_config.on_request_start.append(on_request_start)
trace_config.on_request_end.append(on_request_end)
trace_config.on_request_redirect.append(on_request_redirect)
trace_config.on_request_exception.append(on_request_exception)
Пример #18
0
    async def do(self, target: Target):
        """
        Coroutine: connect to ``target``, send the request, receive the
        response and build the result record as a dict.

        The finished record is serialized with ``ujson_dumps`` and put on
        ``self.output_queue``, subject to ``self.success_only`` and the
        ``self.app_config.status_code`` filter. ``self.stats`` counters are
        updated only when ``self.stats`` is set. The whole body runs under
        ``self.semaphore``.
        """
        def return_ip_from_deep(sess, response) -> str:
            # Best-effort peer-IP lookup: first from the response's own
            # transport, then from any pooled connection in the connector.
            # Returns '' when neither yields a valid IP.
            try:
                ip_port = response.connection.transport.get_extra_info(
                    'peername')
                if is_ip(ip_port[0]):
                    return ip_port[0]
            except BaseException:
                pass
            try:
                _tmp_conn_key = sess.connector._conns.items()
                for k, v in _tmp_conn_key:
                    _h = v[0][0]
                    ip_port = _h.transport.get_extra_info('peername')
                    if is_ip(ip_port[0]):
                        return ip_port[0]
            except BaseException:
                pass
            return ''

        def update_line(json_record, target):
            # Stamp the target IP onto the record (mutates and returns it).
            json_record['ip'] = target.ip
            # json_record['ip_v4_int'] = int(ip_address(target.ip))
            # json_record['datetime'] = datetime.datetime.utcnow()
            # json_record['port'] = int(target.port)
            return json_record

        async with self.semaphore:
            result = None
            timeout = ClientTimeout(total=target.total_timeout)

            # region tmp disable
            trace_config = TraceConfig()
            trace_config.on_request_start.append(on_request_start)
            trace_config.on_request_end.append(on_request_end)
            # endregion
            resolver = AsyncResolver(nameservers=['8.8.8.8', '8.8.4.4'])
            # resolver = None
            # https://github.com/aio-libs/aiohttp/issues/2228  - closed
            if target.ssl_check:

                conn = TCPConnector(
                    ssl=False,
                    family=2,  # need set current family (only IPv4)
                    limit_per_host=0,
                    resolver=resolver)
                session = ClientSession(timeout=timeout,
                                        connector=conn,
                                        response_class=WrappedResponseClass,
                                        trace_configs=[trace_config])
                simple_zero_sleep = 0.250
            else:
                simple_zero_sleep = 0.001
                session = ClientSession(
                    connector=TCPConnector(
                        limit_per_host=0,
                        family=2,  # need set current family (only IPv4)
                        resolver=resolver),
                    timeout=timeout,
                    trace_configs=[trace_config])
            selected_proxy_connection = None
            try:
                selected_proxy_connection = next(
                    self.app_config.proxy_connections)
            except:
                pass
            try:
                async with session.request(
                        target.method,
                        target.url,
                        timeout=timeout,
                        headers=target.headers,
                        cookies=target.cookies,
                        allow_redirects=target.allow_redirects,
                        data=target.payload,
                        proxy=selected_proxy_connection,
                        trace_request_ctx=self.trace_request_ctx) as response:
                    _default_record = create_template_struct(target)
                    if target.ssl_check:
                        cert = convert_bytes_to_cert(response.peer_cert)
                        if not self.app_config.without_certraw:
                            _default_record['data']['http']['result'][
                                'response']['request']['tls_log'][
                                    'handshake_log']['server_certificates'][
                                        'certificate']['raw'] = b64encode(
                                            response.peer_cert).decode('utf-8')
                        if cert:
                            _default_record['data']['http']['result'][
                                'response']['request']['tls_log'][
                                    'handshake_log']['server_certificates'][
                                        'certificate']['parsed'] = cert
                    _default_record['data']['http']['status'] = "success"
                    _default_record['data']['http']['result']['response'][
                        'status_code'] = response.status
                    # region
                    _header = {}
                    for key in response.headers:
                        _header[key.lower().replace(
                            '-', '_')] = response.headers.getall(key)
                    _default_record['data']['http']['result']['response'][
                        'headers'] = _header
                    # endregion
                    if target.method in [
                            'GET', 'POST', 'PUT', 'DELETE', 'UPDATE'
                    ]:
                        buffer = b""
                        try:
                            read_c = asyncio.wait_for(
                                read_http_content(response, n=target.max_size),
                                timeout=target.total_timeout)
                            buffer = await read_c
                        except Exception as e:
                            pass
                        else:
                            if filter_bytes(buffer, target):
                                _default_record['data']['http']['result'][
                                    'response']['content_length'] = len(buffer)
                                _default_record['data']['http']['result'][
                                    'response']['body'] = ''
                                try:
                                    _default_record['data']['http']['result'][
                                        'response']['body'] = buffer.decode()
                                except Exception as e:
                                    pass
                                if not self.app_config.without_base64:
                                    try:
                                        _base64_data = b64encode(
                                            buffer).decode('utf-8')
                                        _default_record['data']['http'][
                                            'result']['response'][
                                                'body_raw'] = _base64_data
                                    except Exception as e:
                                        pass
                                if not self.app_config.without_hashs:
                                    try:
                                        hashs = {
                                            'sha256': sha256,
                                            'sha1': sha1,
                                            'md5': md5
                                        }
                                        for namehash, func in hashs.items():
                                            hm = func()
                                            hm.update(buffer)
                                            _default_record['data']['http'][
                                                'result']['response'][
                                                    f'body_{namehash}'] = hm.hexdigest(
                                                    )
                                    except Exception as e:
                                        pass
                                result = update_line(_default_record, target)
                            else:
                                # TODO: add a success-not-contain status to mark
                                #  that the service was found but did not pass the filter?
                                result = create_error_template(
                                    target,
                                    error_str='',
                                    status_string='success-not-contain')
                    if result:
                        if not result['ip']:
                            result['ip'] = return_ip_from_deep(
                                session, response)
            except Exception as exp:
                error_str = ''
                try:
                    error_str = exp.strerror
                except:
                    pass
                result = create_error_template(target, error_str,
                                               type(exp).__name__)
                await asyncio.sleep(simple_zero_sleep)
                try:
                    await session.close()
                except:
                    pass
                try:
                    await conn.close()
                except:
                    pass
            if result:
                if 'duration' in self.trace_request_ctx:
                    request_duration = self.trace_request_ctx['duration']
                    result['data']['http']['duration'] = request_duration
                success = access_dot_path(result, "data.http.status")
                if self.stats:
                    if success == "success":
                        self.stats.count_good += 1
                    else:
                        self.stats.count_error += 1
                if not (self.app_config.status_code == CONST_ANY_STATUS):
                    response_status = access_dot_path(
                        result, 'data.http.result.response.status_code')
                    if response_status:
                        if self.app_config.status_code != response_status:
                            error_str = f'status code: {response_status} is not equal to filter: {self.app_config.status_code}'
                            result = create_error_template(
                                target,
                                error_str=error_str,
                                status_string='success-not-need-status')
                            # Re-classify the record as an error. Guarded like
                            # the counters above: self.stats may be unset, and
                            # the unguarded version raised AttributeError here,
                            # skipping output and session cleanup.
                            if self.stats:
                                self.stats.count_good -= 1
                                self.stats.count_error += 1
                # NOTE(review): `success` is not recomputed after the status
                # filter above, so with success_only the replacement error
                # record is still written — confirm this is intended.
                line = None
                try:
                    if self.success_only:
                        if success == "success":
                            line = ujson_dumps(result)
                    else:
                        line = ujson_dumps(result)
                except Exception:
                    pass
                if line:
                    await self.output_queue.put(line)

            await asyncio.sleep(simple_zero_sleep)
            try:
                await session.close()
            except:
                pass
            try:
                await conn.close()
            except:
                pass
Пример #19
0
 def __init__(self):
     """Create a traced ClientSession and the application logger."""
     tracing = TraceConfig()
     for signal, hook in ((tracing.on_request_start, self.on_request_start),
                          (tracing.on_request_end, self.on_request_end)):
         signal.append(hook)
     self.session = ClientSession(trace_configs=[tracing])
     self.logger = AppLogger()
Пример #20
0
    async def _track_request_without_callback(cookie_dict: dict, url: str) -> str:
        """
        Make a request to the spotify api and capture the auth code from a redirect.

        No callback server is needed: the ``code`` query parameter is read from
        the ``Location`` header of each redirect via a trace hook. A connection
        error while following the redirect is therefore ignored.

        Args:
            cookie_dict: The cookie dict used for authentication
            url: The url of the spotify request

        Raises:
            SpotifyError: If no code could be extracted from any redirect

        Returns: The code of spotify
        """

        code: Optional[str] = None
        # Bound up front so the error message below is buildable even when the
        # request fails with ClientConnectorError (previously an
        # UnboundLocalError).
        response_text = ''

        async def redirect(_: ClientSession, __: SimpleNamespace, trace_request: TraceRequestRedirectParams) -> None:
            """
            Handle the redirect event aiohttp fires and capture ``code``

            Args:
                _: ClientSession
                __: SimpleNamespace
                trace_request: The current redirect request response
            """
            nonlocal code

            # Get the redirect url
            location: Optional[str] = trace_request.response.headers.get('location')
            if location is None:
                return

            # Parse the url and check if code is in its query string
            query: dict = parse.parse_qs(parse.urlparse(location).query)
            _code: Optional[List[str]] = query.get('code')

            if _code:
                code = _code[0]

        # Create a callback every time there is a redirect
        trace_config = TraceConfig()
        trace_config.on_request_redirect.append(redirect)

        try:
            # Make an api request to spotify; `async with` closes the session.
            async with ClientSession(cookies=cookie_dict, trace_configs=[trace_config]) as session:
                async with session.get(url) as resp:
                    response_text = await resp.text()
        except ClientConnectorError:
            # Ignore the error in case no callback server is running
            pass

        if not code:
            message = f'The collection of the code did not work. Did the user already agree to the scopes' \
                      f' of your app? \n {response_text}'

            raise SpotifyError(ErrorMessage(message=message).__dict__)

        return code
    def get_trace_config(self):
        """Return a new TraceConfig wired to this object's request hooks.

        ``self.on_request_start`` and ``self.on_request_end`` are attached.
        A fresh config is built on every call.
        """
        config = TraceConfig()
        for signal, hook in ((config.on_request_start, self.on_request_start),
                             (config.on_request_end, self.on_request_end)):
            signal.append(hook)
        return config
Пример #22
0
async def test_request_tracing(aiohttp_server):
    """All request, DNS and connection trace hooks fire for a redirected request.

    ``example.com`` is mapped to the local test server by a fake resolver,
    so the DNS-resolution hooks fire while the request stays local.
    """

    def _async_mock():
        # asyncio.coroutine() was removed in Python 3.11. A Mock whose
        # side_effect is a coroutine function returns an awaitable when
        # called (aiohttp awaits it) and still records ``.called``.
        async def _noop(*args, **kwargs):
            return None

        return mock.Mock(side_effect=_noop)

    on_request_start = _async_mock()
    on_request_end = _async_mock()
    on_dns_resolvehost_start = _async_mock()
    on_dns_resolvehost_end = _async_mock()
    on_request_redirect = _async_mock()
    on_connection_create_start = _async_mock()
    on_connection_create_end = _async_mock()

    async def redirector(request):
        raise web.HTTPFound(location=URL('/redirected'))

    async def redirected(request):
        return web.Response()

    trace_config = TraceConfig()

    trace_config.on_request_start.append(on_request_start)
    trace_config.on_request_end.append(on_request_end)
    trace_config.on_request_redirect.append(on_request_redirect)
    trace_config.on_connection_create_start.append(on_connection_create_start)
    trace_config.on_connection_create_end.append(on_connection_create_end)
    trace_config.on_dns_resolvehost_start.append(on_dns_resolvehost_start)
    trace_config.on_dns_resolvehost_end.append(on_dns_resolvehost_end)

    app = web.Application()
    app.router.add_get('/redirector', redirector)
    app.router.add_get('/redirected', redirected)
    server = await aiohttp_server(app)

    class FakeResolver:
        _LOCAL_HOST = {0: '127.0.0.1', socket.AF_INET: '127.0.0.1'}

        def __init__(self, fakes):
            """fakes -- dns -> port dict"""
            self._fakes = fakes
            self._resolver = aiohttp.DefaultResolver()

        async def resolve(self, host, port=0, family=socket.AF_INET):
            fake_port = self._fakes.get(host)
            if fake_port is not None:
                return [{
                    'hostname': host,
                    'host': self._LOCAL_HOST[family],
                    'port': fake_port,
                    'family': socket.AF_INET,
                    'proto': 0,
                    'flags': socket.AI_NUMERICHOST
                }]
            else:
                return await self._resolver.resolve(host, port, family)

    resolver = FakeResolver({'example.com': server.port})
    connector = aiohttp.TCPConnector(resolver=resolver)
    client = aiohttp.ClientSession(connector=connector,
                                   trace_configs=[trace_config])
    try:
        await client.get('http://example.com/redirector', data="foo")

        assert on_request_start.called
        assert on_request_end.called
        assert on_dns_resolvehost_start.called
        assert on_dns_resolvehost_end.called
        assert on_request_redirect.called
        assert on_connection_create_start.called
        assert on_connection_create_end.called
    finally:
        # Close even when an assertion fails so the session/connector
        # are not leaked.
        await client.close()