Exemplo n.º 1
0
def test_custom_validator():
    """Register a custom ``choice`` validator and check copy + compile.

    The schema is a list of choices defaulting to ``'A'``, so a ``None``
    item is expected to come back as ``'A'``.
    """

    @validator(string=True)
    def choice_validator(compiler, items):
        # ``items`` is a whitespace separated string of the allowed values.
        allowed = set(items.split())

        def validate(value):
            if value not in allowed:
                raise Invalid('invalid choice')
            return value

        return validate

    compiler = Compiler(validators={'choice': choice_validator})
    schema = T.list(T.choice('A B C D').default('A'))
    # Copying a schema that uses the custom validator must yield an equal one.
    assert T(schema) == schema
    validate = compiler.compile(schema)
    expected = ['A', 'B', 'C', 'D', 'A']
    assert validate(['A', 'B', 'C', 'D', None]) == expected
Exemplo n.º 2
0
def test_wrapped_validator():
    """Wrap the builtin ``str`` validator and record every validated value."""
    builtin_str = builtin_validators['str']
    assert builtin_str.is_string

    seen = []

    @validator(string=True)
    def wrapped_str_validator(*args, **kwargs):
        # Delegate construction to the builtin validator's factory.
        inner = builtin_str.validator(*args, **kwargs)

        def validate(value):
            seen.append(value)
            return inner(value)

        return validate

    # Registering under 'str' makes ``T.str`` resolve to the wrapper.
    compiler = Compiler(validators={'str': wrapped_str_validator})
    validate = compiler.compile(T.str.optional)
    assert validate('abc') == 'abc'
    assert seen == ['abc']
Exemplo n.º 3
0
@validator(accept=bytes, output=bytes)
def bytes_validator(compiler, maxlen=None):
    """Build a validator for ``bytes`` values with an optional length cap.

    Args:
        compiler: the compiler building the validator (not used here).
        maxlen: maximum allowed ``len(value)``; ``None`` disables the check.

    Returns:
        A ``validate(value)`` callable that returns ``value`` unchanged or
        raises ``Invalid``.
    """
    has_limit = maxlen is not None

    def validate(value):
        if not isinstance(value, bytes):
            raise Invalid('invalid bytes type')
        if has_limit and len(value) > maxlen:
            raise Invalid('value must <= {}'.format(maxlen))
        return value

    return validate


# Custom validators handed to ``Compiler``; each key is the validator name
# (e.g. the 'url' entry is what ``T.url`` below refers to).
VALIDATORS = {
    'cursor': cursor_validator,
    'url': url_validator,
    'datetime': datetime_validator,
    'feed_unionid': create_unionid_validator(FeedUnionId),
    'story_unionid': create_unionid_validator(StoryUnionId),
    'detail': detail_validator,
    'str': str_validator,
    'bytes': bytes_validator,
}


# Module-level compiler configured with the custom validators above.
compiler = Compiler(validators=VALIDATORS)

