示例#1
0
文件: serde.py 项目: clayne/TensorRT
def load_json(src, description=None):
    """
    Loads a file and decodes the JSON contents.

    NOTE: For Polygraphy objects, you should use the ``load()`` method instead.

    Args:
        src (Union[str, file-like]): The path or file-like object to load from.
        description (str):
                A description of what is being loaded. It is forwarded unchanged to
                ``util.load_file``; presumably it is only used in log messages (verify there).
                Defaults to None.

    Returns:
        object: The object, or `None` if nothing could be read.
    """
    try:
        return from_json(util.load_file(src, mode="r", description=description))
    except UnicodeDecodeError:
        # Content that cannot be decoded as text is assumed to be a pickle file
        # from Polygraphy 0.26.1 or older.
        mod.warn_deprecated("pickle", use_instead="JSON", remove_in="0.31.0")
        # When `src` is a file-like object, its repr would be useless in the suggested
        # command below; use the object's file name when it exposes one.
        src_name = getattr(src, "name", src) if hasattr(src, "read") else src
        # NOTE(review): G_LOGGER.critical is expected to raise (not just log). If it
        # only logs, this branch falls through and the function returns None.
        G_LOGGER.critical("It looks like you're trying to load a Pickle file.\nPolygraphy migrated to using JSON "
                          "instead of Pickle in version 0.27.0 for security reasons.\nYou can convert your existing "
                          "pickled data to JSON using the command-line tool: `polygraphy to-json {:} -o new.json`.\nAll data serialized "
                          "from this and future versions of Polygraphy will always use JSON. ".format(src_name))
示例#2
0
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from polygraphy import mod
from polygraphy.exception import *

# Backwards-compatibility shim: everything is re-exported from ``polygraphy.exception``
# by the star import above, and importing this module reports it as deprecated.
mod.warn_deprecated(
    "polygraphy.common.exception",
    use_instead="polygraphy.exception",
    remove_in="0.32.0",
)
示例#3
0
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from polygraphy import mod
from polygraphy.func import *

# Backwards-compatibility shim: everything is re-exported from ``polygraphy.func``
# by the star import above, and importing this module reports it as deprecated.
mod.warn_deprecated(
    "polygraphy.common.func",
    use_instead="polygraphy.func",
    remove_in="0.34.0",
)
示例#4
0
 def __init__(self):
     """
     Flags the ``to-json`` tool as deprecated in favor of JSON serialization,
     then initializes the base class under the name ``to-json``.
     """
     mod.warn_deprecated("to-json", remove_in="0.31.0", use_instead="JSON serialization")
     super().__init__(name="to-json")
示例#5
0
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from polygraphy import mod
from polygraphy.json import *

# Backwards-compatibility shim: the JSON helpers are re-exported from ``polygraphy.json``
# by the star import above, and using them through this module is reported as deprecated.
mod.warn_deprecated(
    "JSON utilities in polygraphy.util",
    use_instead="polygraphy.json",
    remove_in="0.35.0",
)
示例#6
0
    def __init__(self,
                 seed=None,
                 iterations=None,
                 input_metadata=None,
                 int_range=None,
                 float_range=None,
                 val_range=None):
        """
        Args:
            seed (int):
                    The seed to use when generating random inputs.
                    Defaults to ``util.constants.DEFAULT_SEED``.
            iterations (int):
                    The number of iterations for which to supply data.
                    Defaults to 1.
            input_metadata (TensorMetadata):
                    A mapping of input names to their corresponding shapes and data types.
                    This will be used to determine what shapes to supply for inputs with dynamic shape, as
                    well as to set the data type of the generated inputs.
                    If either dtype or shape are None, then the value will be automatically determined.
                    For input shape tensors, i.e. inputs whose *value* describes a shape in the model, the
                    provided shape will be used to populate the values of the inputs, rather than to determine
                    their shape.
            val_range (Union[Tuple[number], Dict[str, Tuple[number]]]):
                    A tuple containing exactly 2 numbers, indicating the minimum and maximum values (inclusive)
                    the data loader should generate.
                    If either value in the tuple is None, the default will be used for that value.
                    If None is provided instead of a tuple, then the default values will be used for both the
                    minimum and maximum.
                    This can be specified on a per-input basis using a dictionary. In that case,
                    use an empty string ("") as the key to specify default range for inputs not explicitly listed.
                    None values inside (or in place of) per-input tuples are filled in the same way.
                    Defaults to (0.0, 1.0).

            int_range (Tuple[int]):
                    [DEPRECATED - Use val_range instead]
                    A tuple containing exactly 2 integers, indicating the minimum and maximum integer values (inclusive)
                    the data loader should generate. If either value in the tuple is None, the default will be used
                    for that value.
                    If None is provided instead of a tuple, then the default values will be used for both the
                    minimum and maximum.
            float_range (Tuple[float]):
                    [DEPRECATED - Use val_range instead]
                    A tuple containing exactly 2 floats, indicating the minimum and maximum float values (inclusive)
                    the data loader should generate. If either value in the tuple is None, the default will be used
                    for that value.
                    If None is provided instead of a tuple, then the default values will be used for both the
                    minimum and maximum.
        """
        def default_tuple(tup, default):
            """
            Returns ``tup`` as a tuple in which each element is passed through
            ``util.default`` against the matching element of ``default``. If ``tup`` is
            not a tuple/list (e.g. None or a dict), ``default`` is returned unchanged.
            """
            if not isinstance(tup, (tuple, list)):
                return default
            # NOTE: zip() stops at the shorter of the two sequences.
            return tuple(util.default(elem, default_elem) for elem, default_elem in zip(tup, default))

        self.seed = util.default(seed, constants.DEFAULT_SEED)
        self.iterations = util.default(iterations, 1)
        self.user_input_metadata = util.default(input_metadata, {})

        # The *_range_set flags record whether the deprecated arguments were explicitly provided.
        self.int_range_set = int_range is not None
        if self.int_range_set:
            mod.warn_deprecated("The int_range parameter in DataLoader",
                                "val_range",
                                remove_in="0.35.0")
        self.int_range = default_tuple(int_range, (1, 25))

        self.float_range_set = float_range is not None
        if self.float_range_set:
            mod.warn_deprecated("The float_range parameter in DataLoader",
                                "val_range",
                                remove_in="0.35.0")
        self.float_range = default_tuple(float_range, (-1.0, 1.0))

        self.input_metadata = None
        self.default_val_range = default_tuple(val_range, (0.0, 1.0))
        if isinstance(val_range, dict):
            # Give each per-input range the same None-substitution the docstring promises
            # for a plain tuple. Entries of any other type are kept as provided.
            self.val_range = {
                name: default_tuple(rng, self.default_val_range)
                if rng is None or isinstance(rng, (tuple, list)) else rng
                for name, rng in val_range.items()
            }
        elif isinstance(val_range, (tuple, list)):
            # default_tuple() has already filled any None entries from (0.0, 1.0). Storing
            # the raw value here would keep those Nones, contrary to the documented behavior.
            self.val_range = self.default_val_range
        else:
            self.val_range = util.default(val_range, self.default_val_range)

        if self.user_input_metadata:
            G_LOGGER.info(
                "Will generate inference input data according to provided TensorMetadata: {}"
                .format(self.user_input_metadata))
