Beispiel #1
0
async def test_drive_cred_generation(reader: ReaderHelper, ui_server: UiServer, snapshot, config: Config, global_info: GlobalInfo, session: ClientSession, google):
    """Expired Drive credentials can be regenerated end to end.

    Expires the fake Google credentials and confirms a sync reports
    ERROR_CREDS_EXPIRED. It then walks the authentication page to obtain new
    credentials and hands them to the addon's ``token`` endpoint. Finally it
    confirms a sync clears the error and bumps the credential version.
    """
    initial = await reader.getjson("getstatus")
    assert len(initial["snapshots"]) == 1
    assert global_info.credVersion == 0

    # Break the Drive credentials; the next sync has to surface the failure.
    google.expireCreds()
    after_expiry = await reader.getjson("sync")
    assert after_expiry["last_error"]["error_type"] == ERROR_CREDS_EXPIRED

    # Play the user's part on the authentication page...
    base_url = reader.getUrl(True)
    auth_query = {
        "redirectbacktoken": base_url + "token",
        "version": VERSION,
        "return": base_url,
    }
    auth_url = URL(config.get(Setting.AUTHENTICATE_URL)).with_query(auth_query)
    async with session.get(auth_url) as auth_resp:
        auth_resp.raise_for_status()
        soup = BeautifulSoup(await auth_resp.text(), 'html.parser')
        # The generated credentials are shown inside a <textarea>.
        creds = str(soup.find("textarea").getText()).strip()

    # ...then paste those credentials back into the addon.
    token_url = URL(base_url + "token").with_query({"creds": creds, "host": base_url})
    async with session.get(token_url) as token_resp:
        token_resp.raise_for_status()
        # The addon is expected to redirect back to its main page.
        assert token_resp.url == URL(base_url)

    await ui_server.sync(None)
    assert global_info._last_error is None
    assert global_info.credVersion == 1
class AIOResponseRedirectTest(TestCase):
    """Redirect handling for requests served by the ``aioresponses`` mock.

    Written with legacy generator-based coroutines (``@asyncio.coroutine``
    plus ``yield from``).
    NOTE(review): ``asyncio.coroutine`` was removed in Python 3.11; these
    tests need porting to ``async def`` to run on newer interpreters.
    """
    @asyncio.coroutine
    def setUp(self):
        # Private address; requests to it are expected to be answered by the
        # aioresponses mock registered in each test.
        self.url = "http://10.1.1.1:8080/redirect"
        self.session = ClientSession()
        super().setUp()

    @asyncio.coroutine
    def tearDown(self):
        # close() may or may not return an awaitable (presumably depending on
        # the aiohttp version); drive it to completion only when it does.
        close_result = self.session.close()
        if close_result is not None:
            yield from close_result
        super().tearDown()

    @aioresponses()
    @asyncio.coroutine
    def test_redirect_followed(self, rsps):
        # A mocked 307 pointing at a second mocked URL is followed: the final
        # response comes from the target, and the history holds exactly the
        # original redirecting response.
        rsps.get(
            self.url,
            status=307,
            headers={"Location": "https://httpbin.org"},
        )
        rsps.get("https://httpbin.org")
        response = yield from self.session.get(self.url, allow_redirects=True)
        self.assertEqual(response.status, 200)
        self.assertEqual(str(response.url), "https://httpbin.org")
        self.assertEqual(len(response.history), 1)
        self.assertEqual(str(response.history[0].url), self.url)

    @aioresponses()
    @asyncio.coroutine
    def test_redirect_missing_mocked_match(self, rsps):
        # Only the redirecting URL is mocked, so following the 307 must fail
        # with ClientConnectionError. The asserted message names the
        # original URL, not the redirect target.
        rsps.get(
            self.url,
            status=307,
            headers={"Location": "https://httpbin.org"},
        )
        with self.assertRaises(ClientConnectionError) as cm:
            # `response` is never used; the call is expected to raise.
            response = yield from self.session.get(self.url,
                                                   allow_redirects=True)
        self.assertEqual(
            str(cm.exception),
            'Connection refused: GET http://10.1.1.1:8080/redirect')

    @aioresponses()
    @asyncio.coroutine
    def test_redirect_missing_location_header(self, rsps):
        # A 307 without a Location header is not followed; the response keeps
        # the original URL.
        rsps.get(self.url, status=307)
        response = yield from self.session.get(self.url, allow_redirects=True)
        self.assertEqual(str(response.url), self.url)
Beispiel #3
0
async def fetch_genderize(session: ClientSession, parameter: str) -> GenderizeResponse:
    """Query the Genderize API for ``parameter``.

    Returns the decoded JSON body when ``valid_response`` accepts the
    response, and ``None`` otherwise.
    """
    async with session.get(f"{GENDERIZE_API_URL}/{parameter}") as response:
        if not valid_response(response, "Genderize"):
            return None
        return await response.json()
Beispiel #4
0
 async def _download_book(self, book: Book, path: Path,
                          session: ClientSession) -> str:
     """Download ``book`` into the directory ``path``.

     ``book.download_url`` serves an HTML landing page. The actual file link
     is the ``href`` of the first element matching
     ``div#download h2 a[href]``. That file is streamed in 1024-byte chunks
     to ``path / book.filename``.

     Returns:
         ``book.filename``.
     """
     async with session.get(url=book.download_url) as response:
         content = await response.text()
     soup = BeautifulSoup(content, features="html.parser")
     anchor_tag = soup.select_one("div#download h2 a[href]")
     url = anchor_tag["href"]
     # Open the target only once the real link is known, and through `with`
     # so the handle is closed even if the download fails partway.
     with open(path.joinpath(book.filename), "wb") as f:
         async with session.get(url=url) as response:
             async for data in response.content.iter_chunked(1024):
                 f.write(data)
     return book.filename
Beispiel #5
0
async def _fetch_cached_response(session: ClientSession, headers: dict,
                                 url: str, params: dict,
                                 cache_expiry_seconds: int):
    """GET ``url`` with ``params``, going through the redis cache.

    The cache key is ``url`` followed by ``?`` and the urlencoded ``params``.
    On a cache hit, the stored JSON is decoded and returned. On a miss, the
    request is made and its JSON body returned. That body is also stored for
    ``cache_expiry_seconds`` when the response status is 200. Redis errors
    are logged and otherwise ignored.
    """
    cache = create_redis()
    key = f'{url}?{urlencode(params)}'

    cached_json = None
    try:
        cached_json = cache.get(key)
    except Exception as error:  # pylint: disable=broad-except
        logger.error(error)

    if cached_json:
        logger.info('redis cache hit %s', key)
        return json.loads(cached_json.decode())

    logger.info('redis cache miss %s', key)
    async with session.get(url, headers=headers,
                           params=params) as response:
        response_json = await response.json()
    # Caching is best-effort: a failing redis must not fail the request.
    try:
        if response.status == 200:
            payload = json.dumps(response_json).encode()
            cache.set(key, payload, ex=cache_expiry_seconds)
    except Exception as error:  # pylint: disable=broad-except
        logger.error(error)
    return response_json
Beispiel #6
0
class ProxiedClientSession:
    """A ClientSession that forwards requests through a custom proxy.

    Every request is sent to ``proxy_url``. The real target URL is carried in
    the ``Requested-URI`` header, together with fixed
    ``Proxy-Authorization-Key`` and ``Accept: application/json`` headers.

    Keyword arguments consumed by the constructor:
        proxy_url: URL of the forwarding proxy (required).
        authorization: value of the ``Proxy-Authorization-Key`` header
            (required).
    All other arguments are passed straight to ``ClientSession``.
    """
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.proxy_url = kwargs.pop("proxy_url")

        self.permanent_headers = {
            "Proxy-Authorization-Key": kwargs.pop("authorization"),
            "Accept": "application/json",
        }
        self._session = ClientSession(*args, **kwargs)

    def _proxy_headers(self, url: str, headers: Any) -> dict:
        """Return the headers for one proxied request to ``url``.

        Builds a new dict so the caller's ``headers`` mapping is never
        mutated (it may be reused across requests). ``None`` is treated as
        no extra headers.
        """
        merged = dict(headers or {})
        merged.update(self.permanent_headers)
        merged["Requested-URI"] = url
        return merged

    def get(self, url: str, *args: Any,
            **kwargs: Any) -> _RequestContextManager:
        """Send a GET for ``url`` through the proxy."""
        headers = self._proxy_headers(url, kwargs.pop("headers", None))
        return self._session.get(self.proxy_url,
                                 *args,
                                 headers=headers,
                                 **kwargs)

    def post(self, url: str, *args: Any,
             **kwargs: Any) -> _RequestContextManager:
        """Send a POST for ``url`` through the proxy."""
        headers = self._proxy_headers(url, kwargs.pop("headers", None))
        return self._session.post(self.proxy_url,
                                  *args,
                                  headers=headers,
                                  **kwargs)
async def download_to_file(client: ClientSession, url: str, save_to: str):
    """Fetch ``url`` and write the whole response body to ``save_to``.

    Raises:
        ValueError: when the server answers with any status other than 200.
    """
    async with client.get(url) as response:
        if response.status == 200:
            with open(save_to, mode='wb') as f:
                f.write(await response.read())
            return
        raise ValueError(f'Status {response.status} for {url}')
Beispiel #8
0
 async def get_runners_tag_list(self, runner: ProjectRunner,
                                session: ClientSession) -> List:
     """Return the ``tag_list`` of ``runner``.

     Reads ``{self.server}/api/v4/runners/{id}`` (a GitLab-style runners
     endpoint) using ``self.headers``.
     """
     runner_id = runner._attrs["id"]
     endpoint = "{}/api/v4/runners/{}".format(self.server, runner_id)
     async with session.get(endpoint, headers=self.headers) as response:
         payload = await response.json()
     return payload["tag_list"]
Beispiel #9
0
async def fetch(session: ClientSession, proxy: Proxy) -> None:
    """Request https://2ip.ru/ through ``proxy`` and print the raw body.

    A failure to connect through the proxy is reported on stdout instead of
    being raised.
    """
    target = 'https://2ip.ru/'
    try:
        async with session.get(target, proxy=str(proxy)) as response:
            body = await response.content.read()
            print(body)
    except ClientProxyConnectionError as e:
        print(f"Proxy: '{proxy}' doesn't work: {e}")