# Warm up the django url validator: compile ``T.url`` and run it once on a
# sample url at import time.
compiler.compile(T.url)('https://example.com/')
Exemplo n.º 4
0
class RestRouter:
    """Collect decorated view functions and expose them as url patterns.

    Functions are registered through :meth:`route` (also reachable by calling
    the router itself) or the ``get``/``post``/``put``/``delete``/``patch``
    shortcuts. The schemas reported by ``get_params`` / ``get_returns`` are
    compiled once at registration time. :attr:`urls` then builds one
    ``APIView`` per distinct url, dispatching on the HTTP method.
    """

    def __init__(self, name=None, permission_classes=None):
        """Create an empty router.

        Args:
            name: stored on the instance as ``self.name``.
            permission_classes: optional iterable of permission classes; when
                truthy it is stored as a tuple and set on every generated view.
        """
        self.name = name
        if permission_classes:
            permission_classes = tuple(permission_classes)
        self.permission_classes = permission_classes
        self._schema_compiler = Compiler(validators=VALIDATORS)
        self._routes = []

    @staticmethod
    def _has_path_argument(url):
        """Return True when ``url`` contains a ``<...>`` path argument.

        Both the ``<converter:name>`` form and the converter-less ``<name>``
        form count, so neither is mistaken for a constant url.
        """
        return '<' in url and '>' in url

    @property
    def urls(self):
        """Build the list of url patterns for all registered routes.

        Routes sharing a url are merged into a single view. The result keeps
        the registration order, except that constant urls are placed before
        urls containing path arguments so a constant url is never shadowed.
        """
        def key_func(r):
            f, url, methods, params, returns = r
            return url

        urls_map = {}
        # groupby only merges adjacent items, hence the sort by the same key.
        routes = sorted(self._routes, key=key_func)
        for url, group in itertools.groupby(routes, key=key_func):
            view = self._make_view(list(group))
            urls_map[url] = path(url, view)
        # keep urls in the same order as self._routes,
        # with constant urls taking priority over urls with path arguments
        constant_urls = []
        argument_urls = []
        urls_added = set()
        for f, url, methods, params, returns in self._routes:
            if url in urls_added:
                continue
            urls_added.add(url)
            if self._has_path_argument(url):
                argument_urls.append(urls_map[url])
            else:
                constant_urls.append(urls_map[url])
        return constant_urls + argument_urls

    @staticmethod
    def _response_from_invalid(ex):
        """Turn a validation error into a 400 response describing it."""
        return Response(
            {
                'description': str(ex),
                'position': ex.position,
                'message': ex.message,
                'field': ex.field,
                'value': ex.value,
            },
            status=400)

    @staticmethod
    def _make_method(method, f, params, returns):
        """Wrap ``f`` as a view method named after the HTTP ``method``.

        Args:
            method: HTTP method name; its lowercase form becomes the name of
                the generated function.
            f: the view function, called as ``f(request, **kwargs)``.
            params: compiled params validator or ``None`` to skip validation.
            returns: compiled returns validator or ``None``.

        Returns:
            A ``rest_method(self, request, format=None, **kwargs)`` function.
        """
        def rest_method(self, request, format=None, **kwargs):
            ret = None
            validr_cost = 0
            if params is not None:
                # url kwargs come first, so they win over query/body values.
                maps = [kwargs]
                if request.method in ['GET', 'DELETE']:
                    maps.append(request.query_params)
                else:
                    maps.append(request.data)
                t_begin = time.time()
                try:
                    kwargs = params(ChainMap(*maps))
                except Invalid as ex:
                    # Invalid params: respond with 400 and skip calling ``f``.
                    ret = RestRouter._response_from_invalid(ex)
                validr_cost += time.time() - t_begin
            t_begin = time.time()
            if ret is None:
                ret = f(request, **kwargs)
            api_cost = time.time() - t_begin
            if returns is not None:
                # Ready-made responses are passed through unvalidated.
                if not isinstance(ret, (Response, HttpResponse)):
                    t_begin = time.time()
                    ret = returns(ret)
                    validr_cost += time.time() - t_begin
                    ret = Response(ret)
            elif ret is None:
                ret = Response(status=204)
            # Report the measured timings through item assignment on ``ret``.
            if validr_cost > 0:
                ret['X-Validr-Time'] = '{:.0f}ms'.format(validr_cost * 1000)
            if api_cost > 0:
                ret['X-API-Time'] = '{:.0f}ms'.format(api_cost * 1000)
            return ret

        rest_method.__name__ = method.lower()
        rest_method.__qualname__ = method.lower()
        rest_method.__doc__ = f.__doc__
        return rest_method

    def _make_view(self, group):
        """Build a view callable serving every route in ``group``.

        Args:
            group: routes ``(f, url, methods, params, returns)`` sharing a url.

        Raises:
            ValueError: if the same HTTP method is registered twice for a url.
        """
        method_maps = {}
        method_meta = {}
        for f, url, methods, params, returns in group:
            for method in methods:
                if method in method_maps:
                    raise ValueError(f'duplicated method {method} of {url}')
                m = self._make_method(method, f, params, returns)
                method_maps[method] = m
                method_meta[method] = f, url, params, returns

        class RestApiView(APIView):

            # Only set permission_classes when the router has some.
            if self.permission_classes:
                permission_classes = self.permission_classes

            schema = RestViewSchema(method_meta)

            # Define a handler only for the methods that were registered.
            if 'GET' in method_maps:
                get = method_maps['GET']
            if 'POST' in method_maps:
                post = method_maps['POST']
            if 'PUT' in method_maps:
                put = method_maps['PUT']
            if 'DELETE' in method_maps:
                delete = method_maps['DELETE']
            if 'PATCH' in method_maps:
                patch = method_maps['PATCH']

        return RestApiView.as_view()

    def _route(self, url, methods):
        """Return a decorator registering a function for ``url``/``methods``.

        Args:
            url: url pattern the function is served on.
            methods: a string of method names separated by commas and/or
                whitespace, or an iterable of names; case-insensitive.
        """
        if isinstance(methods, str):
            methods = methods.strip().replace(',', ' ').split()
        methods = {x.upper() for x in methods}

        def wrapper(f):
            params = get_params(f)
            if params is not None:
                params = self._schema_compiler.compile(params)
            returns = get_returns(f)
            if returns is not None:
                returns = self._schema_compiler.compile(returns)
            self._routes.append((f, url, methods, params, returns))
            # The function itself is returned unchanged.
            return f

        return wrapper

    def get(self, url=''):
        """Register the decorated function for ``GET`` requests on ``url``."""
        return self._route(url, methods='GET')

    def post(self, url=''):
        """Register the decorated function for ``POST`` requests on ``url``."""
        return self._route(url, methods='POST')

    def put(self, url=''):
        """Register the decorated function for ``PUT`` requests on ``url``."""
        return self._route(url, methods='PUT')

    def delete(self, url=''):
        """Register the decorated function for ``DELETE`` requests on ``url``."""
        return self._route(url, methods='DELETE')

    def patch(self, url=''):
        """Register the decorated function for ``PATCH`` requests on ``url``."""
        return self._route(url, methods='PATCH')

    def route(self, url='', methods='GET'):
        """Register the decorated function for the given ``methods`` on ``url``."""
        return self._route(url, methods=methods)

    # Allow ``@router(url, methods=...)`` as a shorthand for ``route``.
    __call__ = route
