def load_models(models_root, args):
    """Load and validate every ``model.yml`` found under ``models_root``.

    Composite model directories (those containing ``composite-model.yml``)
    are validated first: their names must match ``RE_MODEL_NAME``, their
    layout must pass ``check_composite_model_dir`` and their names must be
    unique. Then every ``model.yml`` is parsed, checked against the schema
    and deserialized into a ``Model``.

    Args:
        models_root: Path-like root directory that is searched recursively.
        args: Not used by this function; kept for interface compatibility.

    Returns:
        A list of deserialized models, in sorted config-path order.

    Raises:
        validation.DeserializationError: On an invalid or duplicate name, a
            failed schema check, or an unsupported key in a configuration.
    """
    models = []
    model_names = set()
    # Only used for duplicate detection, so a set gives O(1) membership tests.
    composite_model_names = set()

    schema = _common.get_schema()

    for composite_model_config in sorted(models_root.glob('**/composite-model.yml')):
        composite_model_name = composite_model_config.parent.name
        with validation.deserialization_context('In model "{}"'.format(composite_model_name)):
            if not RE_MODEL_NAME.fullmatch(composite_model_name):
                raise validation.DeserializationError(
                    'Invalid name, must consist only of letters, digits or ._-')

            check_composite_model_dir(composite_model_config.parent)

            if composite_model_name in composite_model_names:
                raise validation.DeserializationError(
                    'Duplicate composite model name "{}"'.format(composite_model_name))
            composite_model_names.add(composite_model_name)

    for config_path in sorted(models_root.glob('**/model.yml')):
        model_dir = config_path.parent

        # A model is a stage of a composite model when its parent directory
        # holds a composite-model.yml file.
        is_composite = (model_dir.parent / 'composite-model.yml').exists()
        composite_model_name = model_dir.parent.name if is_composite else None

        subdirectory = model_dir.relative_to(models_root)

        with config_path.open('rb') as config_file, \
                validation.deserialization_context('In config "{}"'.format(config_path)):
            model = yaml.safe_load(config_file)

            if not schema.check(model):
                raise validation.DeserializationError(
                    'Configuration file check wasn\'t successful.')

            # These values are derived from the directory layout and must not
            # be supplied by the configuration file itself.
            for bad_key in ('name', 'subdirectory'):
                if bad_key in model:
                    raise validation.DeserializationError(
                        'Unsupported key "{}"'.format(bad_key))

            loaded_model = Model.deserialize(
                model, subdirectory.name, subdirectory, composite_model_name)

            if loaded_model.name in model_names:
                raise validation.DeserializationError(
                    'Duplicate model name "{}"'.format(loaded_model.name))
            model_names.add(loaded_model.name)
            models.append(loaded_model)

    return models
def deserialize(cls, model, name, subdirectory, stages):
    """Build a composite model from its parsed config and its stage configs.

    Args:
        cls: Class to instantiate (presumably bound as a classmethod; the
            decorator is not visible in this excerpt).
        model: Parsed composite configuration; must provide ``task_type``,
            ``description``, ``license`` and ``framework``.
        name: Composite model name; must match ``RE_MODEL_NAME``.
        subdirectory: Location of the composite model, passed through to
            ``cls`` unchanged.
        stages: Mapping of stage subdirectory (an object with a ``name``
            attribute) to that stage's parsed configuration.

    Returns:
        An instance of ``cls`` whose precisions and quantization output
        precisions are taken from the first stage.

    Raises:
        validation.DeserializationError: If the name is invalid, a field
            fails validation, or ``stages`` is empty.
    """
    with validation.deserialization_context('In model "{}"'.format(name)):
        if not RE_MODEL_NAME.fullmatch(name):
            raise validation.DeserializationError(
                'Invalid name, must consist only of letters, digits or ._-')

        task_type = validation.validate_string_enum(
            '"task_type"', model['task_type'], _common.KNOWN_TASK_TYPES)
        description = validation.validate_string('"description"', model['description'])
        license_url = validation.validate_string('"license"', model['license'])
        framework = validation.validate_string_enum(
            '"framework"', model['framework'], _common.KNOWN_FRAMEWORKS.keys())

        # The first stage supplies the precisions below, so an empty mapping
        # would otherwise surface as a bare IndexError.
        if not stages:
            raise validation.DeserializationError(
                'Composite model must contain at least one stage')

        model_stages = [
            Model.deserialize(stage_config, stage_dir.name, stage_dir, name)
            for stage_dir, stage_config in stages.items()
        ]

        first_stage = model_stages[0]
        quantization_output_precisions = first_stage.quantization_output_precisions
        precisions = first_stage.precisions

        # The composite model's own name is also passed as the final argument.
        return cls(name, subdirectory, task_type, model_stages, description,
                   framework, license_url, precisions,
                   quantization_output_precisions, name)
