コード例 #1
0
    def deserialize(cls, file):
        """Construct an instance from one entry of a model's ``files`` list.

        *file* is read as a mapping with the keys ``name``, ``size``,
        ``checksum`` and ``source``:

        * ``name`` is checked with ``validation.validate_relative_path``;
        * ``size`` must pass ``validation.validate_nonnegative_int``;
        * ``checksum`` is parsed by ``cache.Checksum.deserialize``;
        * ``source`` is parsed by ``file_source.FileSource.deserialize``.

        Everything after the name is processed inside a
        ``validation.deserialization_context`` labelled with that name; the
        checksum and source steps each get their own nested label.

        Returns:
            ``cls(name, size, checksum, source)``.
        """
        name = validation.validate_relative_path('"name"', file['name'])

        file_label = 'In file "{}"'.format(name)
        with validation.deserialization_context(file_label):
            raw_size = file['size']
            size = validation.validate_nonnegative_int('"size"', raw_size)

            checksum_label, source_label = '"checksum"', '"source"'

            with validation.deserialization_context(checksum_label):
                checksum = cache.Checksum.deserialize(file['checksum'])

            with validation.deserialization_context(source_label):
                source = file_source.FileSource.deserialize(file['source'])

            instance = cls(name, size, checksum, source)
            return instance
コード例 #2
0
def load_models(models_root, args):
    """Load every ``model.yml`` found under *models_root*.

    Composite model directories (those containing ``composite-model.yml``)
    are validated first:
    * their directory name must match ``RE_MODEL_NAME``;
    * their layout must pass ``check_composite_model_dir``;
    * no two may share a name.

    Then every ``model.yml`` below *models_root* is checked in sorted path order:
    * it is parsed with ``yaml.safe_load``;
    * it is checked against ``_common.get_schema()``;
    * it is rejected if it contains the ``name`` or ``subdirectory`` keys.
    Accepted configs are deserialized with ``Model.deserialize``.

    Args:
        models_root: path-like root directory that is searched recursively.
        args: not used by this function; kept for interface compatibility.

    Returns:
        A list of deserialized models, in sorted config-path order.

    Raises:
        validation.DeserializationError: on an invalid or duplicate name,
            a bad composite layout, a schema failure or an unsupported key.
    """
    models = []
    model_names = set()

    # Only used for duplicate detection, so a set gives O(1) membership tests.
    composite_model_names = set()

    schema = _common.get_schema()

    for composite_model_config in sorted(models_root.glob('**/composite-model.yml')):
        composite_model_name = composite_model_config.parent.name
        with validation.deserialization_context('In model "{}"'.format(composite_model_name)):
            if not RE_MODEL_NAME.fullmatch(composite_model_name):
                raise validation.DeserializationError('Invalid name, must consist only of letters, digits or ._-')

            check_composite_model_dir(composite_model_config.parent)

            if composite_model_name in composite_model_names:
                raise validation.DeserializationError(
                    'Duplicate composite model name "{}"'.format(composite_model_name))
            composite_model_names.add(composite_model_name)

    for config_path in sorted(models_root.glob('**/model.yml')):
        subdirectory = config_path.parent

        # A model.yml whose parent directory holds composite-model.yml is a
        # stage of that composite model; tag it with the composite's name.
        is_composite = (subdirectory.parent / 'composite-model.yml').exists()
        composite_model_name = subdirectory.parent.name if is_composite else None

        subdirectory = subdirectory.relative_to(models_root)

        with config_path.open('rb') as config_file, \
                validation.deserialization_context('In config "{}"'.format(config_path)):

            model = yaml.safe_load(config_file)
            if not schema.check(model):
                raise validation.DeserializationError('Configuration file check wasn\'t successful.')

            # These values are derived from the directory layout, not the config.
            for bad_key in ['name', 'subdirectory']:
                if bad_key in model:
                    raise validation.DeserializationError('Unsupported key "{}"'.format(bad_key))

            new_model = Model.deserialize(model, subdirectory.name, subdirectory, composite_model_name)

            if new_model.name in model_names:
                raise validation.DeserializationError(
                    'Duplicate model name "{}"'.format(new_model.name))
            model_names.add(new_model.name)
            models.append(new_model)

    return models