Exemplo n.º 5
0
import os.path
import re
from functools import cached_property
from urllib.parse import urlparse

from dotenv import load_dotenv
from validr import T, Compiler, modelclass, fields, Invalid

from rssant_common.network_helper import LOCAL_NODE_NAME

# NOTE(review): presumably the upper bound on the number of feeds; where it
# is enforced is outside this excerpt.
MAX_FEED_COUNT = 5000

# Shared schema compiler; also passed to ``modelclass`` for ConfigModel.
compiler = Compiler()
# Compiled validator for a list of dicts, each with a ``name`` string and a
# ``url``.
validate_extra_networks = compiler.compile(
    T.list(T.dict(
        name=T.str,
        url=T.url,
    )))


@modelclass(compiler=compiler)
class ConfigModel:
    # Base model class bound to the module-level ``compiler``. It declares no
    # fields itself; concrete config models subclass it.
    pass


class GitHubConfigModel(ConfigModel):
    # GitHub-related settings; all three fields are validated as strings.
    # NOTE(review): looks like credentials of a GitHub OAuth app — confirm.
    domain: str = T.str
    client_id: str = T.str
    secret: str = T.str

Exemplo n.º 6
0
def test_create_enum_validator():
    """An enum validator accepts its items and supports ``default``."""
    compiler = Compiler(
        validators={
            'abcd': create_enum_validator('abcd', ['A', 'B', 'C', 'D']),
        },
    )
    validate = compiler.compile(T.list(T.abcd.default('A')))
    # The trailing None is expected to be replaced by the default 'A'.
    expected = ['A', 'B', 'C', 'D', 'A']
    assert validate(['A', 'B', 'C', 'D', None]) == expected
Exemplo n.º 7
0
from validr import T, Compiler, modelclass, asdict


@modelclass
class Model:
    # Sample model combining a nested dict, a list and an optional string.

    # Dict with a single ``userid`` integer declared with min 0 and max 9.
    user = T.dict(userid=T.int.min(0).max(9).desc("UserID"))
    # List of integers declared with min 0.
    tags = T.list(T.int.min(0))
    # Dict of integer size fields and string border/color fields.
    style = T.dict(
        width=T.int.desc("width"),
        height=T.int.desc("height"),
        border_width=T.int.desc("border_width"),
        border_style=T.str.desc("border_style"),
        border_color=T.str.desc("border_color"),
        color=T.str.desc("color"),
    )
    # String field marked optional.
    optional = T.str.optional.desc("unknown value")


# Compile the schema derived from ``Model`` into a plain validate function.
compiler = Compiler()
default = compiler.compile(T(Model))


