예제 #1
0
    def test_fetch(self, dummy_endpoint):
        """Fetch mocked POST responses and validate the resulting Response.

        Two payloads are checked, each with its own mock/session pair: a
        plain-text body, and a JSON body that is additionally decoded via
        ``resp.json()``.
        """
        cases = [
            (b'hello world', 'text/plain; charset=utf-8', 'text/plain', None),
            (b'{"a": 1234, "b": null}', 'application/json; charset=utf-8',
             'application/json', {'a': 1234, 'b': None}),
        ]
        for payload, header_type, mime, decoded in cases:
            with aioresponses() as mocked, Session() as session:
                mocked.post(
                    dummy_endpoint + 'function',
                    status=200,
                    body=payload,
                    headers={
                        'Content-Type': header_type,
                        'Content-Length': str(len(payload)),
                    },
                )
                request = Request(session, 'POST', 'function')
                with request.fetch() as resp:
                    assert isinstance(resp, Response)
                    assert resp.status == 200
                    assert resp.content_type == mime
                    assert resp.text() == payload.decode()
                    # Only the JSON case decodes the body as JSON.
                    if decoded is not None:
                        assert resp.json() == decoded
                    assert resp.content_length == len(payload)
예제 #2
0
    async def info(self, fields: Iterable[str] = None) -> dict:
        '''
        Queries the ``keypair`` GraphQL field and returns its data.

        :param fields: The keypair fields to select.  When omitted,
            ``access_key``, ``secret_key``, ``is_active`` and ``is_admin``
            are requested.

        .. versionadded:: 18.12
        '''
        selected = fields
        if selected is None:
            selected = ('access_key', 'secret_key', 'is_active', 'is_admin')
        template = (
            'query {'
            '  keypair {'
            '    $fields'
            '  }'
            '}'
        )
        query = template.replace('$fields', ' '.join(selected))
        request = Request(self.session, 'POST', '/admin/graphql')
        request.set_json({'query': query})
        async with request.fetch() as resp:
            payload = await resp.json()
            return payload['keypair']
예제 #3
0
 async def update(cls, access_key: str,
                  is_active: bool = None,
                  is_admin: bool = None,
                  resource_policy: str = None,
                  rate_limit: int = None) -> dict:
     """
     Modifies an existing keypair identified by ``access_key`` using the
     ``modify_keypair`` GraphQL mutation.
     You need an admin privilege for this operation.

     :param access_key: Access key of the keypair to modify.
     :param is_active: Value sent as ``is_active`` in the mutation input.
     :param is_admin: Value sent as ``is_admin`` in the mutation input.
     :param resource_policy: Value sent as ``resource_policy`` in the
         mutation input.
     :param rate_limit: Value sent as ``rate_limit`` in the mutation input.
     :returns: The ``modify_keypair`` payload of the response (the
         mutation selects ``ok`` and ``msg``).
     """
     q = 'mutation($access_key: String!, $input: ModifyKeyPairInput!) {' + \
         '  modify_keypair(access_key: $access_key, props: $input) {' \
         '    ok msg' \
         '  }' \
         '}'
     # Every option is sent as-is, including ``None`` values; how the
     # server treats a null field is not visible here.
     variables = {
         'access_key': access_key,
         'input': {
             'is_active': is_active,
             'is_admin': is_admin,
             'resource_policy': resource_policy,
             'rate_limit': rate_limit,
         },
     }
     rqst = Request(cls.session, 'POST', '/admin/graphql')
     rqst.set_json({
         'query': q,
         'variables': variables,
     })
     async with rqst.fetch() as resp:
         data = await resp.json()
         return data['modify_keypair']
예제 #4
0
    def test_fetch_invalid_method(self, mock_request_params):
        """An unknown HTTP method makes ``fetch()`` raise AssertionError."""
        params = mock_request_params
        params['method'] = 'STRANGE'
        request = Request(**params)
        with pytest.raises(AssertionError), request.fetch():
            pass