def check_composite_model_dir(model_dir):
    """Verify the directory layout of a composite model.

    The directory itself must not contain ``model.yml``, no ``model.yml`` may
    live deeper than a direct child directory, and every direct child that
    holds a ``model.yml`` must be named ``<composite name>-...``.

    Args:
        model_dir: Path of the composite model directory.

    Raises:
        validation.DeserializationError: If any of the layout rules is broken.
    """
    with validation.deserialization_context('In directory "{}"'.format(model_dir)):
        too_deep_configs = list(model_dir.glob('*/*/**/model.yml'))
        if too_deep_configs:
            raise validation.DeserializationError(
                'Directory should not contain any model.yml files in any subdirectories '
                'that are not direct children of the composite model directory')

        own_config = model_dir / 'model.yml'
        if own_config.exists():
            raise validation.DeserializationError('Directory should not contain a model.yml file')

        required_prefix = f'{model_dir.name}-'
        stage_dirs = [stage_config.parent for stage_config in model_dir.glob('*/model.yml')]
        if any(not stage_dir.name.startswith(required_prefix) for stage_dir in stage_dirs):
            raise validation.DeserializationError(
                'Names of composite model parts should start with composite model name')
def deserialize(cls, checksum):
    """Create an instance of ``cls`` from a parsed checksum mapping.

    Args:
        cls: Class to instantiate; must expose ``RE_SHA384SUM``.
        checksum: Mapping whose ``'value'`` entry holds the hex digest string.

    Returns:
        ``cls`` constructed from the digest decoded to bytes.

    Raises:
        validation.DeserializationError: If the digest does not fully match
            ``cls.RE_SHA384SUM`` (or fails string validation).
    """
    digest_hex = validation.validate_string('"sha384"', checksum['value'])

    if cls.RE_SHA384SUM.fullmatch(digest_hex):
        return cls(bytes.fromhex(digest_hex))

    raise validation.DeserializationError(
        '"sha384": got invalid hash {!r}'.format(digest_hex))
