def test_parse_shape_with_dim_param_including_x(self):
    """A quoted dimension parameter containing the letter 'x' must not be split on it."""
    result = parse_meta(["input0,'batchx'x3x224x224"], includes_dtype=False)
    assert result["input0"].shape == ("batchx", 3, 224, 224)
def test_parse_shape_only(self):
    """With dtype parsing disabled, only the shape is populated and dtype stays None."""
    result = parse_meta(["input0,1x3x224x224"], includes_dtype=False)
    entry = result["input0"]
    assert entry.shape == (1, 3, 224, 224)
    assert entry.dtype is None
def test_parse_shape_dtype_auto(self):
    """'auto' placeholders leave both shape and dtype unset (None)."""
    result = parse_meta(["input0,auto,auto"])
    entry = result["input0"]
    assert entry.shape is None
    assert entry.dtype is None
def test_parse_shape_with_dim_param_double_quote(self):
    """Double-quoted dimension parameters parse into string dimensions."""
    result = parse_meta(['input0,"batch"x3x224x224'], includes_dtype=False)
    assert result["input0"].shape == ("batch", 3, 224, 224)
def test_parse_shape_dtype(self):
    """A full 'name,shape,dtype' argument populates both shape and dtype."""
    result = parse_meta(["input0,1x3x224x224,float32"])
    entry = result["input0"]
    assert entry.shape == (1, 3, 224, 224)
    assert entry.dtype == np.float32
def test_parse_dtype_only(self):
    """With shape parsing disabled, only the dtype is populated and shape stays None."""
    result = parse_meta(["input0,float32"], includes_shape=False)
    entry = result["input0"]
    assert entry.shape is None
    assert entry.dtype == np.float32
def test_parse_shape_single_dim(self):
    """A lone dimension parses into a one-element shape tuple."""
    result = parse_meta(["input0,1"], includes_dtype=False)
    assert result["input0"].shape == (1,)
def run(self, args):
    """Import an ONNX graph, resolve input/output metadata (from the command line,
    the model itself, or fallback inference), rewrite the graph's inputs/outputs
    accordingly, and export the result.

    NOTE(review): when metadata is still missing after consulting the model, this
    runs actual inference via ONNX Runtime, which makes dynamic dimensions static.
    """
    onnx_model, graph = super().import_graph(args)
    # Name -> tensor lookup for every tensor in the imported graph.
    TENSOR_MAP = graph.tensors()

    def get_tensor(name):
        # Fetch a tensor by name; logs a critical error for unknown names.
        # NOTE(review): G_LOGGER.critical presumably aborts — otherwise the
        # following lookup would raise KeyError. Confirm against the logger impl.
        if name not in TENSOR_MAP:
            G_LOGGER.critical("Tensor: {:} does not exist in the model.".format(name))
        return TENSOR_MAP[name]

    def missing_meta_tensors(input_metadata, output_metadata):
        # Collect names whose metadata is incomplete: inputs need both dtype and
        # a non-empty shape; outputs only need a dtype.
        names = []
        for name, (dtype, shape) in input_metadata.items():
            if dtype is None or not shape:
                names.append(name)
        for name, (dtype, shape) in output_metadata.items():
            if dtype is None:
                names.append(name)
        return names

    def update_meta_from_tensor_map(meta):
        # Fill any missing dtype/shape in `meta` from the corresponding graph
        # tensor (command-line values take precedence via the `or`).
        for name, (dtype, shape) in meta.items():
            tensor = get_tensor(name)
            meta[name] = (dtype or tensor.dtype, shape or tensor.shape)
        return meta

    def meta_from_tensors(tensors):
        # Build a TensorMetadata directly from a list of graph tensors.
        meta = TensorMetadata()
        for tensor in tensors:
            meta.add(tensor.name, tensor.dtype, tensor.shape)
        return meta

    # Input metadata: user-supplied (augmented from the model) or taken from the
    # graph's declared inputs.
    if args.input_meta:
        input_metadata = update_meta_from_tensor_map(tools_util.parse_meta(args.input_meta))
    else:
        input_metadata = meta_from_tensors(graph.inputs)

    # Output metadata: same idea, but user-supplied outputs carry no shapes.
    if args.output_meta:
        output_metadata = update_meta_from_tensor_map(tools_util.parse_meta(args.output_meta, includes_shape=False))
    else:
        output_metadata = meta_from_tensors(graph.outputs)

    missing_tensors = missing_meta_tensors(input_metadata, output_metadata)
    if missing_tensors:
        # Use ONNX runtime with static shapes to infer shapes when all else fails
        # Returns a TensorMetadata for all tensors in the graph.
        def fallback_shape_inference(onnx_model):
            from polygraphy.backend.onnx import BytesFromOnnx, ModifyOnnx
            from polygraphy.backend.onnxrt import (OnnxrtRunner, SessionFromOnnxBytes)

            # Mark every tensor as an output so one inference pass yields
            # concrete metadata for the whole graph.
            load_model = ModifyOnnx(onnx_model, outputs=constants.MARK_ALL)
            with OnnxrtRunner(SessionFromOnnxBytes(BytesFromOnnx(load_model))) as runner:
                data_loader = self.makers[DataLoaderArgs].get_data_loader()
                data_loader.input_metadata = runner.get_input_metadata()
                # Single inference with the first generated feed dict.
                outputs = runner.infer(feed_dict=data_loader[0])

                meta = TensorMetadata()
                for name, output in outputs.items():
                    meta.add(name, output.dtype, output.shape)
                return meta

        def update_meta_from_meta(meta, golden_meta):
            # Fill remaining gaps in `meta` from the inferred ("golden")
            # metadata; existing values still win via the `or`.
            for name, (dtype, shape) in meta.items():
                if name in golden_meta:
                    (golden_dtype, golden_shape) = golden_meta[name]
                    meta[name] = (dtype or golden_dtype, shape or golden_shape)
                    G_LOGGER.verbose("Updated tensor: {:} metadata to: {:}".format(name, meta[name]))
            return meta

        G_LOGGER.warning("Some tensor shapes or dtypes are missing in the model. Note: Missing Tensors: {:}. "
                         "\nWill run inference to determine shapes. This will cause dynamic "
                         "dimensions to become static.\nTo avoid this, please provide metadata on the command-line. "
                         .format(missing_tensors))
        golden_meta = fallback_shape_inference(onnx_model)
        input_metadata = update_meta_from_meta(input_metadata, golden_meta)
        output_metadata = update_meta_from_meta(output_metadata, golden_meta)

    # Set the graph inputs and outputs
    graph.inputs.clear()
    for name, (dtype, shape) in input_metadata.items():
        tensor = get_tensor(name)
        tensor.dtype, tensor.shape = dtype, shape
        # Detach producers so the tensor behaves as a true graph input.
        tensor.inputs.clear()
        graph.inputs.append(tensor)

    graph.outputs.clear()
    for name, (dtype, shape) in output_metadata.items():
        tensor = get_tensor(name)
        tensor.dtype, tensor.shape = dtype, shape
        graph.outputs.append(tensor)

    G_LOGGER.info("Using Graph Inputs:\n{:}{:}".format(constants.TAB, graph.inputs))
    G_LOGGER.info("Using Graph Outputs:\n{:}{:}".format(constants.TAB, graph.outputs))

    super().export_graph(graph, args)
