示例#1
0
class CustomerResource(ModelResource, CustomerMixin):
    """REST resource exposing ``Customer`` records.

    ``delete`` is wrapped with ``roles_required('admin')``. Anonymous users and
    non-superusers are restricted to their own customer record in listings.
    """

    model = Customer

    method_decorators = {'delete': roles_required('admin')}

    # phone/notes may be omitted, blank or null; unknown keys are discarded
    validation = t.Dict({
        'first_name': t.String,
        'last_name': t.String,
        'email': t.Email,
        'phone': t.Or(t.String(allow_blank=True), t.Null),
        'notes': t.Or(t.String(allow_blank=True), t.Null),
        'sex': t.String,
        'birthdate': t.DateTime,
    }).make_optional('phone', 'notes').ignore_extra('*')

    @method_wrapper(http.ACCEPTED)
    def put(self, id):
        """Clean the request payload, apply it to customer ``id`` and return it serialized."""
        cleaned = self.clean(g.request_data)
        updated = self.get_object(id).update(with_reload=True, **cleaned)
        return self.serialize(updated)

    def get_objects(self, **kwargs):
        """Return the object query, narrowed to the caller's own customer unless superuser."""
        privileged = not current_user.is_anonymous() and current_user.is_superuser()
        if not privileged:
            kwargs['id'] = self._customer.id
        return super(CustomerResource, self).get_objects(**kwargs)
示例#2
0
async def identity_user(*, conn: SAConnection, credentials_data: dict) -> User:
    """
    Check that a user with the given credentials exists in the database.

    :param conn: connection to database
    :param credentials_data: dict with ``username`` and ``password`` keys
        (each may be a string or an int)
    :return: ``User`` built from the matching database row
    :raises AuthenticateNoCredentials: if ``credentials_data`` fails validation
    :raises AuthenticateErrorCredentials: if the user does not exist or the
        password does not match its stored hash
    """
    credentials_format = t.Dict({
        t.Key('username'): t.Or(t.String, t.Int),
        t.Key('password'): t.Or(t.String, t.Int)
    })

    try:
        credentials_data = validate(data_to_check=credentials_data,
                                    trafaret_format=credentials_format)
    except app_exceptions.ValidateDataError as err:
        raise auth_exceptions.AuthenticateNoCredentials from err

    # look for user in database
    # NOTE(review): an unknown user returns before any hash is computed, so
    # response timing may reveal whether a username exists — confirm whether
    # this matters for the deployment.
    try:
        user = await get_user(conn=conn, username=credentials_data['username'])
    except app_exceptions.DoesNotExist as err:
        raise auth_exceptions.AuthenticateErrorCredentials from err

    # check user password
    if not validate_password(password=credentials_data['password'],
                             password_hash=user['password']):
        raise auth_exceptions.AuthenticateErrorCredentials

    return User(user)  # type: ignore
示例#3
0
class AddressResource(ModelResource, CustomerMixin):
    """REST resource for customer addresses.

    Non-superusers (and anonymous users) only see addresses belonging to
    their own customer.
    """

    model = Address
    # 'apartment' and 'type' are optional; unknown keys are discarded.
    # NOTE(review): the 'type' regex is not anchored at the end, so values that
    # merely start with "billing"/"delivery" may pass — confirm intent.
    validation = t.Dict({
        'country_id':
        t.Int,
        'apartment':
        t.Or(t.String(allow_blank=True), t.Null),
        'city':
        t.String,
        'street':
        t.String,
        'zip_code':
        t.String,
        'type':
        t.Or(t.String(allow_blank=True, regex="(billing|delivery)"), t.Null),
    }).make_optional('apartment', 'type').ignore_extra('*')

    @method_wrapper(http.CREATED)
    def post(self):
        """Create an address and attach it to the current customer.

        A missing ``type`` is handled like an explicit ``null``: ``None`` is
        passed to ``set_address``.
        """
        data = self.clean(g.request_data)
        # 'type' is declared optional above, so it can be absent from the
        # cleaned data; a bare pop('type') would raise KeyError in that case.
        address_type = data.pop('type', None)
        address = self.model.create(**data)
        self._customer.set_address(address_type, address)
        return self.serialize(address)

    def get_objects(self, **kwargs):
        """ Method for extraction object list query

        Restricts the query to the current customer's addresses unless the
        caller is an authenticated superuser.
        """
        if current_user.is_anonymous() or not current_user.is_superuser():
            kwargs['customer_id'] = self._customer.id

        return super(AddressResource, self).get_objects(**kwargs)