class AIOResponsesRaiseForStatusSessionTestCase(TestCase):
    """Test case for sessions created with ``raise_for_status=True``.

    This session flag, introduced in aiohttp v2.0.0, automatically calls
    ``raise_for_status()`` on every response. Since aiohttp v3.4.a0 it can be
    overridden per request through the request's own ``raise_for_status``
    argument.

    Written with legacy generator-based coroutines (``@asyncio.coroutine``
    plus ``yield from``); NOTE(review): ``asyncio.coroutine`` was removed in
    Python 3.11.
    """
    # NOTE(review): presumably tells the async TestCase base not to reuse the
    # default event loop; confirm against the base class.
    use_default_loop = False

    @asyncio.coroutine
    def setUp(self):
        self.url = 'http://example.com/api?foo=bar#fragment'
        self.session = ClientSession(raise_for_status=True)
        super().setUp()

    @asyncio.coroutine
    def tearDown(self):
        # close() may or may not return an awaitable (presumably depending on
        # the aiohttp version); drive it to completion only when it does.
        close_result = self.session.close()
        if close_result is not None:
            yield from close_result
        super().tearDown()

    @aioresponses()
    @asyncio.coroutine
    def test_raise_for_status(self, m):
        # A mocked 400 makes the session raise ClientResponseError whose
        # message is the reason phrase listed for 400 in `http.RESPONSES`.
        m.get(self.url, status=400)
        with self.assertRaises(ClientResponseError) as cm:
            yield from self.session.get(self.url)
        self.assertEqual(cm.exception.message, http.RESPONSES[400][0])

    @aioresponses()
    @asyncio.coroutine
    # NOTE(review): if AIOHTTP_VERSION is a plain str this comparison is
    # lexicographic (e.g. '3.10.0' < '3.4.0' is True); confirm it is a
    # version object.
    @skipIf(condition=AIOHTTP_VERSION < '3.4.0',
            reason='aiohttp<3.4.0 does not support raise_for_status '
                   'arguments for requests')
    def test_do_not_raise_for_status(self, m):
        # Passing raise_for_status=False on the request overrides the
        # session-wide flag, so the 400 response is returned instead.
        m.get(self.url, status=400)
        response = yield from self.session.get(self.url,
                                               raise_for_status=False)

        self.assertEqual(response.status, 400)
async def load_coordinates(exchange_office: ExchangeOffice,
                           client: ClientSession) -> Tuple[Decimal, Decimal]:
    """Download the page of ``exchange_office`` and extract its coordinates.

    The page URL is ``SELECTBY_EXCHANGE_OFFICE_PAGE_PATTERN`` formatted with
    the office identifier, and the body is decoded as windows-1251.

    Raises:
        ValueError: when the page is not served with HTTP 200.
    """
    office_id = exchange_office.identifier
    page_url = SELECTBY_EXCHANGE_OFFICE_PAGE_PATTERN % office_id
    async with client.get(page_url) as response:
        if response.status == 200:
            raw = await response.read()
            return extract_coordinates_from_page(raw.decode('windows-1251'))
        raise ValueError(f'Unexpected http status: {response.status} '
                         f'for office {office_id}')
 async def fetch_queue_bindings(
     self, http_session: ClientSession, exchange: str, queue: str
 ) -> List[Metric]:
     """Return the routing keys of the bindings from ``exchange`` to ``queue``.

     Queries ``/api/bindings/{vhost}/e/{exchange}/q/{queue}/`` on
     ``self.http_api_url``, with the vhost from ``self.data_vhost_quoted()``.

     NOTE(review): annotated as ``List[Metric]``, but the list holds each
     binding's ``routing_key`` value; confirm the intended return type.

     Raises:
         RuntimeError: when the API answers with an error status
             (``response.ok`` is false).
     """
     vhost = self.data_vhost_quoted()
     async with http_session.get(
         self.http_api_url.with_path(
             f"/api/bindings/{vhost}/e/{exchange}/q/{queue}/", encoded=True
         )
     ) as response:
         if not response.ok:
             # Check the status before decoding. Error bodies are not always
             # JSON (e.g. an HTML page from a proxy), and decoding them would
             # raise a content-type error that hides the HTTP failure.
             try:
                 error = (await response.json(content_type=None)).get("error")
             except (ValueError, AttributeError):
                 # Not JSON, or JSON that is not an object: report raw text.
                 error = await response.text(errors="replace")
             raise RuntimeError(f"RabbitMQ API returned an error: {error}")
         bindings_json = await response.json()
         return [binding["routing_key"] for binding in bindings_json]
Beispiel #13
0
 async def fetch(self, session: ClientSession, url: str) -> Any:
     """GET ``url`` and pass its decoded JSON to ``self.row_processing``.

     Returns the decoded JSON on success. Every failure is logged as a
     warning and yields ``None`` instead. The failures are a
     ``ClientResponseError`` (its HTTP status is logged), a timeout (the
     request is given ``timeout=15``), and any other exception, including
     ones raised by ``row_processing``.
     """
     try:
         async with session.get(url, timeout=15) as response:
             json_resp = await response.json()
             self.row_processing(json_resp)
     except ClientResponseError as e:
         # `code` is a deprecated alias of `status` on ClientResponseError.
         logging.warning(e.status)
     except asyncio.TimeoutError:
         logging.warning("Timeout")
     except Exception as e:  # deliberate best-effort: log and give up
         logging.warning(e)
     else:
         return json_resp
     return None
 async def fetch_queue_bindings(self, http_session: ClientSession,
                                exchange: str, queue: str) -> List[Metric]:
     """Return the routing keys of the bindings from ``exchange`` to ``queue``.

     Queries ``/api/bindings/{vhost}/e/{exchange}/q/{queue}/`` on
     ``self.http_api_url``, with the vhost from ``self.data_vhost_quoted()``.
     The request URL is logged with user and password masked.

     NOTE(review): annotated as ``List[Metric]``, but the list holds each
     binding's ``routing_key`` value; confirm the intended return type.

     Raises:
         RuntimeError: when the API answers with an error status
             (``response.ok`` is false).
     """
     vhost = self.data_vhost_quoted()
     get_url = self.http_api_url.with_path(
         f"/api/bindings/{vhost}/e/{exchange}/q/{queue}/", encoded=True)
     # NOTE(review): "{!r}" is a brace-style placeholder, so this assumes a
     # brace-formatting logger (e.g. loguru); confirm. Stdlib logging would
     # not interpolate it.
     logger.info(
         "Fetching queue bindings from {!r}",
         get_url.with_user("***").with_password("***").human_repr(),
     )
     async with http_session.get(get_url) as response:
         if not response.ok:
             # Check the status before decoding. Error bodies are not always
             # JSON (e.g. an HTML page from a proxy), and decoding them would
             # raise a content-type error that hides the HTTP failure.
             try:
                 error = (await response.json(content_type=None)).get("error")
             except (ValueError, AttributeError):
                 # Not JSON, or JSON that is not an object: report raw text.
                 error = await response.text(errors="replace")
             raise RuntimeError(f"RabbitMQ API returned an error: {error}")
         bindings_json = await response.json()
         return [binding["routing_key"] for binding in bindings_json]
Beispiel #15
0
async def get_pokemon(session: ClientSession, url: str):
    """Fetch one Pokémon from the Pokémon API (asyncio version).

    The request goes through the module-level ``http_proxy``.

    Parameters
    ----------
    session : ClientSession
        aiohttp client session object
    url : str
        Pokémon API URL to fetch from

    Returns
    -------
    str
        Name of the fetched Pokémon
    """
    async with session.get(url, proxy=http_proxy) as resp:
        body = await resp.json()
    return body['name']
async def test_drive_cred_generation(reader, ui_server, snapshot, server,
                                     config: Config, global_info: GlobalInfo,
                                     session: ClientSession):
    """Expired Drive credentials can be regenerated via the auth workflow.

    Expires the server-side credentials and checks that a sync reports
    ERROR_CREDS_EXPIRED. It then requests the authentication URL with the
    addon's ``token`` endpoint as ``redirectbacktoken``. Finally it checks
    that a new sync reports no error and the credential version was bumped.
    """
    status = await reader.getjson("getstatus")
    assert len(status["snapshots"]) == 1
    assert global_info.credVersion == 0
    # Invalidate the drive creds, sync, then verify we see an error
    server.expireCreds()
    status = await reader.getjson("sync")
    assert status["last_error"]["error_type"] == ERROR_CREDS_EXPIRED

    # simulate the user going through the Drive authentication workflow
    async with session.get(
        config.get(Setting.AUTHENTICATE_URL) + "?redirectbacktoken=" +
        quote(reader.getUrl(True) + "token")) as resp:
        resp.raise_for_status()

    # With fresh credentials in place, the next sync must clear the error.
    status = await reader.getjson("sync")
    assert status["last_error"] is None
    assert global_info.credVersion == 1
Beispiel #17
0
async def fetch_paged_response_generator(
        session: ClientSession,
        headers: dict,
        query_builder: BuildQuery,
        content_key: str,
        use_cache: bool = False,
        cache_expiry_seconds: int = 86400) -> AsyncGenerator[dict, None]:
    """ Asynchronously yield every object of a paged API response.

    Pages are requested in order through ``query_builder.query(page_number)``,
    so callers can iterate items without caring about paging. The page
    count comes from ``response['page']['totalPages']``, or is 1 when there
    is no ``page`` section. Items are read from
    ``response['_embedded'][content_key]``. When ``use_cache`` is set and
    ``REDIS_USE`` is ``'True'``, pages go through the redis cache for
    ``cache_expiry_seconds``.
    """
    page_number = 0
    # The real page count is unknown until the first response arrives.
    page_total = 1
    while page_number < page_total:
        url, params = query_builder.query(page_number)
        logger.debug('loading page %d...', page_number)
        if use_cache and config.get('REDIS_USE') == 'True':
            # Caching is both requested by the caller and enabled in config.
            page = await _fetch_cached_response(
                session, headers, url, params, cache_expiry_seconds)
        else:
            async with session.get(url, headers=headers,
                                   params=params) as response:
                page = await response.json()
                logger.debug('done loading page %d.', page_number)

        # keep this code around for dumping responses to a json file - useful for when you're writing
        # tests to grab actual responses to use in fixtures.
        # import base64
        # TODO: write a better way to make a temporary filename
        # fname = 'thing_{}_{}.json'.format(base64.urlsafe_b64encode(url.encode()), random.randint(0, 1000))
        # with open(fname, 'w') as f:
        #     json.dump(response_json, f)

        if 'page' in page:
            page_total = page['page']['totalPages']
        else:
            page_total = 1
        for item in page['_embedded'][content_key]:
            yield item
        page_number += 1
Beispiel #18
0
async def fetch_raw_dailies_for_all_stations(
        session: ClientSession, headers: dict,
        time_of_interest: datetime) -> list:
    """ Fetch the noon values (observations and forecasts) at ``time_of_interest`` for all weather stations.

    Walks every page of the dailies query and returns the concatenated
    ``_embedded.dailies`` entries.
    """
    dailies = []
    page_number = 0
    page_total = 1  # unknown until the first response; assume one page
    while page_number < page_total:
        url, params = prepare_fetch_dailies_for_all_stations_query(
            time_of_interest, page_number)
        async with session.get(url, params=params,
                               headers=headers) as response:
            body = await response.json()
            page_total = body['page']['totalPages']
            dailies.extend(body['_embedded']['dailies'])
        page_number += 1
    return dailies