コード例 #3
0
    def deserialize(cls, model, name, subdirectory, stages):
        """Build a composite model from its own config and its stage configs.

        Args:
            model: parsed ``composite-model.yml`` mapping; the keys
                ``task_type``, ``description``, ``license`` and ``framework``
                are read and validated.
            name: composite model name; must match ``RE_MODEL_NAME``.
            subdirectory: directory of the composite model, passed through to
                ``cls`` unchanged.
            stages: mapping from each stage's subdirectory (a path object
                whose ``.name`` becomes the stage name) to that stage's parsed
                ``model.yml``.

        Returns:
            An instance of ``cls`` whose ``precisions`` and
            ``quantization_output_precisions`` are taken from the first stage.

        Raises:
            validation.DeserializationError: if the name or a field is
                invalid, a stage fails ``Model.deserialize``, or *stages* is
                empty.
        """
        with validation.deserialization_context('In model "{}"'.format(name)):
            if not RE_MODEL_NAME.fullmatch(name):
                raise validation.DeserializationError(
                    'Invalid name, must consist only of letters, digits or ._-'
                )

            task_type = validation.validate_string_enum(
                '"task_type"', model['task_type'], _common.KNOWN_TASK_TYPES)

            description = validation.validate_string('"description"',
                                                     model['description'])

            license_url = validation.validate_string('"license"',
                                                     model['license'])

            framework = validation.validate_string_enum(
                '"framework"', model['framework'],
                _common.KNOWN_FRAMEWORKS.keys())

            # Each stage is an ordinary Model tagged with this composite's name.
            model_stages = [
                Model.deserialize(model_part, model_subdirectory.name,
                                  model_subdirectory, name)
                for model_subdirectory, model_part in stages.items()
            ]

            # Without this guard, indexing model_stages[0] below would escape
            # as a bare IndexError instead of a deserialization error.
            if not model_stages:
                raise validation.DeserializationError(
                    'Composite model has no stages')

            # NOTE(review): precisions are taken from the first stage only;
            # this assumes all stages agree — confirm.
            first_stage = model_stages[0]

            return cls(name, subdirectory, task_type, model_stages,
                       description, framework, license_url,
                       first_stage.precisions,
                       first_stage.quantization_output_precisions, name)
コード例 #4
0
def check_composite_model_dir(model_dir):
    """Validate the on-disk layout of a composite model directory.

    Raises ``validation.DeserializationError`` when *model_dir*:

    * contains a ``model.yml`` inside a sub-subdirectory or deeper (only
      direct subdirectories may hold stage configs);
    * contains a ``model.yml`` of its own;
    * has a direct subdirectory holding a ``model.yml`` whose name does not
      start with ``"<model_dir.name>-"``.

    *model_dir* is expected to be a ``pathlib.Path``-like object.
    """
    directory_label = 'In directory "{}"'.format(model_dir)
    with validation.deserialization_context(directory_label):
        # '*/*/**' covers every directory two or more levels below model_dir.
        nested_configs = list(model_dir.glob('*/*/**/model.yml'))
        if nested_configs:
            raise validation.DeserializationError(
                'Directory should not contain any model.yml files in any subdirectories '
                'that are not direct children of the composite model directory')

        own_config = model_dir / 'model.yml'
        if own_config.exists():
            raise validation.DeserializationError('Directory should not contain a model.yml file')

        required_prefix = f'{model_dir.name}-'
        stage_configs = list(model_dir.glob('*/model.yml'))
        if not all(stage_config.parent.name.startswith(required_prefix)
                   for stage_config in stage_configs):
            raise validation.DeserializationError(
                'Names of composite model parts should start with composite model name')