示例#4
0
 def from_row(cls, row, imageable: ImageableType, alias=None):
     """Build a populated instance from a database ``row``.

     ``alias`` replaces ``cls.Options.db_table`` when the row was selected
     from an aliased table. ``imageable`` must pass ``t.Type`` for one of
     ``Author``, ``Book`` or ``Series``. Any failing check raises
     ``t.DataError``.
     """
     columns = (alias if alias is not None else cls.Options.db_table).c
     owner_check = t.Or(t.Type(Author), t.Type(Book), t.Type(Series))
     nullable_datetime = t.Or(DateTime, t.Null)
     return cls(
         id=t.Int().check(row[columns.id]),
         title=t.String().check(row[columns.title]),
         uri=str(t.URL.check(row[columns.uri])),
         imageable=owner_check.check(imageable),
         created_at=DateTime().check(row[columns.created_at]),
         updated_at=nullable_datetime.check(row[columns.updated_at]),
         is_populated=True,
     )
示例#5
0
 def from_row(cls, row, books: List['Book'], photos: List['Photo'],
              alias=None):
     """Build a populated instance from ``row`` plus already-loaded relations.

     ``alias`` replaces ``cls.Options.db_table`` when the row was selected
     from an aliased table. Column values and the ``books``/``photos`` lists
     are validated with trafaret; a failing check raises ``t.DataError``.
     """
     columns = (alias if alias is not None else cls.Options.db_table).c
     nullable_date = t.Or(Date, t.Null)
     nullable_datetime = t.Or(DateTime, t.Null)
     return cls(
         id=t.Int().check(row[columns.id]),
         name=t.String().check(row[columns.name]),
         date_of_birth=Date().check(row[columns.date_of_birth]),
         date_of_death=nullable_date.check(row[columns.date_of_death]),
         books=t.List(t.Type(Book)).check(books),
         photos=t.List(t.Type(Photo)).check(photos),
         created_at=DateTime().check(row[columns.created_at]),
         updated_at=nullable_datetime.check(row[columns.updated_at]),
         is_populated=True,
     )
示例#6
0
 def test_nullable_datetime(self):
     """``DateTime | Null`` passes None through and yields datetimes for datetime or string input."""
     checker = t.Or(DateTime, t.Null)
     expected = datetime.datetime(2017, 9, 1, 23, 59)
     assert checker.check(None) is None
     assert checker.check(datetime.datetime(2017, 9, 1, 23, 59)) == expected
     assert checker.check('2017-09-01 23:59') == expected
示例#7
0
 def test_nullable_date(self):
     """``Date | Null`` passes None through and yields dates for date or ISO-string input."""
     checker = t.Or(Date, t.Null)
     expected = datetime.date(1954, 7, 29)
     assert checker.check(None) is None
     assert checker.check(datetime.date(1954, 7, 29)) == expected
     assert checker.check('1954-07-29') == expected
示例#8
0
 def from_row(cls, row, author: 'Author', photos: List['Photo'],
              chapters: List['Chapter'], series: 'Series' = None,
              alias=None):
     """Build a populated instance from ``row`` and its pre-loaded relations.

     ``alias`` replaces ``cls.Options.db_table`` when the row was selected
     from an aliased table. ``series`` may be ``None``. Every value is
     validated with trafaret; a failing check raises ``t.DataError``.
     """
     columns = (alias if alias is not None else cls.Options.db_table).c
     optional_series = t.Or(t.Type(Series), t.Null)
     nullable_datetime = t.Or(DateTime, t.Null)
     return cls(
         id=t.Int().check(row[columns.id]),
         title=t.String().check(row[columns.title]),
         date_published=Date().check(row[columns.date_published]),
         author=t.Type(Author).check(author),
         photos=t.List(t.Type(Photo)).check(photos),
         chapters=t.List(t.Type(Chapter)).check(chapters),
         series=optional_series.check(series),
         created_at=DateTime().check(row[columns.created_at]),
         updated_at=nullable_datetime.check(row[columns.updated_at]),
         is_populated=True,
     )
示例#9
0
 def test_nullable_date(self):
     """``Date | Null`` keeps None and converts date objects and ISO strings to dates."""
     checker = t.Or(Date, t.Null)
     expected = datetime.date(1954, 7, 29)
     self.assertIsNone(checker.check(None))
     for raw in (datetime.date(1954, 7, 29), '1954-07-29'):
         self.assertEqual(checker.check(raw), expected)