Beispiel #19
0
async def fetch_access_token(session: ClientSession) -> dict:
    """ Fetch an access token for the WFWX Fireweather API.

    Credentials and the auth URL come from config (``WFWX_USER``,
    ``WFWX_SECRET``, ``WFWX_AUTH_URL``), and the token is requested with HTTP
    basic auth. Responses with status 200 are cached in redis, keyed on the
    auth URL and user. Redis errors are logged and otherwise ignored.
    """
    logger.debug('fetching access token...')
    password = config.get('WFWX_SECRET')
    user = config.get('WFWX_USER')
    auth_url = config.get('WFWX_AUTH_URL')
    cache = create_redis()
    # NOTE: Consider using a hashed version of the password as part of the key.
    key = f"{auth_url}?{urlencode({'user': user})}"

    cached_json = None
    try:
        cached_json = cache.get(key)
    except Exception as error:  # pylint: disable=broad-except
        logger.error(error)

    if cached_json:
        logger.info('redis cache hit %s', auth_url)
        return json.loads(cached_json.decode())

    logger.info('redis cache miss %s', auth_url)
    credentials = BasicAuth(login=user, password=password)
    async with session.get(auth_url, auth=credentials) as response:
        response_json = await response.json()
        try:
            if response.status == 200:
                # Expire when the token does, or after REDIS_AUTH_CACHE_EXPIRY
                # seconds (default 600), whichever comes first.
                # NOTE: the cap exists because invalidated tokens are not
                # handled yet.
                redis_auth_cache_expiry: Final = int(
                    config.get('REDIS_AUTH_CACHE_EXPIRY', 600))
                expires = min(response_json['expires_in'],
                              redis_auth_cache_expiry)
                cache.set(key, json.dumps(response_json).encode(), ex=expires)
        except Exception as error:  # pylint: disable=broad-except
            logger.error(error)
    return response_json
Beispiel #20
0
    async def _get_latest_version_from_repository(
        self, session: ClientSession, url: str
    ) -> str:
        """Return the newest version listed on the "simple" index page ``url``.

        Each file linked from the page is mapped to a version string. Wheel
        and egg names go through ``_get_version_from_wheel_filename``, and
        everything else through ``_get_version_from_source_filename``. Names
        that fail to parse are skipped. The highest stable version is
        returned; if every version is a pre-release, the highest pre-release
        is returned.

        Raises:
            ValueError: when the page answers 404, or no parseable version
                is found.
            ClientResponseError: for other HTTP error statuses (via
                ``raise_for_status``).
        """
        async with session.get(url) as response:
            if response.status == 404:
                raise ValueError("Dependency doesn't exist on repository")

            response.raise_for_status()

            page_content = await response.text()

        # Parse all filenames as version
        versions = []
        for filename in self._get_filenames_from_simple_page(page_content):
            if filename.endswith((".egg", ".whl")):
                version = self._get_version_from_wheel_filename(filename)
            else:
                version = self._get_version_from_source_filename(filename)

            try:
                parsed_version = version_parser.parse(version)
            except ValueError:
                continue
            versions.append(parsed_version)

        if not versions:
            raise ValueError(f"Cannot check version for {url}")

        # Choose by version ordering rather than by list position, so the
        # result does not depend on the order in which the page lists files.
        stable = [v for v in versions if not v.is_prerelease]
        return str(max(stable or versions))
Beispiel #21
0
async def get_full_depth(symbol: str, session: ClientSession,
                         database: Database, asset_type: AssetType):
    """Fetch a full order-book depth snapshot from Binance and store it.

    The REST endpoint depends on ``asset_type`` (SPOT, USD_M or COIN_M), and
    ``CONFIG.full_fetch_limit`` is sent as ``limit``. The resulting
    ``DepthSnapshot`` is inserted into ``database`` under
    ``asset_type.value + symbol``.

    Raises:
        ValueError: for an ``asset_type`` without a known endpoint.
    """
    limit = CONFIG.full_fetch_limit
    if asset_type == AssetType.SPOT:
        endpoint = "https://api.binance.com/api/v3/depth"
    elif asset_type == AssetType.USD_M:
        endpoint = "https://fapi.binance.com/fapi/v1/depth"
    elif asset_type == AssetType.COIN_M:
        endpoint = "https://dapi.binance.com/dapi/v1/depth"
    else:
        # Fail clearly instead of reaching the request with no URL defined.
        raise ValueError(f"Unsupported asset type: {asset_type!r}")
    url = f"{endpoint}?symbol={symbol}&limit={limit}"
    async with session.get(url) as resp:
        resp_json = await resp.json()
        msg = DepthSnapshotMsg(**resp_json)
        # Each bid/ask entry is a (price, quantity) pair.
        snapshot = DepthSnapshot(
            timestamp=datetime.utcnow(),
            last_update_id=msg.lastUpdateId,
            bids_quantity=[pairs[1] for pairs in msg.bids],
            bids_price=[pairs[0] for pairs in msg.bids],
            asks_quantity=[pairs[1] for pairs in msg.asks],
            asks_price=[pairs[0] for pairs in msg.asks],
            symbol=asset_type.value + symbol,
        )
        database.insert([snapshot])
Beispiel #22
0
async def Mavereckki(server_name: str, server_link: str,
                     client: ClientSession):
    """Resolve a Mavereckki embed link into its HLS stream and subtitles.

    The JSON at ``server_link`` with ``/embed/`` replaced by
    ``/api/source/`` is read for ``hls`` and ``subtitles``. Both hold paths
    that are joined to the part of ``server_link`` before ``/embed/``.

    Returns:
        ``[server_name, hls_url, {'Referer': server_link, 'Subtitle': urls}]``.
        In ``urls``, subtitles whose ``name`` contains ``'eng'``
        (case-insensitive) come first, then the rest. Each group keeps the
        order the API listed them in.
    """
    async with client.get(server_link.replace('/embed/',
                                              '/api/source/')) as resp:
        data: dict = await resp.json()
    site_root = server_link.split('/embed/')[0]
    # Stable partition, instead of deleting from the list while enumerating
    # it: that skips the element after each deletion and can leave English
    # subtitles behind non-English ones.
    english, others = [], []
    for sub in data.get('subtitles') or []:
        name = (sub.get('name') or '').lower()
        bucket = english if 'eng' in name else others
        bucket.append(site_root + sub.get('src'))
    subtitle_url = english + others
    return [
        server_name,
        site_root + data.get('hls'), {
            'Referer': server_link,
            'Subtitle': subtitle_url
        }
    ]
Beispiel #23
0
async def fetch_hourlies(
        session: ClientSession, raw_station: dict, headers: dict,
        start_timestamp: datetime, end_timestamp: datetime, use_cache: bool,
        eco_division: EcodivisionSeasons) -> WeatherStationHourlyReadings:
    """ Fetch hourly weather readings for the specified time range for a given station.

    Only readings whose ``hourlyMeasurementTypeCode.id`` is ``'ACTUAL'`` are
    kept. The response goes through the redis cache when ``use_cache`` is
    set, a cache expiry is configured, and ``REDIS_USE`` is ``'True'``.
    """
    logger.debug('fetching hourlies for %s(%s)', raw_station['displayLabel'],
                 raw_station['stationCode'])

    url, params = prepare_fetch_hourlies_query(raw_station, start_timestamp,
                                               end_timestamp)

    cache_expiry_seconds = config.get(
        'REDIS_HOURLIES_BY_STATION_CODE_CACHE_EXPIRY', 300)

    # Get hourlies
    if use_cache and cache_expiry_seconds is not None and config.get(
            'REDIS_USE') == 'True':
        hourlies_json = await _fetch_cached_response(session, headers, url,
                                                     params,
                                                     cache_expiry_seconds)
    else:
        async with session.get(url, params=params,
                               headers=headers) as response:
            hourlies_json = await response.json()

    hourlies = []
    for hourly in hourlies_json['_embedded']['hourlies']:
        # We only accept "ACTUAL" values. A missing or null type code counts
        # as "not ACTUAL"; a str default would crash because str has no .get.
        measurement_type = hourly.get('hourlyMeasurementTypeCode') or {}
        if measurement_type.get('id') == 'ACTUAL':
            hourlies.append(parse_hourly(hourly))

    logger.debug('fetched %d hourlies for %s(%s)', len(hourlies),
                 raw_station['displayLabel'], raw_station['stationCode'])

    return WeatherStationHourlyReadings(values=hourlies,
                                        station=parse_station(
                                            raw_station, eco_division))