def parse(self, args):
    """Parse model-related command-line arguments onto this object.

    Sets ``self.model_file`` (absolute path, may not exist yet),
    ``self.input_shapes`` (a TensorMetadata), and ``self.model_type``
    (explicit ``self._model_type`` wins over auto-detection).
    """
    def determine_model_type():
        # Explicit --model-type always wins.
        if tools_util.get(args, "model_type") is not None:
            return args.model_type.lower()

        # No model file -> nothing to detect.
        if tools_util.get(args, "model_file") is None:
            return None

        def use_ext(ext_mapping):
            # Returns the mapped type, or implicitly None for unknown extensions.
            file_ext = os.path.splitext(args.model_file)[-1]
            if file_ext in ext_mapping:
                return ext_mapping[file_ext]

        runners = misc.default_value(tools_util.get(args, "runners"), [])
        # A directory (or explicit --ckpt) is treated as a TF checkpoint.
        if tools_util.get(args, "ckpt") or os.path.isdir(args.model_file):
            return "ckpt"
        elif "tf" in runners or "trt_legacy" in runners:
            if args.caffe_model:
                return "caffe"
            ext_mapping = {".hdf5": "keras", ".uff": "uff", ".prototxt": "caffe", ".onnx": "onnx", ".engine": "engine", ".plan": "engine"}
            # TF/legacy-TRT context: unrecognized extensions default to a frozen graph.
            return use_ext(ext_mapping) or "frozen"
        else:
            # When no framework is provided, some extensions can be ambiguous
            ext_mapping = {".hdf5": "keras", ".graphdef": "frozen", ".onnx": "onnx", ".uff": "uff", ".engine": "engine", ".plan": "engine"}
            model_type = use_ext(ext_mapping)
            if model_type:
                return model_type

        # Reached only when the extension could not be resolved.
        # NOTE(review): G_LOGGER.critical presumably raises/exits; confirm.
        G_LOGGER.critical(
            "Could not automatically determine model type for: {:}\n"
            "Please explicitly specify the type with the --model-type option"
            .format(args.model_file))

    if tools_util.get(args, "model_file"):
        G_LOGGER.verbose("Model: {:}".format(args.model_file))
        # A missing path is only a warning — downstream tools may create it.
        if not os.path.exists(args.model_file):
            G_LOGGER.warning("Model path does not exist: {:}".format(
                args.model_file))
        args.model_file = os.path.abspath(args.model_file)

    if tools_util.get(args, "input_shapes"):
        self.input_shapes = tools_util.parse_meta(
            tools_util.get(args, "input_shapes"), includes_dtype=False)  # TensorMetadata
    else:
        # No shapes supplied: keep an empty TensorMetadata rather than None.
        self.input_shapes = TensorMetadata()

    self.model_file = args.model_file
    self.model_type = misc.default_value(self._model_type, determine_model_type())