示例#10
0
    def _validate_outcome_dicts(self) -> None:
        """Check that ``self.outcomes`` is a list of well-formed outcome dicts.

        Each outcome must contain exactly the keys ``id``, ``producer``,
        ``model``, ``image_id`` (string or None), ``description``,
        ``priori_probability`` and ``questions_estimation``. The latter must
        hold key ``1`` with presence/absence probabilities; further keys are
        allowed there.

        Raises:
            TypeError: if ``self.outcomes`` is not a list.
            OutcomesValidationException: for the first outcome that fails the
                schema; the underlying ``t.DataError`` is chained as its cause.
        """
        # cheap type guard first, before building the schema
        if not isinstance(self.outcomes, list):
            raise TypeError('The outcomes variable must be a list instance')

        template = Dict({
            t.Key('id'): t.Int,
            t.Key('producer'): t.String,
            t.Key('model'): t.String,
            t.Key('image_id'): t.Or(t.String(allow_blank=True), t.Null),
            t.Key('description'): t.String(allow_blank=True),
            t.Key('priori_probability'): t.Float,
            # NOTE(review): the key is the int 1, so outcomes decoded from JSON
            # (string key "1") would not match — confirm where outcomes come from.
            t.Key('questions_estimation'): Dict({
                t.Key(1): Dict({
                    t.Key('probability_in_presence'): t.Float,
                    t.Key('probability_in_absence'): t.Float,
                }),
            }).allow_extra('*'),
        })

        number_of_products = len(self.outcomes)

        for current_product_number, outcome in enumerate(self.outcomes, 1):
            try:
                template.check(outcome)
            except t.DataError as error:
                self.log.error('Validation error occurred: {err_msg}'.format(err_msg=error))
                raise OutcomesValidationException(
                    'Validation error occurred: {}'.format(error)) from error
            self.log.debug('({}/{}) outcome dictionaries checked.'.format(
                current_product_number, number_of_products))
        self.log.debug('Outcome dictionaries are OK')
示例#11
0
def create_schema() -> T.Dict:
    """
    Build schema for the configuration's file
    by aggregating all the subsystem configurations

    Top-level keys are section names, each mapped to the trafaret schema of
    the corresponding subsystem. Sections declared through
    ``addon_section(..., optional=True)`` are presumably allowed to be missing
    from the file (addon_section is defined elsewhere — confirm).

    :return: the aggregated ``T.Dict`` schema
    :raises AssertionError: if two top-level keys report the same name; this
        check is an ``assert`` and is skipped when Python runs with ``-O``
    """
    # pylint: disable=protected-access
    schema = T.Dict(
        {
            "version": T.String(),
            "main": T.Dict(
                {
                    "host": T.IP,
                    "port": T.ToInt(),
                    "client_outdir": T.String(),
                    # logging._nameToLevel is private (hence the pylint
                    # disable above); its keys are the registered level names
                    "log_level": T.Enum(*logging._nameToLevel.keys()),
                    "testing": T.Bool(),
                    T.Key("studies_access_enabled", default=False): T.Or(
                        T.Bool(), T.ToInt
                    ),
                }
            ),
            addon_section(tracing.tracing_section_name, optional=True): tracing.schema,
            db_config.CONFIG_SECTION_NAME: db_config.schema,
            director_config.CONFIG_SECTION_NAME: director_config.schema,
            rest_config.CONFIG_SECTION_NAME: rest_config.schema,
            projects_config.CONFIG_SECTION_NAME: projects_config.schema,
            email_config.CONFIG_SECTION_NAME: email_config.schema,
            storage_config.CONFIG_SECTION_NAME: storage_config.schema,
            addon_section(
                login_config.CONFIG_SECTION_NAME, optional=True
            ): login_config.schema,
            addon_section(
                socketio_config.CONFIG_SECTION_NAME, optional=True
            ): socketio_config.schema,
            session_config.CONFIG_SECTION_NAME: session_config.schema,
            activity_config.CONFIG_SECTION_NAME: activity_config.schema,
            resource_manager_config.CONFIG_SECTION_NAME: resource_manager_config.schema,
            # BELOW HERE minimal sections until more options are needed
            addon_section("diagnostics", optional=True): minimal_addon_schema(),
            addon_section("users", optional=True): minimal_addon_schema(),
            addon_section("groups", optional=True): minimal_addon_schema(),
            addon_section("tags", optional=True): minimal_addon_schema(),
            addon_section("publications", optional=True): minimal_addon_schema(),
            addon_section("catalog", optional=True): catalog_config.schema,
            addon_section("products", optional=True): minimal_addon_schema(),
            addon_section("computation", optional=True): minimal_addon_schema(),
            addon_section("director-v2", optional=True): minimal_addon_schema(),
            addon_section("studies_access", optional=True): minimal_addon_schema(),
            addon_section("studies_dispatcher", optional=True): minimal_addon_schema(),
        }
    )

    # NOTE(review): this only detects duplicates that survive the dict literal
    # above. Two plain-string keys with equal values (e.g. two
    # CONFIG_SECTION_NAME constants that happen to match) are merged by Python
    # before T.Dict sees them, so the later schema silently wins. Keys built by
    # addon_section() are presumably distinct key objects even with equal
    # names, which is why the check can still catch those — confirm.
    section_names = [k.name for k in schema.keys]

    # fmt: off
    assert len(section_names) == len(set(section_names)), f"Found repeated section names in {section_names}"  # nosec
    # fmt: on

    return schema
示例#12
0
 def test_or(self):
     """``String | Null`` accepts None and text; an int reports both branch errors."""
     string_or_none = t.Or(t.String, t.Null)
     assert string_or_none.check(None) is None
     assert string_or_none.check(u"test") == u'test'
     assert extract_error(string_or_none, 1) == {
         0: 'value is not a string',
         1: 'value should be None',
     }
