Example #1
0
 def test_parse_shape_with_dim_param_including_x(self):
     """A quoted dimension name that itself contains 'x' is kept intact."""
     parsed = parse_meta(["input0,'batchx'x3x224x224"], includes_dtype=False)
     expected_shape = ("batchx", 3, 224, 224)
     assert parsed["input0"].shape == expected_shape
Example #2
0
 def test_parse_shape_only(self):
     """With includes_dtype=False, the shape is parsed and the dtype stays None."""
     parsed = parse_meta(["input0,1x3x224x224"], includes_dtype=False)
     expected_shape = (1, 3, 224, 224)
     assert parsed["input0"].shape == expected_shape
     assert parsed["input0"].dtype is None
Example #3
0
 def test_parse_shape_dtype_auto(self):
     """'auto' for both fields leaves shape and dtype as None."""
     parsed = parse_meta(["input0,auto,auto"])
     assert parsed["input0"].shape is None
     assert parsed["input0"].dtype is None
Example #4
0
 def test_parse_shape_with_dim_param_double_quote(self):
     """A double-quoted dimension name is parsed as a string dimension."""
     parsed = parse_meta(['input0,"batch"x3x224x224'], includes_dtype=False)
     expected_shape = ("batch", 3, 224, 224)
     assert parsed["input0"].shape == expected_shape
Example #5
0
 def test_parse_shape_dtype(self):
     """An entry of the form name,shape,dtype yields both the shape and the dtype."""
     parsed = parse_meta(["input0,1x3x224x224,float32"])
     expected_shape = (1, 3, 224, 224)
     assert parsed["input0"].shape == expected_shape
     assert parsed["input0"].dtype == np.float32
Example #6
0
 def test_parse_dtype_only(self):
     """With includes_shape=False, the dtype is parsed and the shape stays None."""
     parsed = parse_meta(["input0,float32"], includes_shape=False)
     assert parsed["input0"].shape is None
     assert parsed["input0"].dtype == np.float32
Example #7
0
 def test_parse_shape_single_dim(self):
     """A shape with a single dimension parses to a one-element tuple."""
     parsed = parse_meta(["input0,1"], includes_dtype=False)
     expected_shape = (1, )
     assert parsed["input0"].shape == expected_shape
Example #8
0
    def run(self, args):
        """
        Replace the imported graph's inputs and outputs and export the result.

        Input/output tensors are taken from ``args.input_meta`` / ``args.output_meta``
        when provided, otherwise from the graph's current inputs/outputs. Any dtype or
        shape left unspecified is first filled in from the graph's own tensors. If
        something is still missing afterwards, the model is run once through
        ONNX Runtime (with ``constants.MARK_ALL`` outputs) and the observed dtypes
        and shapes are used instead.

        Args:
            args: Parsed command-line arguments. ``input_meta`` and ``output_meta``
                are read here; the whole object is forwarded to ``import_graph``
                and ``export_graph``.
        """
        onnx_model, graph = super().import_graph(args)
        tensor_map = graph.tensors()

        def get_tensor(name):
            # Look up a tensor by name; unknown names are reported as critical.
            if name not in tensor_map:
                G_LOGGER.critical("Tensor: {:} does not exist in the model.".format(name))
            return tensor_map[name]

        def meta_from_tensors(tensors):
            # Build metadata directly from existing graph tensors.
            collected = TensorMetadata()
            for tensor in tensors:
                collected.add(tensor.name, tensor.dtype, tensor.shape)
            return collected

        def update_meta_from_tensor_map(meta):
            # Fill unspecified dtype/shape entries from the graph's tensors.
            # NOTE(review): the `or` fallbacks are truthiness-based, so an empty
            # shape is treated the same as a missing one - confirm this is intended.
            for name, (dtype, shape) in meta.items():
                known = get_tensor(name)
                meta[name] = (dtype or known.dtype, shape or known.shape)
            return meta

        def missing_meta_tensors(input_metadata, output_metadata):
            # Inputs need both a dtype and a shape; outputs only need a dtype.
            incomplete = [
                name for name, (dtype, shape) in input_metadata.items()
                if dtype is None or not shape
            ]
            incomplete += [
                name for name, (dtype, shape) in output_metadata.items()
                if dtype is None
            ]
            return incomplete

        # Use ONNX runtime with static shapes to infer shapes when all else fails
        # Returns a TensorMetadata for all tensors in the graph.
        def fallback_shape_inference(onnx_model):
            from polygraphy.backend.onnx import BytesFromOnnx, ModifyOnnx
            from polygraphy.backend.onnxrt import (OnnxrtRunner,
                                                   SessionFromOnnxBytes)

            load_model = ModifyOnnx(onnx_model, outputs=constants.MARK_ALL)
            build_session = SessionFromOnnxBytes(BytesFromOnnx(load_model))
            with OnnxrtRunner(build_session) as runner:
                data_loader = self.makers[DataLoaderArgs].get_data_loader()
                data_loader.input_metadata = runner.get_input_metadata()
                outputs = runner.infer(feed_dict=data_loader[0])

                observed = TensorMetadata()
                for name, output in outputs.items():
                    observed.add(name, output.dtype, output.shape)
                return observed

        def update_meta_from_meta(meta, golden_meta):
            # Fill remaining gaps from the metadata observed during inference.
            for name, (dtype, shape) in meta.items():
                if name not in golden_meta:
                    continue
                golden_dtype, golden_shape = golden_meta[name]
                meta[name] = (dtype or golden_dtype, shape or golden_shape)
                G_LOGGER.verbose("Updated tensor: {:} metadata to: {:}".format(name, meta[name]))
            return meta

        input_metadata = (
            update_meta_from_tensor_map(tools_util.parse_meta(args.input_meta))
            if args.input_meta
            else meta_from_tensors(graph.inputs)
        )
        output_metadata = (
            update_meta_from_tensor_map(tools_util.parse_meta(args.output_meta, includes_shape=False))
            if args.output_meta
            else meta_from_tensors(graph.outputs)
        )

        missing_tensors = missing_meta_tensors(input_metadata, output_metadata)
        if missing_tensors:
            warning_text = (
                "Some tensor shapes or dtypes are missing in the model. Note: Missing Tensors: {:}. "
                "\nWill run inference to determine shapes. This will cause dynamic "
                "dimensions to become static.\nTo avoid this, please provide metadata on the command-line. "
            ).format(missing_tensors)
            G_LOGGER.warning(warning_text)

            golden_meta = fallback_shape_inference(onnx_model)
            input_metadata = update_meta_from_meta(input_metadata, golden_meta)
            output_metadata = update_meta_from_meta(output_metadata, golden_meta)

        # Set the graph inputs and outputs
        graph.inputs.clear()
        for name, (dtype, shape) in input_metadata.items():
            graph_input = get_tensor(name)
            graph_input.dtype = dtype
            graph_input.shape = shape
            # Presumably detaches the tensor from whatever produced it so that it
            # can act as a graph input - verify against the graph library.
            graph_input.inputs.clear()
            graph.inputs.append(graph_input)

        graph.outputs.clear()
        for name, (dtype, shape) in output_metadata.items():
            graph_output = get_tensor(name)
            graph_output.dtype = dtype
            graph_output.shape = shape
            graph.outputs.append(graph_output)

        G_LOGGER.info("Using Graph Inputs:\n{:}{:}".format(constants.TAB, graph.inputs))
        G_LOGGER.info("Using Graph Outputs:\n{:}{:}".format(constants.TAB, graph.outputs))

        super().export_graph(graph, args)