示例#7
0
 def basic_compare_func(*args, **kwargs):
     """
     Deprecated alias for ``CompareFunc.simple``.

     Reports the deprecation, then forwards all positional and keyword
     arguments unchanged and returns whatever ``CompareFunc.simple`` returns.
     """
     mod.warn_deprecated("basic_compare_func", use_instead="simple", remove_in="0.40.0")
     return CompareFunc.simple(*args, **kwargs)
示例#8
0
    def parse(self, args):
        """
        Reads TensorRT builder-configuration options from ``args`` and stores them on ``self``.

        Every option is fetched with ``args_util.get``. Processing steps:
        - Profile shapes are combined with the input shapes of ``self.model_args`` (when present)
          via ``parse_profile_shapes``.
        - The calibration base class and tactic source names are validated with
          ``assert_identifier`` and turned into inline ``trt.*`` expressions.
        - The deprecated ``--tactic-replay`` option is mapped onto ``load_tactics`` or ``save_tactics``.

        Args:
            args: The parsed command-line arguments.
        """
        def opt(name):
            return args_util.get(args, name)

        # Shape options that were not provided fall back to empty lists.
        trt_min_shapes = util.default(opt("trt_min_shapes"), [])
        trt_max_shapes = util.default(opt("trt_max_shapes"), [])
        trt_opt_shapes = util.default(opt("trt_opt_shapes"), [])

        default_shapes = TensorMetadata()
        if self.model_args is not None:
            # The model's input shapes are produced by ModelArgs, so it has to be parsed first.
            assert hasattr(self.model_args, "input_shapes"), "ModelArgs must be parsed before TrtConfigArgs!"
            default_shapes = self.model_args.input_shapes

        self.profile_dicts = parse_profile_shapes(
            default_shapes, trt_min_shapes, trt_opt_shapes, trt_max_shapes
        )

        workspace = opt("workspace")
        self.workspace = None if workspace is None else int(workspace)

        # These options are copied onto the instance as-is.
        for flag in ("tf32", "fp16", "int8", "strict_types", "restricted"):
            setattr(self, flag, opt(flag))

        self.calibration_cache = opt("calibration_cache")
        calib_base = opt("calibration_base_class")
        self.calibration_base_class = None
        if calib_base is not None:
            # Validate the name before it is embedded into a ``trt.<name>`` expression.
            base_ident = safe(assert_identifier(calib_base))
            self.calibration_base_class = inline(safe("trt.{:}", inline(base_ident)))

        for attr in ("quantile", "regression_cutoff", "sparse_weights", "timing_cache"):
            setattr(self, attr, opt(attr))

        tactic_replay = opt("tactic_replay")
        self.load_tactics = opt("load_tactics")
        self.save_tactics = opt("save_tactics")
        if tactic_replay is not None:
            mod.warn_deprecated("--tactic-replay",
                                use_instead="--save-tactics or --load-tactics",
                                remove_in="0.35.0")
            G_LOGGER.warning("--tactic-replay is deprecated. Use either --save-tactics or --load-tactics instead.")
            # An existing, non-empty file becomes load_tactics; any other path becomes save_tactics.
            replay_has_data = os.path.exists(tactic_replay) and util.get_file_size(tactic_replay) > 0
            if replay_has_data:
                self.load_tactics = tactic_replay
            else:
                self.save_tactics = tactic_replay

        source_names = opt("tactic_sources")
        self.tactic_sources = None
        if source_names is not None:
            self.tactic_sources = []
            for src_name in source_names:
                # Tactic source names are upper-cased and validated before use in ``trt.TacticSource.<NAME>``.
                ident = safe(assert_identifier(src_name.upper()))
                self.tactic_sources.append(inline(safe("trt.TacticSource.{:}", inline(ident))))

        self.trt_config_script = opt("trt_config_script")
        self.trt_config_func_name = opt("trt_config_func_name")
示例#9
0
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from polygraphy import mod
from polygraphy.constants import *

# Backwards-compatibility shim: everything is re-exported from ``polygraphy.constants``
# by the star import above, and importing this module reports it as deprecated.
mod.warn_deprecated(
    "polygraphy.common.constants",
    use_instead="polygraphy.constants",
    remove_in="0.32.0",
)