示例#13
0
    def __init__(self, allow_extra):
        """Build the trafaret schema for an incoming request payload.

        :param allow_extra: when truthy, the schema accepts unknown keys
            (``allow_extra('*')``) instead of rejecting them.
        """
        # `a | b` already builds an Or; wrapping it again in t.Or(...) only
        # nested failure messages one level deeper, so the bare form is used.
        self.schema = t.Dict({
            'id': t.Int(),
            'client_name': t.String(max_length=255),
            'sort_index': t.Float,
            # t.Key('client_email', optional=True): t.Null | t.Email(),
            t.Key('client_phone', optional=True): t.Null | t.String(max_length=255),

            t.Key('location', optional=True): t.Null | t.Dict({
                'latitude': t.Float | t.Null,
                'longitude': t.Float | t.Null,
            }),

            t.Key('contractor', optional=True): t.Null | t.Int(gt=0),
            t.Key('upstream_http_referrer', optional=True): t.Null | t.String(max_length=1023),
            t.Key('grecaptcha_response'): t.String(min_length=20, max_length=1000),

            # `>>` binds tighter than `|`: a non-null value must be a string,
            # which is then passed through `parse`
            t.Key('last_updated', optional=True): t.Null | t.String >> parse,

            t.Key('skills', default=[]): t.List(t.Dict({
                'subject': t.String,
                'subject_id': t.Int,
                'category': t.String,
                'qual_level': t.String,
                'qual_level_id': t.Int,
                t.Key('qual_level_ranking', default=0): t.Float,
            })),
        })
        if allow_extra:
            self.schema.allow_extra('*')
示例#14
0
 def test_nullable_datetime(self):
     """``DateTime | Null`` keeps None and converts datetimes and strings to datetimes."""
     checker = t.Or(DateTime, t.Null)
     expected = datetime.datetime(2017, 9, 1, 23, 59)
     self.assertIsNone(checker.check(None))
     for raw in (datetime.datetime(2017, 9, 1, 23, 59), '2017-09-01 23:59'):
         self.assertEqual(checker.check(raw), expected)
示例#15
0
    def put(self, id):
        """Update profile ``id`` using a request-specific validation schema.

        The schema depends on ``id`` through ``self._roles_list(id)``. The
        ``password``/``confirmation`` pair is checked by ``self._cmp_pwd``
        via ``te.KeysSubset``.
        NOTE(review): the schema is stored on ``self.validation``; if one
        resource instance serves concurrent requests this may race — confirm.
        """
        # we should check for password matching if user is trying to update it
        fields = {
            'first_name': t.String,
            'last_name': t.String,
            'phone': t.Or(t.Null, t.String),
            'roles': self._roles_list(id),
            'avatar_id': t.Or(t.Null, t.String),
            te.KeysSubset('password', 'confirmation'): self._cmp_pwd,
        }
        schema = t.Dict(fields)
        self.validation = schema.make_optional('roles', 'avatar_id').ignore_extra('*')

        return super(ProfileResource, self).put(id)
示例#16
0
async def create_user(*, conn: SAConnection, user_data: dict) -> User:
    """
    Validate ``user_data``, hash its password and insert it as a new user.

    :param conn: connector to database
    :param user_data: dict with ``username`` and ``password`` (string or int)
    :return: ``User`` built from the first row returned by ``create_objects``
    """
    user_format = t.Dict({
        t.Key('username'): t.Or(t.String, t.Int),
        t.Key('password'): t.Or(t.String, t.Int)
    })

    validated = validate(data_to_check=user_data, trafaret_format=user_format)
    # never store the plain password: replace it with its hash before insert
    validated['password'] = generate_password_hash(password=validated['password'])

    created_rows = await create_objects(conn=conn, table=users, data=validated)
    return User(created_rows[0])  # type: ignore
