Example #1
    def call_impl(self):
        """
        Returns:
            onnx_graphsurgeon.Graph: The ONNX-GraphSurgeon representation of the ONNX model
        """
        model, _ = util.invoke_if_callable(self._model)
        return gs.import_onnx(model)
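Every example on this page funnels its input through util.invoke_if_callable. Inferred purely from these call sites, a minimal sketch of its contract might look like the following (an illustrative reimplementation, not Polygraphy's actual source):

def invoke_if_callable(obj, *args, **kwargs):
    # If obj is callable, call it and return the result; otherwise pass
    # obj through unchanged. The boolean records whether the value was
    # freshly created by a call; the loaders below use it to decide
    # whether they own (and must free) the returned object.
    if callable(obj):
        return obj(*args, **kwargs), True
    return obj, False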
Example #2
    def call_impl(self):
        """
        Returns:
            onnx.ModelProto: The ONNX model.
        """
        (graph, output_names), _ = util.invoke_if_callable(self._graph)
        input_names = list(tf_util.get_input_metadata(graph).keys())

        if self.fold_constant:
            G_LOGGER.info(
                "Folding constants in graph using tf2onnx.tfonnx.tf_optimize")
        graphdef = graph.as_graph_def()
        if self.optimize:
            graphdef = tf2onnx.tfonnx.tf_optimize(
                input_names,
                output_names,
                graph.as_graph_def(),
                fold_constant=self.fold_constant)

        with tf.Graph().as_default() as graph, tf.compat.v1.Session(
                graph=graph) as sess:
            tf.import_graph_def(graphdef, name="")

            onnx_graph = tf2onnx.tfonnx.process_tf_graph(
                graph,
                input_names=input_names,
                output_names=output_names,
                opset=self.opset)
            if self.optimize:
                onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph)
            return onnx_graph.make_model("model")
Example #3
    def call_impl(self):
        """
        Returns:
            Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
        """
        (graph, outputs), _ = util.invoke_if_callable(self._graph)

        if self.path:
            util.save_file(graph.as_graph_def().SerializeToString(),
                           dest=self.path)
        if self.tensorboard_dir:
            G_LOGGER.info("Writing tensorboard events to {:}".format(
                self.tensorboard_dir))
            train_writer = tf.compat.v1.summary.FileWriter(
                self.tensorboard_dir)
            train_writer.add_graph(graph)

        if self.engine_dir is not None:
            graphdef = graph.as_graph_def()
            segment_number = 0
            for node in graphdef.node:
                if node.op == "TRTEngineOp":
                    engine = node.attr["serialized_segment"].s
                    if self.engine_dir is not None:
                        util.save_file(
                            contents=engine,
                            dest=os.path.join(
                                self.engine_dir,
                                "segment-{:}".format(segment_number)))
                    segment_number += 1

        return graph, outputs
Example #4
    def call_impl(self):
        """
        Returns:
            bytes: The serialized model.
        """
        model, _ = util.invoke_if_callable(self._model)
        return model.SerializeToString()
Example #5
    def call_impl(self):
        """
        Returns:
            Tuple[tf.Session, Sequence[str]]: The TensorFlow session, and the names of its outputs.
        """
        config, _ = util.invoke_if_callable(self.config)
        (graph, output_names), _ = util.invoke_if_callable(self.graph)

        with graph.as_default() as graph, tf.compat.v1.Session(
                graph=graph, config=config).as_default() as sess:
            G_LOGGER.verbose(
                "Using TensorFlow outputs: {:}".format(output_names))
            G_LOGGER.extra_verbose(
                "Initializing variables in TensorFlow Graph")
            sess.run(tf.compat.v1.initializers.global_variables())
            return sess, output_names
Example #6
    def call_impl(self):
        """
        Returns:
            onnxruntime.InferenceSession: The inference session.
        """
        model_bytes, _ = util.invoke_if_callable(self._model_bytes_or_path)
        return onnxruntime.InferenceSession(model_bytes)
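The session produced here is an ordinary onnxruntime session. A hypothetical way to drive it (the model path, dummy input shape, and dtype below are placeholders, not taken from this example):

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("model.onnx")  # placeholder path
input_name = sess.get_inputs()[0].name
# Passing None as the first argument requests all model outputs.
outputs = sess.run(None, {input_name: np.zeros((1, 3, 224, 224), dtype=np.float32)})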
Example #7
    def call_impl(self):
        """
        Returns:
            onnx.ModelProto: The new ONNX model with shapes inferred.
        """
        model, _ = util.invoke_if_callable(self._model)

        G_LOGGER.verbose("Starting ONNX shape inference")
        try:
            if isinstance(model, onnx.ModelProto):
                if model.ByteSize() > LARGE_MODEL_THRESHOLD:
                    G_LOGGER.warning(
                        "Attempting to run shape inference on a large model. "
                        "This may require a large amount of memory.\nIf memory consumption becomes too high, "
                        "the process may be killed. You may want to try disabling shape inference in that case. ",
                        mode=LogMode.ONCE)
                model = shape_inference.infer_shapes(model)
            else:
                with tempfile.NamedTemporaryFile(suffix=".onnx") as f:
                    shape_inference.infer_shapes_path(model, f.name)
                    model = onnx_from_path(f.name)
            G_LOGGER.verbose("ONNX Shape Inference completed successfully")
        except Exception as err:
            if not self.error_ok:
                raise
            G_LOGGER.warning(
                "ONNX shape inference exited with an error:\n{:}".format(err))
        return model
Example #8
    def call_impl(self):
        """
        Returns:
            bytes: The bytes saved.
        """
        obj, _ = util.invoke_if_callable(self._bytes)
        util.save_file(obj, self._path)
        return obj
Example #9
    def __enter__(self):
        model, _ = util.invoke_if_callable(self._model)
        self.USE_GS_GRAPH = isinstance(model, gs.Graph)
        if self.USE_GS_GRAPH:
            self.graph = model.copy()
        else:
            self.graph = gs_from_onnx(model)
        return self
Example #10
    def call_impl(self):
        """
        Returns:
            (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                    A TensorRT network, as well as the builder used to create it, and the parser
                    used to populate it.
        """
        with util.FreeOnException(super().call_impl()) as (builder, network, parser):
            success = parser.parse(util.invoke_if_callable(self._model_bytes)[0])
            trt_util.check_onnx_parser_errors(parser, success)
            return builder, network, parser
Example #11
    def call_impl(self, *args, **kwargs):
        """
        Returns:
            object:
                    The provided ``obj`` argument, or its return value if it is
                    callable. Returns ``None`` if ``obj`` was not set.
        """
        for plugin in self.plugins:
            G_LOGGER.info("Loading plugin library: {:}".format(plugin))
            ctypes.CDLL(plugin)

        ret, _ = util.invoke_if_callable(self.obj, *args, **kwargs)
        return ret
Example #12
    def call_impl(self):
        from polygraphy.backend.onnx import util as onnx_util

        with util.FreeOnException(super().call_impl()) as (builder, network,
                                                           parser):
            onnx_model, _ = util.invoke_if_callable(self.onnx_loader)
            _, shape = list(
                onnx_util.get_input_metadata(onnx_model.graph).values())[0]

            success = parser.parse(onnx_model.SerializeToString())
            trt_util.check_onnx_parser_errors(parser, success)

            return builder, network, parser, shape[0]
Example #13
    def call_impl(self):
        """
        Returns:
            bytes: The serialized engine.
        """
        engine, owns_engine = util.invoke_if_callable(self._engine)

        with contextlib.ExitStack() as stack:
            if owns_engine:
                stack.enter_context(util.FreeOnException([engine]))

            with engine.serialize() as buffer:
                return bytes(buffer)
Example #14
    def call_impl(self):
        """
        Returns:
            trt.ICudaEngine: The engine that was saved.
        """
        engine, owns_engine = util.invoke_if_callable(self._engine)

        with contextlib.ExitStack() as stack:
            if owns_engine:
                stack.enter_context(util.FreeOnException([engine]))

            util.save_file(contents=bytes_from_engine(engine), dest=self.path, description="engine")
            return engine
Example #15
    def call_impl(self):
        """
        Returns:
            Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
        """
        (graph, outputs), _ = util.invoke_if_callable(self._graph)

        if self.outputs == constants.MARK_ALL:
            outputs = list(
                tf_util.get_output_metadata(graph, layerwise=True).keys())
        elif self.outputs is not None:
            outputs = self.outputs

        return graph, outputs
Example #16
    def activate_impl(self):
        def make_buffers(engine):
            """
            Creates empty host and device buffers for the specified engine.
            Always uses binding names from Profile 0.
            """
            device_buffers = OrderedDict()
            host_output_buffers = OrderedDict()

            for idx in range(trt_util.get_bindings_per_profile(engine)):
                binding = engine[idx]
                dtype = trt_util.np_dtype_from_trt(
                    engine.get_binding_dtype(binding))
                device_buffers[binding] = cuda.DeviceArray(dtype=dtype)
                if not engine.binding_is_input(binding):
                    host_output_buffers[binding] = np.empty(shape=tuple(),
                                                            dtype=dtype)
            G_LOGGER.extra_verbose(
                "Created device buffers: {:}".format(device_buffers))
            return device_buffers, host_output_buffers

        engine_or_context, owning = util.invoke_if_callable(
            self._engine_or_context)

        if isinstance(engine_or_context, trt.ICudaEngine):
            self.engine = engine_or_context
            self.owns_engine = owning
            self.context = self.engine.create_execution_context()
            self.owns_context = True
            if not self.context:
                G_LOGGER.critical(
                    "Invalid Context. See error log for details.")
        elif isinstance(engine_or_context, trt.IExecutionContext):
            self.engine = None
            self.owns_engine = False
            self.context = engine_or_context
            self.owns_context = owning
        else:
            G_LOGGER.critical(
                "Invalid Engine or Context. Please ensure the engine was built correctly. See error log for details."
            )

        if not owning:
            G_LOGGER.verbose(
                "Object was provided directly instead of via a Callable. This runner will not assume ownership. "
                "Please ensure it is freed.")

        self.device_buffers, self.host_output_buffers = make_buffers(
            self.context.engine)
        self.stream = cuda.Stream()
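This activate_impl reads like the setup step of a TensorRT runner. Assuming it belongs to Polygraphy's TrtRunner, a typical way to drive it might be (the model path and the input tensor name "x" are placeholders):

import numpy as np
from polygraphy.backend.trt import EngineFromNetwork, NetworkFromOnnxPath, TrtRunner

build_engine = EngineFromNetwork(NetworkFromOnnxPath("model.onnx"))
with TrtRunner(build_engine) as runner:  # activate_impl runs on activation
    outputs = runner.infer(feed_dict={"x": np.ones((1, 3, 224, 224), dtype=np.float32)})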
Example #17
    def call_impl(self):
        """
        Returns:
            Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
        """
        (graph, output_names), _ = util.invoke_if_callable(self._graph)
        with tf.Session(graph=graph) as sess:
            sess.run(tf.initializers.global_variables())
            sess.run(tf.initializers.local_variables())

            graphdef = sess.graph.as_graph_def()
            removed = tf.graph_util.remove_training_nodes(graphdef)
            G_LOGGER.ultra_verbose("Removed nodes: {:}".format(removed))

            for node in graphdef.node:
                if node.op == "RefSwitch":
                    node.op = "Switch"
                    for index in range(len(node.input)):
                        if "moving_" in node.input[index]:
                            node.input[index] = node.input[index] + "/read"
                elif node.op == "AssignSub":
                    node.op = "Sub"
                    if "use_locking" in node.attr:
                        del node.attr["use_locking"]
                elif node.op == "AssignAdd":
                    node.op = "Add"
                    if "use_locking" in node.attr:
                        del node.attr["use_locking"]
                elif node.op == "Assign":
                    node.op = "Identity"
                    if "use_locking" in node.attr:
                        del node.attr["use_locking"]
                    if "validate_shape" in node.attr:
                        del node.attr["validate_shape"]
                    if len(node.input) == 2:
                        # input0: ref: Should be from a Variable node. May be uninitialized.
                        # input1: value: The value to be assigned to the variable.
                        node.input[0] = node.input[1]
                        del node.input[1]

            # Strip port information from outputs
            output_names = [name.split(":")[0] for name in output_names]
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graphdef, output_names)
            output_graph_def = self.constfold(output_graph_def, output_names)
            return graph_from_frozen(output_graph_def)
Example #18
    def call_impl(self):
        """
        Returns:
            Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
        """
        from tensorflow.contrib import tensorrt as tf_trt

        (graph, output_names), _ = util.invoke_if_callable(self._graph)

        precision_mode = "FP16" if self.fp16 else "FP32"
        precision_mode = "INT8" if self.int8 else precision_mode

        G_LOGGER.info(
            "For TF-TRT, using outputs={:}, max_workspace_size_bytes={:}, max_batch_size={:}, "
            "minimum_segment_size={:}, is_dynamic_op={:}, precision_mode={:}".
            format(
                output_names,
                self.max_workspace_size,
                self.max_batch_size,
                self.minimum_segment_size,
                self.is_dynamic_op,
                precision_mode,
            ))

        graphdef = tf_trt.create_inference_graph(
            graph.as_graph_def(),
            outputs=output_names,
            max_workspace_size_bytes=self.max_workspace_size,
            max_batch_size=self.max_batch_size,
            minimum_segment_size=self.minimum_segment_size,
            is_dynamic_op=self.is_dynamic_op,
            precision_mode=precision_mode,
        )

        segment_number = 0
        for node in graphdef.node:
            if node.op == "TRTEngineOp":
                engine = node.attr["serialized_segment"].s
                segment_number += 1
        G_LOGGER.info(
            "Found {:} engines in TFTRT graph".format(segment_number))

        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graphdef, name="")
            return graph, tf_util.get_graph_output_names(graph)
Example #19
    def call_impl(self):
        """
        Returns:
            (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                    A TensorRT network, as well as the builder used to create it, and the parser
                    used to populate it.
        """
        path = util.invoke_if_callable(self.path)[0]
        if mod.version(trt.__version__) >= mod.version("7.1"):
            with util.FreeOnException(super().call_impl()) as (builder, network, parser):
                # We need to use parse_from_file for the ONNX parser to keep track of the location of the ONNX file for
                # potentially parsing any external weights.
                success = parser.parse_from_file(path)
                trt_util.check_onnx_parser_errors(parser, success)
                return builder, network, parser
        else:
            from polygraphy.backend.common import bytes_from_path
            return network_from_onnx_bytes(bytes_from_path(path), self.explicit_precision)
Example #20
    def call_impl(self):
        """
        Returns:
            onnx.ModelProto: The model, after saving it.
        """
        model, _ = util.invoke_if_callable(self._model)
        G_LOGGER.info("Saving ONNX model to: {:}".format(self.path))
        if self.external_data_path is not None:
            G_LOGGER.verbose(
                "Saving external data for ONNX model to: {:}".format(
                    self.external_data_path))
            try:
                external_data_helper.convert_model_to_external_data(
                    model,
                    location=self.external_data_path,
                    all_tensors_to_one_file=util.default(
                        self.all_tensors_to_one_file, True),
                    size_threshold=util.default(self.size_threshold, 1024),
                )
            except TypeError:
                if self.size_threshold is not None:
                    G_LOGGER.warning(
                        "This version of onnx does not support size_threshold in convert_model_to_external_data"
                    )
                external_data_helper.convert_model_to_external_data(
                    model,
                    location=self.external_data_path,
                    all_tensors_to_one_file=util.default(
                        self.all_tensors_to_one_file, True),
                )
        else:
            if self.size_threshold is not None:
                G_LOGGER.warning(
                    "size_threshold is set, but external data path has not been set. "
                    "No external data will be written.")
            if self.all_tensors_to_one_file is not None:
                G_LOGGER.warning(
                    "all_tensors_to_one_file is set, but external data path has not been set. "
                    "No external data will be written.")

        util.makedirs(self.path)
        onnx.save(model, self.path)
        return model
Example #21
    def call_impl(self):
        """
        Returns:
            trt.ICudaEngine: The deserialized engine.
        """
        buffer, owns_buffer = util.invoke_if_callable(self._serialized_engine)

        trt.init_libnvinfer_plugins(trt_util.get_trt_logger(), "")
        with contextlib.ExitStack() as stack, trt.Runtime(trt_util.get_trt_logger()) as runtime:
            if owns_buffer:
                try:
                    buffer.__enter__ # IHostMemory is freed only in __exit__
                except AttributeError:
                    pass
                else:
                    stack.enter_context(buffer)

            engine = runtime.deserialize_cuda_engine(buffer)
            if not engine:
                G_LOGGER.critical("Could not deserialize engine. See log for details.")
            return engine
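Examples #13 and #21 form a natural serialize/deserialize round trip. A sketch of composing them via Polygraphy's loader classes (assuming the imported names below; "model.onnx" is a placeholder):

from polygraphy.backend.trt import (BytesFromEngine, EngineFromBytes,
                                    EngineFromNetwork, NetworkFromOnnxPath)

build_engine = EngineFromNetwork(NetworkFromOnnxPath("model.onnx"))
serialized = BytesFromEngine(build_engine)()    # Example #13: engine -> bytes
engine = EngineFromBytes(lambda: serialized)()  # this example: bytes -> engine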
Example #22
    def call_impl(self):
        """
        Returns:
            trt.INetworkDefinition: The modified network.
        """
        ret, owns_network = util.invoke_if_callable(self._network)
        builder, network, parser = util.unpack_args(ret, num=3)

        with contextlib.ExitStack() as stack:
            if owns_network:
                stack.enter_context(util.FreeOnException([builder, network, parser]))

            if self.outputs == constants.MARK_ALL:
                trt_util.mark_layerwise(network)
            elif self.outputs is not None:
                trt_util.mark_outputs(network, self.outputs)

            if self.exclude_outputs is not None:
                trt_util.unmark_outputs(network, self.exclude_outputs)

            if parser is None:
                return builder, network
            return builder, network, parser
Example #23
    def call_impl(self):
        """
        Returns:
            onnx.ModelProto: The model, after saving it.
        """
        model, _ = util.invoke_if_callable(self._model)
        G_LOGGER.info("Saving ONNX model to: {:}".format(self.path))
        if self.external_data_path is not None:
            try:
                external_data_helper.convert_model_to_external_data(
                    model,
                    location=self.external_data_path,
                    size_threshold=util.default(self.size_threshold, 0))
            except TypeError:
                if self.size_threshold is not None:
                    G_LOGGER.warning(
                        "This version of onnx does not support size_threshold in convert_model_to_external_data"
                    )
                external_data_helper.convert_model_to_external_data(
                    model, location=self.external_data_path)

        onnx.save(model, self.path)
        return model
Example #24
    def call_impl(self):
        """
        Returns:
            onnx.ModelProto: The new ONNX model with shapes inferred.
        """
        model, _ = util.invoke_if_callable(self._model)
        external_data_dir = self.external_data_dir

        try:
            if isinstance(model, onnx.ModelProto):
                MODEL_SIZE = model.ByteSize()
                if MODEL_SIZE > LARGE_MODEL_THRESHOLD:
                    G_LOGGER.warning(
                        "Attempting to run shape inference on a large model. "
                        "This may require a large amount of memory.\nIf memory consumption becomes too high, "
                        "the process may be killed. You may want to try disabling shape inference in that case. ",
                        mode=LogMode.ONCE,
                    )

                if MODEL_SIZE > self.save_to_disk_threshold_bytes:
                    G_LOGGER.warning(
                        "Model size ({:.3} MiB) exceeds the in-memory size threshold: {:.3} MiB.\n"
                        "The model will be saved to a temporary file before shape inference is run."
                        .format(
                            MODEL_SIZE / (1024.0**2),
                            self.save_to_disk_threshold_bytes / (1024.0**2)),
                        mode=LogMode.ONCE,
                    )
                    outdir = tempfile.TemporaryDirectory()
                    outpath = os.path.join(outdir.name, "tmp_model.onnx")
                    save_onnx(model, outpath, external_data_path="ext.data")
                    model = outpath
                    external_data_dir = outdir.name

            G_LOGGER.verbose("Starting ONNX shape inference")
            if isinstance(model, onnx.ModelProto):
                model = shape_inference.infer_shapes(model)
            else:
                tmp_path = util.NamedTemporaryFile(prefix="tmp_polygraphy_",
                                                   suffix=".onnx").name
                G_LOGGER.verbose(
                    "Writing shape-inferred model to: {:}".format(tmp_path))
                shape_inference.infer_shapes_path(model, tmp_path)
                # When external_data_dir is unset, use the model's current directory
                model = onnx_from_path(tmp_path,
                                       external_data_dir=util.default(
                                           external_data_dir,
                                           os.path.dirname(model) or None))
            G_LOGGER.verbose("ONNX Shape Inference completed successfully")
        except Exception as err:
            if not self.error_ok:
                raise
            G_LOGGER.warning(
                "ONNX shape inference exited with an error:\n{:}".format(err))
            G_LOGGER.internal_error(
                "ONNX shape inference exited with an error:\n{:}".format(err))

            if not isinstance(model, onnx.ModelProto):
                model = onnx_from_path(
                    model, external_data_dir=self.external_data_dir)
        return model
Example #25
    def activate_impl(self):
        self.sess, _ = util.invoke_if_callable(self._sess)
Example #26
    def activate_impl(self):
        self.model, _ = util.invoke_if_callable(self._model)
        self.model.eval()
Example #27
    def call_impl(self):
        """
        Returns:
            onnx.ModelProto: The ONNX-like, but **not** valid ONNX, representation of the TensorRT network.
        """
        ret, owns_network = util.invoke_if_callable(self._network)
        builder, network, parser = util.unpack_args(ret, num=3)

        if builder is None or network is None:
            G_LOGGER.critical(
                "Expected to recevie a (builder, network) tuple for the `network` parameter, "
                "but received: ({:}, {:})".format(builder, network))

        with contextlib.ExitStack() as stack:
            if owns_network:
                stack.enter_context(builder)
                stack.enter_context(network)
                if parser is not None:
                    stack.enter_context(parser)

            tensor_map = {}

            def tensors_from_meta(meta):
                nonlocal tensor_map
                tensors = []
                for name, (dtype, shape) in meta.items():
                    if name not in tensor_map:
                        tensor_map[name] = gs.Variable(name=name,
                                                       dtype=dtype,
                                                       shape=shape)
                    tensors.append(tensor_map[name])
                return tensors

            nodes = []
            graph_inputs = tensors_from_meta(
                trt_util.get_network_input_metadata(network))
            graph_outputs = tensors_from_meta(
                trt_util.get_network_output_metadata(network))

            LAYER_TYPE_CLASS_MAPPING = trt_util.get_layer_class_mapping()

            for layer in network:
                op_name = layer.type.name
                if layer.type in LAYER_TYPE_CLASS_MAPPING:
                    layer.__class__ = LAYER_TYPE_CLASS_MAPPING[layer.type]

                node_inputs = tensors_from_meta(
                    trt_util.get_layer_input_metadata(layer))
                node_outputs = tensors_from_meta(
                    trt_util.get_layer_output_metadata(layer))
                attrs = {}
                attr_names = trt_util.get_layer_attribute_names(layer)
                for name in attr_names:
                    with G_LOGGER.verbosity():
                        attr = getattr(layer, name)

                    if util.is_sequence(attr) or any(
                            isinstance(attr, cls)
                            for cls in [trt.Dims, trt.Permutation]):
                        try:
                            attr = list(attr)
                        except ValueError:  # Invalid dims
                            attr = []

                    if hasattr(attr, "__entries"):  # TensorRT Enums
                        attr = attr.name

                    if isinstance(attr, trt.ILoop):
                        attr = attr.name

                    VALID_TYPES = [np.ndarray, list, int, str, bool, float]
                    if not any(isinstance(attr, cls) for cls in VALID_TYPES):
                        G_LOGGER.internal_error(
                            "Unknown type: {:} for layer attribute: {:}.\n"
                            "Note: Layer was: {:}".format(
                                type(attr), attr, layer))
                        try:
                            attr = str(attr)
                        except:
                            attr = "<error during conversion>"

                    attrs[name] = attr

                nodes.append(
                    gs.Node(name=layer.name,
                            op=op_name,
                            attrs=attrs,
                            inputs=node_inputs,
                            outputs=node_outputs))

            graph = gs.Graph(name=network.name,
                             inputs=graph_inputs,
                             outputs=graph_outputs,
                             nodes=nodes)

            return gs.export_onnx(graph)
Example #28
    def call_impl(self):
        """
        Returns:
            bytes: The serialized engine that was created.
        """
        # If network is a callable, then we own its return value
        ret, owns_network = util.invoke_if_callable(self._network)
        builder, network, parser = util.unpack_args(ret, num=3)

        if builder is None or network is None:
            G_LOGGER.critical("Expected to recevie a (builder, network) tuple for the `network` parameter, "
                              "but received: ({:}, {:})".format(builder, network))

        with contextlib.ExitStack() as stack:
            if owns_network:
                stack.enter_context(builder)
                stack.enter_context(network)
                if parser is not None:
                    stack.enter_context(parser)
            else:
                provided = "Builder and Network" if parser is None else "Builder, Network, and Parser"
                G_LOGGER.verbose("{:} were provided directly instead of via a Callable. This loader will not assume ownership. "
                                 "Please ensure that they are freed.".format(provided))

            config, owns_config = util.invoke_if_callable(self._config, builder, network)
            if owns_config:
                stack.enter_context(config)
            else:
                G_LOGGER.verbose("Builder configuration was provided directly instead of via a Callable. This loader will not assume "
                                 "ownership. Please ensure it is freed.")

            try:
                config.int8_calibrator.__enter__ # Polygraphy calibrator frees device buffers on exit.
            except AttributeError:
                pass
            else:
                stack.enter_context(config.int8_calibrator)

            network_log_mode = "full" if G_LOGGER.severity <= G_LOGGER.ULTRA_VERBOSE else "attrs"
            G_LOGGER.super_verbose(lambda: ("Displaying TensorRT Network:\n" + trt_util.str_from_network(network, mode=network_log_mode)))

            G_LOGGER.start("Building engine with configuration:\n{:}".format(trt_util.str_from_config(config)))

            try:
                engine_bytes = builder.build_serialized_network(network, config)
            except AttributeError:
                engine = builder.build_engine(network, config)
                if not engine:
                    G_LOGGER.critical("Invalid Engine. Please ensure the engine was built correctly")
                stack.enter_context(engine)
                engine_bytes = engine.serialize()

            if not engine_bytes:
                G_LOGGER.critical("Invalid Engine. Please ensure the engine_bytes was built correctly")

            try:
                timing_cache = config.get_timing_cache()
            except AttributeError:
                if self.timing_cache_path:
                    trt_util.fail_unavailable("save_timing_cache in EngineBytesFromNetwork")
            else:
                if timing_cache and self.timing_cache_path:
                    with timing_cache.serialize() as buffer:
                        util.save_file(buffer, self.timing_cache_path, description="tactic timing cache")

            return engine_bytes
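A sketch of how this builder loader is commonly composed (assuming the Polygraphy loader names below; the path and fp16 flag are placeholders). Passing loaders rather than live objects is what lets the code above assume ownership and free the builder, network, parser, and config:

from polygraphy.backend.trt import (CreateConfig, EngineBytesFromNetwork,
                                    NetworkFromOnnxPath)

serialize_engine = EngineBytesFromNetwork(NetworkFromOnnxPath("model.onnx"),
                                          config=CreateConfig(fp16=True))
engine_bytes = serialize_engine()  # lazily invokes the whole chain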
Example #29
    def activate_impl(self):
        (self.sess, self.output_names), _ = util.invoke_if_callable(self._sess)
Example #30
    def load(self):
        model, _ = util.invoke_if_callable(self._model)
        if self.copy:
            model = copy.copy(model)
        return model
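Taken together, the examples illustrate the two calling conventions that invoke_if_callable supports. Any of the loaders above can receive either form (SomeLoader is a hypothetical stand-in, not a real Polygraphy class):

loader = SomeLoader(model)          # direct object: the caller keeps ownership
loader = SomeLoader(lambda: model)  # callable: the loader assumes ownership of the result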