예제 #5
0
    async def list(cls,
                   operation: bool = False,
                   fields: Iterable[str] = None) -> Sequence[dict]:
        '''
        Fetches the list of registered images in this cluster.

        :param operation: Forwarded as the ``is_operation`` GraphQL variable.
        :param fields: The image fields to select; defaults to ``name``,
            ``tag`` and ``hash``.
        :returns: The ``images`` entry of the GraphQL response.
        '''
        selected = ('name', 'tag', 'hash') if fields is None else fields
        query = (
            'query($is_operation: Boolean) {'
            '  images(is_operation: $is_operation) {'
            '    $fields'
            '  }'
            '}'
        ).replace('$fields', ' '.join(selected))
        request = Request(cls.session, 'POST', '/admin/graphql')
        request.set_json({
            'query': query,
            'variables': {'is_operation': operation},
        })
        async with request.fetch() as resp:
            result = await resp.json()
            return result['images']
예제 #6
0
async def test_fetch_invalid_method(mock_request_params):
    """An unknown HTTP method makes the async ``fetch()`` raise AssertionError."""
    mock_request_params.update(method='STRANGE')
    request = Request(**mock_request_params)

    with pytest.raises(AssertionError):
        async with request.fetch():
            pass
예제 #7
0
 async def deactivate(cls, access_key: str) -> dict:
     '''
     Deactivates the keypair identified by ``access_key``.
     Deactivated keypairs cannot make any API requests
     unless activated again by an administrator.
     You need an admin privilege for this operation.

     :param access_key: Access key of the keypair to deactivate.
     :returns: The ``modify_keypair`` payload (``ok`` and ``msg``).
     '''
     mutation = (
         'mutation($access_key: String!, $input: ModifyKeyPairInput!) {'
         '  modify_keypair(access_key: $access_key, props: $input) {'
         '    ok msg'
         '  }'
         '}'
     )
     # Only ``is_active`` is set; the other properties are sent as null.
     props = dict(
         is_active=False,
         is_admin=None,
         resource_policy=None,
         rate_limit=None,
     )
     request = Request(cls.session, 'POST', '/admin/graphql')
     request.set_json({
         'query': mutation,
         'variables': {'access_key': access_key, 'input': props},
     })
     async with request.fetch() as resp:
         result = await resp.json()
         return result['modify_keypair']
def test_auth_missing_signature(defconfig):
    """An unsigned request to ``/authorize`` is rejected with status 401.

    The request's ``_sign`` hook is replaced with a no-op so the request
    goes out without a signature.  Without that, the request would
    presumably be signed normally and the missing-signature path would
    not be exercised.
    """
    random_msg = uuid.uuid4().hex
    request = Request('GET', '/authorize', {
        'echo': random_msg,
    })
    # let it bypass actual signing
    request._sign = lambda *args, **kwargs: None
    resp = request.send()
    assert resp.status == 401
예제 #9
0
    async def detail(cls, gid: str, fields: Iterable[str] = None) -> dict:
        '''
        Fetch information of a group with group ID.

        :param gid: ID of the group to fetch.
        :param fields: The group fields to select.  When omitted, ``id``,
            ``name``, ``description``, ``is_active``, ``created_at``,
            ``domain_name``, ``total_resource_slots``,
            ``allowed_vfolder_hosts`` and ``integration_id`` are requested.
        :returns: The ``group`` object of the GraphQL response.  This is a
            single group rather than a sequence, hence the ``dict``
            annotation.
        '''
        if fields is None:
            fields = ('id', 'name', 'description', 'is_active', 'created_at', 'domain_name',
                      'total_resource_slots', 'allowed_vfolder_hosts', 'integration_id')
        query = textwrap.dedent('''\
            query($gid: String!) {
                group(id: $gid) {$fields}
            }
        ''')
        query = query.replace('$fields', ' '.join(fields))
        variables = {'gid': gid}
        rqst = Request(cls.session, 'POST', '/admin/graphql')
        rqst.set_json({
            'query': query,
            'variables': variables,
        })
        async with rqst.fetch() as resp:
            data = await resp.json()
            return data['group']