示例#17
0
class CustomerResource(ModelResource, CustomerMixin):
    """REST resource exposing ``Customer`` records.

    ``delete`` is wrapped with ``roles_required('admin')``. Anonymous users
    and non-superusers are restricted to their own record. A
    ``CustomerIsTooOldError`` raised on create or update is reported as a
    validation error on ``birthdate``.
    """

    model = Customer

    method_decorators = {'delete': roles_required('admin')}

    validation = t.Dict({
        'first_name': t.String,
        'last_name': t.String,
        'email': t.Email,
        'phone': t.Or(t.String(allow_blank=True), t.Null),
        'notes': t.Or(t.String(allow_blank=True), t.Null),
        'sex': t.String,
        'birthdate': t.DateTime,
    }).make_optional('phone', 'notes').ignore_extra('*')

    @method_wrapper(http.CREATED)
    def post(self):
        """Create a customer; ``CustomerIsTooOldError`` becomes a ``birthdate`` error."""
        try:
            return super(CustomerResource, self).post()
        except CustomerIsTooOldError:
            self._raise_too_old_customer_error()

    @method_wrapper(http.ACCEPTED)
    def put(self, id):
        """Clean the payload and update customer ``id``; too-old birthdates become a ``birthdate`` error."""
        data = self.clean(g.request_data)

        try:
            instance = self.get_object(id).update(with_reload=True, **data)
            return self.serialize(instance)
        except CustomerIsTooOldError:
            self._raise_too_old_customer_error()

    def _raise_too_old_customer_error(self):
        """Raise ``t.DataError`` reporting a ``birthdate`` earlier than ``Customer.MIN_BIRTHDATE_YEAR``.

        The constant template is translated first and formatted afterwards, so
        the gettext lookup key is the template rather than a year-specific
        string that no catalog contains.
        """
        raise t.DataError({
            'birthdate':
            _('must not be earlier than {}').format(
                Customer.MIN_BIRTHDATE_YEAR)
        })

    def get_objects(self, **kwargs):
        """Return the object query, narrowed to the caller's customer unless superuser."""
        if current_user.is_anonymous() or not current_user.is_superuser():
            kwargs['id'] = self._customer.id
        return super(CustomerResource, self).get_objects(**kwargs)
示例#18
0
class BaseModel(Model):
    """Abstract model (``__abstract__``) with a name/quantity/attrs structure."""

    __abstract__ = True
    inc_id = True
    # 'attrs' maps string keys to int/float/string values. The keys '_id',
    # '_ns' and '_int_id' are allowed extras; 'wrong_attr' is ignored.
    structure = t.Dict(
        {
            'name': t.String,
            'quantity': t.Int,
            'attrs': t.Mapping(t.String, t.Or(t.Int, t.Float, t.String)),
        }
    ).allow_extra('_id', '_ns', '_int_id').ignore_extra('wrong_attr')
    indexes = ['id']
    required_fields = ['name', 'quantity']
示例#19
0
 def test_or(self):
     """``String | Null``: repr, None/str acceptance, per-branch errors and ``<<`` building."""
     null_string = t.Or(t.String, t.Null)
     self.assertEqual(repr(null_string), '<Or(<String>, <Null>)>')
     # the None case was previously checked but never asserted
     self.assertIsNone(null_string.check(None))
     self.assertEqual(null_string.check("test"), 'test')
     res = extract_error(null_string, 1)
     self.assertEqual(res, {0: 'value is not a string', 1: 'value should be None'})
     res = t.Or << t.Int << t.String
     self.assertEqual(repr(res), '<Or(<Int>, <String>)>')
示例#20
0
class ImageFinger(
        namedtuple('ImageFinger', 'id,vectors,message,file_id,chat_id'),
        StorableMix):
    """Image fingerprint record stored in the ``images`` collection.

    ``trafaret`` validates the record fields; ``collection`` is presumably
    consumed by ``StorableMix`` (defined elsewhere — confirm).
    """
    collection = 'images'
    # `a | b` already builds an Or; the former t.Or(a | b) wrapper only
    # nested error messages one level deeper.
    trafaret = t.Dict({
        'id': t.String | MongoId(allow_blank=True),
        'vectors': t.Any,
        'message': t.Dict().allow_extra('*'),
        'file_id': t.String,
        'chat_id': t.Int
    })
示例#21
0
 def from_row(cls, row, book: 'Book', alias=None):
     """Build a populated instance from ``row`` and its parent ``book``.

     ``alias`` replaces ``cls.Options.db_table`` when the row was selected
     from an aliased table. A failing check raises ``t.DataError``.
     """
     columns = (alias if alias is not None else cls.Options.db_table).c
     nullable_datetime = t.Or(DateTime, t.Null)
     return cls(
         id=t.Int().check(row[columns.id]),
         title=t.String().check(row[columns.title]),
         ordering=t.Int().check(row[columns.ordering]),
         book=t.Type(Book).check(book),
         created_at=DateTime().check(row[columns.created_at]),
         updated_at=nullable_datetime.check(row[columns.updated_at]),
         is_populated=True,
     )
示例#22
0
class BaseProduct(Model):
    """Abstract product model; ``name`` and ``attrs`` are listed in ``i18n``."""

    __abstract__ = True
    # any extra keys are accepted alongside the declared fields
    structure = t.Dict(
        {
            'name': t.String,
            'quantity': t.Int,
            'attrs': t.Mapping(t.String, t.Or(t.Int, t.Float, t.String)),
        }
    ).allow_extra('*')
    i18n = ['name', 'attrs']
    indexes = ['id']
示例#23
0
 def test_2_0_regression(self):
     """A mapping under optional ``params`` must pass via the ``Mapping`` branch of the ``Or``."""
     params_schema = t.Or(
         t.List(t.Any()),
         t.Mapping(t.AnyString(), t.Any()),
     )
     t_request = t.Dict({t.Key('params', optional=True): params_schema})
     payload = {'params': {'aaa': 123}}
     assert t_request.check(payload) == {'params': {'aaa': 123}}
