Example #1
0
def load_graph(path):
    """
    Loads a TensorFlow frozen model.

    Args:
        path (Union[str, tf.Graph, tf.GraphDef]):
                A path to the frozen model, or a frozen TensorFlow graph or graphdef.

    Returns:
        tf.Graph: The TensorFlow graph. If ``path`` is already a ``tf.Graph``,
                it is returned as-is.

    Raises:
        TypeError: If ``path`` is not a str, tf.Graph, or tf.GraphDef.
    """
    if isinstance(path, tf.Graph):
        # Already a graph, so there is nothing to import.
        return path

    if isinstance(path, str):
        # Import the submodule explicitly: a bare `import google` only exposes
        # `google.protobuf.message` if something else has already imported it.
        from google.protobuf.message import DecodeError

        graphdef = tf.compat.v1.GraphDef()
        try:
            graphdef.ParseFromString(
                util.load_file(path, description="GraphDef"))
        except DecodeError:
            G_LOGGER.backtrace()
            # NOTE(review): presumably G_LOGGER.critical raises; otherwise execution
            # continues below with whatever was parsed so far - confirm.
            G_LOGGER.critical(
                "Could not import TensorFlow GraphDef from: {:}. Is this a valid TensorFlow model?"
                .format(path))
    elif isinstance(path, tf.compat.v1.GraphDef):
        graphdef = path
    else:
        # Previously this fell through and failed with an UnboundLocalError on `graphdef`.
        raise TypeError(
            "Expected a path (str), tf.Graph, or tf.GraphDef, but received: {:}".format(
                type(path).__name__))

    # Import the GraphDef into a fresh graph, keeping node names unprefixed.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graphdef, name="")
        return graph
Example #2
0
            def load_from_cache():
                """
                Reads the calibration cache referenced by ``self._cache``, if any.

                Returns:
                    The value returned by ``util.load_file`` for the cache, or None if no
                    cache is set, ``util.get_file_size`` reports a falsy size, or the
                    cache could not be read.
                """
                if self._cache is None or not util.get_file_size(self._cache):
                    return None

                try:
                    return util.load_file(self._cache, description="calibration cache")
                except Exception:
                    # Best-effort read: a failure is reported and treated as "no cache".
                    # `Exception` (rather than a bare except) lets KeyboardInterrupt and
                    # SystemExit propagate instead of being swallowed here.
                    G_LOGGER.warning("Could not read from calibration cache: {:}".format(self._cache))
                    return None
Example #3
0
            def load_from_cache():
                """
                Attempts to read the calibration cache referenced by ``self._cache``.

                Returns:
                    The value returned by ``util.load_file`` for the cache, or None when
                    no cache is set, ``util.get_file_size`` reports a falsy size, or
                    loading raises an exception (which is logged as an error).
                """
                cache_path = self._cache
                if cache_path is None:
                    return None
                if not util.get_file_size(cache_path):
                    return None

                try:
                    contents = util.load_file(cache_path, description="calibration cache")
                except Exception as exc:
                    message = "Could not read from calibration cache: {:}\nNote: Error was: {:}".format(
                        cache_path, exc)
                    G_LOGGER.error(message)
                    return None
                return contents
Example #4
0
 def call_impl(self):
     """
     Loads the file located at ``self._path`` using ``util.load_file``.

     Returns:
         bytes: The contents of the file.
     """
     contents = util.load_file(self._path, description="bytes")
     return contents