예제 #10
0
    async def list(cls, domain_name: str,
                   fields: Iterable[str] = None) -> Sequence[dict]:
        '''
        Fetches the list of groups.

        :param domain_name: Name of the domain whose groups are listed;
            forwarded as the ``domain_name`` GraphQL variable.
        :param fields: The group fields to select.  When omitted, a default
            set (id, name, description, is_active, created_at, domain_name,
            total_resource_slots, allowed_vfolder_hosts, integration_id)
            is requested.
        :returns: The ``groups`` entry of the GraphQL response.
        '''
        requested = fields if fields is not None else (
            'id', 'name', 'description', 'is_active',
            'created_at', 'domain_name',
            'total_resource_slots', 'allowed_vfolder_hosts',
            'integration_id',
        )
        template = textwrap.dedent('''\
            query($domain_name: String) {
                groups(domain_name: $domain_name) {$fields}
            }
        ''')
        request = Request(cls.session, 'POST', '/admin/graphql')
        request.set_json({
            'query': template.replace('$fields', ' '.join(requested)),
            'variables': {'domain_name': domain_name},
        })
        async with request.fetch() as resp:
            result = await resp.json()
            return result['groups']
예제 #11
0
 async def remove_users(cls, gid: str, user_uuids: Iterable[str],
                        fields: Iterable[str] = None) -> dict:
     '''
     Remove users from a group.
     You need an admin privilege for this operation.

     :param gid: ID of the group to modify.
     :param user_uuids: UUIDs of the users to remove.  Any iterable is
         accepted and materialized into a list, so it can be encoded as a
         JSON array (the standard ``json`` encoder rejects sets and
         generators).  A single string and ``None`` are passed through
         unchanged.
     :param fields: Currently unused; kept for interface compatibility.
     :returns: The ``modify_group`` payload (``ok`` and ``msg``).
     '''
     query = textwrap.dedent('''\
         mutation($gid: String!, $input: ModifyGroupInput!) {
             modify_group(gid: $gid, props: $input) {
                 ok msg
             }
         }
     ''')
     # Materialize arbitrary iterables (sets, generators, ...) so the JSON
     # payload is serializable; do not explode a lone string into chars.
     if user_uuids is not None and not isinstance(user_uuids, str):
         user_uuids = list(user_uuids)
     variables = {
         'gid': gid,
         'input': {
             'user_update_mode': 'remove',
             'user_uuids': user_uuids,
         },
     }
     rqst = Request(cls.session, 'POST', '/admin/graphql')
     rqst.set_json({
         'query': query,
         'variables': variables,
     })
     async with rqst.fetch() as resp:
         data = await resp.json()
         return data['modify_group']
예제 #12
0
 async def update(cls, gid: str, name: str = None, description: str = None,
                  is_active: bool = None, total_resource_slots: str = None,
                  allowed_vfolder_hosts: Iterable[str] = None,
                  integration_id: str = None,
                  fields: Iterable[str] = None) -> dict:
     '''
     Update existing group.
     You need an admin privilege for this operation.

     :param gid: ID of the group to modify.
     :param name: Value sent as ``name`` in the mutation input.
     :param description: Value sent as ``description`` in the mutation input.
     :param is_active: Value sent as ``is_active`` in the mutation input.
     :param total_resource_slots: Value sent as ``total_resource_slots``.
     :param allowed_vfolder_hosts: Any iterable of host names.  It is
         materialized into a list so it can be encoded as a JSON array (the
         standard ``json`` encoder rejects sets and generators).  A single
         string and ``None`` are passed through unchanged.
     :param integration_id: Value sent as ``integration_id``.
     :param fields: Currently unused; kept for interface compatibility.
     :returns: The ``modify_group`` payload (``ok`` and ``msg``).
     '''
     query = textwrap.dedent('''\
         mutation($gid: String!, $input: ModifyGroupInput!) {
             modify_group(gid: $gid, props: $input) {
                 ok msg
             }
         }
     ''')
     # Materialize arbitrary iterables so the JSON payload is serializable;
     # leave None (null) and a lone string untouched.
     if allowed_vfolder_hosts is not None and \
             not isinstance(allowed_vfolder_hosts, str):
         allowed_vfolder_hosts = list(allowed_vfolder_hosts)
     variables = {
         'gid': gid,
         'input': {
             'name': name,
             'description': description,
             'is_active': is_active,
             'total_resource_slots': total_resource_slots,
             'allowed_vfolder_hosts': allowed_vfolder_hosts,
             'integration_id': integration_id,
         },
     }
     rqst = Request(cls.session, 'POST', '/admin/graphql')
     rqst.set_json({
         'query': query,
         'variables': variables,
     })
     async with rqst.fetch() as resp:
         data = await resp.json()
         return data['modify_group']