示例#24
0
class Reaction(
        namedtuple(
            'BaseReaction',
            'id,patterns,image_url,image_id,text,created_at,created_by,last_used'
        ), StorableMix):
    """Stored reaction triggered by text patterns, kept in the ``reactions`` collection."""
    collection = 'reactions'

    # `a | b` already builds an Or; the former t.Or(a | b) wrapper only
    # nested error messages one level deeper.
    trafaret = t.Dict({
        'id': t.String | MongoId(allow_blank=True),
        'patterns': t.List(t.String, min_length=1),
        'image_url': t.URL(allow_blank=True),
        'image_id': t.String(allow_blank=True),
        'text': t.String(allow_blank=True),
        'created_at': t.Int,
        'created_by': User.trafaret,
        t.Key('last_used', default=0): t.Int,
    }).make_optional('image_id', 'image_url', 'text', 'last_used')

    @classmethod
    @inject.params(db=AsyncIOMotorDatabase)
    def find_by_pattern(cls, patterns, db=None):
        """Return a cursor over reactions having any of ``patterns`` (``$in`` query)."""
        return db[cls.collection].find({'patterns': {'$in': patterns}})

    @inject.params(db=AsyncIOMotorDatabase)
    def update_usage(self, db=None):
        """Set ``last_used`` of this reaction to the current epoch seconds."""
        # NOTE(review): Collection.update is deprecated (removed in PyMongo 4 /
        # Motor 3); update_one is the replacement but returns a different
        # result type — confirm the Motor version before switching.
        epoch_now = int(time.time())
        return db[self.collection].update({'_id': self.id},
                                          {'$set': {
                                              'last_used': epoch_now
                                          }})

    @property
    @inject.params(config=Config)
    def on_hold(self, config=None):
        """True if used within the last ``config.reaction_threshold`` minutes."""
        epoch_now = int(time.time())
        return self.last_used >= (epoch_now - config.reaction_threshold * 60)
示例#25
0
        ),
        T.Key("callback_game", optional=True): T.String,
        T.Key("pay", optional=True): T.String,
        # T.Key("callback_strategy", optional=True, default="sum"): T.String
    }
)

# A markup entry: non-empty name and caption, a sign of at most one character,
# and optional buttons. Each button is either a full BUTTON or a dict holding
# only a non-empty name.
MARKUP = T.Dict({
    "name": T.String(min_length=1),
    "caption": T.String(min_length=1),
    "sign": T.String(max_length=1),
    T.Key("row_width", optional=True, default=1): T.Int(),
    T.Key("next", optional=True, default=""): T.String(allow_blank=True),
    T.Key("buttons", optional=True): T.List(
        T.Or(BUTTON, T.Dict({"name": T.String(min_length=1)})),
    ),
})

# Global settings section.
CONFIG = T.Dict({
    "callback_strategy": T.String,
    "delimiter": T.String(min_length=1),
})

# Root document: required markups, optional shared buttons and config.
TRAFARET = T.Dict({
    T.Key("markups"): T.List(MARKUP),
    T.Key("buttons", optional=True): T.List(BUTTON),
    T.Key("config", optional=True): CONFIG,
})

示例#26
0
from models_library.settings.redis import RedisConfig
from servicelib.application_keys import APP_CONFIG_KEY

CONFIG_SECTION_NAME = "resource_manager"
APP_CLIENT_REDIS_CLIENT_KEY = __name__ + ".resource_manager.redis_client"
APP_CLIENT_REDIS_LOCK_KEY = __name__ + ".resource_manager.redis_lock"
APP_CLIENT_SOCKET_REGISTRY_KEY = __name__ + ".resource_manager.registry"
APP_RESOURCE_MANAGER_TASKS_KEY = __name__ + ".resource_manager.tasks.key"
APP_GARBAGE_COLLECTOR_KEY = __name__ + ".resource_manager.garbage_collector_key"

# lock names and format strings
# ({{user_id}} survives the f-string as a literal "{user_id}" placeholder,
# to be filled later with .format(user_id=...))
GUEST_USER_RC_LOCK_FORMAT = f"{__name__}:redlock:garbage_collect_user:{{user_id}}"

# Config-file schema for the resource_manager section.
schema = T.Dict({
    T.Key("enabled", default=True, optional=True):
    T.Or(T.Bool(), T.ToInt()),
    T.Key("resource_deletion_timeout_seconds", default=900, optional=True):
    T.ToInt(),
    T.Key("garbage_collection_interval_seconds", default=30, optional=True):
    T.ToInt(),
    T.Key("redis", optional=False):
    T.Dict({
        T.Key("enabled", default=True, optional=True): T.Bool(),
        T.Key("host", default="redis", optional=True): T.String(),
        # 6379 is Redis' standard port; the previous default 6793 was a typo
        T.Key("port", default=6379, optional=True): T.ToInt(),
    }),
})