Beispiel #24
0
class AIOResponsesTestCase(TestCase):
    """Core behaviour of the ``aioresponses`` mock on a plain ClientSession.

    Written with legacy generator-based coroutines (``@asyncio.coroutine``
    plus ``yield from``); NOTE(review): ``asyncio.coroutine`` was removed in
    Python 3.11.
    """
    use_default_loop = False

    @asyncio.coroutine
    def setUp(self):
        self.url = 'http://example.com/api'
        self.session = ClientSession()
        super().setUp()

    @asyncio.coroutine
    def tearDown(self):
        # close() may return an awaitable (it is a coroutine in newer
        # aiohttp). Calling it without driving that awaitable leaves the
        # session open, so await it whenever one is returned.
        close_result = self.session.close()
        if close_result is not None:
            yield from close_result
        super().tearDown()

    @data(
        hdrs.METH_HEAD,
        hdrs.METH_GET,
        hdrs.METH_POST,
        hdrs.METH_PUT,
        hdrs.METH_PATCH,
        hdrs.METH_DELETE,
        hdrs.METH_OPTIONS,
    )
    @patch('aioresponses.aioresponses.add')
    @fail_on(unused_loop=False)
    def test_shortcut_method(self, http_method, mocked):
        with aioresponses() as m:
            getattr(m, http_method.lower())(self.url)
            mocked.assert_called_once_with(self.url, method=http_method)

    @aioresponses()
    def test_returned_instance(self, m):
        m.get(self.url)
        response = self.loop.run_until_complete(self.session.get(self.url))
        self.assertIsInstance(response, ClientResponse)

    @aioresponses()
    @asyncio.coroutine
    def test_returned_instance_and_status_code(self, m):
        m.get(self.url, status=204)
        response = yield from self.session.get(self.url)
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 204)

    @aioresponses()
    @asyncio.coroutine
    def test_returned_response_headers(self, m):
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = yield from self.session.get(self.url)

        self.assertEqual(response.headers['Connection'], 'keep-alive')
        self.assertEqual(response.headers[hdrs.CONTENT_TYPE], 'text/html')

    @aioresponses()
    def test_method_dont_match(self, m):
        m.get(self.url)
        with self.assertRaises(ClientConnectionError):
            self.loop.run_until_complete(self.session.post(self.url))

    @aioresponses()
    @asyncio.coroutine
    def test_streaming(self, m):
        m.get(self.url, body='Test')
        resp = yield from self.session.get(self.url)
        content = yield from resp.content.read()
        self.assertEqual(content, b'Test')

    @aioresponses()
    @asyncio.coroutine
    def test_streaming_up_to(self, m):
        m.get(self.url, body='Test')
        resp = yield from self.session.get(self.url)
        content = yield from resp.content.read(2)
        self.assertEqual(content, b'Te')
        content = yield from resp.content.read(2)
        self.assertEqual(content, b'st')

    @asyncio.coroutine
    def test_mocking_as_context_manager(self):
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)
            payload = yield from resp.json()
            self.assertDictEqual(payload, {'foo': 'bar'})

    def test_mocking_as_decorator(self):
        @aioresponses()
        def foo(loop, m):
            m.add(self.url, payload={'foo': 'bar'})

            resp = loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            payload = loop.run_until_complete(resp.json())
            self.assertDictEqual(payload, {'foo': 'bar'})

        foo(self.loop)

    @asyncio.coroutine
    def test_passing_argument(self):
        @aioresponses(param='mocked')
        @asyncio.coroutine
        def foo(mocked):
            mocked.add(self.url, payload={'foo': 'bar'})
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)

        yield from foo()

    @fail_on(unused_loop=False)
    def test_mocking_as_decorator_wrong_mocked_arg_name(self):
        @aioresponses(param='foo')
        def foo(bar):
            # no matter what is here it should raise an error
            pass

        with self.assertRaises(TypeError) as cm:
            foo()
        exc = cm.exception
        self.assertIn("foo() got an unexpected keyword argument 'foo'",
                      str(exc))

    @asyncio.coroutine
    def test_unknown_request(self):
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            with self.assertRaises(ClientConnectionError):
                yield from self.session.get('http://example.com/foo')

    @asyncio.coroutine
    def test_raising_custom_error(self):
        with aioresponses() as aiomock:
            aiomock.get(self.url, exception=HttpProcessingError(message='foo'))
            with self.assertRaises(HttpProcessingError):
                yield from self.session.get(self.url)

    @asyncio.coroutine
    def test_multiple_requests(self):
        with aioresponses() as m:
            m.get(self.url, status=200)
            m.get(self.url, status=201)
            m.get(self.url, status=202)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 201)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 202)

            key = ('GET', self.url)
            self.assertIn(key, m.requests)
            self.assertEqual(len(m.requests[key]), 3)
            self.assertEqual(m.requests[key][0].args, tuple())
            self.assertEqual(m.requests[key][0].kwargs,
                             {'allow_redirects': True})

    @asyncio.coroutine
    def test_address_as_instance_of_url_combined_with_pass_through(self):
        external_api = 'http://httpbin.org/status/201'

        @asyncio.coroutine
        def doit():
            api_resp = yield from self.session.get(self.url)
            # we have to hit actual url,
            # otherwise we do not test pass through option properly
            ext_rep = yield from self.session.get(URL(external_api))
            return api_resp, ext_rep

        with aioresponses(passthrough=[external_api]) as m:
            m.get(self.url, status=200)
            api, ext = yield from doit()

            self.assertEqual(api.status, 200)
            self.assertEqual(ext.status, 201)