예제 #13
0
async def test_fetch_timeout_async(dummy_endpoint):
    """A timeout raised by the mocked transport propagates out of ``fetch()``."""
    timeout = asyncio.TimeoutError()
    with aioresponses() as mocked:
        async with AsyncSession() as session:
            mocked.post(dummy_endpoint, exception=timeout)
            request = Request(session, 'POST', '/')
            with pytest.raises(asyncio.TimeoutError):
                async with request.fetch():
                    pass
예제 #14
0
async def test_fetch_client_error_async(dummy_endpoint):
    """A connection error from aiohttp surfaces as BackendClientError."""
    connection_error = aiohttp.ClientConnectionError()
    with aioresponses() as mocked:
        async with AsyncSession() as session:
            mocked.post(dummy_endpoint, exception=connection_error)
            request = Request(session, 'POST', '/')
            with pytest.raises(BackendClientError):
                async with request.fetch():
                    pass
예제 #15
0
def test_build_correct_url(req_params):
    """``build_url()`` joins the endpoint, the API major version and the path."""
    config = req_params['config']
    req = Request(**req_params)

    major_ver = config.version.partition('.')[0]
    suffix = '/' + req.path if len(req.path) else ''
    expected = urljoin(config.endpoint, major_ver + suffix)

    assert req.build_url() == expected
예제 #16
0
def test_set_content_correctly(req_params):
    """Assigning ``content`` updates both the payload and Content-Length."""
    req_params['content'] = OrderedDict()
    req = Request(**req_params)
    payload = b'new-data'

    assert not req.content
    req.content = payload
    assert req.content == payload
    expected_length = str(len(payload))
    assert req.headers['Content-Length'] == expected_length
예제 #17
0
 async def test_fetch_cancellation_async(self, dummy_endpoint):
     """A cancellation raised by the mocked transport propagates out of ``fetch()``."""
     cancellation = asyncio.CancelledError()
     with aioresponses() as mocked:
         async with AsyncSession() as session:
             mocked.post(dummy_endpoint, exception=cancellation)
             request = Request(session, 'POST', '/')
             with pytest.raises(asyncio.CancelledError):
                 async with request.fetch():
                     pass
예제 #18
0
 async def test_fetch_timeout_async(self, dummy_endpoint):
     """A timeout raised by the mocked transport propagates out of ``fetch()``."""
     timeout = asyncio.TimeoutError()
     with aioresponses() as mocked:
         async with AsyncSession() as session:
             mocked.post(dummy_endpoint, exception=timeout)
             request = Request(session, 'POST', '/')
             with pytest.raises(asyncio.TimeoutError):
                 async with request.fetch():
                     pass
def test_auth_missing_signature(defconfig):
    """The server answers 401 when the request is sent without signing."""
    message = uuid.uuid4().hex
    request = Request('GET', '/authorize', {'echo': message})

    def _skip_signing(*args, **kwargs):
        return None

    # Replace the instance's signing hook so the actual signing is bypassed.
    request._sign = _skip_signing
    response = request.send()
    assert response.status == 401
