def deserialize(cls, postproc):
    """Build a pattern-replacement postprocessing step from its config mapping.

    Fields are validated in this order: ``file``, ``pattern`` (which is then
    compiled with ``re.compile``), ``replacement``, and ``count``. ``count``
    is optional and defaults to 0 when absent.

    Args:
        cls: The class being constructed.
        postproc: Mapping with the keys ``file``, ``pattern``,
            ``replacement`` and, optionally, ``count``.

    Returns:
        ``cls(file, compiled_pattern, replacement, count)``.
    """
    target_file = validation.validate_relative_path('"file"', postproc['file'])
    pattern_source = validation.validate_string('"pattern"', postproc['pattern'])
    compiled_pattern = re.compile(pattern_source)
    replacement_text = validation.validate_string(
        '"replacement"', postproc['replacement'])
    replace_count = validation.validate_nonnegative_int(
        '"count"', postproc.get('count', 0))
    return cls(target_file, compiled_pattern, replacement_text, replace_count)
def deserialize(cls, model, name, subdirectory, stages):
    """Validate a composite model's configuration and build ``cls``.

    The composite's own fields (task type, description, license, framework)
    are validated first. Then every stage is deserialized with
    ``Model.deserialize``. The composite's precisions and quantization
    output precisions are taken from the first stage, in the iteration
    order of ``stages``.

    Args:
        cls: The class being constructed.
        model: Parsed configuration mapping of the composite model; the keys
            ``task_type``, ``description``, ``license`` and ``framework``
            are read.
        name: Composite model name; must match ``RE_MODEL_NAME``.
        subdirectory: Passed through to ``cls`` unchanged.
        stages: Mapping from a stage subdirectory (an object exposing a
            ``name`` attribute, presumably a path -- confirm with callers)
            to that stage's configuration mapping.

    Returns:
        A new ``cls`` instance.

    Raises:
        validation.DeserializationError: If the name or a field is invalid,
            if ``stages`` is empty, or if a stage fails to deserialize.
    """
    with validation.deserialization_context('In model "{}"'.format(name)):
        if not RE_MODEL_NAME.fullmatch(name):
            raise validation.DeserializationError(
                'Invalid name, must consist only of letters, digits or ._-'
            )

        task_type = validation.validate_string_enum(
            '"task_type"', model['task_type'], _common.KNOWN_TASK_TYPES)
        description = validation.validate_string('"description"', model['description'])
        license_url = validation.validate_string('"license"', model['license'])
        framework = validation.validate_string_enum(
            '"framework"', model['framework'], _common.KNOWN_FRAMEWORKS.keys())

        # Precisions are taken from the first stage, so at least one stage is
        # required. Without this check an empty mapping would escape as a
        # bare IndexError instead of a DeserializationError.
        if not stages:
            raise validation.DeserializationError(
                'Composite model has no stages')

        # Each stage is deserialized as a regular model, passing this
        # composite's name as the last argument.
        model_stages = [
            Model.deserialize(model_part, model_subdirectory.name,
                              model_subdirectory, name)
            for model_subdirectory, model_part in stages.items()
        ]

        first_stage = model_stages[0]
        quantization_output_precisions = first_stage.quantization_output_precisions
        precisions = first_stage.precisions

        # NOTE(review): ``name`` is passed twice; the final argument
        # presumably fills a composite-model-name field -- confirm against
        # ``cls.__init__``.
        return cls(name, subdirectory, task_type, model_stages, description,
                   framework, license_url, precisions,
                   quantization_output_precisions, name)
def deserialize(cls, checksum):
    """Parse a SHA-384 checksum entry into an instance of ``cls``.

    Args:
        cls: The class being constructed; must provide the compiled regex
            ``RE_SHA384SUM`` used to validate the hex digest.
        checksum: Mapping whose ``value`` key holds the digest as a
            hexadecimal string.

    Returns:
        ``cls(digest_bytes)``, where ``digest_bytes`` is the decoded digest.

    Raises:
        validation.DeserializationError: If the string does not fully match
            ``cls.RE_SHA384SUM``.
    """
    hex_digest = validation.validate_string('"sha384"', checksum['value'])
    if cls.RE_SHA384SUM.fullmatch(hex_digest) is None:
        raise validation.DeserializationError(
            '"sha384": got invalid hash {!r}'.format(hex_digest))
    return cls(bytes.fromhex(hex_digest))
