Example #1
def build_profile(builder, network, profile):
    trt_profile = builder.create_optimization_profile()
    unused_keys = set(profile.keys())
    for idx in range(network.num_inputs):
        inp = network.get_input(idx)
        if inp.name in unused_keys:
            unused_keys.remove(inp.name)

        with G_LOGGER.verbosity():  # Workaround (WAR) for log spam from TRT
            is_shape_tensor = inp.is_shape_tensor

        if is_shape_tensor:
            if inp.name in profile:
                shapes = profile[inp.name]
                trt_profile.set_shape_input(inp.name, shapes.min, shapes.opt, shapes.max)
                G_LOGGER.extra_verbose("Input shape-tensor: {:24} | Setting values to min: {:}, opt: {:}, max: {:}".format(inp.name, shapes.min, shapes.opt, shapes.max))
            else:
                G_LOGGER.warning("input shape-tensor: {:24} | No values provided. Assuming this is not a dynamic shape-tensor.".format(inp.name), mode=LogMode.ONCE)
        elif misc.is_shape_dynamic(inp.shape):
            # Dynamic data inputs must provide min/opt/max shapes in the profile.
            shapes = profile[inp.name]
            trt_profile.set_shape(inp.name, shapes.min, shapes.opt, shapes.max)
            G_LOGGER.extra_verbose("Input tensor: {:24} | Setting shape to min: {:}, opt: {:}, max: {:}".format(inp.name, shapes.min, shapes.opt, shapes.max))

    if unused_keys:
        G_LOGGER.warning("Some inputs provided in the profile were unused: {:}".format(list(unused_keys)))

    return check_profile(trt_profile)
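
A minimal usage sketch may help here (hypothetical names: the ShapeTuple container and the "input_ids" input are assumptions, not part of the example). It relies on build_profile and its helpers from the excerpt above, plus the standard TensorRT builder APIs:

import tensorrt as trt
from collections import namedtuple

ShapeTuple = namedtuple("ShapeTuple", ["min", "opt", "max"])  # hypothetical min/opt/max container

# Hypothetical dynamic-batch profile for a single input named "input_ids".
profile = {"input_ids": ShapeTuple(min=(1, 128), opt=(8, 128), max=(32, 128))}

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
# ... the network would be populated here, e.g. by an ONNX parser ...

trt_profile = build_profile(builder, network, profile)
config = builder.create_builder_config()
config.add_optimization_profile(trt_profile)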
Example #2
    def set_shapes_from_feed_dict(self, feed_dict):
        """
        Sets context shapes according to the provided feed_dict, then resizes
        buffers as needed.

        Args:
            feed_dict (OrderedDict[str, numpy.ndarray]): A mapping of input tensor names to corresponding input NumPy arrays.

        Returns:
            Tuple[int, int]: The start and end binding indices of the modified bindings.
        """
        def is_dynamic_shape_input(binding):
            # get_profile_shape_input raises a RuntimeError for bindings that
            # are not shape inputs, so use that failure as the check.
            try:
                self.context.engine.get_profile_shape_input(0, binding)
                return True
            except RuntimeError:
                return False

        start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
        for name, inp in feed_dict.items():
            binding = start_binding + self.context.engine[name]
            shape = inp.shape
            # Only set shapes if required.
            # get_shape/get_binding_shape will return what a shape input/data input is currently set to.
            if is_dynamic_shape_input(binding):  # For input shape tensors
                G_LOGGER.verbose("Setting shape binding: {:} (index: {:}) to: {:}".format(name, binding, inp))
                if tuple(self.context.get_shape(binding)) != tuple(inp):
                    self.context.set_shape_input(binding, inp)

            elif misc.is_shape_dynamic(self.context.engine.get_binding_shape(binding)):
                G_LOGGER.verbose("Setting binding: {:} (index: {:}) to shape: {:}".format(name, binding, shape))
                if tuple(self.context.get_binding_shape(binding)) != tuple(shape):
                    self.context.set_binding_shape(binding, shape)

        if not self.context.all_binding_shapes_specified:
            G_LOGGER.critical("Some input shapes were not specified.\nNote: Network inputs are: {:}".format(self.get_input_metadata()))
        if not self.context.all_shape_inputs_specified:
            G_LOGGER.critical("Some shape inputs were not specified.\nNote: Network inputs are: {:}".format(self.get_input_metadata()))

        # Resize device buffers - host buffers will be automatically resized by copy_to.
        for binding in range(start_binding, end_binding):
            # Use profile 0 binding names for all buffers.
            name = self.context.engine[binding - start_binding]
            shape = tuple(self.context.get_binding_shape(binding))
            self.device_buffers[name].resize(shape)

        return start_binding, end_binding
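
Since this method offsets every index by the active profile's start binding, a standalone sketch of the arithmetic may clarify it (numbers are illustrative; it assumes the TensorRT layout where an engine with multiple optimization profiles replicates its bindings once per profile):

# Illustrative numbers only: an engine with 4 bindings per profile.
num_bindings_per_profile = 4
active_profile = 2

# Profile k's bindings occupy [k * N, (k + 1) * N); this is what
# trt_util.get_active_profile_bindings is assumed to compute.
start_binding = active_profile * num_bindings_per_profile
end_binding = start_binding + num_bindings_per_profile
print(start_binding, end_binding)  # 8 12

# engine[name] resolves to the profile-0 index, so the active-profile
# index is start_binding + engine[name], as in the loop above.
profile0_index = 1
print(start_binding + profile0_index)  # 9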
Example #3
        def is_shape_tensor(name, dtype):
            if name not in self.input_metadata or name not in self.user_input_metadata:
                return False

            _, shape = self.input_metadata[name]
            # A shape tensor must be integer-typed, 1-D, and itself have a static shape.
            is_shape = np.issubdtype(dtype, np.integer) and not misc.is_shape_dynamic(shape) and len(shape) == 1
            if not is_shape:
                # Short-circuit so shape[0] below is never read for a non-1-D shape.
                return False

            user_shape = self.user_input_metadata[name][1]
            # The user-provided shape must supply exactly shape[0] values.
            is_shape &= len(user_shape) == shape[0]
            # Can't have negative values in shapes.
            is_shape &= all(elem >= 0 for elem in user_shape)
            return is_shape
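
Because the heuristic depends on instance state, here is a hedged standalone rendering with the metadata lookups replaced by literal values (all values illustrative; the dynamic-shape check is omitted since the literal shape here is already static):

import numpy as np

dtype = np.int32
shape = (4,)                   # static, 1-D, integer-typed: a shape-tensor candidate
user_shape = (1, 3, 224, 224)  # the user-provided shape the tensor's values would describe

is_shape = np.issubdtype(dtype, np.integer) and len(shape) == 1
is_shape &= len(user_shape) == shape[0]            # 4 values for a 4-D shape
is_shape &= all(elem >= 0 for elem in user_shape)  # shapes cannot contain negatives
print(is_shape)  # True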
Example #4
    def get_static_shape(name, shape):
        static_shape = shape
        if misc.is_shape_dynamic(shape):
            static_shape = misc.override_dynamic_shape(shape)
            if static_shape != shape and name not in self.user_input_metadata:
                if not misc.is_valid_shape_override(static_shape, shape):
                    G_LOGGER.critical("Input tensor: {:24} | Cannot override original shape: {:} to {:}".format(name, shape, static_shape))
                G_LOGGER.warning("Input tensor: {:24} | Adjusted shape: {:} to: {:}. If this is incorrect, please set input_metadata "
                                 "or provide a custom data loader.".format(name, shape, static_shape),
                                 mode=LogMode.ONCE)
        return static_shape
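
This helper leans on misc.override_dynamic_shape. A plausible sketch of the two misc helpers it uses (an assumption about their behavior, not the actual implementations) treats None or negative dimensions as dynamic and clamps them to 1:

def is_shape_dynamic(shape):
    # Assumed convention: None or a negative value marks a dynamic dimension.
    return any(dim is None or dim < 0 for dim in shape)

def override_dynamic_shape(shape):
    # Assumed behavior: clamp every dynamic dimension to 1.
    return tuple(1 if dim is None or dim < 0 else dim for dim in shape)

print(override_dynamic_shape((-1, 3, 224, 224)))  # (1, 3, 224, 224)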
Example #5
    def infer(self, feed_dict):
        def is_dynamic_shape_input(binding):
            # get_profile_shape_input raises a RuntimeError for bindings that
            # are not shape inputs, so use that failure as the check.
            try:
                self.engine.get_profile_shape_input(0, binding)
                return True
            except RuntimeError:
                return False

        start_binding, end_binding = trt_util.get_active_profile_bindings(self.engine, self.context)
        for name, inp in feed_dict.items():
            binding = start_binding + self.engine[name]
            shape = inp.shape
            # Only set shapes if required.
            # get_shape/get_binding_shape will return what a shape input/data input is currently set to.
            if is_dynamic_shape_input(binding):
                G_LOGGER.verbose("Setting shape binding: {:} (index: {:}) to: {:}".format(name, binding, inp))
                if tuple(self.context.get_shape(binding)) != tuple(inp):
                    self.context.set_shape_input(binding, inp)

            elif misc.is_shape_dynamic(self.engine.get_binding_shape(binding)):
                G_LOGGER.verbose("Setting binding: {:} (index: {:}) to shape: {:}".format(name, binding, shape))
                if tuple(self.context.get_binding_shape(binding)) != tuple(shape):
                    self.context.set_binding_shape(binding, shape)

        if not self.context.all_binding_shapes_specified:
            G_LOGGER.critical("Some input shapes were not specified.\nNote: Network inputs are: {:}".format(self.get_input_metadata()))
        if not self.context.all_shape_inputs_specified:
            G_LOGGER.critical("Some shape inputs were not specified.\nNote: Network inputs are: {:}".format(self.get_input_metadata()))

        # Inference
        # Need to resize output buffers.
        self.buffers.resize(self.engine, self.context, start_binding=start_binding, end_binding=end_binding)

        start = time.time()
        self.buffers.copy_inputs(feed_dict, self.stream)
        # Need to offset bindings in case the active profile is not 0.
        status = self.context.execute_async_v2(bindings=[0] * start_binding + self.buffers.bindings(),
                                               stream_handle=self.stream.address())
        if not status:
            G_LOGGER.critical("Model execution failed. Please see the log messages above for details.")

        self.buffers.copy_outputs(self.stream)
        self.stream.synchronize()
        end = time.time()

        self.inference_time = end - start
        return self.buffers.outputs
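
A hypothetical call site (the runner construction is assumed; only the infer() contract comes from the class above): pass a name-to-array mapping in, get a name-to-array mapping back.

import numpy as np
from collections import OrderedDict

# `runner` is assumed to wrap an engine/context pair as in the class above.
feed_dict = OrderedDict([("input", np.zeros((1, 3, 224, 224), dtype=np.float32))])
outputs = runner.infer(feed_dict)
print("inference took {:.4f}s, outputs: {}".format(runner.inference_time, list(outputs)))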