예제 #20
0
 async def test_fetch_client_error_async(self, dummy_endpoint):
     """A connection error from aiohttp surfaces as BackendClientError."""
     connection_error = aiohttp.ClientConnectionError()
     with aioresponses() as mocked:
         async with AsyncSession() as session:
             mocked.post(dummy_endpoint, exception=connection_error)
             request = Request(session, 'POST', '/')
             with pytest.raises(BackendClientError):
                 async with request.fetch():
                     pass
def test_auth(defconfig):
    """A normal ``/authorize`` request succeeds and echoes the message back."""
    token = uuid.uuid4().hex
    request = Request('GET', '/authorize', {'echo': token})
    resp = request.send()
    assert resp.status == 200
    payload = resp.json()
    assert payload['authorized'] == 'yes'
    assert payload['echo'] == token
예제 #22
0
    def test_build_correct_url(self, mock_request_params):
        """The built URL is identical with or without a leading slash in the path."""
        canonical_url = 'http://127.0.0.1:8081/function?app=999'

        for path in ('/function', 'function'):
            mock_request_params['path'] = path
            rqst = Request(**mock_request_params)
            assert str(rqst._build_url()) == canonical_url
예제 #23
0
async def test_fetch_cancellation_async(dummy_endpoint):
    """A CancelledError injected by the mock is expected to surface from ``fetch()``."""
    # NOTE: aiohttp has been observed to swallow asyncio.CancelledError;
    # this test asserts that fetch() still lets it propagate.
    cancellation = asyncio.CancelledError()
    with aioresponses() as mocked:
        async with AsyncSession():
            mocked.post(dummy_endpoint, exception=cancellation)
            rqst = Request('POST', '/')
            with pytest.raises(asyncio.CancelledError):
                async with rqst.fetch():
                    pass
예제 #24
0
 def test_auth(self):
     """An echo request to ``/auth`` is authorized and echoed back."""
     message = uuid.uuid4().hex
     with Session() as sess:
         req = Request(sess, 'GET', '/auth')
         req.set_json({'echo': message})
         with req.fetch() as resp:
             assert resp.status == 200
             body = resp.json()
             assert body['authorized'] == 'yes'
             assert body['echo'] == message
예제 #25
0
def test_auth():
    """An echo request to ``/auth`` is authorized and echoed back."""
    message = uuid.uuid4().hex
    with Session():
        req = Request('GET', '/auth')
        req.set_json({'echo': message})
        with req.fetch() as resp:
            assert resp.status == 200
            body = resp.json()
            assert body['authorized'] == 'yes'
            assert body['echo'] == message
예제 #26
0
 def test_not_found(self):
     """Unknown API paths are reported as BackendAPIError with status 404."""
     with Session() as sess:
         for bad_path in ('/invalid-url-wow', '/auth/uh-oh'):
             request = Request(sess, 'GET', bad_path)
             with pytest.raises(BackendAPIError) as e:
                 with request.fetch():
                     pass
             assert e.value.status == 404
예제 #27
0
def test_auth_missing_signature(monkeypatch):
    """The server rejects an unsigned ``/auth`` request with status 401."""
    echo_value = uuid.uuid4().hex
    with Session():
        rqst = Request('GET', '/auth')
        rqst.set_json({'echo': echo_value})
        # Replace the module-level signer with a no-op returning
        # ``({}, None)`` so the actual signing is bypassed.
        from ai.backend.client import request

        def noop_sign(*args, **kwargs):
            return {}, None

        monkeypatch.setattr(request, 'generate_signature', noop_sign)
        with pytest.raises(BackendAPIError) as exc_info:
            with rqst.fetch():
                pass
        assert exc_info.value.status == 401
예제 #28
0
async def test_response_async(defconfig, dummy_endpoint):
    """Async ``text()`` and ``json()`` both decode a mocked JSON body."""
    payload = b'{"test": 5678}'
    mock_headers = {
        'Content-Type': 'application/json',
        'Content-Length': str(len(payload)),
    }
    with aioresponses() as mocked:
        mocked.post(dummy_endpoint + 'function', status=200,
                    body=payload, headers=mock_headers)
        async with AsyncSession(config=defconfig):
            rqst = Request('POST', '/function')
            async with rqst.fetch() as resp:
                assert await resp.text() == '{"test": 5678}'
                assert await resp.json() == {'test': 5678}