コード例 #5
0
    def deserialize(cls, model, name, subdirectory, composite_model_name):
        """Validate a parsed ``model.yml`` mapping and build a model from it.

        Validation steps, in the order they run (the first failure wins):

        * *name* must match ``RE_MODEL_NAME``.
        * Each ``files`` entry is parsed with ``ModelFile.deserialize``.
          Duplicate file names are rejected.
        * ``postprocessing`` is optional; each entry goes through
          ``postprocessing.Postproc.deserialize``.
        * ``framework`` must be a key of ``_common.KNOWN_FRAMEWORKS``.
          If the mapped value is truthy, ``conversion_to_onnx_args`` is
          required. Otherwise it must be absent or empty.
        * ``quantized`` is optional and may only be ``'INT8'``.
        * If ``model_optimizer_args`` is present, precisions are
          ``FP16``/``FP32``, suffixed with ``-INT8`` when quantized.
        * Without ``model_optimizer_args`` the framework must be ``'dldt'``.
          Precisions are then derived from the first path component of
          every file. Each precision needs ``<name>.xml`` and ``<name>.bin``.
        * ``quantizable`` is an optional boolean.
        * ``description``, ``license`` and ``task_type`` are read last.

        Args:
            model: parsed config mapping (presumably from
                ``yaml.safe_load``). Missing required keys surface as
                ``KeyError`` from the ``model[...]`` lookups.
            name: model name, validated against ``RE_MODEL_NAME``.
            subdirectory: model directory, passed through to ``cls``.
            composite_model_name: name of the owning composite model, or
                ``None``. It is passed through to ``cls``.

        Returns:
            ``cls(...)`` built from the validated fields.

        Raises:
            validation.DeserializationError: on any failed validation above.
        """
        with validation.deserialization_context('In model "{}"'.format(name)):
            if not RE_MODEL_NAME.fullmatch(name):
                raise validation.DeserializationError(
                    'Invalid name, must consist only of letters, digits or ._-'
                )

            files = []
            file_names = set()

            # Deserialize each file entry and reject repeated file names.
            for file in model['files']:
                files.append(ModelFile.deserialize(file))

                if files[-1].name in file_names:
                    raise validation.DeserializationError(
                        'Duplicate file name "{}"'.format(files[-1].name))
                file_names.add(files[-1].name)

            postprocessings = []

            # "postprocessing" is optional; each step is labelled by its index.
            for i, postproc in enumerate(model.get('postprocessing', [])):
                with validation.deserialization_context(
                        '"postprocessing" #{}'.format(i)):
                    postprocessings.append(
                        postprocessing.Postproc.deserialize(postproc))

            framework = validation.validate_string_enum(
                '"framework"', model['framework'],
                _common.KNOWN_FRAMEWORKS.keys())

            conversion_to_onnx_args = model.get('conversion_to_onnx_args',
                                                None)
            # A truthy KNOWN_FRAMEWORKS value marks a framework that is only
            # supported via conversion to ONNX (per the error message below).
            if _common.KNOWN_FRAMEWORKS[framework]:
                if not conversion_to_onnx_args:
                    raise validation.DeserializationError(
                        '"conversion_to_onnx_args" is absent. '
                        'Framework "{}" is supported only by conversion to ONNX.'
                        .format(framework))
                conversion_to_onnx_args = [
                    validation.validate_string(
                        '"conversion_to_onnx_args" #{}'.format(i), arg)
                    for i, arg in enumerate(model['conversion_to_onnx_args'])
                ]
            else:
                # NOTE(review): a falsy but present value (e.g. []) is kept
                # as-is rather than normalized to None.
                if conversion_to_onnx_args:
                    raise validation.DeserializationError(
                        'Conversion to ONNX not supported for "{}" framework'.
                        format(framework))

            quantized = model.get('quantized', None)
            if quantized is not None and quantized != 'INT8':
                raise validation.DeserializationError(
                    '"quantized": expected "INT8", got {!r}'.format(quantized))

            if 'model_optimizer_args' in model:
                # The model is converted, so precisions are fixed.
                # A quantized model gets "-INT8"-suffixed precisions.
                mo_args = [
                    validation.validate_string(
                        '"model_optimizer_args" #{}'.format(i), arg)
                    for i, arg in enumerate(model['model_optimizer_args'])
                ]
                precisions = {f'FP16-{quantized}', f'FP32-{quantized}'
                              } if quantized is not None else {'FP16', 'FP32'}
            else:
                # No conversion: the files must already be in IR form.
                if framework != 'dldt':
                    raise validation.DeserializationError(
                        'Model not in IR format, but no conversions defined')

                mo_args = None

                # Maps precision (first path component) -> set of file names
                # (second path component) under that precision.
                files_per_precision = {}

                for file in files:
                    if len(file.name.parts) != 2:
                        raise validation.DeserializationError(
                            'Can\'t derive precision from file name {!r}'.
                            format(file.name))
                    p = file.name.parts[0]
                    if p not in _common.KNOWN_PRECISIONS:
                        raise validation.DeserializationError(
                            'Unknown precision {!r} derived from file name {!r}, expected one of {!r}'
                            .format(p, file.name, _common.KNOWN_PRECISIONS))
                    files_per_precision.setdefault(p, set()).add(
                        file.name.parts[1])

                # Every precision directory must contain <name>.xml and <name>.bin.
                for precision, precision_files in files_per_precision.items():
                    for ext in ['xml', 'bin']:
                        if (name + '.' + ext) not in precision_files:
                            raise validation.DeserializationError(
                                'No {} file for precision "{}"'.format(
                                    ext.upper(), precision))

                # NOTE(review): an empty "files" list yields an empty set here.
                precisions = set(files_per_precision.keys())

            quantizable = model.get('quantizable', False)
            if not isinstance(quantizable, bool):
                raise validation.DeserializationError(
                    '"quantizable": expected a boolean, got {!r}'.format(
                        quantizable))

            # NOTE(review): a dict keys view when quantizable, but a set
            # otherwise; both support iteration and membership tests.
            quantization_output_precisions = _common.KNOWN_QUANTIZED_PRECISIONS.keys(
            ) if quantizable else set()

            description = validation.validate_string('"description"',
                                                     model['description'])

            license_url = validation.validate_string('"license"',
                                                     model['license'])

            task_type = validation.validate_string_enum(
                '"task_type"', model['task_type'], _common.KNOWN_TASK_TYPES)

            return cls(name, subdirectory, files, postprocessings, mo_args,
                       framework, description, license_url, precisions,
                       quantization_output_precisions, task_type,
                       conversion_to_onnx_args, composite_model_name)