def deserialize(cls, source):
    """Build ``cls`` from a mapping whose ``id`` key holds a string.

    Args:
        cls: The class being constructed.
        source: Mapping containing an ``id`` entry.

    Returns:
        ``cls(validated_id)``.
    """
    source_id = validation.validate_string('"id"', source['id'])
    return cls(source_id)
def deserialize(cls, source):
    """Build ``cls`` from a mapping whose ``url`` key holds a string.

    Args:
        cls: The class being constructed.
        source: Mapping containing a ``url`` entry.

    Returns:
        ``cls(validated_url)``.
    """
    source_url = validation.validate_string('"url"', source['url'])
    return cls(source_url)
def deserialize(cls, postproc):
    """Build a postprocessing step from a mapping with ``file`` and ``format``.

    ``file`` must be a valid relative path and ``format`` a string. They are
    validated in that order.

    Args:
        cls: The class being constructed.
        postproc: Mapping with the keys ``file`` and ``format``.

    Returns:
        ``cls(file, format)``.
    """
    target_file = validation.validate_relative_path('"file"', postproc['file'])
    archive_format = validation.validate_string('"format"', postproc['format'])
    return cls(target_file, archive_format)
def deserialize(cls, model, name, subdirectory, composite_model_name):
    """Validate a single model's configuration mapping and build ``cls``.

    Checks run in a fixed order, so the first failing check decides which
    ``validation.DeserializationError`` is raised. The order is:
      1. the model name;
      2. files;
      3. postprocessing steps;
      4. framework and ONNX-conversion arguments;
      5. ``quantized``;
      6. Model Optimizer arguments, or, when those are absent, the
         per-precision IR files;
      7. ``quantizable``, description, license, and task type.

    Args:
        cls: The class being constructed.
        model: Parsed configuration mapping for the model.
            Required keys read here: ``files``, ``framework``,
            ``description``, ``license``, ``task_type``.
            Optional keys: ``postprocessing``, ``conversion_to_onnx_args``,
            ``quantized``, ``model_optimizer_args``, ``quantizable``.
        name: Model name; must match ``RE_MODEL_NAME``. Also used as the
            expected base name of the ``.xml``/``.bin`` files when
            precisions are derived from files.
        subdirectory: Passed through to ``cls`` unchanged.
        composite_model_name: Passed through to ``cls`` unchanged
            (presumably the name of an enclosing composite model -- confirm
            with callers).

    Returns:
        A new ``cls`` instance.

    Raises:
        validation.DeserializationError: If any validation check fails.
            A missing required key surfaces as whatever indexing ``model``
            raises (``KeyError`` for a dict).
    """
    with validation.deserialization_context('In model "{}"'.format(name)):
        if not RE_MODEL_NAME.fullmatch(name):
            raise validation.DeserializationError(
                'Invalid name, must consist only of letters, digits or ._-'
            )

        # Deserialize every file entry, rejecting duplicate file names.
        files = []
        file_names = set()
        for file in model['files']:
            files.append(ModelFile.deserialize(file))
            if files[-1].name in file_names:
                raise validation.DeserializationError(
                    'Duplicate file name "{}"'.format(files[-1].name))
            file_names.add(files[-1].name)

        # Optional postprocessing steps. Each one is validated in its own
        # context so that error messages include the step's index.
        postprocessings = []
        for i, postproc in enumerate(model.get('postprocessing', [])):
            with validation.deserialization_context(
                    '"postprocessing" #{}'.format(i)):
                postprocessings.append(
                    postprocessing.Postproc.deserialize(postproc))

        framework = validation.validate_string_enum(
            '"framework"', model['framework'], _common.KNOWN_FRAMEWORKS.keys())

        # A truthy KNOWN_FRAMEWORKS entry marks a framework that is supported
        # only through conversion to ONNX. For such frameworks the conversion
        # arguments are mandatory; for all others they are forbidden.
        conversion_to_onnx_args = model.get('conversion_to_onnx_args', None)
        if _common.KNOWN_FRAMEWORKS[framework]:
            if not conversion_to_onnx_args:
                raise validation.DeserializationError(
                    '"conversion_to_onnx_args" is absent. '
                    'Framework "{}" is supported only by conversion to ONNX.'
                    .format(framework))
            conversion_to_onnx_args = [
                validation.validate_string(
                    '"conversion_to_onnx_args" #{}'.format(i), arg)
                for i, arg in enumerate(model['conversion_to_onnx_args'])
            ]
        else:
            if conversion_to_onnx_args:
                raise validation.DeserializationError(
                    'Conversion to ONNX not supported for "{}" framework'.
                    format(framework))

        # The only accepted quantization value is 'INT8'.
        quantized = model.get('quantized', None)
        if quantized is not None and quantized != 'INT8':
            raise validation.DeserializationError(
                '"quantized": expected "INT8", got {!r}'.format(quantized))

        if 'model_optimizer_args' in model:
            # Model Optimizer arguments are given: the resulting precisions
            # are fixed to FP16 and FP32, suffixed with the quantization
            # type when one is set.
            mo_args = [
                validation.validate_string(
                    '"model_optimizer_args" #{}'.format(i), arg)
                for i, arg in enumerate(model['model_optimizer_args'])
            ]
            precisions = {f'FP16-{quantized}', f'FP32-{quantized}'
                          } if quantized is not None else {'FP16', 'FP32'}
        else:
            # No Model Optimizer arguments: only the 'dldt' framework is
            # allowed, and precisions are derived from the file names.
            if framework != 'dldt':
                raise validation.DeserializationError(
                    'Model not in IR format, but no conversions defined')
            mo_args = None
            files_per_precision = {}
            for file in files:
                # Each file name must have exactly two parts,
                # <precision>/<file>. file.name is presumably a PurePath
                # (it exposes .parts) -- confirm in ModelFile.
                if len(file.name.parts) != 2:
                    raise validation.DeserializationError(
                        'Can\'t derive precision from file name {!r}'.
                        format(file.name))
                p = file.name.parts[0]
                if p not in _common.KNOWN_PRECISIONS:
                    raise validation.DeserializationError(
                        'Unknown precision {!r} derived from file name {!r}, expected one of {!r}'
                        .format(p, file.name, _common.KNOWN_PRECISIONS))
                files_per_precision.setdefault(p, set()).add(
                    file.name.parts[1])

            # Every derived precision must provide both <name>.xml and
            # <name>.bin.
            for precision, precision_files in files_per_precision.items():
                for ext in ['xml', 'bin']:
                    if (name + '.' + ext) not in precision_files:
                        raise validation.DeserializationError(
                            'No {} file for precision "{}"'.format(
                                ext.upper(), precision))

            precisions = set(files_per_precision.keys())

        quantizable = model.get('quantizable', False)
        if not isinstance(quantizable, bool):
            raise validation.DeserializationError(
                '"quantizable": expected a boolean, got {!r}'.format(
                    quantizable))

        # NOTE(review): this value is a dict keys view when quantizable and
        # a set otherwise. Confirm that callers rely only on set-like
        # operations.
        quantization_output_precisions = _common.KNOWN_QUANTIZED_PRECISIONS.keys(
        ) if quantizable else set()

        description = validation.validate_string('"description"', model['description'])
        license_url = validation.validate_string('"license"', model['license'])
        task_type = validation.validate_string_enum(
            '"task_type"', model['task_type'], _common.KNOWN_TASK_TYPES)

        return cls(name, subdirectory, files, postprocessings, mo_args,
                   framework, description, license_url, precisions,
                   quantization_output_precisions, task_type,
                   conversion_to_onnx_args, composite_model_name)