예제 #29
0
 async def test_response_async(self, defconfig, dummy_endpoint):
     """Async ``text()`` and ``json()`` both decode a mocked JSON body."""
     payload = b'{"test": 5678}'
     mock_headers = {
         'Content-Type': 'application/json',
         'Content-Length': str(len(payload)),
     }
     with aioresponses() as mocked:
         mocked.post(dummy_endpoint + 'function', status=200,
                     body=payload, headers=mock_headers)
         async with AsyncSession(config=defconfig) as session:
             rqst = Request(session, 'POST', '/function')
             async with rqst.fetch() as resp:
                 assert await resp.text() == '{"test": 5678}'
                 assert await resp.json() == {'test': 5678}
예제 #30
0
 def test_auth_missing_signature(self, monkeypatch):
     """The server rejects an unsigned ``/auth`` request with status 401."""
     echo_value = uuid.uuid4().hex
     with Session() as sess:
         rqst = Request(sess, 'GET', '/auth')
         rqst.set_json({'echo': echo_value})
         # Replace the module-level signer with a no-op returning
         # ``({}, None)`` so the actual signing is bypassed.
         from ai.backend.client import request

         def noop_sign(*args, **kwargs):
             return {}, None

         monkeypatch.setattr(request, 'generate_signature', noop_sign)
         with pytest.raises(BackendAPIError) as exc_info:
             with rqst.fetch():
                 pass
         assert exc_info.value.status == 401
예제 #31
0
def test_auth_missing_body():
    """``/auth`` without a request body is rejected with status 400."""
    with Session():
        rqst = Request('GET', '/auth')
        with pytest.raises(BackendAPIError) as excinfo, rqst.fetch():
            pass
        assert excinfo.value.status == 400
예제 #32
0
def test_content_is_auto_set_to_blank_if_no_data(req_params):
    """A ``None`` content becomes an empty octet-stream payload."""
    params = dict(req_params, content=None)
    req = Request(**params)

    assert req.content_type == 'application/octet-stream'
    assert req.content == b''
예제 #33
0
 async def rescan_images(cls, registry: str):
     '''
     Sends the ``rescan_images`` GraphQL mutation.

     :param registry: Forwarded as the ``registry`` variable (declared as a
         nullable ``String`` in the mutation).
     :returns: The ``rescan_images`` payload (``ok`` and ``msg``).
     '''
     mutation = ('mutation($registry: String) {'
                 '  rescan_images(registry:$registry) {'
                 '   ok msg'
                 '  }'
                 '}')
     request = Request(cls.session, 'POST', '/admin/graphql')
     request.set_json({
         'query': mutation,
         'variables': {'registry': registry},
     })
     async with request.fetch() as resp:
         body = await resp.json()
         return body['rescan_images']
예제 #34
0
 async def dealias_image(cls, alias: str) -> dict:
     '''
     Sends the ``dealias_image`` GraphQL mutation for the given alias.

     :param alias: Forwarded as the required ``alias`` variable.
     :returns: The ``dealias_image`` payload (``ok`` and ``msg``).
     '''
     mutation = ('mutation($alias: String!) {'
                 '  dealias_image(alias: $alias) {'
                 '   ok msg'
                 '  }'
                 '}')
     request = Request(cls.session, 'POST', '/admin/graphql')
     request.set_json({
         'query': mutation,
         'variables': {'alias': alias},
     })
     async with request.fetch() as resp:
         body = await resp.json()
         return body['dealias_image']