コード例 #6
0
def load_models(models_root, args, mode=ModelLoadingMode.all):
    """Load composite and/or standalone models found under *models_root*.

    *mode* selects what is loaded:

    * Composite models (directories with ``composite-model.yml``) are loaded
      for ``ModelLoadingMode.all`` and ``ModelLoadingMode.composite_only``.
    * Standalone ``model.yml`` files are loaded unless *mode* is
      ``ModelLoadingMode.composite_only``.
    * A stage ``model.yml`` of a composite model is loaded on its own only in
      ``ModelLoadingMode.ignore_composite``, tagged with the composite's
      name. In other modes it is loaded as part of its composite model.

    Models whose subdirectory name is in ``EXCLUDED_MODELS`` are skipped.

    Args:
        models_root: path-like root directory that is searched recursively.
        args: not used by this function; kept for interface compatibility.
        mode: a ``ModelLoadingMode`` value, see above.

    Returns:
        A list of all loaded models and composite models, sorted by ``.name``.

    Raises:
        validation.DeserializationError: on an invalid or duplicate name, a
            bad composite layout, or an unsupported config key.
    """
    models = []
    model_names = set()

    composite_models = []
    composite_model_names = set()

    if mode in (ModelLoadingMode.all, ModelLoadingMode.composite_only):

        for composite_model_config in sorted(
                models_root.glob('**/composite-model.yml')):
            composite_model_name = composite_model_config.parent.name
            with validation.deserialization_context(
                    'In model "{}"'.format(composite_model_name)):
                if not RE_MODEL_NAME.fullmatch(composite_model_name):
                    raise validation.DeserializationError(
                        'Invalid name, must consist only of letters, digits or ._-'
                    )

                check_composite_model_dir(composite_model_config.parent)

                with composite_model_config.open('rb') as config_file, \
                        validation.deserialization_context('In config "{}"'.format(composite_model_config)):

                    composite_model = yaml.safe_load(config_file)
                    # NOTE(review): stage configs are not checked for the
                    # unsupported "name"/"subdirectory" keys here, unlike
                    # the standalone loop below.
                    model_stages = {}
                    for stage in sorted(
                            composite_model_config.parent.glob('*/model.yml')):
                        # Label errors with the stage's path; formatting the
                        # open file object would only show its repr.
                        with stage.open('rb') as stage_config_file, \
                                validation.deserialization_context('In config "{}"'.format(stage)):
                            model = yaml.safe_load(stage_config_file)

                            stage_subdirectory = stage.parent.relative_to(
                                models_root)
                            model_stages[stage_subdirectory] = model

                    # A composite directory without any stage configs is skipped.
                    if not model_stages:
                        continue
                    subdirectory = composite_model_config.parent.relative_to(
                        models_root)
                    composite_models.append(
                        CompositeModel.deserialize(composite_model,
                                                   composite_model_name,
                                                   subdirectory, model_stages))

                    if composite_model_name in composite_model_names:
                        raise validation.DeserializationError(
                            'Duplicate composite model name "{}"'.format(
                                composite_model_name))
                    composite_model_names.add(composite_model_name)

    if mode != ModelLoadingMode.composite_only:
        for config_path in sorted(models_root.glob('**/model.yml')):
            subdirectory = config_path.parent

            # A model.yml whose parent directory holds composite-model.yml
            # is a stage of that composite model.
            is_composite = (subdirectory.parent /
                            'composite-model.yml').exists()
            composite_model_name = None
            if is_composite:
                if mode != ModelLoadingMode.ignore_composite:
                    continue
                composite_model_name = subdirectory.parent.name

            subdirectory = subdirectory.relative_to(models_root)

            with config_path.open('rb') as config_file, \
                    validation.deserialization_context('In config "{}"'.format(config_path)):

                model = yaml.safe_load(config_file)

                # These values are derived from the directory layout.
                for bad_key in ['name', 'subdirectory']:
                    if bad_key in model:
                        raise validation.DeserializationError(
                            'Unsupported key "{}"'.format(bad_key))

                # Excluded models are skipped. Every other model is
                # deserialized and checked for a duplicate name.
                if subdirectory.name in EXCLUDED_MODELS:
                    continue

                new_model = Model.deserialize(model, subdirectory.name,
                                              subdirectory, composite_model_name)

                if new_model.name in model_names:
                    raise validation.DeserializationError(
                        'Duplicate model name "{}"'.format(new_model.name))
                model_names.add(new_model.name)
                models.append(new_model)

    return sorted(models + composite_models, key=lambda model: model.name)