class AIOResponsesTestCase(AsyncTestCase):
    """Tests for ``aioresponses`` written as native ``async def`` tests.

    ``setup`` creates the ``ClientSession`` under test.
    Its requests are intercepted in one of two ways:

    * the ``@aioresponses()`` decorator, which supplies the mock as an extra
      test argument;
    * ``aioresponses()`` used as a context manager.

    All checks go through ``self.assert*``. Bare ``assert`` statements would
    be stripped under ``python -O`` and the checks would then pass silently.
    """

    async def setup(self):
        self.url = 'http://example.com/api?foo=bar#fragment'
        self.session = ClientSession()

    async def teardown(self):
        # Depending on the aiohttp version, ``close()`` returns either an
        # awaitable or ``None``; only await it when there is something to
        # wait on.
        close_result = self.session.close()
        if close_result is not None:
            await close_result

    def run_async(self, coroutine: Union[Coroutine, Generator]):
        """Run ``coroutine`` to completion on ``self.loop``; return result."""
        return self.loop.run_until_complete(coroutine)

    async def request(self, url: str):
        """Send a GET request for ``url`` through the test session."""
        return await self.session.get(url)

    @data(
        hdrs.METH_HEAD,
        hdrs.METH_GET,
        hdrs.METH_POST,
        hdrs.METH_PUT,
        hdrs.METH_PATCH,
        hdrs.METH_DELETE,
        hdrs.METH_OPTIONS,
    )
    @patch('aioresponses.aioresponses.add')
    @fail_on(unused_loop=False)
    def test_shortcut_method(self, http_method, mocked):
        """``m.<verb>(url)`` delegates to ``add(url, method=<VERB>)``."""
        with aioresponses() as m:
            getattr(m, http_method.lower())(self.url)
            mocked.assert_called_once_with(self.url, method=http_method)

    @aioresponses()
    def test_returned_instance(self, m):
        """A mocked request returns a ``ClientResponse``."""
        m.get(self.url)
        response = self.run_async(self.session.get(self.url))
        self.assertIsInstance(response, ClientResponse)

    @aioresponses()
    async def test_returned_instance_and_status_code(self, m):
        """The ``status`` given to the mock is reported by the response."""
        m.get(self.url, status=204)
        response = await self.session.get(self.url)
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 204)

    @aioresponses()
    async def test_returned_response_headers(self, m):
        """Custom headers and ``content_type`` appear in ``headers``."""
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = await self.session.get(self.url)

        self.assertEqual(response.headers['Connection'], 'keep-alive')
        self.assertEqual(response.headers[hdrs.CONTENT_TYPE], 'text/html')

    @aioresponses()
    async def test_returned_response_cookies(self, m):
        """A ``Set-Cookie`` header populates ``response.cookies``."""
        m.get(self.url, headers={'Set-Cookie': 'cookie=value'})
        response = await self.session.get(self.url)

        self.assertEqual(response.cookies['cookie'].value, 'value')

    @aioresponses()
    async def test_returned_response_raw_headers(self, m):
        """``raw_headers`` holds the headers as byte pairs, content type first."""
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = await self.session.get(self.url)
        expected_raw_headers = ((hdrs.CONTENT_TYPE.encode(), b'text/html'),
                                (b'Connection', b'keep-alive'))

        self.assertEqual(response.raw_headers, expected_raw_headers)

    @aioresponses()
    async def test_raise_for_status(self, m):
        """``raise_for_status()`` raises with the standard reason phrase."""
        m.get(self.url, status=400)
        with self.assertRaises(ClientResponseError) as cm:
            response = await self.session.get(self.url)
            response.raise_for_status()
        self.assertEqual(cm.exception.message, http.RESPONSES[400][0])

    @aioresponses()
    @skipIf(condition=AIOHTTP_VERSION < '3.4.0',
            reason='aiohttp<3.4.0 does not support raise_for_status '
            'arguments for requests')
    async def test_request_raise_for_status(self, m):
        """Per-request ``raise_for_status=True`` raises on a 400 response."""
        m.get(self.url, status=400)
        with self.assertRaises(ClientResponseError) as cm:
            await self.session.get(self.url, raise_for_status=True)
        self.assertEqual(cm.exception.message, http.RESPONSES[400][0])

    @aioresponses()
    async def test_returned_instance_and_params_handling(self, m):
        """``params`` are merged into the URL's query before matching."""
        expected_url = 'http://example.com/api?foo=bar&x=42#fragment'
        m.get(expected_url)
        response = await self.session.get(self.url, params={'x': 42})
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 200)

        expected_url = 'http://example.com/api?x=42#fragment'
        m.get(expected_url)
        response = await self.session.get('http://example.com/api#fragment',
                                          params={'x': 42})
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 200)

    @aioresponses()
    def test_method_dont_match(self, m):
        """A mock registered for GET does not answer a POST to the same URL."""
        m.get(self.url)
        with self.assertRaises(ClientConnectionError):
            self.run_async(self.session.post(self.url))

    @aioresponses()
    async def test_streaming(self, m):
        """The mocked ``body`` can be read through ``resp.content``."""
        m.get(self.url, body='Test')
        resp = await self.session.get(self.url)
        content = await resp.content.read()
        self.assertEqual(content, b'Test')

    @aioresponses()
    async def test_streaming_up_to(self, m):
        """``content.read(2)`` yields the body two bytes at a time."""
        m.get(self.url, body='Test')
        resp = await self.session.get(self.url)
        content = await resp.content.read(2)
        self.assertEqual(content, b'Te')
        content = await resp.content.read(2)
        self.assertEqual(content, b'st')

    async def test_mocking_as_context_manager(self):
        """A ``payload`` registered in a ``with`` block is served as JSON."""
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            resp = await self.session.get(self.url)
            self.assertEqual(resp.status, 200)
            payload = await resp.json()
            self.assertDictEqual(payload, {'foo': 'bar'})

    def test_mocking_as_decorator(self):
        """The decorator appends the mock to the positional arguments."""
        @aioresponses()
        def foo(loop, m):
            m.add(self.url, payload={'foo': 'bar'})

            resp = loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            payload = loop.run_until_complete(resp.json())
            self.assertDictEqual(payload, {'foo': 'bar'})

        foo(self.loop)

    async def test_passing_argument(self):
        """``param='mocked'`` passes the mock as the ``mocked`` keyword."""
        @aioresponses(param='mocked')
        async def foo(mocked):
            mocked.add(self.url, payload={'foo': 'bar'})
            resp = await self.session.get(self.url)
            self.assertEqual(resp.status, 200)

        await foo()

    @fail_on(unused_loop=False)
    def test_mocking_as_decorator_wrong_mocked_arg_name(self):
        """A ``param`` the function does not accept raises ``TypeError``."""
        @aioresponses(param='foo')
        def foo(bar):
            # no matter what is here it should raise an error
            pass

        with self.assertRaises(TypeError) as cm:
            foo()
        exc = cm.exception
        self.assertIn("foo() got an unexpected keyword argument 'foo'",
                      str(exc))

    async def test_unknown_request(self):
        """A request to an unregistered URL raises ``ClientConnectionError``."""
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            with self.assertRaises(ClientConnectionError):
                await self.session.get('http://example.com/foo')

    async def test_raising_exception(self):
        """Exception classes and instances given as ``exception=`` are raised."""
        with aioresponses() as aiomock:
            url = 'http://example.com/Exception'
            aiomock.get(url, exception=Exception)
            with self.assertRaises(Exception):
                await self.session.get(url)

            url = 'http://example.com/Exception_object'
            aiomock.get(url, exception=Exception())
            with self.assertRaises(Exception):
                await self.session.get(url)

            url = 'http://example.com/BaseException'
            aiomock.get(url, exception=BaseException)
            with self.assertRaises(BaseException):
                await self.session.get(url)

            url = 'http://example.com/BaseException_object'
            aiomock.get(url, exception=BaseException())
            with self.assertRaises(BaseException):
                await self.session.get(url)

            url = 'http://example.com/CancelError'
            aiomock.get(url, exception=CancelledError)
            with self.assertRaises(CancelledError):
                await self.session.get(url)

            url = 'http://example.com/TimeoutError'
            aiomock.get(url, exception=TimeoutError)
            with self.assertRaises(TimeoutError):
                await self.session.get(url)

            url = 'http://example.com/HttpProcessingError'
            aiomock.get(url, exception=HttpProcessingError(message='foo'))
            with self.assertRaises(HttpProcessingError):
                await self.session.get(url)

    async def test_multiple_requests(self):
        """Ensure that requests are saved the way they would have been sent."""
        with aioresponses() as m:
            m.get(self.url, status=200)
            m.get(self.url, status=201)
            m.get(self.url, status=202)
            # The same list object is mutated between calls; each recorded
            # request must keep the value it had when it was sent.
            json_content_as_ref = [1]
            resp = await self.session.get(self.url, json=json_content_as_ref)
            self.assertEqual(resp.status, 200)
            json_content_as_ref[:] = [2]
            resp = await self.session.get(self.url, json=json_content_as_ref)
            self.assertEqual(resp.status, 201)
            json_content_as_ref[:] = [3]
            resp = await self.session.get(self.url, json=json_content_as_ref)
            self.assertEqual(resp.status, 202)

            key = ('GET', URL(self.url))
            self.assertIn(key, m.requests)
            self.assertEqual(len(m.requests[key]), 3)

            first_request = m.requests[key][0]
            self.assertEqual(first_request.args, tuple())
            self.assertEqual(first_request.kwargs, {
                'allow_redirects': True,
                "json": [1]
            })

            second_request = m.requests[key][1]
            self.assertEqual(second_request.args, tuple())
            self.assertEqual(second_request.kwargs, {
                'allow_redirects': True,
                "json": [2]
            })

            third_request = m.requests[key][2]
            self.assertEqual(third_request.args, tuple())
            self.assertEqual(third_request.kwargs, {
                'allow_redirects': True,
                "json": [3]
            })

    async def test_request_with_non_deepcopyable_parameter(self):
        """A kwarg that cannot be deep-copied is recorded as the same object."""
        def non_deep_copyable():
            """A generator does not allow deepcopy."""
            for line in ["header1,header2", "v1,v2", "v10,v20"]:
                yield line

        generator_value = non_deep_copyable()

        with aioresponses() as m:
            m.get(self.url, status=200)
            resp = await self.session.get(self.url, data=generator_value)
            self.assertEqual(resp.status, 200)

            key = ('GET', URL(self.url))
            self.assertIn(key, m.requests)
            self.assertEqual(len(m.requests[key]), 1)

            request = m.requests[key][0]
            self.assertEqual(request.args, tuple())
            # Generators compare by identity, so this checks the original
            # object was kept rather than a copy.
            self.assertEqual(request.kwargs, {
                'allow_redirects': True,
                "data": generator_value
            })

    async def test_request_retrieval_in_case_no_response(self):
        """A request with no matching mock is still recorded."""
        with aioresponses() as m:
            with self.assertRaises(ClientConnectionError):
                await self.session.get(self.url)

            key = ('GET', URL(self.url))
            self.assertIn(key, m.requests)
            self.assertEqual(len(m.requests[key]), 1)
            self.assertEqual(m.requests[key][0].args, tuple())
            self.assertEqual(m.requests[key][0].kwargs,
                             {'allow_redirects': True})

    async def test_request_failure_in_case_session_is_closed(self):
        """A request on a closed session raises ``RuntimeError``."""
        async def do_request(session):
            return (await session.get(self.url))

        with aioresponses():
            # The coroutine is created first but only awaited after the
            # session has been closed.
            coro = do_request(self.session)
            await self.session.close()

            with self.assertRaises(RuntimeError) as exception_info:
                await coro
            self.assertEqual(str(exception_info.exception),
                             "Session is closed")

    async def test_address_as_instance_of_url_combined_with_pass_through(self):
        """A mocked URL and a pass-through URL (requested as ``URL``) coexist."""
        external_api = 'http://httpbin.org/status/201'

        async def doit():
            api_resp = await self.session.get(self.url)
            # we have to hit actual url,
            # otherwise we do not test pass through option properly
            ext_rep = await self.session.get(URL(external_api))
            return api_resp, ext_rep

        with aioresponses(passthrough=[external_api]) as m:
            m.get(self.url, status=200)
            api, ext = await doit()

            self.assertEqual(api.status, 200)
            self.assertEqual(ext.status, 201)

    async def test_pass_through_with_origin_params(self):
        """``params`` are kept on requests forwarded via ``passthrough``."""
        external_api = 'http://httpbin.org/get'

        async def doit(params):
            # we have to hit actual url,
            # otherwise we do not test pass through option properly
            ext_rep = await self.session.get(URL(external_api), params=params)
            return ext_rep

        with aioresponses(passthrough=[external_api]) as m:
            params = {'foo': 'bar'}
            ext = await doit(params=params)
            self.assertEqual(ext.status, 200)
            self.assertEqual(str(ext.url), 'http://httpbin.org/get?foo=bar')

    @aioresponses()
    async def test_custom_response_class(self, m):
        """``response_class=`` sets the type of the returned response."""
        class CustomClientResponse(ClientResponse):
            pass

        m.get(self.url, body='Test', response_class=CustomClientResponse)
        resp = await self.session.get(self.url)
        self.assertIsInstance(resp, CustomClientResponse)

    @aioresponses()
    def test_exceptions_in_the_middle_of_responses(self, mocked):
        """Queued responses and exceptions are consumed in registration order."""
        mocked.get(self.url, payload={}, status=204)
        mocked.get(
            self.url,
            exception=ValueError('oops'),
        )
        mocked.get(self.url, payload={}, status=204)
        mocked.get(
            self.url,
            exception=ValueError('oops'),
        )
        mocked.get(self.url, payload={}, status=200)

        async def doit():
            return (await self.session.get(self.url))

        self.assertEqual(self.run_async(doit()).status, 204)
        with self.assertRaises(ValueError):
            self.run_async(doit())
        self.assertEqual(self.run_async(doit()).status, 204)
        with self.assertRaises(ValueError):
            self.run_async(doit())
        self.assertEqual(self.run_async(doit()).status, 200)

    @aioresponses()
    async def test_request_should_match_regexp(self, mocked):
        """A compiled regex registered as the URL matches a request."""
        mocked.get(re.compile(r'^http://example\.com/api\?foo=.*$'),
                   payload={},
                   status=200)

        response = await self.request(self.url)
        self.assertEqual(response.status, 200)

    @aioresponses()
    async def test_request_does_not_match_regexp(self, mocked):
        """A non-matching regex leads to ``ClientConnectionError``."""
        mocked.get(re.compile(r'^http://exampleexample\.com/api\?foo=.*$'),
                   payload={},
                   status=200)
        with self.assertRaises(ClientConnectionError):
            await self.request(self.url)

    @aioresponses()
    def test_timeout(self, mocked):
        """``timeout=True`` makes the request raise ``asyncio.TimeoutError``."""
        mocked.get(self.url, timeout=True)

        with self.assertRaises(asyncio.TimeoutError):
            self.run_async(self.request(self.url))

    @aioresponses()
    def test_callback(self, m):
        """A callback gets the URL and kwargs; its ``CallbackResult`` is used."""
        body = b'New body'

        def callback(url, **kwargs):
            self.assertEqual(str(url), self.url)
            self.assertEqual(kwargs, {'allow_redirects': True})
            return CallbackResult(body=body)

        m.get(self.url, callback=callback)
        response = self.run_async(self.request(self.url))
        data = self.run_async(response.read())
        self.assertEqual(data, body)

    @aioresponses()
    def test_callback_coroutine(self, m):
        """A coroutine callback is awaited; the request waits for it."""
        body = b'New body'
        event = asyncio.Event()

        async def callback(url, **kwargs):
            await event.wait()
            self.assertEqual(str(url), self.url)
            self.assertEqual(kwargs, {'allow_redirects': True})
            return CallbackResult(body=body)

        m.get(self.url, callback=callback)
        future = asyncio.ensure_future(self.request(self.url))
        # Let the loop spin once: the callback is blocked on ``event``.
        self.run_async(asyncio.wait([future], timeout=0))
        self.assertFalse(future.done())
        event.set()
        self.run_async(asyncio.wait([future], timeout=0))
        self.assertTrue(future.done())
        response = future.result()
        data = self.run_async(response.read())
        self.assertEqual(data, body)

    @aioresponses()
    async def test_exception_requests_are_tracked(self, mocked):
        """A request that ends in a mocked exception is still recorded."""
        kwargs = {"json": [42], "allow_redirects": True}
        mocked.get(self.url, exception=ValueError('oops'))

        with self.assertRaises(ValueError):
            await self.session.get(self.url, **kwargs)

        key = ('GET', URL(self.url))
        mocked_requests = mocked.requests[key]
        self.assertEqual(len(mocked_requests), 1)

        request = mocked_requests[0]
        self.assertEqual(request.args, ())
        self.assertEqual(request.kwargs, kwargs)

    async def test_possible_race_condition(self):
        """Concurrent requests served by slow async callbacks all complete."""
        async def random_sleep_cb(url, **kwargs):
            await asyncio.sleep(uniform(0.1, 1))
            return CallbackResult(body='test')

        with aioresponses() as mocked:
            for i in range(20):
                mocked.get('http://example.org/id-{}'.format(i),
                           callback=random_sleep_cb)

            tasks = [
                self.session.get('http://example.org/id-{}'.format(i))
                for i in range(20)
            ]
            await asyncio.gather(*tasks)