예제 #35
0
def test_request_initialization(mock_request_params):
    """Constructor arguments are exposed as attributes; leading slashes are stripped from the path."""
    params = mock_request_params
    rqst = Request(**params)

    assert rqst.method == params['method']
    assert rqst.params == params['params']
    expected_path = params['path'].lstrip('/')
    assert rqst.path == expected_path
    assert rqst.content == params['content']
    assert 'X-BackendAI-Version' in rqst.headers
예제 #36
0
async def test_streaming_fetch(dummy_endpoint):
    """Read a mocked response incrementally with ``read(n)``.

    Once the body has been drained, ``text()`` is expected to fail with
    AssertionError.
    """
    with aioresponses() as mocked, Session():
        payload = b'hello world'
        mocked.post(
            dummy_endpoint + 'function',
            status=200,
            body=payload,
            headers={
                'Content-Type': 'text/plain; charset=utf-8',
                'Content-Length': str(len(payload)),
            },
        )
        rqst = Request('POST', 'function')
        async with rqst.fetch() as resp:
            assert resp.status == 200
            assert resp.content_type == 'text/plain'
            # Consume 3 bytes, then 2 bytes, then whatever remains.
            for size, expected in ((3, b'hel'), (2, b'lo')):
                assert await resp.read(size) == expected
            await resp.read()
            with pytest.raises(AssertionError):
                assert await resp.text()
예제 #37
0
 def test_streaming_fetch(self, dummy_endpoint):
     """Read a mocked response incrementally with ``read(n)``.

     Once the body has been drained, ``text()`` is expected to fail with
     AssertionError.
     """
     with aioresponses() as mocked, Session() as session:
         payload = b'hello world'
         mocked.post(
             dummy_endpoint + 'function',
             status=200,
             body=payload,
             headers={
                 'Content-Type': 'text/plain; charset=utf-8',
                 'Content-Length': str(len(payload)),
             },
         )
         rqst = Request(session, 'POST', 'function')
         with rqst.fetch() as resp:
             assert resp.status == 200
             assert resp.content_type == 'text/plain'
             # Consume 3 bytes, then 2 bytes, then whatever remains.
             for size, expected in ((3, b'hel'), (2, b'lo')):
                 assert resp.read(size) == expected
             resp.read()
             with pytest.raises(AssertionError):
                 assert resp.text()
예제 #38
0
async def test_invalid_requests(dummy_endpoint):
    """A problem+json 404 response is raised as BackendAPIError.

    The captured error is inspected after the ``pytest.raises`` block:
    statements placed inside it after the raising call would never run.
    """
    with aioresponses() as m, Session() as session:
        body = json.dumps({
            'type': 'https://api.backend.ai/probs/kernel-not-found',
            'title': 'Kernel Not Found',
        }).encode('utf8')
        m.post(
            dummy_endpoint, status=404, body=body,
            headers={'Content-Type': 'application/problem+json; charset=utf-8',
                     'Content-Length': str(len(body))},
        )
        rqst = Request('POST', '/')
        with pytest.raises(BackendAPIError) as e:
            async with rqst.fetch():
                pass
        # ``e`` is pytest's ExceptionInfo; the raised error is ``e.value``.
        assert e.value.status == 404
        assert e.value.data['type'] == \
            'https://api.backend.ai/probs/kernel-not-found'
        assert e.value.data['title'] == 'Kernel Not Found'
예제 #39
0
def test_content_is_files(req_params):
    """A list of file tuples is kept as-is and marks the request as multipart."""
    files = [
        ('src', filename, io.BytesIO(), 'application/octet-stream')
        for filename in ('test1.txt', 'test2.txt')
    ]
    req_params['content'] = files
    req = Request(**req_params)

    assert req.content_type == 'multipart/form-data'
    assert req.content == files