Example #9
0
    def parse(self, args):
        """
        Populate the model-related attributes from parsed command-line arguments.

        Sets:
            self.model_file: The model path from ``args`` (made absolute when one
                was provided), or whatever ``tools_util.get`` yields when absent.
            self.input_shapes: ``TensorMetadata`` parsed from ``args.input_shapes``
                (shapes only), or an empty ``TensorMetadata`` when none were given.
            self.model_type: ``self._model_type`` when it is set, otherwise the
                type determined from ``args`` (explicit ``model_type``, checkpoint
                hints, or the model file's extension).

        Args:
            args: The parsed arguments object. Optional attributes are read through
                ``tools_util.get`` rather than direct attribute access, consistently
                with how this method already treats ``model_file`` and ``model_type``.
        """
        def determine_model_type():
            # An explicitly requested model type always wins.
            if tools_util.get(args, "model_type") is not None:
                return args.model_type.lower()

            # Without a model file there is nothing to inspect.
            if tools_util.get(args, "model_file") is None:
                return None

            def use_ext(ext_mapping):
                # Returns None when the file extension is not in the mapping.
                file_ext = os.path.splitext(args.model_file)[-1]
                return ext_mapping.get(file_ext)

            runners = misc.default_value(tools_util.get(args, "runners"), [])
            if tools_util.get(args, "ckpt") or os.path.isdir(args.model_file):
                return "ckpt"
            elif "tf" in runners or "trt_legacy" in runners:
                # Read via tools_util.get, like the other optional arguments, so a
                # parser that does not define --caffe-model does not break this path.
                if tools_util.get(args, "caffe_model"):
                    return "caffe"
                ext_mapping = {
                    ".hdf5": "keras",
                    ".uff": "uff",
                    ".prototxt": "caffe",
                    ".onnx": "onnx",
                    ".engine": "engine",
                    ".plan": "engine",
                }
                return use_ext(ext_mapping) or "frozen"
            else:
                # When no framework is provided, some extensions can be ambiguous
                ext_mapping = {
                    ".hdf5": "keras",
                    ".graphdef": "frozen",
                    ".onnx": "onnx",
                    ".uff": "uff",
                    ".engine": "engine",
                    ".plan": "engine",
                }
                model_type = use_ext(ext_mapping)
                if model_type:
                    return model_type

            G_LOGGER.critical(
                "Could not automatically determine model type for: {:}\n"
                "Please explicitly specify the type with the --model-type option"
                .format(args.model_file))

        # Read the model path once through tools_util.get so that an argument set
        # without `model_file` does not fail on direct attribute access below.
        model_file = tools_util.get(args, "model_file")
        if model_file:
            G_LOGGER.verbose("Model: {:}".format(model_file))
            if not os.path.exists(model_file):
                G_LOGGER.warning("Model path does not exist: {:}".format(
                    model_file))
            model_file = os.path.abspath(model_file)
            # Keep `args` in sync: determine_model_type() and any later consumers
            # of `args` read the normalized path from there.
            args.model_file = model_file

        input_shapes = tools_util.get(args, "input_shapes")
        if input_shapes:
            self.input_shapes = tools_util.parse_meta(
                input_shapes, includes_dtype=False)  # TensorMetadata
        else:
            self.input_shapes = TensorMetadata()

        self.model_file = model_file

        # Resolve a preset type through the same helper as before, but only run
        # auto-detection when no preset exists. Auto-detection can end in a
        # critical error for unknown extensions, which must not fire when the
        # model type is already fixed.
        model_type = misc.default_value(self._model_type, None)
        if model_type is None:
            model_type = determine_model_type()
        self.model_type = model_type