class RedisSection(RedisConfig):
    """``RedisConfig`` extended with an on/off switch (enabled by default)."""

    enabled: bool = True
示例#27
0
import aiozipkin as az
import trafaret as T
from aiohttp import web
from pydantic import AnyHttpUrl, BaseSettings

log = logging.getLogger(__name__)


def setup_tracing(app: web.Application, app_name: str, host: str, port: str,
                  config: Dict) -> bool:
    """Attach an aiozipkin tracer to ``app`` reporting to Zipkin.

    Spans go to ``<config['zipkin_endpoint']>/api/v2/spans`` with sample rate
    1.0, i.e. every request is traced. The tracer is created with
    ``run_until_complete`` on the current event loop, so this must be called
    while that loop is not already running.

    :return: always ``True``
    """
    spans_url = f"{config['zipkin_endpoint']}/api/v2/spans"
    local_endpoint = az.create_endpoint(app_name, ipv4=host, port=port)
    event_loop = asyncio.get_event_loop()
    tracer = event_loop.run_until_complete(
        az.create(spans_url, local_endpoint, sample_rate=1.0)
    )
    az.setup(app, tracer)
    return True


# Config-file schema for the tracing section.
schema = T.Dict(
    {
        T.Key("enabled", default=True, optional=True): T.Or(T.Bool(), T.ToInt),
        T.Key("zipkin_endpoint", default="http://jaeger:9411"): T.String(),
    }
)


class TracingSettings(BaseSettings):
    """Pydantic settings mirroring ``schema``: tracing switch and Zipkin collector URL."""

    enabled: Optional[bool] = True
    zipkin_endpoint: AnyHttpUrl = "http://jaeger:9411"