Example #5
0
    def call_impl(self, builder, network):
        """
        Creates a TensorRT builder configuration and applies the settings stored on
        this object (profiles, workspace size, precision flags, calibrator,
        tactic sources, timing cache, and algorithm selector).

        Args:
            builder (trt.Builder):
                    The TensorRT builder to use to create the configuration.
            network (trt.INetworkDefinition):
                    The TensorRT network for which to create the config. The network is used to
                    automatically create a default optimization profile if none are provided.

        Returns:
            trt.IBuilderConfig: The TensorRT builder configuration.
        """
        with util.FreeOnException([builder.create_builder_config()]) as (config, ):
            def try_run(func, name):
                # Runs `func`; an AttributeError is reported via `trt_util.fail_unavailable`
                # as the feature `name` being unavailable.
                try:
                    return func()
                except AttributeError:
                    trt_util.fail_unavailable("{:} in CreateConfig".format(name))


            def try_set_flag(flag_name):
                # Looks up `trt.BuilderFlag.<flag_name>` and sets it on the config.
                return try_run(lambda: config.set_flag(getattr(trt.BuilderFlag, flag_name)), flag_name.lower())


            with G_LOGGER.indent():
                G_LOGGER.verbose("Setting TensorRT Optimization Profiles")
                # Work on a copy so that `fill_defaults` is not applied to `self.profiles` itself.
                profiles = copy.deepcopy(self.profiles)
                for profile in profiles:
                    # Last trt_profile is used for set_calibration_profile.
                    trt_profile = profile.fill_defaults(network).to_trt(builder, network)
                    config.add_optimization_profile(trt_profile)
                G_LOGGER.info("Configuring with profiles: {:}".format(profiles))
                # NOTE(review): `trt_profile` is only bound if `self.profiles` is non-empty;
                # presumably the constructor guarantees at least one profile - confirm.

            config.max_workspace_size = int(self.max_workspace_size)

            if self.strict_types:
                try_set_flag("STRICT_TYPES")

            if self.tf32:
                try_set_flag("TF32")
            else: # TF32 is on by default
                with contextlib.suppress(AttributeError):
                    config.clear_flag(trt.BuilderFlag.TF32)

            if self.fp16:
                try_set_flag("FP16")

            if self.int8:
                try_set_flag("INT8")
                if not network.has_explicit_precision:
                    if self.calibrator is not None:
                        input_metadata = trt_util.get_input_metadata_from_profile(trt_profile, network)
                        with contextlib.suppress(AttributeError): # Polygraphy calibrator has a reset method
                            self.calibrator.reset(input_metadata)
                        config.int8_calibrator = self.calibrator
                        try:
                            config.set_calibration_profile(trt_profile)
                        except Exception:
                            # Best-effort: failure here is only logged. `Exception` (rather than a
                            # bare except) lets KeyboardInterrupt and SystemExit propagate.
                            G_LOGGER.extra_verbose("Cannot set calibration profile on TensorRT 7.0 and older.")
                    else:
                        G_LOGGER.warning("Network does not have explicit precision and no calibrator was provided. Please ensure "
                                         "that tensors in the network have dynamic ranges set, or provide a calibrator in order to use int8 mode.")

            if self.sparse_weights:
                try_set_flag("SPARSE_WEIGHTS")

            if self.tactic_sources is not None:
                # Build a bitmask with one bit set per requested tactic source.
                tactic_sources_flag = 0
                for source in self.tactic_sources:
                    tactic_sources_flag |= (1 << int(source))
                try_run(lambda: config.set_tactic_sources(tactic_sources_flag), name="tactic_sources")

            try:
                if self.timing_cache_path:
                    timing_cache_data = util.load_file(self.timing_cache_path, description="tactic timing cache")
                    cache = config.create_timing_cache(timing_cache_data)
                else:
                    # Create an empty timing cache by default so it will be populated during engine build.
                    # This way, consumers of CreateConfig have the option to use the cache later.
                    cache = config.create_timing_cache(b"")
            except AttributeError:
                # Only report the feature as unavailable if a timing cache was explicitly
                # requested; the default empty cache is skipped silently.
                if self.timing_cache_path:
                    trt_util.fail_unavailable("load_timing_cache in CreateConfig")
            else:
                config.set_timing_cache(cache, ignore_mismatch=False)

            if self.algorithm_selector is not None:
                def set_algo_selector():
                    config.algorithm_selector = self.algorithm_selector
                try_run(set_algo_selector, "algorithm_selector")

            return config