예제 #40
0
 def test_invalid_requests(self, dummy_endpoint):
     """A problem+json 404 response is raised as BackendAPIError.

     The captured error is inspected after the ``pytest.raises`` block:
     statements placed inside it after the raising call would never run.
     """
     with aioresponses() as m, Session() as session:
         body = json.dumps({
             'type': 'https://api.backend.ai/probs/kernel-not-found',
             'title': 'Kernel Not Found',
         }).encode('utf8')
         m.post(
             dummy_endpoint, status=404, body=body,
             headers={'Content-Type': 'application/problem+json; charset=utf-8',
                      'Content-Length': str(len(body))},
         )
         rqst = Request(session, 'POST', '/')
         with pytest.raises(BackendAPIError) as e:
             with rqst.fetch():
                 pass
         # ``e`` is pytest's ExceptionInfo; the raised error is ``e.value``.
         assert e.value.status == 404
         assert e.value.data['type'] == \
             'https://api.backend.ai/probs/kernel-not-found'
         assert e.value.data['title'] == 'Kernel Not Found'
예제 #41
0
def test_request_initialization(req_params):
    """Constructor arguments are stored, default headers are added, and the
    content is kept in ``_content`` as UTF-8 encoded JSON."""
    req = Request(**req_params)

    assert req.config == req_params['config']
    assert req.method == req_params['method']
    assert req.path == req_params['path'][1:]
    assert req.content == req_params['content']
    for header in ('Date', 'X-BackendAI-Version'):
        assert header in req.headers
    encoded = json.dumps(req_params['content']).encode('utf8')
    assert req._content == encoded
예제 #42
0
    def test_request_attach_files(self, mock_request_params):
        """``attach_files`` requires empty content and switches to multipart."""
        files = [
            AttachedFile(filename, io.BytesIO(), 'application/octet-stream')
            for filename in ('test1.txt', 'test2.txt')
        ]

        # Attaching files to a request that already has a body is rejected.
        mock_request_params['content'] = b'something'
        rqst = Request(**mock_request_params)
        with pytest.raises(AssertionError):
            rqst.attach_files(files)

        # With an empty body, the files are attached as multipart form data.
        mock_request_params['content'] = b''
        rqst = Request(**mock_request_params)
        rqst.attach_files(files)

        assert rqst.content_type == 'multipart/form-data'
        assert rqst.content == b''
        assert rqst._pack_content().is_multipart
예제 #43
0
    def test_request_set_content(self, mock_request_params):
        """The content type is inferred from the content when given as None."""
        rqst = Request(**mock_request_params)
        assert rqst.content == mock_request_params['content']
        assert rqst.content_type == 'application/json'
        assert rqst._pack_content() is rqst.content

        # (raw content, expected stored content, expected content type)
        inferred_cases = [
            ('hello', b'hello', 'text/plain'),
            (b'\x00\x01\xfe\xff', b'\x00\x01\xfe\xff', 'application/octet-stream'),
        ]
        for raw, expected_content, expected_type in inferred_cases:
            mock_request_params['content'] = raw
            mock_request_params['content_type'] = None
            rqst = Request(**mock_request_params)
            assert rqst.content == expected_content
            assert rqst.content_type == expected_type
            assert rqst._pack_content() is rqst.content
예제 #44
0
 async def test_fetch_invalid_method_async(self):
     """An unknown HTTP method makes the async ``fetch()`` raise AssertionError."""
     async with AsyncSession() as session:
         bogus = Request(session, 'STRANGE', '/')
         with pytest.raises(AssertionError):
             async with bogus.fetch():
                 pass
예제 #45
0
 def test_connection(self):
     """The API root responds with a JSON document containing ``version``."""
     with Session() as sess:
         with Request(sess, 'GET', '/').fetch() as resp:
             payload = resp.json()
             assert 'version' in payload
예제 #46
0
 async def test_async_connection(self):
     """The API root responds (async) with a JSON document containing ``version``."""
     async with AsyncSession() as sess:
         async with Request(sess, 'GET', '/').fetch() as resp:
             payload = await resp.json()
             assert 'version' in payload
예제 #47
0
 def test_request_set_content_none(self, mock_request_params):
     """``content=None`` is normalized to an empty bytes payload."""
     params = dict(mock_request_params)
     params['content'] = None
     rqst = Request(**params)
     assert rqst.content == b''
     assert rqst._pack_content() is rqst.content