def build_profile(builder, network, profile):
    trt_profile = builder.create_optimization_profile()
    unused_keys = set(profile.keys())
    for idx in range(network.num_inputs):
        inp = network.get_input(idx)
        if inp.name in unused_keys:
            unused_keys.remove(inp.name)

        with G_LOGGER.verbosity():  # WAR for spam from TRT
            is_shape_tensor = inp.is_shape_tensor

        if is_shape_tensor:
            if inp.name in profile:
                shapes = profile[inp.name]
                trt_profile.set_shape_input(inp.name, shapes.min, shapes.opt, shapes.max)
                G_LOGGER.extra_verbose(
                    "Input shape-tensor: {:24} | Setting values to min: {:}, opt: {:}, max: {:}".format(
                        inp.name, shapes.min, shapes.opt, shapes.max
                    )
                )
            else:
                G_LOGGER.warning(
                    "Input shape-tensor: {:24} | No values provided. "
                    "Assuming this is not a dynamic shape-tensor.".format(inp.name),
                    mode=LogMode.ONCE,
                )
        elif misc.is_shape_dynamic(inp.shape):
            shapes = profile[inp.name]
            trt_profile.set_shape(inp.name, shapes.min, shapes.opt, shapes.max)
            G_LOGGER.extra_verbose(
                "Input tensor: {:24} | Setting shape to min: {:}, opt: {:}, max: {:}".format(
                    inp.name, shapes.min, shapes.opt, shapes.max
                )
            )

    if unused_keys:
        G_LOGGER.warning("Some inputs provided in the profile were unused: {:}".format(list(unused_keys)))

    return check_profile(trt_profile)
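# A minimal usage sketch for build_profile, assuming a TensorRT builder, network,
# and builder config already exist. `ShapeTuple`, `example_build_profile`, and the
# input name "input_ids" are hypothetical; the function only requires that each
# profile entry expose `.min`, `.opt`, and `.max` attributes.
from collections import namedtuple

ShapeTuple = namedtuple("ShapeTuple", ["min", "opt", "max"])

def example_build_profile(builder, network, config):
    # One entry per dynamically-shaped network input (hypothetical shapes).
    profile = {
        "input_ids": ShapeTuple(min=(1, 128), opt=(8, 128), max=(32, 128)),
    }
    trt_profile = build_profile(builder, network, profile)
    # Register the validated profile on the builder config before building the engine.
    config.add_optimization_profile(trt_profile)
    return config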
def set_shapes_from_feed_dict(self, feed_dict):
    """
    Sets context shapes according to the provided feed_dict, then resizes buffers as needed.

    Args:
        feed_dict (OrderedDict[str, numpy.ndarray]):
            A mapping of input tensor names to corresponding input NumPy arrays.

    Returns:
        Tuple[int, int]: The start and end binding indices of the modified bindings.
    """

    def is_dynamic_shape_input(binding):
        try:
            self.context.engine.get_profile_shape_input(0, binding)
            return True
        except RuntimeError:
            return False

    start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
    for name, inp in feed_dict.items():
        binding = start_binding + self.context.engine[name]
        shape = inp.shape
        # Only set shapes if required.
        # get_shape/get_binding_shape will return what a shape input/data input is currently set to.
        if is_dynamic_shape_input(binding):  # For input shape tensors
            G_LOGGER.verbose("Setting shape binding: {:} (index: {:}) to: {:}".format(name, binding, inp))
            if tuple(self.context.get_shape(binding)) != tuple(inp):
                self.context.set_shape_input(binding, inp)
        elif misc.is_shape_dynamic(self.context.engine.get_binding_shape(binding)):
            G_LOGGER.verbose("Setting binding: {:} (index: {:}) to shape: {:}".format(name, binding, shape))
            if tuple(self.context.get_binding_shape(binding)) != tuple(shape):
                self.context.set_binding_shape(binding, shape)

    if not self.context.all_binding_shapes_specified:
        G_LOGGER.critical(
            "Some input shapes were not specified.\nNote: Network inputs are: {:}".format(self.get_input_metadata())
        )
    if not self.context.all_shape_inputs_specified:
        G_LOGGER.critical(
            "Some shape inputs were not specified.\nNote: Network inputs are: {:}".format(self.get_input_metadata())
        )

    # Resize device buffers - host buffers will be automatically resized by copy_to
    for binding in range(start_binding, end_binding):
        name = self.context.engine[binding - start_binding]  # Use profile 0 binding names for all buffers.
        shape = tuple(self.context.get_binding_shape(binding))
        self.device_buffers[name].resize(shape)

    return start_binding, end_binding
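# A minimal usage sketch for set_shapes_from_feed_dict; `runner` stands for the
# (assumed) object that owns this method, and the input name and shape below are
# hypothetical.
from collections import OrderedDict
import numpy as np

def example_set_shapes(runner):
    feed_dict = OrderedDict([("images", np.zeros((4, 3, 224, 224), dtype=np.float32))])
    start_binding, end_binding = runner.set_shapes_from_feed_dict(feed_dict)
    # Bindings in [start_binding, end_binding) now match the feed_dict shapes.
    return start_binding, end_binding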
def is_shape_tensor(name, dtype):
    if name not in self.input_metadata or name not in self.user_input_metadata:
        return False

    _, shape = self.input_metadata[name]
    is_shape = np.issubdtype(dtype, np.integer) and (not misc.is_shape_dynamic(shape)) and (len(shape) == 1)

    user_shape = self.user_input_metadata[name][1]
    is_shape &= len(user_shape) == shape[0]
    # Can't have negative values in shapes
    is_shape &= all(elem >= 0 for elem in user_shape)
    return is_shape
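# A self-contained sketch of the heuristic above, using hypothetical metadata:
# an input qualifies as a shape tensor only if it is an integer, statically-shaped,
# 1-D tensor whose length equals the rank of the user-provided shape, and the
# user-provided shape contains no negative values.
import numpy as np

def example_is_shape_tensor():
    dtype, shape = np.int32, (4,)  # model input: static 1-D int32 tensor of length 4
    user_shape = (1, 3, 224, 224)  # user supplied a rank-4 shape for this input
    is_shape = np.issubdtype(dtype, np.integer) and len(shape) == 1
    is_shape &= len(user_shape) == shape[0]  # rank must equal the tensor's length
    is_shape &= all(elem >= 0 for elem in user_shape)  # shapes cannot be negative
    return is_shape  # True: this input plausibly carries a shape value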
def get_static_shape(name, shape):
    static_shape = shape
    if misc.is_shape_dynamic(shape):
        static_shape = misc.override_dynamic_shape(shape)
        if static_shape != shape and name not in self.user_input_metadata:
            if not misc.is_valid_shape_override(static_shape, shape):
                G_LOGGER.critical(
                    "Input tensor: {:24} | Cannot override original shape: {:} to {:}".format(
                        name, shape, static_shape
                    )
                )
            G_LOGGER.warning(
                "Input tensor: {:24} | Adjusted shape: {:} to: {:}. If this is incorrect, please set "
                "input_metadata or provide a custom data loader.".format(name, shape, static_shape),
                mode=LogMode.ONCE,
            )
    return static_shape
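# A minimal sketch of what misc.override_dynamic_shape is assumed to do here:
# replace dynamic dimensions (reported as -1 or None) with a fixed default so
# the tensor can be materialized. The default value of 1 is an assumption for
# illustration, not the library's documented behavior.
def example_override_dynamic_shape(shape, default=1):
    return tuple(default if dim is None or dim < 0 else dim for dim in shape)

# e.g. example_override_dynamic_shape((-1, 3, 224, 224)) -> (1, 3, 224, 224)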
def infer(self, feed_dict):
    def is_dynamic_shape_input(binding):
        try:
            self.engine.get_profile_shape_input(0, binding)
            return True
        except RuntimeError:
            return False

    start_binding, end_binding = trt_util.get_active_profile_bindings(self.engine, self.context)
    for name, inp in feed_dict.items():
        binding = start_binding + self.engine[name]
        shape = inp.shape
        # Only set shapes if required.
        # get_shape/get_binding_shape will return what a shape input/data input is currently set to.
        if is_dynamic_shape_input(binding):
            G_LOGGER.verbose("Setting shape binding: {:} (index: {:}) to: {:}".format(name, binding, inp))
            if tuple(self.context.get_shape(binding)) != tuple(inp):
                self.context.set_shape_input(binding, inp)
        elif misc.is_shape_dynamic(self.engine.get_binding_shape(binding)):
            G_LOGGER.verbose("Setting binding: {:} (index: {:}) to shape: {:}".format(name, binding, shape))
            if tuple(self.context.get_binding_shape(binding)) != tuple(shape):
                self.context.set_binding_shape(binding, shape)

    if not self.context.all_binding_shapes_specified:
        G_LOGGER.critical(
            "Some input shapes were not specified.\nNote: Network inputs are: {:}".format(self.get_input_metadata())
        )
    if not self.context.all_shape_inputs_specified:
        G_LOGGER.critical(
            "Some shape inputs were not specified.\nNote: Network inputs are: {:}".format(self.get_input_metadata())
        )

    # Inference
    # Need to resize output buffers
    self.buffers.resize(self.engine, self.context, start_binding=start_binding, end_binding=end_binding)

    start = time.time()
    self.buffers.copy_inputs(feed_dict, self.stream)
    # Need to offset bindings in case the active profile is not 0.
    status = self.context.execute_async_v2(
        bindings=[0] * start_binding + self.buffers.bindings(), stream_handle=self.stream.address()
    )
    if not status:
        G_LOGGER.critical("Model execution failed. Please see the log messages above for details")

    self.buffers.copy_outputs(self.stream)
    self.stream.synchronize()
    end = time.time()

    self.inference_time = end - start
    return self.buffers.outputs
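# An end-to-end usage sketch for infer; `runner` stands for the (assumed) object
# that owns this method, and the input name and shape below are hypothetical.
from collections import OrderedDict
import numpy as np

def example_infer(runner):
    feed_dict = OrderedDict([("images", np.random.rand(1, 3, 224, 224).astype(np.float32))])
    outputs = runner.infer(feed_dict)  # mapping of output names to host arrays
    print("Inference took {:.4f}s".format(runner.inference_time))
    return outputs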