class AIOResponsesTestCase(TestCase):
    """Tests for ``aioresponses`` written with ``@asyncio.coroutine``.

    The tests use ``yield from`` rather than ``async``/``await``.
    ``setUp`` and ``tearDown`` are coroutines themselves, so this presumably
    relies on a ``TestCase`` that runs coroutine fixtures (e.g. asynctest).
    TODO confirm.

    All checks go through ``self.assert*``. Bare ``assert`` statements would
    be stripped under ``python -O``.
    """

    # NOTE(review): presumably a TestCase option selecting a dedicated (not
    # the default) event loop per test — confirm against the base class.
    use_default_loop = False

    @asyncio.coroutine
    def setUp(self):
        self.url = 'http://example.com/api?foo=bar#fragment'
        self.session = ClientSession()
        super().setUp()

    @asyncio.coroutine
    def tearDown(self):
        # Depending on the aiohttp version, ``close()`` returns either an
        # awaitable or ``None``; only wait on it when there is one.
        close_result = self.session.close()
        if close_result is not None:
            yield from close_result
        super().tearDown()

    def run_async(self, coroutine: Union[Coroutine, Generator]):
        """Run ``coroutine`` to completion on ``self.loop``; return result."""
        return self.loop.run_until_complete(coroutine)

    @asyncio.coroutine
    def request(self, url: str):
        """Send a GET request for ``url`` through the test session."""
        return (yield from self.session.get(url))

    @data(
        hdrs.METH_HEAD,
        hdrs.METH_GET,
        hdrs.METH_POST,
        hdrs.METH_PUT,
        hdrs.METH_PATCH,
        hdrs.METH_DELETE,
        hdrs.METH_OPTIONS,
    )
    @patch('aioresponses.aioresponses.add')
    @fail_on(unused_loop=False)
    def test_shortcut_method(self, http_method, mocked):
        """``m.<verb>(url)`` delegates to ``add(url, method=<VERB>)``."""
        with aioresponses() as m:
            getattr(m, http_method.lower())(self.url)
            mocked.assert_called_once_with(self.url, method=http_method)

    @aioresponses()
    def test_returned_instance(self, m):
        """A mocked request returns a ``ClientResponse``."""
        m.get(self.url)
        response = self.run_async(self.session.get(self.url))
        self.assertIsInstance(response, ClientResponse)

    @aioresponses()
    @asyncio.coroutine
    def test_returned_instance_and_status_code(self, m):
        """The ``status`` given to the mock is reported by the response."""
        m.get(self.url, status=204)
        response = yield from self.session.get(self.url)
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 204)

    @aioresponses()
    @asyncio.coroutine
    def test_returned_response_headers(self, m):
        """Custom headers and ``content_type`` appear in ``headers``."""
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = yield from self.session.get(self.url)

        self.assertEqual(response.headers['Connection'], 'keep-alive')
        self.assertEqual(response.headers[hdrs.CONTENT_TYPE], 'text/html')

    @aioresponses()
    @asyncio.coroutine
    def test_returned_response_raw_headers(self, m):
        """``raw_headers`` holds the headers as byte pairs, content type first."""
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = yield from self.session.get(self.url)
        expected_raw_headers = ((hdrs.CONTENT_TYPE.encode(), b'text/html'),
                                (b'Connection', b'keep-alive'))

        self.assertEqual(response.raw_headers, expected_raw_headers)

    @aioresponses()
    @asyncio.coroutine
    def test_raise_for_status(self, m):
        """``raise_for_status()`` raises with the standard reason phrase."""
        m.get(self.url, status=400)
        with self.assertRaises(ClientResponseError) as cm:
            response = yield from self.session.get(self.url)
            response.raise_for_status()
        self.assertEqual(cm.exception.message, http.RESPONSES[400][0])

    @aioresponses()
    @asyncio.coroutine
    def test_returned_instance_and_params_handling(self, m):
        """``params`` are merged into the URL's query before matching."""
        expected_url = 'http://example.com/api?foo=bar&x=42#fragment'
        m.get(expected_url)
        response = yield from self.session.get(self.url, params={'x': 42})
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 200)

        expected_url = 'http://example.com/api?x=42#fragment'
        m.get(expected_url)
        response = yield from self.session.get(
            'http://example.com/api#fragment', params={'x': 42})
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 200)

    @aioresponses()
    def test_method_dont_match(self, m):
        """A mock registered for GET does not answer a POST to the same URL."""
        m.get(self.url)
        with self.assertRaises(ClientConnectionError):
            self.run_async(self.session.post(self.url))

    @aioresponses()
    @asyncio.coroutine
    def test_streaming(self, m):
        """The mocked ``body`` can be read through ``resp.content``."""
        m.get(self.url, body='Test')
        resp = yield from self.session.get(self.url)
        content = yield from resp.content.read()
        self.assertEqual(content, b'Test')

    @aioresponses()
    @asyncio.coroutine
    def test_streaming_up_to(self, m):
        """``content.read(2)`` yields the body two bytes at a time."""
        m.get(self.url, body='Test')
        resp = yield from self.session.get(self.url)
        content = yield from resp.content.read(2)
        self.assertEqual(content, b'Te')
        content = yield from resp.content.read(2)
        self.assertEqual(content, b'st')

    @asyncio.coroutine
    def test_mocking_as_context_manager(self):
        """A ``payload`` registered in a ``with`` block is served as JSON."""
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)
            payload = yield from resp.json()
            self.assertDictEqual(payload, {'foo': 'bar'})

    def test_mocking_as_decorator(self):
        """The decorator appends the mock to the positional arguments."""
        @aioresponses()
        def foo(loop, m):
            m.add(self.url, payload={'foo': 'bar'})

            resp = loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            payload = loop.run_until_complete(resp.json())
            self.assertDictEqual(payload, {'foo': 'bar'})

        foo(self.loop)

    @asyncio.coroutine
    def test_passing_argument(self):
        """``param='mocked'`` passes the mock as the ``mocked`` keyword."""
        @aioresponses(param='mocked')
        @asyncio.coroutine
        def foo(mocked):
            mocked.add(self.url, payload={'foo': 'bar'})
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)

        yield from foo()

    @fail_on(unused_loop=False)
    def test_mocking_as_decorator_wrong_mocked_arg_name(self):
        """A ``param`` the function does not accept raises ``TypeError``."""
        @aioresponses(param='foo')
        def foo(bar):
            # no matter what is here it should raise an error
            pass

        with self.assertRaises(TypeError) as cm:
            foo()
        exc = cm.exception
        self.assertIn("foo() got an unexpected keyword argument 'foo'",
                      str(exc))

    @asyncio.coroutine
    def test_unknown_request(self):
        """A request to an unregistered URL raises ``ClientConnectionError``."""
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            with self.assertRaises(ClientConnectionError):
                yield from self.session.get('http://example.com/foo')

    @asyncio.coroutine
    def test_raising_custom_error(self):
        """An exception instance registered via ``exception=`` is raised."""
        with aioresponses() as aiomock:
            aiomock.get(self.url, exception=HttpProcessingError(message='foo'))
            with self.assertRaises(HttpProcessingError):
                yield from self.session.get(self.url)

    @asyncio.coroutine
    def test_multiple_requests(self):
        """Responses are served in registration order and calls recorded."""
        with aioresponses() as m:
            m.get(self.url, status=200)
            m.get(self.url, status=201)
            m.get(self.url, status=202)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 201)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 202)

            key = ('GET', URL(self.url))
            self.assertIn(key, m.requests)
            self.assertEqual(len(m.requests[key]), 3)
            self.assertEqual(m.requests[key][0].args, tuple())
            self.assertEqual(m.requests[key][0].kwargs,
                             {'allow_redirects': True})

    @asyncio.coroutine
    def test_address_as_instance_of_url_combined_with_pass_through(self):
        """A mocked URL and a pass-through URL (requested as ``URL``) coexist."""
        external_api = 'http://httpbin.org/status/201'

        @asyncio.coroutine
        def doit():
            api_resp = yield from self.session.get(self.url)
            # we have to hit actual url,
            # otherwise we do not test pass through option properly
            ext_rep = yield from self.session.get(URL(external_api))
            return api_resp, ext_rep

        with aioresponses(passthrough=[external_api]) as m:
            m.get(self.url, status=200)
            api, ext = yield from doit()

            self.assertEqual(api.status, 200)
            self.assertEqual(ext.status, 201)

    @aioresponses()
    @asyncio.coroutine
    def test_custom_response_class(self, m):
        """``response_class=`` sets the type of the returned response."""
        class CustomClientResponse(ClientResponse):
            pass

        m.get(self.url, body='Test', response_class=CustomClientResponse)
        resp = yield from self.session.get(self.url)
        self.assertIsInstance(resp, CustomClientResponse)

    @aioresponses()
    def test_exceptions_in_the_middle_of_responses(self, mocked):
        """Queued responses and exceptions are consumed in registration order."""
        mocked.get(self.url, payload={}, status=204)
        mocked.get(
            self.url,
            exception=ValueError('oops'),
        )
        mocked.get(self.url, payload={}, status=204)
        mocked.get(
            self.url,
            exception=ValueError('oops'),
        )
        mocked.get(self.url, payload={}, status=200)

        @asyncio.coroutine
        def doit():
            return (yield from self.session.get(self.url))

        self.assertEqual(self.run_async(doit()).status, 204)
        with self.assertRaises(ValueError):
            self.run_async(doit())
        self.assertEqual(self.run_async(doit()).status, 204)
        with self.assertRaises(ValueError):
            self.run_async(doit())
        self.assertEqual(self.run_async(doit()).status, 200)

    @aioresponses()
    @asyncio.coroutine
    def test_request_should_match_regexp(self, mocked):
        """A compiled regex registered as the URL matches a request."""
        mocked.get(re.compile(r'^http://example\.com/api\?foo=.*$'),
                   payload={},
                   status=200)

        response = yield from self.request(self.url)
        self.assertEqual(response.status, 200)

    @aioresponses()
    @asyncio.coroutine
    def test_request_does_not_match_regexp(self, mocked):
        """A non-matching regex leads to ``ClientConnectionError``."""
        mocked.get(re.compile(r'^http://exampleexample\.com/api\?foo=.*$'),
                   payload={},
                   status=200)
        with self.assertRaises(ClientConnectionError):
            yield from self.request(self.url)

    @aioresponses()
    def test_timeout(self, mocked):
        """``timeout=True`` makes the request raise ``asyncio.TimeoutError``."""
        mocked.get(self.url, timeout=True)

        with self.assertRaises(asyncio.TimeoutError):
            self.run_async(self.request(self.url))

    @aioresponses()
    def test_callback(self, m):
        """A callback gets the URL and kwargs; its ``CallbackResult`` is used."""
        body = b'New body'

        def callback(url, **kwargs):
            self.assertEqual(str(url), self.url)
            self.assertEqual(kwargs, {'allow_redirects': True})
            return CallbackResult(body=body)

        m.get(self.url, callback=callback)
        response = self.run_async(self.request(self.url))
        data = self.run_async(response.read())
        self.assertEqual(data, body)