def deserialize(cls, model, name, subdirectory, composite_model_name):
    """Validate a parsed model configuration and construct ``cls`` from it.

    Validation runs in a fixed order (name, files, postprocessing, framework,
    ONNX conversion arguments, quantization, Model Optimizer arguments or IR
    file layout, quantizable flag, description, license, task type), so the
    first problem encountered in that order is the one reported.

    Args:
        cls: Class to instantiate.
        model: Parsed configuration mapping.
        name: Model name; must match ``RE_MODEL_NAME``.
        subdirectory: Model location, passed through to ``cls`` unchanged.
        composite_model_name: Passed through to ``cls`` unchanged.

    Returns:
        An instance of ``cls`` built from the validated values.

    Raises:
        validation.DeserializationError: If any validation step fails.
    """
    with validation.deserialization_context('In model "{}"'.format(name)):
        if not RE_MODEL_NAME.fullmatch(name):
            raise validation.DeserializationError(
                'Invalid name, must consist only of letters, digits or ._-')

        # Files: deserialize each entry and reject repeated names.
        files = []
        seen_file_names = set()
        for file_config in model['files']:
            model_file = ModelFile.deserialize(file_config)
            if model_file.name in seen_file_names:
                raise validation.DeserializationError(
                    'Duplicate file name "{}"'.format(model_file.name))
            seen_file_names.add(model_file.name)
            files.append(model_file)

        # Postprocessing steps are optional; errors carry the step index.
        postprocessings = []
        for index, postproc_config in enumerate(model.get('postprocessing', [])):
            with validation.deserialization_context('"postprocessing" #{}'.format(index)):
                postprocessings.append(
                    postprocessing.Postproc.deserialize(postproc_config))

        framework = validation.validate_string_enum(
            '"framework"', model['framework'], _common.KNOWN_FRAMEWORKS.keys())

        # A truthy KNOWN_FRAMEWORKS entry makes the ONNX conversion arguments
        # mandatory; a falsy one forbids them.
        conversion_to_onnx_args = model.get('conversion_to_onnx_args', None)
        if _common.KNOWN_FRAMEWORKS[framework]:
            if not conversion_to_onnx_args:
                raise validation.DeserializationError(
                    '"conversion_to_onnx_args" is absent. '
                    'Framework "{}" is supported only by conversion to ONNX.'
                    .format(framework))
            conversion_to_onnx_args = [
                validation.validate_string(
                    '"conversion_to_onnx_args" #{}'.format(index), arg)
                for index, arg in enumerate(conversion_to_onnx_args)
            ]
        elif conversion_to_onnx_args:
            raise validation.DeserializationError(
                'Conversion to ONNX not supported for "{}" framework'.format(framework))

        quantized = model.get('quantized', None)
        if quantized not in (None, 'INT8'):
            raise validation.DeserializationError(
                '"quantized": expected "INT8", got {!r}'.format(quantized))

        if 'model_optimizer_args' in model:
            mo_args = [
                validation.validate_string(
                    '"model_optimizer_args" #{}'.format(index), arg)
                for index, arg in enumerate(model['model_optimizer_args'])
            ]
            if quantized is None:
                precisions = {'FP16', 'FP32'}
            else:
                precisions = {f'FP16-{quantized}', f'FP32-{quantized}'}
        else:
            # Without Model Optimizer arguments the files themselves must
            # already be laid out as <precision>/<file name>.
            if framework != 'dldt':
                raise validation.DeserializationError(
                    'Model not in IR format, but no conversions defined')

            mo_args = None

            files_per_precision = {}
            for model_file in files:
                path_parts = model_file.name.parts
                if len(path_parts) != 2:
                    raise validation.DeserializationError(
                        'Can\'t derive precision from file name {!r}'.format(model_file.name))
                precision_dir = path_parts[0]
                if precision_dir not in _common.KNOWN_PRECISIONS:
                    raise validation.DeserializationError(
                        'Unknown precision {!r} derived from file name {!r}, expected one of {!r}'
                        .format(precision_dir, model_file.name, _common.KNOWN_PRECISIONS))
                files_per_precision.setdefault(precision_dir, set()).add(path_parts[1])

            # Every precision directory needs both <name>.xml and <name>.bin.
            for precision, precision_files in files_per_precision.items():
                for ext in ['xml', 'bin']:
                    expected_file = name + '.' + ext
                    if expected_file not in precision_files:
                        raise validation.DeserializationError(
                            'No {} file for precision "{}"'.format(ext.upper(), precision))

            precisions = set(files_per_precision)

        quantizable = model.get('quantizable', False)
        if not isinstance(quantizable, bool):
            raise validation.DeserializationError(
                '"quantizable": expected a boolean, got {!r}'.format(quantizable))

        if quantizable:
            quantization_output_precisions = _common.KNOWN_QUANTIZED_PRECISIONS.keys()
        else:
            quantization_output_precisions = set()

        description = validation.validate_string('"description"', model['description'])
        license_url = validation.validate_string('"license"', model['license'])
        task_type = validation.validate_string_enum(
            '"task_type"', model['task_type'], _common.KNOWN_TASK_TYPES)

        return cls(name, subdirectory, files, postprocessings, mo_args,
                   framework, description, license_url, precisions,
                   quantization_output_precisions, task_type,
                   conversion_to_onnx_args, composite_model_name)
