Example #1
    def __init__(self,
                 *,
                 allow_blank=False,
                 regex=None,
                 choices=None,
                 min_length=None,
                 max_length=None,
                 **kwargs):
        super(String, self).__init__(**kwargs)
        if regex is not None:
            self._trafaret = t.Regexp(regex)
        else:
            self._trafaret = t.String(allow_blank=allow_blank,
                                      min_length=min_length,
                                      max_length=max_length)
        self.choices = None
        if choices and is_collection(choices):
            if isinstance(choices, type(Enum)):
                self.choices = choices
                self._trafaret &= t.Enum(*choices.__members__.keys())
            else:
                self._trafaret &= t.Enum(*choices)

        if self.allow_none:
            self._trafaret |= t.Null()
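
A minimal usage sketch for the field defined above, assuming the enclosing String field class (only its __init__ is shown here); the Color enum and the field names are illustrative:

import enum

class Color(enum.Enum):
    RED = 'red'
    BLUE = 'blue'

title = String(max_length=80)          # length-checked, non-blank string
color = String(choices=Color)          # accepts the member names 'RED' / 'BLUE'
slug = String(regex=r'^[a-z0-9-]+$')   # a regex replaces the blank/length options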
Example #2
    def __init__(self,
                 request: web.Request,
                 prev_cursor=None,
                 next_cursor=None,
                 cursor_regex: str = None):
        super(Cursor, self).__init__(request)

        self.cursor = request.query.get('page[cursor]', self.FIRST)
        if isinstance(self.cursor, str):
            if cursor_regex is not None:
                try:
                    self.cursor = t.Regexp(cursor_regex).check(self.cursor)
                except t.DataError:
                    raise HTTPBadRequest(detail='The cursor is invalid.',
                                         source_parameter='page[cursor]')
            self.cursor = make_sentinel(var_name=str(self.cursor))

        self.prev_cursor = \
            make_sentinel(var_name=str(prev_cursor)) if prev_cursor else None
        self.next_cursor = \
            make_sentinel(var_name=str(next_cursor)) if next_cursor else None

        self.limit = request.query.get('page[limit]', DEFAULT_LIMIT)
        try:
            self.limit = t.Int(gt=0).check(self.limit)
        except t.DataError:
            raise HTTPBadRequest(detail='The limit must be an integer > 0.',
                                 source_parameter='page[limit]')
Example #3
class TestList(unittest.TestCase):

    TRAFARET = T.Dict({
        "hosts": T.List(T.String() & T.Regexp(r"\w+:\d+")),
    })

    def test_ok(self):
        self.assertEqual(
            get_err(
                self.TRAFARET, u"""\
            hosts:
            - bear:8080
            - cat:7070
            """), None)

    def test_err(self):
        self.assertEqual(
            get_err(
                self.TRAFARET, u"""\
                hosts:
                - bear:8080
                - cat:x
            """), "config.yaml:3: hosts[1]: "
            "does not match pattern \\w+:\\d+\n")
Example #4
import os

import trafaret as T
from models_library.basic_types import PortInt, VersionTag
from pydantic import BaseSettings, Field
from servicelib.application_keys import APP_CLIENT_SESSION_KEY, APP_CONFIG_KEY

CONFIG_SECTION_NAME = "catalog"

_default_values = {
    "host": os.environ.get("CATALOG_HOST", "catalog"),
    "port": int(os.environ.get("CATALOG_PORT", 8000)),
}

schema = T.Dict({
    T.Key("enabled", default=True, optional=True): T.Bool(),
    T.Key("host", default=_default_values["host"]): T.String(),
    T.Key("port", default=_default_values["port"]): T.ToInt(),
    T.Key("version", default="v0"):
    T.Regexp(regexp=r"^v\d+"),  # catalog API version basepath
})
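
A rough sketch of how this trafaret schema behaves when checked (assuming CATALOG_HOST and CATALOG_PORT are unset, so the fallbacks "catalog" and 8000 apply):

assert schema.check({}) == {
    "enabled": True, "host": "catalog", "port": 8000, "version": "v0"
}
assert schema.check({"port": "8080", "version": "v2"})["port"] == 8080  # T.ToInt coerces the string
# schema.check({"version": "0"}) raises T.DataError ("does not match pattern ^v\d+")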


class CatalogSettings(BaseSettings):
    enabled: bool = True
    host: str = "catalog"
    port: PortInt = 8000
    vtag: VersionTag = Field("v0",
                             alias="version",
                             description="Catalog service API's version tag")

    class Config:
        prefix = "CATALOG_"

Example #5
import trafaret as t

from aiohttp.web import HTTPBadRequest
from .fields import Fields

_email_scheme = t.Dict({t.Key(Fields.EMAIL): t.Email})

_password_scheme = t.Dict(
    {t.Key(Fields.PASSWORD): t.Regexp(regexp=r'^[A-Za-z0-9_]{4,64}$')})

_token_scheme = t.Dict({
    t.Key(Fields.TOKEN):
    t.Regexp(regexp=r'^[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+$'),
})

_ref_token_scheme = t.Dict(
    {t.Key(Fields.REF_TOKEN): t.Regexp(regexp=r'^[0-9a-f]{32}$')})


def _fetcher(data, scheme, key, exc_msg):
    scheme = scheme.allow_extra('*')
    try:
        data = scheme(data)
    except t.DataError as exc:
        print(exc)
        raise HTTPBadRequest(reason=exc_msg)
    return data[key]


def fetch_email(json):
    return _fetcher(json, _email_scheme, Fields.EMAIL, "Invalid email")
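
An illustrative call, assuming Fields.EMAIL is the plain key 'email' (the real constant lives in .fields):

payload = {'email': 'user@example.com', 'other': 'ignored'}  # extra keys pass thanks to allow_extra('*')
assert fetch_email(payload) == 'user@example.com'
# fetch_email({'email': 'not-an-address'}) fails the t.Email check and raises HTTPBadRequest(reason='Invalid email')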
Example #6
import pydash as _
import trafaret as t
import datetime
from functools import partial

from jinja2.utils import import_string
from trafaret.contrib.object_id import MongoId
from trafaret.contrib.rfc_3339 import DateTime

Optional = partial(t.Key, optional=True)
SimpleType = t.IntRaw | t.Bool | t.String | t.FloatRaw

DateTimeType = DateTime | t.Type(datetime.datetime)
NumericType = t.Float | t.Int >> (lambda val: float(val))
URLType = t.Regexp(r'^([a-z]{2,5}:)?(\/\/?)?[a-z][a-z0-9\.\-\/]+$')

OptionValue = t.String(
    allow_blank=True) | t.Bool | t.Float | t.Int | t.Type(dict)
Optional = partial(t.Key, optional=True)

SimpleDoc = t.Dict({
    t.Key('id', optional=True) >> '_id': MongoId,
    Optional('_id'): MongoId
})

TimestampDoc = SimpleDoc + t.Dict({
    Optional('created', default=datetime.datetime.now):
    DateTimeType | t.Null,
    Optional('modified'):
    DateTimeType
})
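
A rough behavior sketch for these document trafarets (assumes bson is installed for MongoId; the hex id is made up):

TimestampDoc.check({})
# -> {'created': datetime.datetime(...)}  # the default factory fills in "now"

SimpleDoc.check({'id': '5f4e3d2c1b0a98765432abcd'})
# -> {'_id': ObjectId('5f4e3d2c1b0a98765432abcd')}  # ">> '_id'" renames the key on output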
Example #7
                        set(g_dev_type))
            objs_per_group[group_id]['g_smp'] += c_info['smp']
            objs_per_group[group_id]['g_gpu_mem_allocated'] += c_info[
                'gpu_mem_allocated']
            objs_per_group[group_id]['g_gpu_allocated'] += c_info[
                'gpu_allocated']
            objs_per_group[group_id]['c_infos'].append(c_info)
    return list(objs_per_group.values())


@atomic
@server_status_required(READ_ALLOWED)
@superadmin_required
@check_api_params(t.Dict({
    tx.MultiKey('group_ids'): t.List(t.String) | t.Null,
    t.Key('month'): t.Regexp(r'^\d{6}', re.ASCII),
}),
                  loads=_json_loads)
async def usage_per_month(request: web.Request, params: Any) -> web.Response:
    '''
    Return usage statistics of terminated containers for a specified month.
    The date/time comparison is done using the configured timezone.

    :param group_ids: If not None, query containers only in those groups.
    :param month: The year-month to query usage statistics, e.g. "202006" for June 2020.
    '''
    log.info('USAGE_PER_MONTH (g:[{}], month:{})',
             ','.join(params['group_ids']), params['month'])
    local_tz = request.app['config']['system']['timezone']
    try:
        start_date = datetime.strptime(params['month'],
Example #8
class TfObjectDetectionModel(object):
    """Implementation for TF Object Detection API Inference"""

    # model's input tensor name from official website
    # https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
    input_tensors = {'images': "image_tensor:0"}

    # model's output tensors names from official website
    # https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
    output_tensors = {
        "labels": "detection_classes:0",
        "boxes": "detection_boxes:0",
        "scores": "detection_scores:0",
        "masks": "detection_masks:0"
    }

    # default configuration
    _config_schema = t.Dict(
        {
            # path or name of the frozen weights file (*.pb)
            t.Key('weights', default="data/models/*.pb"):
            t.String(min_length=4),
            t.Key('width', default=300):
            t.Int(gt=0),  # input tensor width
            t.Key('height', default=300):
            t.Int(gt=0),  # input tensor height
            t.Key('threshold', default=0.5):
            t.Float(gte=0.0,
                    lte=1.0),  # confidence threshold for detected objects
            # labels dict or file
            t.Key('labels', default={1: 'person'}):
            t.Or(t.Dict({}, allow_extra='*'), t.String(min_length=4)),
            # device to execute graph
            t.Key('device', default='GPU|CPU'):
            t.Regexp(r'GPU\|CPU|CPU(?:\:0)?|GPU(?:\:\d)?') >> _parse_device,
            t.Key('log_device_placement', default=False):
            t.Bool,  # TF specific
            t.Key('per_process_gpu_memory_fraction', default=0.0):
            t.Float(gte=0.0, lte=1.0),  # TF specific
        },
        allow_extra='*')

    def __init__(self, **kwargs):
        # validate config
        try:
            self.config = self._config_schema.check(kwargs or {})
        except t.DataError as err:
            raise ValueError('Wrong model configuration for {}: {}'.format(
                self, err))

        self._session = None  # tf.Session
        self._inputs = None  # typ.Dict[str, tf.Tensor]
        self._outputs = None  # typ.Dict[str, tf.Tensor]
        self._labels = None  # typ.Dict[int, str]

    def __str__(self) -> str:
        return self.__class__.__name__

    def __repr__(self) -> str:
        return '<{}>'.format(self)

    def __enter__(self):
        self.startup()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()

    @property
    def log(self) -> logging.Logger:
        return log

    def startup(self):

        self.log.info("Starting %s ...", self)

        self._labels = self.config['labels'] if isinstance(
            self.config['labels'], dict) else load_labels_from_file(
                self.config['labels'])

        if len(self._labels) <= 0:
            raise ValueError(f"Labels can't be empty {self._labels}")

        tf_config = create_config(
            self.config['device'],
            log_device_placement=self.config['log_device_placement'],
            per_process_gpu_memory_fraction=self.
            config['per_process_gpu_memory_fraction'])

        graph = import_graph(parse_graph_def(self.config['weights']),
                             self.config['device'])
        has_masks = self.output_tensors['masks'].replace(":0", "") in set(
            [n.name for n in graph.as_graph_def().node])

        self.log.debug("Model (%s) placed on %s", self.config['weights'],
                       self.config['device'])

        self._session = tf.Session(graph=graph, config=tf_config)

        output_tensors = self.output_tensors
        if not has_masks:
            output_tensors = deepcopy(self.output_tensors)
            output_tensors.pop('masks')

        self._inputs = {
            alias: graph.get_tensor_by_name(name)
            for alias, name in self.input_tensors.items()
        }
        self._outputs = {
            alias: graph.get_tensor_by_name(name)
            for alias, name in output_tensors.items()
        }

        # warm up
        self.log.info("Warming up %s ...", self)
        self.process_single(np.zeros((2, 2, 3), dtype=np.uint8))

    def shutdown(self):
        """ Releases model when object deleted """
        self.log.info("Shutdown %s ...", self)

        if self._session is None:
            return

        try:
            self._session.close()
            self._session = None
        except tf.OpError as err:
            self.log.error('%s close TF session error: %s. Skipping...', self,
                           err)

        self.log.info("%s Destroyed successfully", self)

    def process_single(self, image: np.ndarray) -> typ.List[dict]:
        """Run inference on a single image.

        Returns: list of detection* for the image (see process_batch() for the detection format)
        """
        return self.process_batch([image])[0]

    def process_batch(
            self, images: typ.List[np.ndarray]) -> typ.List[typ.List[dict]]:
        """
        Returns: list of detection* per image

        *detection: dict

        | Field name   | Type                          | Description                        |
        |--------------|-------------------------------|------------------------------------|
        | confidence   | float                         | class score                        |
        | bounding_box | typ.Tuple[int, int, int, int] | x, y, width, height wrt to image   |
        | class_name   | str                           | human readable class name          |
        | mask         | np.ndarray[np.bool]           | with shape (box_width, box_height) |

        """

        preprocessed = np.stack([self._preprocess(image) for image in images])

        result = self._session.run(
            self._outputs, feed_dict={self._inputs['images']: preprocessed})

        detections_per_image = []

        has_masks = 'masks' in result
        for i, image in enumerate(images):

            detections = []
            boxes = result['boxes'][i]
            scores, labels = result['scores'][i], result['labels'][i]
            for j in range(len(scores)):
                score = scores[j]

                class_name = self._labels.get(int(labels[j]))
                if not class_name or score < self.config['threshold']:
                    continue

                # resize bounding box wrt to image size
                ymin, xmin, ymax, xmax = (np.tile(image.shape[:2], 2) *
                                          boxes[j]).astype(np.int32).tolist()
                width, height = xmax - xmin, ymax - ymin

                obj = {
                    'confidence': float(score),
                    'bounding_box': [xmin, ymin, width, height],
                    'class_name': class_name,
                }
                if has_masks:
                    # resize mask wrt to bounding box size
                    mask = cv2.resize(result['masks'][i][j], (width, height),
                                      interpolation=cv2.INTER_NEAREST)
                    obj['mask'] = mask >= self.config[
                        'threshold']  # threshold mask

                detections.append(obj)

            detections_per_image.append(detections)

        return detections_per_image

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        return cv2.resize(image, (self.config['width'], self.config['height']),
                          interpolation=cv2.INTER_NEAREST)

    @classmethod
    def from_config_file(cls, filename: str) -> "TfObjectDetectionModel":
        """
        :param filename: filename to model config
        """
        return cls(**load_config(filename))
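
A hypothetical usage sketch; the weights path, label map, and input image below are placeholders, not taken from the original code:

model = TfObjectDetectionModel(weights='data/models/frozen_inference_graph.pb',
                               threshold=0.6,
                               labels={1: 'person', 2: 'bicycle'})
with model:  # __enter__ -> startup(), __exit__ -> shutdown()
    image = np.zeros((480, 640, 3), dtype=np.uint8)
    for det in model.process_single(image):
        print(det['class_name'], det['confidence'], det['bounding_box'])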
Example #9
class TfObjectDetectionModel(object):
    """Implementation for TF Object Detection API Inference"""

    # model's input tensor name from official website
    # https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
    input_tensors = {'images': "image_tensor:0"}

    # model's output tensors names from official website
    # https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
    output_tensors = {
        "labels": "detection_classes:0",
        "boxes": "detection_boxes:0",
        "scores": "detection_scores:0"
    }

    # default configuration
    _config_schema = t.Dict(
        {
            # path or name of the frozen weights file (*.pb)
            t.Key('weights', default="data/models/*.pb"):
            t.String(min_length=4),
            t.Key('width', default=300):
            t.Int(gt=0),  # input tensor width
            t.Key('height', default=300):
            t.Int(gt=0),  # input tensor height
            t.Key('threshold', default=0.5):
            t.Float(gte=0.0,
                    lte=1.0),  # confidence threshold for detected objects
            # labels dict or file
            t.Key('labels', default={1: 'person'}):
            t.Or(t.Dict({}, allow_extra='*'), t.String(min_length=4)),
            # device to execute graph
            t.Key('device', default='GPU|CPU'):
            t.Regexp(r'GPU\|CPU|CPU(?:\:0)?|GPU(?:\:\d)?') >> _parse_device,
            t.Key('log_device_placement', default=False):
            t.Bool,  # TF specific
            t.Key('per_process_gpu_memory_fraction', default=0.0):
            t.Float(gte=0.0, lte=1.0),  # TF specific
        },
        allow_extra='*')

    def __init__(self, **kwargs):

        # validate config
        try:
            self.config = self._config_schema.check(kwargs or {})
        except t.DataError as err:
            raise ValueError('Wrong model configuration for {}: {}'.format(
                self, err))

        self._session = None  # tf.Session
        self._inputs = None  # typ.Dict[str, tf.Tensor]
        self._outputs = None  # typ.Dict[str, tf.Tensor]
        self._labels = None  # typ.Dict[int, str]

    def __str__(self) -> str:
        return self.__class__.__name__

    def __repr__(self) -> str:
        return '<{}>'.format(self)

    def __enter__(self):
        self.startup()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()

    @property
    def log(self) -> logging.Logger:
        return log

    def startup(self):

        self.log.info("Starting %s ...", self)

        self._labels = self.config['labels'] if isinstance(
            self.config['labels'], dict) else load_labels_from_file(
                self.config['labels'])

        if len(self._labels) <= 0:
            raise ValueError(f"Labels can't be empty {self._labels}")

        tf_config = create_config(
            self.config['device'],
            log_device_placement=self.config['log_device_placement'],
            per_process_gpu_memory_fraction=self.
            config['per_process_gpu_memory_fraction'])

        graph = import_graph(parse_graph_def(self.config['weights']),
                             self.config['device'])

        self.log.debug("Model (%s) placed on %s", self.config['weights'],
                       self.config['device'])

        self._session = tf.Session(graph=graph, config=tf_config)

        self._inputs = {
            alias: graph.get_tensor_by_name(name)
            for alias, name in self.input_tensors.items()
        }
        self._outputs = {
            alias: graph.get_tensor_by_name(name)
            for alias, name in self.output_tensors.items()
        }

        # warm up
        self.log.info("Warming up %s ...", self)
        self.process_single(np.zeros((2, 2, 3), dtype=np.uint8))

    def shutdown(self):
        """ Releases model when object deleted """
        self.log.info("Shutdown %s ...", self)

        if self._session is None:
            return

        try:
            self._session.close()
            self._session = None
        except tf.OpError as err:
            self.log.error('%s close TF session error: %s. Skipping...', self,
                           err)

        self.log.info("%s Destroyed successfully", self)

    def process_single(self, image: np.ndarray) -> typ.List[dict]:
        return self.process_batch([image])[0]

    def process_batch(
            self, images: typ.List[np.ndarray]) -> typ.List[typ.List[dict]]:
        preprocessed = np.stack([self._preprocess(image) for image in images])

        result = self._session.run(
            self._outputs, feed_dict={self._inputs['images']: preprocessed})

        detections_per_image = []

        for image, scores, boxes, labels in zip(images, result['scores'],
                                                result['boxes'],
                                                result['labels']):
            detections = []
            for score, box, label in zip(scores, boxes, labels):

                class_name = self._labels.get(int(label))
                if not class_name or score < self.config['threshold']:
                    continue

                # scale boxes wrt initial image size
                ymin, xmin, ymax, xmax = (np.tile(image.shape[:2], 2) *
                                          box).astype(np.int32).tolist()

                width, height = xmax - xmin, ymax - ymin

                detections.append({
                    'confidence': float(score),
                    'bounding_box': [xmin, ymin, width, height],
                    'class_name': class_name,
                })

            detections_per_image.append(detections)

        return detections_per_image

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        return cv2.resize(image, (self.config['width'], self.config['height']),
                          interpolation=cv2.INTER_NEAREST)
Example #10
import os
from typing import Dict

import trafaret as T
from aiohttp import ClientSession, web
from servicelib.application_keys import APP_CLIENT_SESSION_KEY, APP_CONFIG_KEY

CONFIG_SECTION_NAME = "catalog"


_default_values = {
    "host": os.environ.get("CATALOG_HOST", "catalog"),
    "port": int(os.environ.get("CATALOG_PORT", 8000)),
}

schema = T.Dict(
    {
        T.Key("enabled", default=True, optional=True): T.Bool(),
        T.Key("host", default=_default_values["host"]): T.String(),
        T.Key("port", default=_default_values["port"]): T.ToInt(),
        T.Key("version", default="v0"): T.Regexp(
            regexp=r"^v\d+"
        ),  # catalog API version basepath
    }
)


def get_config(app: web.Application) -> Dict:
    return app[APP_CONFIG_KEY][CONFIG_SECTION_NAME]


def get_client_session(app: web.Application) -> ClientSession:
    return app[APP_CLIENT_SESSION_KEY]
Example #11
 def test_upper(self):
     trafaret = t.Regexp(r'\w+-\w+') & str.upper
     self.assertEqual(trafaret('abc-Abc'), 'ABC-ABC')
Example #12
CONFIG_FILENAME = 'batch_scoring.ini'


def verify_objectid(value):
    """Verify that value is a proper ObjectId."""
    try:
        t.Regexp(regexp='^[A-Fa-f0-9]{24}$').check(value)
    except t.DataError:
        raise ValueError('id {} not a valid project/model id'.format(value))


config_validator = t.Dict({
    OptKey('host'):
    t.String,
    OptKey('project_id'):
    t.Regexp(regexp='^[A-Fa-f0-9]{24}$'),
    OptKey('model_id'):
    t.Regexp(regexp='^[A-Fa-f0-9]{24}$'),
    OptKey('deployment_id'):
    t.Regexp(regexp='^[A-Fa-f0-9]{24}$'),
    OptKey('import_id'):
    t.String,
    OptKey('n_retry'):
    t.Int,
    OptKey('keep_cols'):
    t.String,
    OptKey('n_concurrent'):
    t.Int,
    OptKey('dataset'):
    t.String,
    OptKey('n_samples'):
Example #13
import trafaret as t


def verify_objectid(value):
    """Verify that value is a proper ObjectId."""
    try:
        t.Regexp(regexp='^[A-Fa-f0-9]{24}$').check(value)
    except t.DataError:
        raise ValueError('id {} not a valid project/model id'.format(value))
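
Minimal usage sketch:

verify_objectid('5f4e3d2c1b0a98765432abcd')  # 24 hex characters: passes silently
try:
    verify_objectid('not-a-valid-id')
except ValueError as exc:
    print(exc)  # id not-a-valid-id not a valid project/model id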
Example #14
    - config-file schema
    - settings
"""
from typing import Dict

import trafaret as T
from aiohttp import ClientSession, web

from servicelib.application_keys import APP_CLIENT_SESSION_KEY, APP_CONFIG_KEY

CONFIG_SECTION_NAME = "storage"

schema = T.Dict(
    {
        T.Key("enabled", default=True, optional=True): T.Bool(),
        T.Key("host", default="storage"): T.String(),
        T.Key("port", default=11111): T.ToInt(),
        T.Key("version", default="v0"): T.Regexp(
            regexp=r"^v\d+"
        ),  # storage API version basepath
    }
)


def get_config(app: web.Application) -> Dict:
    return app[APP_CONFIG_KEY][CONFIG_SECTION_NAME]


def get_client_session(app: web.Application) -> ClientSession:
    return app[APP_CLIENT_SESSION_KEY]
Example #15
    tx.AliasedKey(['clusterSize', 'cluster_size'], default=None):
    t.Null | t.Int[1:],
    tx.AliasedKey(['scalingGroup', 'scaling_group'], default=None):
    t.Null | t.String,
    t.Key('resources', default=None):
    t.Null | t.Mapping(t.String, t.Any),
    t.Key('resource_opts', default=None):
    t.Null | t.Mapping(t.String, t.Any),
})


@server_status_required(ALL_ALLOWED)
@auth_required
@check_api_params(t.Dict({
    t.Key('clientSessionToken') >> 'sess_id':
    t.Regexp(r'^(?=.{4,64}$)\w[\w.-]*\w$', re.ASCII),
    tx.AliasedKey(['image', 'lang']):
    t.String,
    tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'):
    t.String,
    tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'):
    t.String,
    t.Key('config', default=dict):
    t.Mapping(t.String, t.Any),
    t.Key('tag', default=None):
    t.Null | t.String,
}),
                  loads=_json_loads)
async def create(request: web.Request, params: Any) -> web.Response:
    if params['domain'] is None:
        params['domain'] = request['user']['domain_name']
Example #16
""" director subsystem's configuration

    - config-file schema
    - settings
"""

import trafaret as T

APP_DIRECTOR_API_KEY = __name__ + ".director_api"

CONFIG_SECTION_NAME = "director"

# TODO: deprecate trafaret schema
schema = T.Dict(
    {
        T.Key("enabled", default=True, optional=True): T.Bool(),
        T.Key(
            "host",
            default="director",
        ): T.String(),
        T.Key("port", default=8001): T.ToInt(),
        T.Key("version", default="v0"): T.Regexp(
            regexp=r"^v\d+"
        ),  # director API version basepath
    }
)
Example #17
keywords = (
    t.Key('enum', optional=True, trafaret=t.List(t.Any) & (lambda consts: t.Or(*(t.Atom(cnst) for cnst in consts)))),
    t.Key('const', optional=True, trafaret=t.Any() & then(t.Atom)),
    t.Key('type', optional=True, trafaret=ensure_list(json_schema_type) & then(Any)),

    # number validation
    t.Key('multipleOf', optional=True, trafaret=t.Float(gt=0) & then(multipleOf)),
    t.Key('maximum', optional=True, trafaret=t.Float() & (lambda maximum: t.Float(lte=maximum))),
    t.Key('exclusiveMaximum', optional=True, trafaret=t.Float() & (lambda maximum: t.Float(lt=maximum))),
    t.Key('minimum', optional=True, trafaret=t.Float() & (lambda minimum: t.Float(gte=minimum))),
    t.Key('exclusiveMinimum', optional=True, trafaret=t.Float() & (lambda minimum: t.Float(gt=minimum))),

    # string
    t.Key('maxLength', optional=True, trafaret=t.Int(gte=0) & (lambda length: t.String(max_length=length))),
    t.Key('minLength', optional=True, trafaret=t.Int(gte=0) & (lambda length: t.String(min_length=length))),
    t.Key('pattern', optional=True, trafaret=Pattern() & (lambda pattern: t.Regexp(pattern))),

    # array
    t.Key('maxItems', optional=True, trafaret=t.Int(gte=0) & (lambda length: t.List(t.Any, max_length=length))),
    t.Key('minItems', optional=True, trafaret=t.Int(gte=0) & (lambda length: t.List(t.Any, min_length=length))),
    t.Key(
        'uniqueItems',
        optional=True,
        trafaret=t.Bool() & (lambda need_check: t.List(t.Any) & uniq if need_check else t.Any)
    ),

    # object
    t.Key(
        'maxProperties',
        optional=True,
        trafaret=(
Example #18
    query_domain_dotfiles,
    verify_dotfile_name,
    MAXIMUM_DOTFILE_SIZE,
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.dotfile'))


@server_status_required(READ_ALLOWED)
@admin_required
@check_api_params(
    t.Dict({
        t.Key('domain'): t.String,
        t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE),
        t.Key('path'): t.String,
        t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII),
    }))
async def create(request: web.Request, params: Any) -> web.Response:
    log.info('CREATE DOTFILE (domain: {0})', params['domain'])
    if not request['is_superadmin'] and request['user'][
            'domain_name'] != params['domain']:
        raise GenericForbidden(
            'Domain admins cannot create dotfiles of other domains')

    dbpool = request.app['dbpool']
    async with dbpool.acquire() as conn, conn.begin():
        dotfiles, leftover_space = await query_domain_dotfiles(
            conn, params['domain'])
        if dotfiles is None:
            raise DomainNotFound('Input domain is not found')
        if leftover_space == 0:
Example #19
 def test_auto_call(self):
     import functools
     to_int_10000 = functools.partial(int, '10000')
     trafaret = t.Regexp('2|10|16') & t.Int & to_int_10000
     self.assertEqual(trafaret('10'), 10000)
Example #20
 def test_regexp(self):
     trafaret = t.Regexp('cat')
     assert trafaret('cat1212') == 'cat'
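
For context, t.Regexp returns only the matched part of the input and raises t.DataError when nothing matches; a small standalone sketch (assuming import trafaret as t):

pattern = t.Regexp('cat')
assert pattern('cat1212') == 'cat'  # only the matched substring comes back
try:
    pattern('dog')
except t.DataError as err:
    print(err)  # e.g. "does not match pattern cat"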
Example #21
from servicelib.application_keys import APP_CONFIG_KEY, APP_CLIENT_SESSION_KEY
from yarl import URL

APP_DIRECTOR_API_KEY = __name__ + ".director_api"

CONFIG_SECTION_NAME = "director"

schema = T.Dict({
    T.Key("enabled", default=True, optional=True): T.Bool(),
    T.Key(
        "host",
        default="director",
    ): T.String(),
    T.Key("port", default=8001): T.ToInt(),
    T.Key("version", default="v0"):
    T.Regexp(regexp=r"^v\d+"),  # director API version basepath
})


def build_api_url(config: Dict) -> URL:
    api_baseurl = URL.build(scheme="http",
                            host=config["host"],
                            port=config["port"]).with_path(config["version"])
    return api_baseurl


def get_config(app: web.Application) -> Dict:
    return app[APP_CONFIG_KEY][CONFIG_SECTION_NAME]


def get_client_session(app: web.Application) -> ClientSession:
Example #22
import logging
import os
import pathlib

import trafaret as t
import yaml

logger = logging.getLogger(__name__)

BASE_DIR = pathlib.Path(__file__).parent.parent
DEFAULT_CONFIG = 'config.yml'
CITIZENS_COLLECTION = 'requests_dump'

config_template = t.Dict({
    t.Key('mongo'): t.Dict({
        t.Key('host'): t.Regexp(regexp=r'^\d+\.\d+\.\d+\.\d+$'),
        t.Key('port'): t.Int(gte=0),
        t.Key('database'): t.String(),
        t.Key('max_pool_size'): t.Int(gte=0),
    }),
    t.Key('host'): t.Regexp(regexp=r'^\d+\.\d+\.\d+\.\d+$'),
    t.Key('port'): t.Int(gte=0),
    t.Key('proxy-port'): t.Int(gte=0),
})


def get_config(mode=DEFAULT_CONFIG):
    config_path = os.path.join(BASE_DIR, 'config', mode)

    with open(config_path, 'rt') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
Example #23
 def test_upper(self):
     trafaret = t.Regexp(r'\w+-\w+') & str.upper
     assert trafaret('abc-Abc') == 'ABC-ABC'
Example #24
from servicelib.aiohttp.application_keys import APP_CONFIG_KEY

APP_DIRECTOR_API_KEY = __name__ + ".director_api"

CONFIG_SECTION_NAME = "director"

# TODO: deprecate trafaret schema
schema = T.Dict({
    T.Key("enabled", default=True, optional=True): T.Bool(),
    T.Key(
        "host",
        default="director",
    ): T.String(),
    T.Key("port", default=8001): T.ToInt(),
    T.Key("version", default="v0"):
    T.Regexp(regexp=r"^v\d+"),  # director API version basepath
})


class DirectorSettings(BaseSettings):
    enabled: bool = True
    host: str = "director"
    port: PortInt = 8001
    vtag: VersionTag = Field("v0",
                             alias="version",
                             description="Director service API's version tag")

    url: Optional[AnyHttpUrl] = None

    @validator("url", pre=True)
    @classmethod
Example #25
 def test_callable(self):
     import functools
     # partial(int, '10000') means to_int_10000(10) == int('10000', 10) == 10000
     to_int_10000 = functools.partial(int, '10000')
     # '10' passes the regexp, ToInt turns it into the integer 10, and t.Call
     # feeds that 10 into to_int_10000 as the base argument
     trafaret = t.Regexp('2|10|16') & t.ToInt & t.Call(to_int_10000)
     assert trafaret('10') == 10000
Example #26
from yarl import URL

APP_DIRECTOR_SESSION_KEY = __name__ + ".director_session"
APP_DIRECTOR_API_KEY = __name__ + ".director_api"

CONFIG_SECTION_NAME = 'director'

schema = T.Dict({
    T.Key("enabled", default=True, optional=True): T.Bool(),
    T.Key(
        "host",
        default="director",
    ): T.String(),
    T.Key("port", default=8001): T.Int(),
    T.Key("version", default="v0"):
    T.Regexp(regexp=r'^v\d+')  # director API version basepath
})


def build_api_url(config: Dict) -> URL:
    api_baseurl = URL.build(scheme='http',
                            host=config['host'],
                            port=config['port']).with_path(config["version"])
    return api_baseurl


def get_config(app: web.Application) -> Dict:
    return app[APP_CONFIG_KEY][CONFIG_SECTION_NAME]


def get_client_session(app: web.Application) -> ClientSession:
Example #27
 def test_regexp(self):
     trafaret = t.Regexp('cat')
     self.assertEqual(trafaret('cat1212'), 'cat')
Example #28

CONFIG_FILENAME = 'batch_scoring.ini'


def verify_objectid(value):
    """Verify that value is a proper ObjectId."""
    try:
        t.Regexp(regexp='^[A-Fa-f0-9]{24}$').check(value)
    except t.DataError:
        raise ValueError('id {} not a valid project/model id'.format(value))


config_validator = t.Dict({
    OptKey('host'): t.String,
    OptKey('project_id'): t.Regexp(regexp='^[A-Fa-f0-9]{24}$'),
    OptKey('model_id'): t.Regexp(regexp='^[A-Fa-f0-9]{24}$'),
    OptKey('deployment_id'): t.Regexp(regexp='^[A-Fa-f0-9]{24}$'),
    OptKey('import_id'): t.String,
    OptKey('n_retry'): t.Int,
    OptKey('keep_cols'): t.String,
    OptKey('n_concurrent'): t.Int,
    OptKey('dataset'): t.String,
    OptKey('n_samples'): t.Int,
    OptKey('delimiter'): t.String,
    OptKey('out'): t.String,
    OptKey('user'): t.String,
    OptKey('password'): t.String,
    OptKey('datarobot_key'): t.String,
    OptKey('timeout'): t.Int,
    OptKey('api_token'): t.String,