Beispiel #27
0
class AIOResponsesTestCase(TestCase):
    """Synchronous tests for ``aioresponses``.

    Every request is driven to completion with
    ``self.loop.run_until_complete``.
    """

    def setUp(self):
        self.url = 'http://example.com/api'
        self.loop = asyncio.get_event_loop()
        self.session = ClientSession()

    def tearDown(self):
        # In aiohttp 3.x ``ClientSession.close()`` is a coroutine: merely
        # calling it never closes the session and leaks its connector (plus a
        # "coroutine was never awaited" warning). Run whatever awaitable it
        # returns on the test loop. Older releases may return ``None``, which
        # needs no further handling.
        close_result = self.session.close()
        if close_result is not None:
            self.loop.run_until_complete(close_result)

    @data(
        hdrs.METH_GET,
        hdrs.METH_POST,
        hdrs.METH_PUT,
        hdrs.METH_PATCH,
        hdrs.METH_DELETE,
        hdrs.METH_OPTIONS,
    )
    @patch('aioresponses.aioresponses.add')
    def test_shortcut_method(self, http_method, mocked):
        """``m.<verb>(url)`` delegates to ``add(url, method=<VERB>)``."""
        with aioresponses() as m:
            getattr(m, http_method.lower())(self.url)
            mocked.assert_called_once_with(self.url, method=http_method)

    @aioresponses()
    def test_returned_instance(self, m):
        """A mocked request returns a ``ClientResponse``."""
        m.get(self.url)
        response = self.loop.run_until_complete(self.session.get(self.url))
        self.assertIsInstance(response, ClientResponse)

    @aioresponses()
    def test_returned_response_headers(self, m):
        """Custom headers and ``content_type`` appear in ``headers``."""
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = self.loop.run_until_complete(self.session.get(self.url))

        self.assertEqual(response.headers['Connection'], 'keep-alive')
        self.assertEqual(response.headers[hdrs.CONTENT_TYPE], 'text/html')

    @aioresponses()
    def test_method_dont_match(self, m):
        """A mock registered for GET does not answer a POST to the same URL."""
        m.get(self.url)
        with self.assertRaises(ClientConnectionError):
            self.loop.run_until_complete(self.session.post(self.url))

    def test_mocking_as_context_manager(self):
        """A ``payload`` registered in a ``with`` block is served as JSON."""
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            payload = self.loop.run_until_complete(resp.json())
            self.assertDictEqual(payload, {'foo': 'bar'})

    def test_mocking_as_decorator(self):
        """The decorator supplies the mock as a positional argument."""
        @aioresponses()
        def foo(m):
            m.add(self.url, payload={'foo': 'bar'})

            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            payload = self.loop.run_until_complete(resp.json())
            self.assertDictEqual(payload, {'foo': 'bar'})

        foo()

    def test_passing_argument(self):
        """``param='mocked'`` passes the mock as the ``mocked`` keyword."""
        @aioresponses(param='mocked')
        def foo(mocked):
            mocked.add(self.url, payload={'foo': 'bar'})
            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)

        foo()

    def test_mocking_as_decorator_wrong_mocked_arg_name(self):
        """A ``param`` the function does not accept raises ``TypeError``."""
        @aioresponses(param='foo')
        def foo(bar):
            # no matter what is here it should raise an error
            pass

        with self.assertRaises(TypeError) as cm:
            foo()
        exc = cm.exception
        self.assertIn("foo() got an unexpected keyword argument 'foo'",
                      str(exc))

    def test_unknown_request(self):
        """A request to an unregistered URL raises ``ClientConnectionError``."""
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            with self.assertRaises(ClientConnectionError):
                self.loop.run_until_complete(
                    self.session.get('http://example.com/foo'))

    def test_multiple_requests(self):
        """Responses registered for one URL are served in order."""
        with aioresponses() as m:
            m.get(self.url, status=200)
            m.get(self.url, status=201)
            m.get(self.url, status=202)
            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 201)
            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 202)
Beispiel #28
0
class GraphitePlugin(BasePlugin):
    """Evaluate check triggers against series fetched from Graphite.

    ``execute`` queries the Graphite ``/render`` endpoint of the check's
    environment for the check's metric (JSON format) and hands the decoded
    result to ``process``. ``process`` compares recent datapoints with each
    trigger's condition/threshold (plus optional hysteresis), then records a
    status, a message and a graph link on every trigger.
    """

    _type = 'graphite'
    # Trigger ``condition`` strings mapped to the comparison that is applied
    # as ``condition(value, threshold)``; a match means the value fails.
    condition_map = {
        '>': operator.gt,
        '>=': operator.ge,
        '<': operator.lt,
        '<=': operator.le,
        '==': operator.eq
    }

    def __init__(self):
        super().__init__()
        # One HTTP session shared by every ``execute`` call of this plugin.
        self.session = ClientSession()

    def supported_types(self):
        """Return the check types handled by this plugin (``settings.GRAPHITE_TYPES``)."""
        return settings.GRAPHITE_TYPES

    def init(self, *args):
        """Delegate to ``BasePlugin.init`` unchanged."""
        super().init(*args)

    def get_graphite_url(self, env):
        """Return the Graphite render API URL for environment *env*.

        Built from ``settings.ENVIRONMENTS[env]['metrics_service']['host']``.
        Raises ``KeyError`` if the environment or its host is not configured.
        """
        return 'http://{}/render'.format(
            settings.ENVIRONMENTS[env]['metrics_service']['host']
        )

    def get_graph_url(self, env: str):
        """Return the base URL used for human-facing graph links.

        Prefers the environment's optional ``graph_url`` setting. Otherwise it
        falls back to :meth:`get_graphite_url`. The fallback is only computed
        when needed, so an environment configuring ``graph_url`` does not
        also need a ``host``.
        """
        metrics_service = settings.ENVIRONMENTS[env]['metrics_service']
        if 'graph_url' in metrics_service:
            return metrics_service['graph_url']
        return self.get_graphite_url(env)

    @asyncio.coroutine
    @stats.increment(metric_name='graphite.execute')
    @stats.timer(metric_name='graphite.execute')
    def execute(self, check):
        """Fetch the check's metric from Graphite and evaluate its triggers.

        Generator-based coroutine. Requests ``target=<metric>`` over
        ``settings.GRAPHITE_TIME_RANGE`` in JSON format from the render API of
        the check's ``environment`` (default ``'test'``). The fetch is bounded
        by ``settings.GRAPHITE_TIMEOUT``; the decoded data is then passed to
        :meth:`process`.

        ``ValueError``, ``HttpProcessingError``, ``TimeoutError`` and
        ``ClientOSError`` raised while fetching or processing are logged, and
        the check status is set to ``RESULT_UNKNOWN``. Other exceptions
        propagate.
        """
        logger.info(
            'Processing check: id="{}", metric="{}"'.format(
                check['id'], check['fields']['metric']
            )
        )
        fields = check['fields']
        metric = fields['metric']
        check['target'] = metric

        env = check.get('environment', 'test')
        url = self.get_graphite_url(env)

        try:
            with Timeout(settings.GRAPHITE_TIMEOUT):
                response = yield from self.session.get(
                    url,
                    params={
                        'target': metric,
                        'from': settings.GRAPHITE_TIME_RANGE,
                        'format': 'json'
                    }
                )
                try:
                    response.raise_for_status()
                    data = yield from response.json()
                except BaseException:
                    # On any failure (bad status, undecodable body, timeout,
                    # cancellation) close the response explicitly. Otherwise
                    # its connection can stay open until the object is
                    # garbage-collected.
                    response.close()
                    raise
            self.process(check, data, env)

        except (ValueError, HttpProcessingError,
                TimeoutError, ClientOSError) as e:
            message = self.exception_repr(e, pattern='{}: {}')
            logger.error('Could not execute plugin for check `%s`, %s',
                         check['id'], message)
            self.set_check_status(check, RESULT_UNKNOWN, message)

    def add_meta(self, check, trigger, env, threshold=None):
        """Attach a Graphite chart link to ``trigger['meta']['links']``.

        The chart covers the last hour and plots the check's metric plus a
        second target ``threshold(<threshold>, "threshold = <threshold>", red)``.
        A missing/falsy *threshold* is drawn as ``0.0``.
        """
        threshold = threshold or 0.0
        now = datetime.now()
        # - 1 hour
        from_date = now - timedelta(0, 3600)
        params = [
            ('width', settings.GRAPHITE_CHART_WIDTH),
            ('height', settings.GRAPHITE_CHART_HEIGHT),
            ('tz', settings.GRAPHITE_TIME_ZONE),
            ('from', from_date.strftime('%H:%M_%Y%m%d')),
            ('until', now.strftime('%H:%M_%Y%m%d')),
            ('target', check['fields']['metric']),
            ('target', 'threshold({}, \"{}\", red)'.format(
                threshold, 'threshold = {}'.format(threshold),
            )),
        ]
        graphite_url = "{}?{}".format(
            self.get_graph_url(env), urlencode(params)
        )
        trigger['meta']['links']['graphite_url'] = dict(
            type='link', href=graphite_url
        )

    @stats.increment(metric_name='graphite.process')
    @stats.timer(metric_name='graphite.process')
    def process(self, check, data, env):
        """Evaluate every trigger of *check* against the Graphite *data*.

        *data* is iterated as a list of series. Each series has a
        ``datapoints`` list whose items carry the value at index 0; ``None``
        values are skipped. The values of all series are pooled.

        For each trigger:

        - A value matching ``condition(value, threshold)`` sets both the
          status and the hysteresis status to ``RESULT_FAILED``.
        - A value matching only the hysteresis-shifted threshold sets just
          the hysteresis status.
        - Independently, fewer series with data than the expected number of
          hosts fails the trigger.
        - A threshold/hysteresis that cannot be parsed, or an unknown
          condition, yields ``RESULT_UNKNOWN``.

        Results are stored with ``set_status``; ``add_meta`` adds a chart link.
        """
        frequency = int(check['fields']['frequency'])
        expected_num_hosts = check['fields'].get('expected_num_hosts', 0)

        ret = Result()

        for result in data:
            datapoints = result['datapoints']

            # Window of int(frequency / 60) points ending just before the
            # newest datapoint; empty when frequency < 60. Presumably one
            # point per minute, with the newest bucket skipped as still
            # incomplete -- TODO confirm.
            values = [float(t[0])
                      for t in datapoints[-int(frequency / 60) - 1:-1]
                      if t[0] is not None]

            ret.num_series_with_data += 1 if values else 0
            ret.num_series_no_data += 0 if values else 1
            ret.all_values.extend(values)

        for trigger in check['triggers']:
            status, failed_values, message = RESULT_OK, [], ''
            hyst_status, hystfail_values = RESULT_OK, []

            # Trigger-level debounce wins; otherwise the check's, default 1.
            debounce = trigger.get('debounce') or check['fields'].get(
                'debounce', 1)

            trigger_expected_num_hosts = int(
                trigger.get('expected_num_hosts', 0) or expected_num_hosts
            )
            try:
                threshold = float(trigger.get('threshold'))
                condition = self.condition_map[trigger.get('condition')]
                hysteresis_value = float(trigger.get('hysteresis', "0.0"))
            except (ValueError, TypeError) as e:
                message = INVALID_THRESHOLD.format(
                    trigger.get('threshold', '<unknown>'), check['id'], e
                )
                logger.error(message)
                self.set_status(
                    trigger, RESULT_UNKNOWN, message, hysteresis=RESULT_UNKNOWN
                )
                continue
            except KeyError:
                # Only the condition lookup raises KeyError; ``threshold``
                # was already parsed, so the chart link can still be added.
                message = INVALID_OPERATOR.format(
                    trigger.get('condition'),
                    list(self.condition_map.keys())
                )
                logger.error(message)
                self.set_status(
                    trigger, RESULT_UNKNOWN, message, hysteresis=RESULT_UNKNOWN
                )
                self.add_meta(check, trigger, env, threshold=threshold)
                continue

            # Hysteresis moves the soft threshold towards the passing side:
            # up for '<'/'<=', down for '>'/'>='; '==' ignores it.
            if trigger.get('condition') in ['<', '<=']:
                hyst_threshold = threshold + hysteresis_value
            elif trigger.get('condition') in ['>', '>=']:
                hyst_threshold = threshold - hysteresis_value
            else:
                hyst_threshold = threshold

            failed_values = list(filter(
                lambda x: condition(x, threshold), ret.all_values
            ))

            hystfail_values = list(filter(
                lambda x: condition(x, hyst_threshold), ret.all_values
            ))

            if failed_values:
                status = hyst_status = RESULT_FAILED
                # Quote only the last ``debounce`` failing values.
                values = ', '.join(
                    [str(x) for x in failed_values[-int(debounce):]]
                )
                message = '[{}] {} {}'.format(
                    values, trigger['condition'], threshold
                )
            elif hystfail_values:
                # if hysteresis is off or hardfail threshold hits,
                # we won't even reach here
                hyst_status = RESULT_FAILED

                hysteresis_sign = '-' if hyst_threshold < threshold else '+'
                values = ', '.join(
                    [str(x) for x in hystfail_values[-int(debounce):]]
                )
                message = '[{}] {} ({} {} {} hysteresis)'.format(
                    values, trigger['condition'],
                    threshold, hysteresis_sign, hysteresis_value
                )

            if ret.num_series_with_data < trigger_expected_num_hosts:
                status = hyst_status = RESULT_FAILED
                hosts_message = MESSAGE_TPL.format(
                    ret.num_series_with_data, trigger_expected_num_hosts
                )
                # Prepend the host-count message to any threshold message.
                message = '\n '.join(
                    part for part in (hosts_message, message) if part
                )

            self.set_status(trigger, status, message, hysteresis=hyst_status)
            self.add_meta(check, trigger, env, threshold=threshold)