def load_models(models_root, args, mode=ModelLoadingMode.all):
    """Load composite and/or regular models found under ``models_root``.

    Args:
        models_root: Path-like root directory that is searched recursively.
        args: Not used by this function; kept for interface compatibility.
        mode: A ``ModelLoadingMode`` value. ``all`` and ``composite_only``
            load composite models; every mode except ``composite_only`` loads
            regular ``model.yml`` files. Stages of composite models are loaded
            as regular models only in ``ignore_composite`` mode.

    Returns:
        Regular and composite models together, sorted by ``name``.

    Raises:
        validation.DeserializationError: On an invalid or duplicate name, or
            an unsupported key in a configuration file.
    """
    models = []
    model_names = set()
    composite_models = []
    composite_model_names = set()

    if mode in (ModelLoadingMode.all, ModelLoadingMode.composite_only):
        for composite_model_config in sorted(models_root.glob('**/composite-model.yml')):
            composite_model_dir = composite_model_config.parent
            composite_model_name = composite_model_dir.name
            with validation.deserialization_context(
                    'In model "{}"'.format(composite_model_name)):
                if not RE_MODEL_NAME.fullmatch(composite_model_name):
                    raise validation.DeserializationError(
                        'Invalid name, must consist only of letters, digits or ._-')

                check_composite_model_dir(composite_model_dir)

                with composite_model_config.open('rb') as config_file, \
                        validation.deserialization_context(
                            'In config "{}"'.format(composite_model_config)):
                    composite_model = yaml.safe_load(config_file)

                    model_stages = {}
                    for stage in sorted(composite_model_dir.glob('*/model.yml')):
                        # Report the stage's path in errors, not the repr of
                        # the open file object.
                        with stage.open('rb') as stage_config_file, \
                                validation.deserialization_context(
                                    'In config "{}"'.format(stage)):
                            model = yaml.safe_load(stage_config_file)

                            stage_subdirectory = stage.parent.relative_to(models_root)
                            model_stages[stage_subdirectory] = model

                    # A composite model without any stage is silently skipped.
                    if not model_stages:
                        continue

                    subdirectory = composite_model_dir.relative_to(models_root)
                    composite_models.append(
                        CompositeModel.deserialize(
                            composite_model, composite_model_name,
                            subdirectory, model_stages))

                    if composite_model_name in composite_model_names:
                        raise validation.DeserializationError(
                            'Duplicate composite model name "{}"'.format(
                                composite_model_name))
                    composite_model_names.add(composite_model_name)

    if mode != ModelLoadingMode.composite_only:
        for config_path in sorted(models_root.glob('**/model.yml')):
            model_dir = config_path.parent

            is_composite = (model_dir.parent / 'composite-model.yml').exists()
            composite_model_name = None
            if is_composite:
                # Stages are loaded individually only in ignore_composite mode.
                if mode != ModelLoadingMode.ignore_composite:
                    continue
                composite_model_name = model_dir.parent.name

            subdirectory = model_dir.relative_to(models_root)

            with config_path.open('rb') as config_file, \
                    validation.deserialization_context(
                        'In config "{}"'.format(config_path)):
                model = yaml.safe_load(config_file)

                # These values are derived from the directory layout and must
                # not be supplied by the configuration file itself.
                for bad_key in ('name', 'subdirectory'):
                    if bad_key in model:
                        raise validation.DeserializationError(
                            'Unsupported key "{}"'.format(bad_key))

                # Excluded models are skipped before deserialization, so the
                # duplicate-name check below always sees the model just loaded.
                if subdirectory.name in EXCLUDED_MODELS:
                    continue

                loaded_model = Model.deserialize(
                    model, subdirectory.name, subdirectory, composite_model_name)

                if loaded_model.name in model_names:
                    raise validation.DeserializationError(
                        'Duplicate model name "{}"'.format(loaded_model.name))
                model_names.add(loaded_model.name)
                models.append(loaded_model)

    return sorted(models + composite_models, key=lambda model: model.name)
def deserialize(cls, value):
    """Dispatch deserialization to the type registered under ``value['$type']``.

    Args:
        cls: Class exposing a ``types`` mapping from type name to a type with
            its own ``deserialize(value)``.
        value: Parsed mapping that must contain a ``'$type'`` entry.

    Returns:
        Whatever the selected type's ``deserialize`` returns.

    Raises:
        validation.DeserializationError: If ``'$type'`` is missing or not
            registered, or if the selected type's ``deserialize`` raises
            ``KeyError``.
    """
    # Read the key once, so a missing "$type" is reported cleanly instead of
    # raising a second KeyError while the error message is being built.
    try:
        type_name = value['$type']
    except KeyError:
        raise validation.DeserializationError('Missing "$type"') from None

    try:
        selected_type = cls.types[type_name]
    except KeyError:
        raise validation.DeserializationError(
            'Unknown "$type": "{}"'.format(type_name)) from None

    # A KeyError raised by the selected type is still surfaced as a
    # DeserializationError, but no longer blamed on an unknown "$type".
    try:
        return selected_type.deserialize(value)
    except KeyError as exc:
        raise validation.DeserializationError(
            'Missing key {}'.format(exc)) from exc