class TfObjectDetectionModel(object):
    """Implementation for TF Object Detection API Inference

    Wraps a frozen TF Object Detection API graph. Configuration is passed as
    keyword arguments and validated against ``_config_schema``; missing keys
    take the schema defaults. Call ``startup()`` before inference and
    ``shutdown()`` afterwards, or use the instance as a context manager.
    """

    # model's input tensor name from official website
    # https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
    input_tensors = {'images': "image_tensor:0"}

    # model's output tensors names from official website
    # https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
    output_tensors = {
        "labels": "detection_classes:0",
        "boxes": "detection_boxes:0",
        "scores": "detection_scores:0"
    }

    # default configuration
    _config_schema = t.Dict(
        {
            # path or name of the frozen weights file (*.pb)
            t.Key('weights', default="data/models/*.pb"):
            t.String(min_length=4),
            t.Key('width', default=300):
            t.Int(gt=0),  # input tensor width
            t.Key('height', default=300):
            t.Int(gt=0),  # input tensor height
            t.Key('threshold', default=0.5):
            t.Float(gte=0.0,
                    lte=1.0),  # confidence threshold for detected objects
            # labels dict or file
            t.Key('labels', default={1: 'person'}):
            t.Or(t.Dict({}, allow_extra='*'), t.String(min_length=4)),
            # device to execute graph
            t.Key('device', default='GPU|CPU'):
            t.Regexp(r'GPU\|CPU|CPU(?:\:0)?|GPU(?:\:\d)?') >> _parse_device,
            t.Key('log_device_placement', default=False):
            t.Bool,  # TF specific
            t.Key('per_process_gpu_memory_fraction', default=0.0):
            t.Float(gte=0.0, lte=1.0),  # TF specific
        },
        allow_extra='*')

    def __init__(self, **kwargs):
        """Validate the configuration; TF resources are only created by ``startup()``.

        Raises:
            ValueError: if ``kwargs`` does not satisfy ``_config_schema``
                (the trafaret ``DataError`` is chained as the cause).
        """
        # validate config
        try:
            self.config = self._config_schema.check(kwargs)
        except t.DataError as err:
            raise ValueError('Wrong model configuration for {}: {}'.format(
                self, err)) from err

        self._session = None  # tf.Session
        self._inputs = None  # typ.Dict[str, tf.Tensor]
        self._outputs = None  # typ.Dict[str, tf.Tensor]
        self._labels = None  # typ.Dict[int, str]

    def __str__(self) -> str:
        return self.__class__.__name__

    def __repr__(self) -> str:
        return '<{}>'.format(self)

    def __enter__(self):
        self.startup()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()

    @property
    def log(self) -> logging.Logger:
        return log

    def startup(self):
        """Load labels, import the frozen graph, open a session and warm it up.

        Raises:
            ValueError: if the resolved label mapping is empty.
        """
        self.log.info("Starting %s ...", self)

        # 'labels' is either an inline dict or a path to a labels file
        labels = self.config['labels']
        self._labels = (labels if isinstance(labels, dict)
                        else load_labels_from_file(labels))

        if not self._labels:
            raise ValueError(f"Labels can't be empty {self._labels}")

        tf_config = create_config(
            self.config['device'],
            log_device_placement=self.config['log_device_placement'],
            per_process_gpu_memory_fraction=self.
            config['per_process_gpu_memory_fraction'])

        graph = import_graph(parse_graph_def(self.config['weights']),
                             self.config['device'])

        self.log.debug("Model (%s) placed on %s", self.config['weights'],
                       self.config['device'])

        self._session = tf.Session(graph=graph, config=tf_config)

        self._inputs = {
            alias: graph.get_tensor_by_name(name)
            for alias, name in self.input_tensors.items()
        }
        self._outputs = {
            alias: graph.get_tensor_by_name(name)
            for alias, name in self.output_tensors.items()
        }

        # warm up
        self.log.info("Warming up %s ...", self)
        self.process_single(np.zeros((2, 2, 3), dtype=np.uint8))

    def shutdown(self):
        """ Releases model when object deleted

        Closes the TF session if one is open; a ``tf.OpError`` while closing
        is logged and otherwise ignored.
        """
        self.log.info("Shutdown %s ...", self)

        if self._session is None:
            return

        try:
            self._session.close()
            self._session = None
        except tf.OpError as err:
            self.log.error('%s close TF session error: %s. Skipping...', self,
                           err)

        self.log.info("%s Destroyed successfully", self)

    def process_single(self, image: np.ndarray) -> typ.List[dict]:
        """Run detection on one image; returns that image's list of detections."""
        return self.process_batch([image])[0]

    def process_batch(self, images: typ.List[np.ndarray]) -> typ.List[dict]:
        """Run detection on a batch of images.

        Each image is resized to the configured ``width`` x ``height`` before
        stacking, so inputs may differ in size; boxes are scaled back to each
        image's original size.

        Returns:
            One list per input image. Each list holds dicts with
            ``confidence``, ``bounding_box`` (``[x, y, width, height]`` in
            pixels) and ``class_name``. Detections whose label is missing from
            the label map, or whose score is below ``threshold``, are dropped.
            An empty batch yields ``[]``.
        """
        # np.stack raises on an empty sequence. len() is used rather than
        # truthiness so an ndarray batch still works.
        if len(images) == 0:
            return []

        preprocessed = np.stack([self._preprocess(image) for image in images])

        result = self._session.run(
            self._outputs, feed_dict={self._inputs['images']: preprocessed})

        threshold = self.config['threshold']
        detections_per_image = []

        for image, scores, boxes, labels in zip(images, result['scores'],
                                                result['boxes'],
                                                result['labels']):
            detections = []
            for score, box, label in zip(scores, boxes, labels):

                class_name = self._labels.get(int(label))
                if not class_name or score < threshold:
                    continue

                # scale boxes wrt initial image size: box is unpacked as
                # (ymin, xmin, ymax, xmax) and multiplied by (h, w, h, w)
                ymin, xmin, ymax, xmax = (np.tile(image.shape[:2], 2) *
                                          box).astype(np.int32).tolist()

                width, height = xmax - xmin, ymax - ymin

                detections.append({
                    'confidence': float(score),
                    'bounding_box': [xmin, ymin, width, height],
                    'class_name': class_name,
                })

            detections_per_image.append(detections)

        return detections_per_image

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Resize ``image`` to the configured input size (cv2 takes (width, height))."""
        return cv2.resize(image, (self.config['width'], self.config['height']),
                          interpolation=cv2.INTER_NEAREST)
示例#29
0
 def test_repr(self):
     """``repr`` of an Or lists its branches in order, via ``t.Or(...)`` or ``|``."""
     assert repr(t.Or(t.String, t.Null)) == '<Or(<String>, <Null>)>'
     assert repr(t.ToInt | t.String) == '<Or(<ToInt>, <String>)>'
""" rest subsystem's configuration

    - constants
    - config-file schema
"""
import trafaret as T

from servicelib.application_keys import APP_OPENAPI_SPECS_KEY

# Re-exposes the imported key as a name of this module; the self-assignment
# presumably keeps linters from flagging the import as unused (hence the
# pylint disable).
APP_OPENAPI_SPECS_KEY = APP_OPENAPI_SPECS_KEY  # pylint: disable=self-assigning-variable,bad-option-value

# name of this subsystem's section in the configuration file
CONFIG_SECTION_NAME = 'rest'

# NOTE(review): "{{cookiecutter.openapi_specs_version}}" is a cookiecutter
# placeholder; this file is presumably a template rendered before use.
# NOTE(review): T.String is tried first and accepts any non-blank string, so
# the T.URL branch only matters for non-string input — confirm intent.
schema = T.Dict({
    "version": T.Enum("{{cookiecutter.openapi_specs_version}}"),
    "location": T.Or(T.String,
                     T.URL),  # either path or url should contain version in it
})