def get_input_metadata_impl(self):
    """Return input metadata for the engine attached to ``self.context``.

    Only the binding index range ``[0, bindings_per_profile)`` is inspected,
    so the binding names reported are always those of the 0th profile.
    """
    engine = self.context.engine
    num_profile_bindings = trt_util.get_bindings_per_profile(engine)
    # Restrict the lookup to the first profile's slice of binding indices.
    return trt_util.get_input_metadata_from_engine(
        engine,
        start_binding=0,
        end_binding=num_profile_bindings,
    )
def from_engine(engine):
    """Build a ``Buffers`` object holding one device buffer per binding.

    Only the first ``trt_util.get_bindings_per_profile(engine)`` binding
    indices are visited (presumably the bindings of profile 0 — confirm
    against ``trt_util``). Every visited binding gets a ``cuda.DeviceBuffer``
    in ``buffers.device_buffers``. Each non-input binding also gets an empty
    0-d NumPy array of the same dtype in ``buffers.outputs``.

    Args:
        engine: The TensorRT engine whose bindings are enumerated.

    Returns:
        Buffers: The populated buffers object.
    """
    buffers = Buffers()
    for binding_index in range(trt_util.get_bindings_per_profile(engine)):
        name = engine[binding_index]
        np_dtype = trt.nptype(engine.get_binding_dtype(name))
        buffers.device_buffers[name] = cuda.DeviceBuffer(dtype=np_dtype)
        if engine.binding_is_input(name):
            continue
        # Placeholder host array; shape is tuple(), i.e. a 0-d array.
        buffers.outputs[name] = np.empty(shape=tuple(), dtype=np_dtype)
    G_LOGGER.extra_verbose("Created device buffers: {:}".format(
        buffers.device_buffers))
    return buffers
def make_buffers(engine):
    """
    Creates empty host and device buffers for the specified engine.
    Always uses binding names from Profile 0.

    Every binding in the index range ``[0, get_bindings_per_profile(engine))``
    gets a ``cuda.DeviceArray`` of the matching NumPy dtype. Bindings that are
    not inputs also get an empty 0-d host array of that dtype.

    Args:
        engine: The TensorRT engine whose bindings are enumerated.

    Returns:
        Tuple[OrderedDict, OrderedDict]: ``(device_buffers, host_output_buffers)``,
        both keyed by binding name in binding-index order.
    """
    device_buffers = OrderedDict()
    host_output_buffers = OrderedDict()
    num_bindings = trt_util.get_bindings_per_profile(engine)
    binding_names = (engine[index] for index in range(num_bindings))
    for name in binding_names:
        np_dtype = trt_util.np_dtype_from_trt(engine.get_binding_dtype(name))
        device_buffers[name] = cuda.DeviceArray(dtype=np_dtype)
        if engine.binding_is_input(name):
            continue
        # Host-side placeholder; shape tuple() gives a 0-d array.
        host_output_buffers[name] = np.empty(shape=tuple(), dtype=np_dtype)
    G_LOGGER.extra_verbose("Created device buffers: {:}".format(device_buffers))
    return device_buffers, host_output_buffers