def model(value):
    """Build a ``Model`` from ``value`` and return it converted by ``asdict``."""
    instance = Model(value)
    return asdict(instance)


# Named callables that each take a value and validate it: the compiled
# schema and the model-class route.
# NOTE(review): presumably consumed by a benchmark/test runner — confirm.
CASES = {"default": default, "model": model}
Exemplo n.º 8
0
class ValidrRouteTableDef(RouteTableDef):
    """Route table that validates handler params and returns with validr.

    Handlers registered through :meth:`route` are wrapped by :meth:`decorate`,
    which compiles the schemas reported by ``get_params`` / ``get_returns``
    once and applies them on every request.
    """

    def __init__(self):
        super().__init__()
        self._schema_compiler = Compiler(validators=VALIDATORS)

    @staticmethod
    def _response_from_invalid(ex):
        """Turn a validation error into a 400 JSON response describing it."""
        return json_response(
            {
                'description': str(ex),
                'position': ex.position,
                'message': ex.message,
                'field': ex.field,
                'value': ex.value,
            },
            status=400)

    def decorate(self, f):
        """Wrap coroutine handler ``f`` with params/returns validation.

        Args:
            f: coroutine function called as ``await f(request, **kwargs)``.

        Returns:
            A coroutine function ``wrapped(request, **kwargs)`` carrying the
            name, qualname and docstring of ``f``.
        """
        assert inspect.iscoroutinefunction(f), f'{f} is not coroutine function'
        params = get_params(f)
        if params is not None:
            params = self._schema_compiler.compile(params)
        returns = get_returns(f)
        if returns is not None:
            returns = self._schema_compiler.compile(returns)

        async def wrapped(request, **kwargs):
            ret = None
            if params is not None:
                # Earlier maps win: explicit kwargs, then url match info,
                # then the query string or JSON body.
                maps = [kwargs, request.match_info]
                if request.method in ['GET', 'DELETE']:
                    maps.append(request.query)
                else:
                    try:
                        body = await request.json()
                    except json.JSONDecodeError:
                        return json_response({"message": 'Invalid JSON'},
                                             status=400)
                    # Valid JSON that is not an object (null, string, number,
                    # list) cannot act as a mapping inside the ChainMap and
                    # would surface as a TypeError, so reject it up front.
                    if not isinstance(body, dict):
                        return json_response(
                            {"message": 'JSON body must be an object'},
                            status=400)
                    maps.append(body)
                try:
                    kwargs = params(ChainMap(*maps))
                except Invalid as ex:
                    # Invalid params: respond with 400 and skip calling ``f``.
                    ret = self._response_from_invalid(ex)
            if ret is None:
                ret = await f(request, **kwargs)
            if returns is not None:
                # Ready-made responses are passed through unvalidated.
                if not isinstance(ret, StreamResponse):
                    ret = returns(ret)
                    ret = json_response(ret)
            elif ret is None:
                ret = Response(status=204)
            return ret

        wrapped.__name__ = f.__name__
        wrapped.__qualname__ = f.__qualname__
        wrapped.__doc__ = f.__doc__
        return wrapped

    def route(self, *args, **kwargs):
        """Like the parent ``route``, but registers the validated wrapper.

        The arguments are forwarded to the parent class unchanged.
        """
        routes_decorate = super().route(*args, **kwargs)

        def wrapper(f):
            return routes_decorate(self.decorate(f))

        return wrapper
Exemplo n.º 9
0
def test_load_schema():
    """``T(...)`` rebuilds an equal schema from several kinds of source."""
    compiler = Compiler()
    schema = T.list(T.int.min(0))
    # From an existing schema object.
    assert T(schema) == schema
    # From a compiled validator.
    compiled = compiler.compile(schema)
    assert T(compiled) == schema
    # From the primitive (list of string) representation.
    primitive = ['int.min(0)']
    assert T(primitive) == schema
Exemplo n.º 10
0
def test_compiled_items():
    """Compiled validators can be used as items of ``T.dict`` and ``T.list``."""
    compiled = Compiler().compile(T.int.min(0))
    dict_schema = T.dict(key=compiled)
    assert repr(dict_schema) == 'T.dict({key})'
    list_schema = T.list(compiled)
    assert repr(list_schema) == 'T.list(int)'