Beispiel #29
0
 async def get_activation(self, session: ClientSession, heuristics):
     """Request the activation for *heuristics* and return it.

     Sends ``{"heuristics": heuristics}`` as the JSON body of a GET to the
     ``activation`` endpoint template. Returns the ``"activation"`` entry of
     the object produced by ``self._get_response``.
     """
     request = session.get(
         self._format_template('activation'),
         json={'heuristics': heuristics},
     )
     response = await self._get_response(request)
     # NOTE(review): ``json`` is subscripted as an attribute, not awaited as a
     # method, so this presumes ``_get_response`` returns a wrapper exposing
     # the decoded body -- confirm against its implementation.
     return response.json["activation"]
class AIOResponseRedirectTest(TestCase):
    """Redirect handling and ``request_info`` of responses mocked by aioresponses.

    Each test talks to ``self.url`` through a ``ClientSession`` created in
    :meth:`setUp` and closed in :meth:`tearDown`. Hooks and tests are native
    coroutines (``async def``); ``@asyncio.coroutine`` no longer exists in
    Python 3.11+. This relies on ``TestCase`` (imported elsewhere, presumably
    asynctest's) running coroutine hooks and test methods on its loop.
    """

    async def setUp(self):
        self.url = "http://10.1.1.1:8080/redirect"
        self.session = ClientSession()
        super().setUp()

    async def tearDown(self):
        # ``ClientSession.close()`` is a coroutine on current aiohttp, but
        # may return ``None`` on older releases; only await it when needed.
        close_result = self.session.close()
        if close_result is not None:
            await close_result
        super().tearDown()

    @aioresponses()
    async def test_redirect_followed(self, rsps):
        rsps.get(
            self.url,
            status=307,
            headers={"Location": "https://httpbin.org"},
        )
        rsps.get("https://httpbin.org")
        response = await self.session.get(
            self.url, allow_redirects=True
        )
        self.assertEqual(response.status, 200)
        self.assertEqual(str(response.url), "https://httpbin.org")
        self.assertEqual(len(response.history), 1)
        self.assertEqual(str(response.history[0].url), self.url)

    @aioresponses()
    async def test_post_redirect_followed(self, rsps):
        rsps.post(
            self.url,
            status=307,
            headers={"Location": "https://httpbin.org"},
        )
        rsps.get("https://httpbin.org")
        response = await self.session.post(
            self.url, allow_redirects=True
        )
        self.assertEqual(response.status, 200)
        self.assertEqual(str(response.url), "https://httpbin.org")
        self.assertEqual(response.method, "get")
        self.assertEqual(len(response.history), 1)
        self.assertEqual(str(response.history[0].url), self.url)

    @aioresponses()
    async def test_redirect_missing_mocked_match(self, rsps):
        rsps.get(
            self.url,
            status=307,
            headers={"Location": "https://httpbin.org"},
        )
        with self.assertRaises(ClientConnectionError) as cm:
            await self.session.get(
                self.url, allow_redirects=True
            )
        self.assertEqual(
            str(cm.exception),
            'Connection refused: GET http://10.1.1.1:8080/redirect'
        )

    @aioresponses()
    async def test_redirect_missing_location_header(self, rsps):
        rsps.get(self.url, status=307)
        response = await self.session.get(self.url, allow_redirects=True)
        self.assertEqual(str(response.url), self.url)

    @aioresponses()
    @skipIf(condition=AIOHTTP_VERSION < '3.1.0',
            reason='aiohttp<3.1.0 does not add request info on response')
    async def test_request_info(self, rsps):
        rsps.get(self.url, status=200)

        response = await self.session.get(self.url)

        request_info = response.request_info
        # assertEqual rather than bare ``assert``: still checked under ``-O``.
        self.assertEqual(str(request_info.url), self.url)
        self.assertEqual(request_info.headers, {})

    @aioresponses()
    @skipIf(condition=AIOHTTP_VERSION < '3.1.0',
            reason='aiohttp<3.1.0 does not add request info on response')
    async def test_request_info_with_original_request_headers(self, rsps):
        headers = {"Authorization": "Bearer access-token"}
        rsps.get(self.url, status=200)

        response = await self.session.get(self.url, headers=headers)

        request_info = response.request_info
        self.assertEqual(str(request_info.url), self.url)
        self.assertEqual(request_info.headers, headers)
Beispiel #31
0
 async def get_content_for_user(self, session: ClientSession,
                                request: RecommendationRequest):
     """Fetch recommendations for *request*.

     The fields of *request* (via ``_asdict()``) are sent as query
     parameters of a GET to the ``recommendation`` endpoint template. Returns
     whatever ``self._get_response`` produces for that request.
     """
     pending = session.get(
         self._format_template('recommendation'),
         params=request._asdict(),
     )
     return await self._get_response(pending)
Beispiel #32
0
 def test_http_methods(self, patched):
     """Each HTTP-verb helper of ``ClientSession`` must delegate to ``request``.

     ``patched`` stands in for ``ClientSession.request``. It is presumably a
     ``mock.patch`` object injected by a decorator outside this view -- TODO
     confirm. After every helper call the test checks that the mock was
     called once more, with:

     - the upper-case verb and the URL;
     - the forwarded keyword arguments;
     - the ``allow_redirects`` value that GET, OPTIONS and HEAD add.
     """
     session = ClientSession(loop=self.loop)
     run = self.loop.run_until_complete
     add_params = dict(
         headers={"Authorization": "Basic ..."},
         max_redirects=2,
         encoding="latin1",
         version=aiohttp.HttpVersion10,
         compress="deflate",
         chunked=True,
         expect100=True,
         read_until_eof=False)
     # (helper, expected verb, url, call kwargs, kwargs the helper is
     # expected to add before delegating to ``request``)
     cases = [
         (session.get, "GET", "http://test.example.com",
          dict(params={"x": 1}), dict(allow_redirects=True)),
         (session.options, "OPTIONS", "http://opt.example.com",
          dict(params={"x": 2}), dict(allow_redirects=True)),
         (session.head, "HEAD", "http://head.example.com",
          dict(params={"x": 2}), dict(allow_redirects=False)),
         (session.post, "POST", "http://post.example.com",
          dict(params={"x": 2}, data="Some_data", files={"x": '1'}), {}),
         (session.put, "PUT", "http://put.example.com",
          dict(params={"x": 2}, data="Some_data", files={"x": '1'}), {}),
         (session.patch, "PATCH", "http://patch.example.com",
          dict(params={"x": 2}, data="Some_data", files={"x": '1'}), {}),
         (session.delete, "DELETE", "http://delete.example.com",
          dict(params={"x": 2}), {}),
     ]
     try:
         for expected_count, case in enumerate(cases, start=1):
             helper, verb, url, kwargs, implied = case
             run(helper(url, **dict(kwargs, **add_params)))
             self.assertEqual(
                 patched.call_count, expected_count,
                 "`ClientSession.request` not called")
             expected_kwargs = dict(kwargs, **implied)
             expected_kwargs.update(add_params)
             self.assertEqual(
                 list(patched.call_args),
                 [(verb, url), expected_kwargs])
     finally:
         # Close the session so its connector does not outlive the test
         # (aiohttp otherwise warns about an unclosed client session).
         # ``close()`` is synchronous on old aiohttp and awaitable on newer
         # releases, so only drive it on the loop when it returns something.
         close_result = session.close()
         if close_result is not None:
             run(close_result)