def build_engines(self):
    """Build the TensorRT engine for the configured batch size and save it.

    Calls self.initialize() if it has not been called yet, makes sure the
    engine output directory exists, configures batching on the builder, then
    builds the engine and writes its serialized bytes to the path returned by
    ``self._get_engine_fpath(self.device_type, self.batch_size)``.

    Batching is configured in one of three ways:
      * implicit-batch network: ``builder.max_batch_size`` is set;
      * explicit-batch network on GPU (``dla_core is None``): one optimization
        profile per ``self.num_profiles`` with batch range [1, batch_size];
      * explicit-batch network on DLA: dimension 0 of every network input is
        overwritten with ``self.batch_size``.

    Raises:
        RuntimeError: If an optimization profile is invalid, or if the builder
            returns no engine (previously this surfaced as an AttributeError
            on ``None.serialize()``).
    """
    if not self.initialized:
        self.initialize()

    # Create output directory if it does not exist. exist_ok avoids the race
    # between a separate existence check and the creation.
    os.makedirs(self.engine_dir, exist_ok=True)

    engine_name = self._get_engine_fpath(self.device_type, self.batch_size)
    logging.info("Building {:}".format(engine_name))

    if self.network.has_implicit_batch_dimension:
        self.builder.max_batch_size = self.batch_size
    else:
        self.profiles = []
        if self.dla_core is None:
            # Create optimization profiles if on GPU
            for _ in range(self.num_profiles):
                profile = self.builder.create_optimization_profile()
                for input_idx in range(self.network.num_inputs):
                    network_input = self.network.get_input(input_idx)
                    min_shape = trt.Dims(network_input.shape)
                    min_shape[0] = 1
                    max_shape = trt.Dims(network_input.shape)
                    max_shape[0] = self.batch_size
                    # The max shape doubles as the "opt" shape.
                    profile.set_shape(network_input.name, min_shape, max_shape, max_shape)
                if not profile:
                    raise RuntimeError("Invalid optimization profile!")
                self.builder_config.add_optimization_profile(profile)
                self.profiles.append(profile)
        else:
            # Use fixed batch size if on DLA
            for input_idx in range(self.network.num_inputs):
                network_input = self.network.get_input(input_idx)
                input_shape = network_input.shape
                input_shape[0] = self.batch_size
                network_input.shape = input_shape

    # Build engines
    engine = self.builder.build_engine(self.network, self.builder_config)
    if engine is None:
        # Fail before touching the output file so no empty/partial engine is left behind.
        raise RuntimeError("Failed to build engine {:}".format(engine_name))
    buf = engine.serialize()
    with open(engine_name, 'wb') as f:
        f.write(buf)
def pixel_unshuffle(network: trt.INetworkDefinition, input: trt.ITensor,
                    downscale_factor: int) -> trt.ITensor:
    """Add a pixel-unshuffle (space-to-depth) rearrangement to ``network``.

    The 4-D ``input`` shape is unpacked as (n, ic, ih, iw). The result has
    shape (n, ic * downscale_factor**2, ih // downscale_factor,
    iw // downscale_factor), built from two shuffle layers.

    Args:
        network: Network definition the shuffle layers are added to.
        input: 4-D tensor whose shape is read at build time.
        downscale_factor: Spatial reduction factor; must divide both ih and iw.

    Returns:
        The output tensor of the second shuffle layer.
    """
    n, ic, ih, iw = input.shape
    # Both spatial extents must be divisible (the original checked ih twice
    # and never validated iw).
    assert ih % downscale_factor == 0 and iw % downscale_factor == 0, \
        "input height and width must be divisible by downscale_factor"
    oc = ic * (downscale_factor**2)
    oh = ih // downscale_factor
    ow = iw // downscale_factor

    # Split each spatial axis into (out, factor), then move both factor axes
    # in front of the spatial axes.
    reshape = network.add_shuffle(input)
    reshape.reshape_dims = trt.Dims(
        [n, ic, oh, downscale_factor, ow, downscale_factor])
    reshape.second_transpose = trt.Permutation([0, 1, 3, 5, 2, 4])

    # Fold (ic, factor, factor) into the output channel axis.
    reshape = network.add_shuffle(reshape.get_output(0))
    reshape.reshape_dims = trt.Dims([n, oc, oh, ow])
    return reshape.get_output(0)
def pixel_shuffle(network: trt.INetworkDefinition, input: trt.ITensor,
                  upscale_factor: int) -> trt.ITensor:
    """Add a pixel-shuffle (depth-to-space) rearrangement to ``network``.

    The 4-D ``input`` shape is unpacked as (n, ic, ih, iw); the channel count
    must be divisible by upscale_factor**2. The result has shape
    (n, ic // upscale_factor**2, ih * upscale_factor, iw * upscale_factor).

    Returns:
        The output tensor of the second shuffle layer.
    """
    batch, in_channels, in_h, in_w = input.shape
    factor_sq = upscale_factor**2
    assert in_channels % factor_sq == 0
    out_channels = in_channels // factor_sq

    # Stage 1: peel the two upscale factors off the channel axis and
    # interleave them with the spatial axes.
    split = network.add_shuffle(input)
    split.reshape_dims = trt.Dims(
        [batch, out_channels, upscale_factor, upscale_factor, in_h, in_w])
    split.second_transpose = trt.Permutation([0, 1, 4, 2, 5, 3])

    # Stage 2: merge each (spatial, factor) pair into the enlarged axis.
    merged = network.add_shuffle(split.get_output(0))
    merged.reshape_dims = trt.Dims(
        [batch, out_channels, in_h * upscale_factor, in_w * upscale_factor])
    return merged.get_output(0)
def sequence_class_output(prefix, init_dict, network, input_tensor, softmax=True):
    """Append a sequence-level classification head to ``network``.

    The input is transposed with permutation (0, 3, 4, 1, 2), a slice starting
    at offset 0 with extent 1 on axis 3 is taken (named the "first token" by
    this code), and the slice is fed through FC -> ReLU -> FC. When ``softmax``
    is true a softmax layer follows. The result is reshaped with dims [0, -1].

    Args:
        prefix: Key prefix used to look up weights in ``init_dict`` and passed
            to ``set_layer_name``.
        init_dict: Mapping holding the ``mlp.layer0`` / ``mlp.layer2`` weights
            and biases.
        network: TensorRT network definition to add layers to.
        input_tensor: Input tensor; ``shape[2]`` is used as the hidden size.
            NOTE(review): the transpose implies a 5-D input -- confirm layout
            against the caller.
        softmax: Whether to append a softmax after the second FC layer.

    Returns:
        The final shuffle layer (not its output tensor).
    """
    logging.info(input_tensor.shape)
    hidden_size = input_tensor.shape[2]

    shuf = network.add_shuffle(input_tensor)
    shuf.first_transpose = (0, 3, 4, 1, 2)
    transposed = shuf.get_output(0)
    # Use a %s placeholder: passing the shape as a bare extra argument makes
    # logging fail to format the record.
    logging.info("seq class in: %s", transposed.shape)

    # Runtime slice size: take the transposed shape and replace entry 3 with
    # entry 2 by gathering indices [0, 1, 2, 2, 4].
    in_shape_tensor = network.add_shape(transposed).get_output(0)
    size_indices = network.add_constant(
        (5, ), trt.Weights(np.array([0, 1, 2, 2, 4]).astype(np.int32))).get_output(0)
    out_shape_tensor = network.add_gather(in_shape_tensor, size_indices, 0).get_output(0)

    first_token_tensor = network.add_slice(
        transposed,
        start=(0, 0, 0, 0, 0),
        shape=(-1, 1, 1, 1, hidden_size),
        stride=(1, 1, 1, 1, 1),
    )
    # Inputs 1 and 2 override the static start/shape given above with tensors.
    first_token_tensor.set_input(
        1,
        network.add_constant(
            (5, ), trt.Weights(np.array([0, 0, 0, 0, 0]).astype(np.int32))).get_output(0),
    )
    first_token_tensor.set_input(2, out_shape_tensor)

    W_out = init_dict[prefix + "mlp.layer0." + SQD_W]
    B_out = init_dict[prefix + "mlp.layer0." + SQD_B]
    dense = network.add_fully_connected(first_token_tensor.get_output(0), W_out.shape[0], W_out, B_out)
    dense_relu = network.add_activation(dense.get_output(0), trt.ActivationType.RELU)

    W_out = init_dict[prefix + "mlp.layer2." + SQD_W]
    B_out = init_dict[prefix + "mlp.layer2." + SQD_B]
    classifier = network.add_fully_connected(dense_relu.get_output(0), W_out.shape[0], W_out, B_out)

    if softmax:
        probs = network.add_softmax(classifier.get_output(0))
        probs.axes = 4  # last dimension
        classifier = probs

    classifier = network.add_shuffle(classifier.get_output(0))
    classifier.reshape_dims = trt.Dims([0, -1])
    set_layer_name(classifier, prefix, "classifier")
    logging.info("seq class: %s", classifier.get_output(0).shape)
    return classifier
def token_class_output(prefix, init_dict, network, input_tensor, softmax=True):
    """Append a token-level classification head to ``network``.

    ``input_tensor`` is fed through FC -> ReLU -> FC using the
    ``mlp.layer0`` / ``mlp.layer2`` weights from ``init_dict``. When
    ``softmax`` is true a softmax layer follows. The layer before the final
    shuffle is the one named via ``set_layer_name``; the final shuffle
    reshapes with dims [0, 0, 0].

    Args:
        prefix: Key prefix for ``init_dict`` lookups and for the layer name.
        init_dict: Mapping holding the MLP weights and biases.
        network: TensorRT network definition to add layers to.
        input_tensor: Tensor fed to the first fully connected layer.
        softmax: Whether to append a softmax after the second FC layer.

    Returns:
        The final shuffle layer (not its output tensor).
    """
    W_out = init_dict[prefix + "mlp.layer0." + SQD_W]
    B_out = init_dict[prefix + "mlp.layer0." + SQD_B]
    dense = network.add_fully_connected(input_tensor, W_out.shape[0], W_out, B_out)
    dense_relu = network.add_activation(dense.get_output(0), trt.ActivationType.RELU)

    W_out = init_dict[prefix + "mlp.layer2." + SQD_W]
    B_out = init_dict[prefix + "mlp.layer2." + SQD_B]
    classifier = network.add_fully_connected(dense_relu.get_output(0), W_out.shape[0], W_out, B_out)

    if softmax:
        probs = network.add_softmax(classifier.get_output(0))
        probs.axes = 4  # last dimension
        classifier = probs

    set_layer_name(classifier, prefix, "classifier")
    classifier = network.add_shuffle(classifier.get_output(0))
    classifier.reshape_dims = trt.Dims([0, 0, 0])

    # Use a %s placeholder: passing the shape as a bare extra argument makes
    # logging fail to format the record.
    logging.info("tok class: %s", classifier.get_output(0).shape)
    return classifier
def forward(self, *inputs, out=None):
    """Enqueue the engine on the current CUDA stream and return the outputs.

    Args:
        *inputs: One tensor per entry of ``self.input_names``, in that order.
            If a tensor's shape differs from the context's current binding
            shape, the binding shape is updated to match.
        out: Optional sequence of preallocated output tensors, one per entry
            of ``self.output_names``. When omitted, outputs are allocated here
            using the binding dtype, shape and location.

    Returns:
        A tuple of output tensors, or the single tensor when there is exactly
        one output. Execution is asynchronous (``execute_async_v2``); no
        synchronization is performed here.
    """
    bindings = [None] * (len(self.input_names) + len(self.output_names))

    # Hold the (possibly copied) contiguous inputs until the work is enqueued.
    # Without this, a temporary created by .contiguous() is released right
    # after data_ptr() is taken, and the output allocation below may be handed
    # the same memory.
    live_inputs = []
    for i, input_name in enumerate(self.input_names):
        # XXX Conclude dynamic input shape
        idx = self.engine.get_binding_index(input_name)
        binding_shape = tuple(self.context.get_binding_shape(idx))
        arg_shape = tuple(inputs[i].shape)
        if binding_shape != arg_shape:
            self.context.set_binding_shape(idx, trt.Dims(arg_shape))
        tensor = inputs[i].contiguous()
        live_inputs.append(tensor)
        bindings[idx] = tensor.data_ptr()

    # create output tensors
    outputs = [None] * len(self.output_names)
    if out is None:
        for i, output_name in enumerate(self.output_names):
            idx = self.engine.get_binding_index(output_name)
            dtype = t2t.torch_dtype_from_trt(
                self.engine.get_binding_dtype(idx))
            shape = tuple(self.context.get_binding_shape(idx))
            device = t2t.torch_device_from_trt(
                self.engine.get_location(idx))
            output = th.empty(size=shape, dtype=dtype, device=device)
            outputs[i] = output
            bindings[idx] = output.data_ptr()
    else:
        for i, output_name in enumerate(self.output_names):
            idx = self.engine.get_binding_index(output_name)
            outputs[i] = out[i]
            bindings[idx] = out[i].data_ptr()

    self.context.execute_async_v2(
        bindings=bindings,
        stream_handle=th.cuda.current_stream().cuda_stream)
    # NOTE(review): live_inputs is released on return; this relies on the
    # allocator ordering reuse after work already queued on this stream --
    # confirm if inputs may live on a different stream.

    outputs = tuple(outputs)
    if len(outputs) == 1:
        outputs = outputs[0]
    return outputs
def allocate_buffers_with_existing_inputs(context, inp):
    '''
    allocate_buffers() (see TRT python samples) but uses an existing input on device.

    context: execution context; its engine must expose bindings named
        "FEATURES" (input) and "LOGITS" (output).
    inp: sequence of tensors; only inp[0] is used. It must be contiguous and
        provide .shape and .data_ptr(). When the "FEATURES" binding has a
        dynamic first dimension, the binding shape is set from the first three
        dimensions of inp[0].

    Returns (outputs, bindings, out_shape):
        outputs: list with one HostDeviceMem (page-locked float32 host buffer
            plus a device allocation of the same size) for "LOGITS"
        bindings: two-element list of device addresses indexed by binding index
        out_shape: the "LOGITS" binding shape reported by the context
    '''
    # Add input to bindings
    bindings = [0, 0]
    outputs = []
    engine = context.engine

    features = inp[0]
    assert (features.is_contiguous())
    inp_idx = engine.get_binding_index("FEATURES")
    bindings[inp_idx] = features.data_ptr()

    sh = features.shape
    batch_size = sh[0]
    orig_shape = context.get_binding_shape(inp_idx)
    if orig_shape[0] == -1:
        # Dynamic batch: resolve the binding shape from the actual input.
        context.set_binding_shape(inp_idx, trt.Dims([batch_size, sh[1], sh[2]]))

    assert context.all_binding_shapes_specified

    out_idx = engine.get_binding_index("LOGITS")
    # Allocate output buffer by querying the size from the context.
    # This may be different for different input shapes.
    out_shape = context.get_binding_shape(out_idx)
    # Pass the dtype class itself rather than a scalar instance.
    h_output = cuda.pagelocked_empty(tuple(out_shape), dtype=np.float32)
    d_output = cuda.mem_alloc(h_output.nbytes)
    bindings[out_idx] = int(d_output)
    outputs.append(HostDeviceMem(h_output, d_output))
    return outputs, bindings, out_shape
# Replace every LeakyRelu node with an LReLU_TRT plugin node (negSlope=0.1).
nodes = [n.name for n in dynamic_graph.as_graph_def().node]
ns = {}
for node in nodes:
    if "LeakyRelu" in node:
        ns[node] = gs.create_plugin_node(name=node, op="LReLU_TRT", negSlope=0.1)
dynamic_graph.collapse_namespaces(ns)

# convert to UFF
uff_model = uff.from_tensorflow(dynamic_graph.as_graph_def(), output_nodes=[output_node])

# convert to TRT
G_LOGGER = trt.Logger(trt.Logger.ERROR)
trt.init_libnvinfer_plugins(G_LOGGER, "")
builder = trt.Builder(G_LOGGER)
builder.max_batch_size = 1
builder.max_workspace_size = 1 << 20
if data_type == trt.DataType.HALF:
    builder.fp16_mode = True
network = builder.create_network()
parser = trt.UffParser()
parser.register_input(input_node, trt.Dims([3, 256, 512]))
parser.register_output(output_node)
# Stop on a parse failure instead of building from an incomplete network.
if not parser.parse_buffer(uff_model, network, data_type):
    raise RuntimeError("Failed to parse the UFF model")
engine = builder.build_cuda_engine(network)
# The builder returns None on failure; report it before writing the file.
if engine is None:
    raise RuntimeError("Failed to build the TensorRT engine")
with open(engine_file, "wb") as f:
    f.write(engine.serialize())
def build_engine(onnx_path, using_half):
    """Load a cached engine or build one from ONNX with a BatchedNMS head.

    If a file named like ``onnx_path`` with ``.onnx`` replaced by ``.engine``
    exists, it is deserialized and returned. Otherwise the ONNX model is
    parsed, its first output is unmarked and split along axis 2 into boxes
    ([..., 0:4]), objectness ([..., 4:5]) and class scores
    ([..., 5:5 + num_classes]). Objectness is multiplied into the class scores
    and both are fed to the ``BatchedNMS_TRT`` plugin, whose four outputs
    become the network outputs.

    Reads the module-level names ``TRT_LOGGER``, ``EXPLICIT_BATCH``, ``GiB``,
    ``num_classes``, ``topK``, ``keepTopK``, ``conf_thres`` and ``iou_thres``.

    Args:
        onnx_path: Path to the ONNX model.
        using_half: When true, the FP16 builder flag is set.

    Returns:
        The engine, or None when ONNX parsing fails (errors are printed).
        NOTE(review): a freshly built engine is not written to the
        ``.engine`` path here -- presumably the caller does that; confirm.
    """
    trt.init_libnvinfer_plugins(None, '')
    engine_file = onnx_path.replace(".onnx", ".engine")
    if os.path.exists(engine_file):
        with open(engine_file, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())

    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
            EXPLICIT_BATCH) as network, trt.OnnxParser(network,
                                                       TRT_LOGGER) as parser:
        builder.max_batch_size = 1  # always 1 for explicit batch
        config = builder.create_builder_config()
        config.max_workspace_size = GiB(1)
        if using_half:
            config.set_flag(trt.BuilderFlag.FP16)

        # Load the Onnx model and parse it in order to populate the TensorRT network.
        with open(onnx_path, 'rb') as model:
            if not parser.parse(model.read()):
                print('ERROR: Failed to parse the ONNX file.')
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None

        previous_output = network.get_output(0)
        network.unmark_output(previous_output)

        # slice boxes, obj_score, class_scores
        strides = trt.Dims([1, 1, 1])
        starts = trt.Dims([0, 0, 0])
        bs, num_boxes, _ = previous_output.shape
        shapes = trt.Dims([bs, num_boxes, 4])
        boxes = network.add_slice(previous_output, starts, shapes, strides)
        starts[2] = 4
        shapes[2] = 1
        obj_score = network.add_slice(previous_output, starts, shapes, strides)
        starts[2] = 5
        shapes[2] = num_classes
        scores = network.add_slice(previous_output, starts, shapes, strides)

        # Gathering index 0 num_classes times along axis 2 repeats the
        # objectness so it lines up with the per-class scores.
        indices = network.add_constant(
            trt.Dims([num_classes]),
            trt.Weights(np.zeros(num_classes, np.int32)))
        gather_layer = network.add_gather(obj_score.get_output(0),
                                          indices.get_output(0), 2)

        # scores = obj_score * class_scores => [bs, num_boxes, nc]
        updated_scores = network.add_elementwise(
            gather_layer.get_output(0), scores.get_output(0),
            trt.ElementWiseOperation.PROD)

        # reshape box to [bs, num_boxes, 1, 4]
        reshaped_boxes = network.add_shuffle(boxes.get_output(0))
        reshaped_boxes.reshape_dims = trt.Dims([0, 0, 1, 4])

        # add batchedNMSPlugin, inputs:[boxes:(bs, num, 1, 4), scores:(bs, num, nc)]
        trt.init_libnvinfer_plugins(TRT_LOGGER, "")
        registry = trt.get_plugin_registry()
        assert (registry)
        creator = registry.get_plugin_creator("BatchedNMS_TRT", "1")
        assert (creator)

        def int_field(name, value):
            # np.int32 matches the declared INT32 field type; the removed
            # np.int alias was a platform-sized Python int.
            return trt.PluginField(name, np.array([value], dtype=np.int32),
                                   trt.PluginFieldType.INT32)

        def float_field(name, value):
            return trt.PluginField(name, np.array([value], dtype=np.float32),
                                   trt.PluginFieldType.FLOAT32)

        fc = trt.PluginFieldCollection([
            int_field("shareLocation", 1),
            int_field("backgroundLabelId", -1),
            int_field("numClasses", num_classes),
            int_field("topK", topK),
            int_field("keepTopK", keepTopK),
            float_field("scoreThreshold", conf_thres),
            float_field("iouThreshold", iou_thres),
            int_field("isNormalized", 0),
            int_field("clipBoxes", 0),
        ])
        nms_layer = creator.create_plugin("nms_layer", fc)

        layer = network.add_plugin_v2(
            [reshaped_boxes.get_output(0),
             updated_scores.get_output(0)], nms_layer)
        layer.get_output(0).name = "num_detections"
        layer.get_output(1).name = "nmsed_boxes"
        layer.get_output(2).name = "nmsed_scores"
        layer.get_output(3).name = "nmsed_classes"
        for i in range(4):
            network.mark_output(layer.get_output(i))

        return builder.build_engine(network, config)
def initialize(self):
    """Create the TensorRT network definition and mark the object initialized.

    Steps, in order: split the checkpoint into an embedding-weights binary and
    an embedding-free model file when either is missing; optionally dump row
    frequencies; load the embedding-free weights; create an explicit-batch
    network with the numerical input, bottom MLP, index input, interactions
    plugin, top MLP (optionally wrapped in interleave shuffles) and a sigmoid
    output; finally release ``self.weights`` and set ``self.initialized``.
    """
    conv_for_fc = (self.precision == "int8")
    # True only when not calibrating and the input dtype is int8; this selects
    # the fused bottom-MLP plugin and the interleaved interactions output.
    int8_input_path = not self.need_calibration and self.input_dtype == "int8"
    interleaved_interactions = int8_input_path

    # Check if we should split the model into the binary file with embedding
    # weights quantized and model without embeddings
    split_files_present = os.path.isfile(self.embedding_weights_binary_filepath) \
        and os.path.isfile(self.model_without_embedding_weights_filepath)
    if not split_files_present:
        logging.info("Loading checkpoint from " + self.model_filepath)
        checkpoint = torch.load(self.model_filepath, map_location="cpu")
        self.weights = checkpoint["state_dict"]
        self.dump_embedding_weights_to_binary_file()
        logging.info("Writing model without embedding weights to " + self.model_without_embedding_weights_filepath)
        torch.save(self.weights, self.model_without_embedding_weights_filepath)
        del self.weights

    # Dump row frequencies to file in binary format
    if self.use_row_frequencies and not os.path.isfile(self.row_frequencies_binary_filepath):
        logging.info("Writing row frequencies to " + self.row_frequencies_binary_filepath)
        self.dump_row_frequencies_to_binary_file()

    # Load weights
    self.weights = torch.load(self.model_without_embedding_weights_filepath, map_location="cpu")

    # Create network.
    explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    self.network = self.builder.create_network(explicit_batch_flag)

    # Numerical input
    numerical_input = self.network.add_input(
        "numerical_input", trt.DataType.FLOAT, (-1, self.num_numerical_inputs, 1, 1))
    if not self.need_calibration:
        # Attribute names are resolved lazily so only the selected one is touched.
        dtype_attr = {"int8": "int8", "fp16": "float16"}.get(self.input_dtype)
        if dtype_attr is not None:
            numerical_input.dtype = getattr(trt, dtype_attr)
        format_attr = {"linear": "LINEAR", "chw4": "CHW4", "chw32": "CHW32"}.get(self.input_format)
        if format_attr is not None:
            numerical_input.allowed_formats = 1 << int(getattr(trt.TensorFormat, format_attr))

    # Bottom MLP
    if int8_input_path:
        bottom_mlp_plugin, fused_output_name = self.add_fused_bottom_mlp(
            "DLRM_BOTTOM_MLP_TRT", numerical_input, self.num_numerical_inputs,
            self.bottom_mlp_channels, self.bottom_mlp_names)
        bottom_mlp = self.network.add_plugin_v2([numerical_input], bottom_mlp_plugin)
        bottom_mlp.get_output(0).name = fused_output_name
    else:
        bottom_mlp = self.add_mlp(
            numerical_input, self.num_numerical_inputs, self.bottom_mlp_channels,
            self.bottom_mlp_names, last_relu=True, useConvForFC=conv_for_fc)
    # NOTE(review): this shuffle's output is not consumed below -- confirm
    # whether it is intentionally left dangling.
    bottom_mlp_shuffle = self.network.add_shuffle(bottom_mlp.get_output(0))
    bottom_mlp_shuffle.reshape_dims = trt.Dims((-1, 1, self.embedding_size))

    # Index input
    index_input = self.network.add_input("index_input", trt.DataType.INT32, (-1, self.num_features))

    # Embedding lookup and interactions
    row_counts = np.array([0] + self.embedding_rows[:-1]).astype(np.int32)
    table_offsets = np.cumsum(row_counts).astype(np.int32)
    dlrm_interactions_plugin = self.get_dlrm_interactions_plugin(
        "DLRM_INTERACTIONS_TRT", table_offsets, interleaved_interactions)
    interactions = self.network.add_plugin_v2(
        [bottom_mlp.get_output(0), index_input], dlrm_interactions_plugin)
    interactions.name = "interaction_plugin"
    interactions.get_output(0).name = "interaction_output_concat_output"

    top_mlp_input = interactions.get_output(0)
    if self.INTERLEAVED_TOP_MLP and not interleaved_interactions:
        # Shuffle from [BS, C, 1, 1] to [BS//2, C, 2, 1] before top_mlp
        pre_shuffle = self.network.add_shuffle(interactions.get_output(0))
        pre_shuffle.reshape_dims = trt.Dims((-1, 2, interactions.get_output(0).shape[1], 0))
        pre_shuffle.second_transpose = trt.Permutation([0, 2, 1, 3])
        pre_shuffle.name = "interleave_pre_top_mlp"
        top_mlp_input = pre_shuffle.get_output(0)
        top_mlp_input.name = "interleave_pre_top_mlp"

    # Top MLP
    top_mlp = self.add_mlp(
        top_mlp_input, self.top_mlp_input_size, self.top_mlp_channels,
        self.top_mlp_names, last_relu=False, useConvForFC=conv_for_fc)

    if self.INTERLEAVED_TOP_MLP:
        # Shuffle back to [BS, 1, 1, 1] from [BS//2, 1, 2, 1]
        post_shuffle = self.network.add_shuffle(top_mlp.get_output(0))
        post_shuffle.reshape_dims = trt.Dims((-1, 0, 1, 0))
        post_shuffle.name = "interleave_post_top_mlp"
        sigmoid_input = post_shuffle.get_output(0)
        sigmoid_input.name = "interleave_post_top_mlp"
    else:
        sigmoid_input = top_mlp.get_output(0)

    # Sigmoid
    sigmoid_layer = self.network.add_activation(sigmoid_input, trt.ActivationType.SIGMOID)
    sigmoid_layer.name = "sigmoid"
    sigmoid_layer.get_output(0).name = "sigmoid_output"

    # Output
    self.network.mark_output(sigmoid_layer.get_output(0))

    # Make sure we release the memory to system
    del self.weights

    self.initialized = True
def torch2trt(module,
              inputs,
              input_names=None,
              output_names=None,
              max_batch_size=1,
              max_workspace_size=GiB(1),
              strict_type_constraints=False,
              fp16_mode=False,
              int8_mode=False,
              int8_calib_batch_size=16,
              int8_calib_algorithm=t2t.DEFAULT_CALIBRATION_ALGORITHM,
              keep_network=True,
              log_level=trt.Logger.INFO,
              use_onnx=True,
              **kwargs):
    """Convert a module to a TensorRT-backed predictor.

    Revised to support dynamic batch size through ONNX by default. The engine
    is always built through a builder config, so ``fp16_mode``, ``int8_mode``,
    ``strict_type_constraints`` and ``max_workspace_size`` take effect whether
    or not dynamic axes are in use.

    Args:
        module(nn.Module): module to convert
        inputs(Tensor | List[Tensor]): example inputs; only the first entry
            along axis 0 of each tensor is used (on a clone)
        input_names: names of the inputs; defaults to
            ``t2t.default_input_names``
        output_names: names of the outputs; defaults to
            ``t2t.default_output_names`` sized by the number of tensors in the
            (possibly nested) module output
        max_batch_size(int): max batch size as the dynamic axis 0; values > 1
            enable a dynamic axis 0 on every input unless ``dynamic_axes`` is
            given
        max_workspace_size(int): builder workspace size
        strict_type_constraints(bool): set the STRICT_TYPES builder flag
        fp16_mode(bool): set the FP16 builder flag
        int8_mode(bool): set the INT8 builder flag and attach a calibrator
        int8_calib_batch_size(int): batch size to load for calibration
        int8_calib_algorithm: algorithm passed to the calibrator
        keep_network(bool): attach the network to the returned predictor
        log_level: TensorRT logger severity
        use_onnx(bool): export through ONNX (default) or use
            ``t2t.ConversionContext``

    Kwargs:
        dynamic_axes(Dict[str, Dict[int, str]]): {input_key: {index: name}}
        opset_version(int): ONNX opset, default 13
        min_shapes(Dict[int, Tuple(int, int, int)]): minimum shape for each dynamic input
        opt_shapes(Dict[int, Tuple(int, int, int)]): optimal shape for each dynamic input
        max_shapes(Dict[int, Tuple(int, int, int)]): maximum shape for each dynamic input
        int8_calib_cache: cache of calibration dataset
        int8_calib_data: path to calibration data
        int8_calib_max(int): max amount of calibration data to use, default 512
        int8_calib_preprocess_func: preprocessing of calibration data

    Returns:
        TRTPredictor wrapping the built engine.

    Raises:
        RuntimeError: If the exported ONNX model cannot be parsed or the
            engine fails to build.
    """
    # A bare tensor would otherwise be iterated row by row below.
    if th.is_tensor(inputs):
        inputs = (inputs, )
    # copy inputs to avoid modifications to source data; only run single entry
    inputs = tuple(tensor.clone()[0:1] for tensor in inputs)

    logger = trt.Logger(log_level)
    builder = trt.Builder(logger)

    # run once to get num outputs
    outputs = module(*inputs)
    if not isinstance(outputs, (tuple, list)):
        outputs = (outputs, )

    def count_tensors(nested):
        """Count tensors in an arbitrarily nested sequence of outputs."""
        if th.is_tensor(nested):
            return 1
        return sum(count_tensors(item) for item in nested)

    if input_names is None:
        # list of tensors expected
        input_names = t2t.default_input_names(len(inputs))
    if output_names is None:
        # in case of nested tensors
        output_names = t2t.default_output_names(count_tensors(outputs))

    logging.info(f"input_names={input_names}")
    logging.info(f"output_names={output_names}")

    dynamic_axes = kwargs.pop('dynamic_axes', None)
    if dynamic_axes is None and max_batch_size > 1:
        dynamic_axes = {
            input_name: {
                0: 'batch_size'
            }
            for input_name in input_names
        }

    if use_onnx:
        onnx_buffer = io.BytesIO()
        th.onnx.export(module,
                       inputs,
                       onnx_buffer,
                       input_names=input_names,
                       output_names=output_names,
                       dynamic_axes=dynamic_axes,
                       opset_version=kwargs.pop('opset_version', 13))
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        parser = trt.OnnxParser(network, logger)
        if not parser.parse(onnx_buffer.getvalue()):
            # Surface parser errors instead of building from a partial network.
            errors = "; ".join(
                str(parser.get_error(i)) for i in range(parser.num_errors))
            raise RuntimeError(
                f"Failed to parse the exported ONNX model: {errors}")
    else:
        # FIXME No dynamic batch size by default
        network = builder.create_network()
        with t2t.ConversionContext(network) as ctx:
            ctx.add_inputs(inputs, input_names)
            outputs = module(*inputs)
            if not isinstance(outputs, (tuple, list)):
                outputs = (outputs, )
            ctx.mark_outputs(outputs, output_names)

    builder.max_batch_size = max_batch_size

    # Builder-level precision/workspace attributes were removed in newer
    # TensorRT releases, so all build options go through the builder config.
    cfg = builder.create_builder_config()
    # XXX: set max_workspace in config for dynamic input
    cfg.max_workspace_size = max_workspace_size
    if fp16_mode:
        cfg.set_flag(trt.BuilderFlag.FP16)
    if strict_type_constraints:
        cfg.set_flag(trt.BuilderFlag.STRICT_TYPES)
    if int8_mode:
        from .calibrator import Calibrator
        cfg.set_flag(trt.BuilderFlag.INT8)
        calib_cache = kwargs.pop('int8_calib_cache', None)
        calib_data = kwargs.pop('int8_calib_data', None)
        calib_max = kwargs.pop('int8_calib_max', 512)
        calib_preprocess_func = kwargs.pop('int8_calib_preprocess_func', None)
        calib_files = (get_calibration_files(calib_data, calib_max)
                       if calib_data else None) or []
        # TODO: test calibrator with dynamic shapes other than batch size dimension
        cfg.int8_calibrator = Calibrator(
            batch_size=int8_calib_batch_size,
            inputs=[tuple(tensor.shape[1:]) for tensor in inputs],
            cache=calib_cache,
            calibration_files=calib_files,
            max_calib_data=calib_max,
            preprocess_func=calib_preprocess_func,
            algorithm=int8_calib_algorithm)

    min_shapes = kwargs.pop('min_shapes', None)
    max_shapes = kwargs.pop('max_shapes', None)
    opt_shapes = kwargs.pop('opt_shapes', None)
    logging.info(f"dynamic_axes={dynamic_axes}")

    # One optimization profile per input that has a dynamic dimension.
    profiles = {}
    for i in range(network.num_inputs):
        shape = network.get_input(i).shape
        if not any(s < 1 for s in shape):
            logging.info(f"network.get_input({i}).shape={shape}")
            continue
        logging.info(f"dynamic network.get_input({i}).shape={shape}")
        profile = builder.create_optimization_profile()
        min_shape = (1, *(min_shapes[i] if min_shapes else shape[1:]))
        max_shape = (max_batch_size,
                     *(max_shapes[i] if max_shapes else shape[1:]))
        opt_shape = ((max_batch_size, *opt_shapes[i])
                     if opt_shapes else max_shape)
        profile.set_shape(input_names[i],
                          min=trt.Dims(min_shape),
                          opt=trt.Dims(opt_shape),
                          max=trt.Dims(max_shape))
        idx = cfg.add_optimization_profile(profile)
        profiles[idx] = profile
        logging.info(
            f"set dynamic {input_names[i]}.shape to min={min_shape}, opt={opt_shape}, max={max_shape} in profile[{idx}]"
        )

    for i in range(network.num_outputs):
        shape = network.get_output(i).shape
        logging.info(f"network.get_output({i}).shape={shape}")

    logging.info(
        f"building TensorRT engine with fp16={fp16_mode}, int8={int8_mode}, strict={strict_type_constraints}"
    )
    # A calibration profile is only needed (and only available) when some
    # input is dynamic.
    if int8_mode and profiles:
        assert network.num_inputs == 1, "Only one dynamic tensor input is supported for int8 calibration"
        cfg.set_calibration_profile(profiles[0])

    engine = builder.build_engine(network, cfg)
    if engine is None:
        raise RuntimeError("Failed to build the TensorRT engine")

    module_trt = TRTPredictor(engine, input_names, output_names)
    if keep_network:
        module_trt.network = network
    return module_trt
def initialize(self):
    """Create DLRM network using TRT API and plugins and set the weights.

    Side effects, in order:
      * if either the embedding-weights binary or the embedding-free model
        file is missing, the full checkpoint is loaded and both are written;
      * row frequencies are dumped when enabled and not already on disk;
      * ``self.network`` is created (explicit batch) and populated with the
        numerical input, bottom MLP, index input, interactions plugin, top MLP
        layers (optionally Small-Tile GEMM plugins) and a sigmoid output;
      * ``self.use_small_tile_gemm_plugin`` is forced to False when the
        detected architecture is not Ampere;
      * ``self.weights`` is deleted and ``self.initialized`` is set to True.
    """
    # Both MLPs use the same setting, derived from the precision.
    useConvForFC_bottom = (self.precision == "int8")
    useConvForFC_top = (self.precision == "int8")
    # True only when not calibrating and the input dtype is int8.
    interactionsOutputInterleaved = False if self.need_calibration or self.input_dtype != "int8" else True

    # Turn off interleaved format if top_mlp use non-interleaved format
    if not self.enable_interleaved_top_mlp:
        interactionsOutputInterleaved = False
    else:
        print("Using batch-interleaved format for top_mlp.")

    # Check if we should split the model into the binary file with embedding weights quantized and model without embeddings
    if not (os.path.isfile(self.embedding_weights_binary_filepath) and os.path.isfile(self.model_without_embedding_weights_filepath)):
        logging.info("Loading checkpoint from " + self.model_filepath)
        self.weights = torch.load(self.model_filepath, map_location="cpu")["state_dict"]
        self.dump_embedding_weights_to_binary_file()
        logging.info("Writing model without embedding weights to " + self.model_without_embedding_weights_filepath)
        torch.save(self.weights, self.model_without_embedding_weights_filepath)
        del self.weights

    # Dump row frequencies to file in binary format
    if self.use_row_frequencies and not os.path.isfile(self.row_frequencies_binary_filepath):
        logging.info("Writing row frequencies to " + self.row_frequencies_binary_filepath)
        self.dump_row_frequencies_to_binary_file()

    # Load weights
    self.weights = torch.load(self.model_without_embedding_weights_filepath, map_location="cpu")

    # Create network.
    self.network = self.builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

    # Numerical input
    numerical_input = self.network.add_input("numerical_input", trt.DataType.FLOAT, (-1, self.num_numerical_inputs, 1, 1))
    # Input dtype/format overrides are applied only when not calibrating.
    if not self.need_calibration:
        if self.input_dtype == "int8":
            numerical_input.dtype = trt.int8
        elif self.input_dtype == "fp16":
            numerical_input.dtype = trt.float16
        if self.input_format == "linear":
            numerical_input.allowed_formats = 1 << int(trt.TensorFormat.LINEAR)
        elif self.input_format == "chw4":
            numerical_input.allowed_formats = 1 << int(trt.TensorFormat.CHW4)
        elif self.input_format == "chw32":
            numerical_input.allowed_formats = 1 << int(trt.TensorFormat.CHW32)

    # Bottom MLP: regular layers when calibrating or for non-int8 input,
    # otherwise the fused bottom-MLP plugin.
    if self.need_calibration or self.input_dtype != "int8":
        bottom_mlp = self.add_mlp(numerical_input, self.num_numerical_inputs, self.bottom_mlp_channels, self.bottom_mlp_names,
                                  last_relu=True, useConvForFC=useConvForFC_bottom)
    else:
        bottom_mlp_plugin, output_tensor_name = self.add_fused_bottom_mlp("DLRM_BOTTOM_MLP_TRT", numerical_input, self.num_numerical_inputs, self.bottom_mlp_channels, self.bottom_mlp_names)
        bottom_mlp = self.network.add_plugin_v2([numerical_input], bottom_mlp_plugin)
        bottom_mlp.get_output(0).name = output_tensor_name
    # NOTE(review): this shuffle's output is not consumed below -- confirm
    # whether it is intentionally left dangling.
    bottom_mlp_shuffle = self.network.add_shuffle(bottom_mlp.get_output(0))
    bottom_mlp_shuffle.reshape_dims = trt.Dims((-1, 1, self.embedding_size))

    # Index input
    index_input = self.network.add_input("index_input", trt.DataType.INT32, (-1, self.num_features))

    # Embedding lookup and interactions. The second plugin argument is the
    # running sum of [0] + embedding_rows[:-1].
    dlrm_interactions_plugin = self.get_dlrm_interactions_plugin("DLRM_INTERACTIONS_TRT", np.cumsum(
        np.array([0] + self.embedding_rows[:-1]).astype(np.int32)).astype(np.int32), interactionsOutputInterleaved)
    interaction_output_concat = self.network.add_plugin_v2([bottom_mlp.get_output(0), index_input], dlrm_interactions_plugin)
    interaction_output_concat.name = "interaction_plugin"
    interaction_output_concat.get_output(0).name = "interaction_output_concat_output"

    if self.enable_interleaved_top_mlp and not interactionsOutputInterleaved:
        # Shuffle from [BS, C, 1, 1] to [BS//2, C, 2, 1] before top_mlp
        interleave_pre_top_mlp = self.network.add_shuffle(interaction_output_concat.get_output(0))
        interleave_pre_top_mlp.reshape_dims = trt.Dims((-1, 2, interaction_output_concat.get_output(0).shape[1], 0))
        interleave_pre_top_mlp.second_transpose = trt.Permutation([0, 2, 1, 3])
        interleave_pre_top_mlp.name = "interleave_pre_top_mlp"

        top_mlp_input = interleave_pre_top_mlp.get_output(0)
        top_mlp_input.name = "interleave_pre_top_mlp"
    else:
        top_mlp_input = interaction_output_concat.get_output(0)

    # Insert small-tile GEMM plugin. The plugin supports Ampere-only.
    gpu_arch = get_system().arch
    system_id = get_system().gpu
    if self.use_small_tile_gemm_plugin:
        if gpu_arch != Architecture.Ampere:
            print("Small-Tile GEMM plugin does not support {}. Plugin disabled.".format(system_id))
            self.use_small_tile_gemm_plugin = False

    # Enable gemm plugin with interleaved format is not recommended.
    # Note (2/7/21): GEMM plugin doesn't perform well when H*W > 1
    if self.use_small_tile_gemm_plugin and self.enable_interleaved_top_mlp:
        print("Warning: small-Tile GEMM plugin performance will be "
              "significantly impacted by interleaved format. Turn off "
              "interleaved format for the best performance")

    # Running input tensor / channel count while the top MLP is built layer by layer.
    tmp_mlp_input = top_mlp_input
    tmp_input_size = self.top_mlp_input_size

    # Helper function to check whether the provided shape is supported by
    # Small-Tile GEMM plugin
    def support_small_tile_gemm_func(C, K):
        return \
            (C >= 256) and (C <= 1280) and (C % 128 == 0) and (K % 128 == 0)

    # Split the top_mlp layers, and use GEMM plugin for 2,4,6
    # C, K for top_mlp.0,2,4,6,8: [480,1024],[1024,1024],[1024,512],[512,256],[256,1]
    for i in range(len(self.top_mlp_channels)):
        # Insert plugin if the layer meets the restriction
        if support_small_tile_gemm_func(tmp_input_size, self.top_mlp_channels[i]) and \
                self.use_small_tile_gemm_plugin:
            print("Replacing {} with Small-Tile GEMM Plugin, with fairshare cache size {}".
                  format(self.top_mlp_names[i], self.gemm_plugin_fairshare_cache_size))
            layer_top_mlp = self.add_small_tile_gemm_top_mlp(
                tmp_mlp_input, tmp_input_size,
                self.top_mlp_channels[i], self.top_mlp_names[i],
                self.gemm_plugin_fairshare_cache_size
            )
        else:
            # ReLU is added after every layer except the last one.
            layer_top_mlp = self.add_single_mlp(
                tmp_mlp_input, tmp_input_size,
                self.top_mlp_channels[i], self.top_mlp_names[i],
                useConvForFC=useConvForFC_top,
                add_relu=(i != len(self.top_mlp_channels) - 1))

        tmp_mlp_input = layer_top_mlp.get_output(0)
        tmp_input_size = self.top_mlp_channels[i]

    # The last layer built is the top MLP's output layer.
    top_mlp = layer_top_mlp

    if self.enable_interleaved_top_mlp:
        # Shuffle [BS//2, 1, 2, 1] back to [BS, 1, 1, 1]
        interleave_post_top_mlp = self.network.add_shuffle(top_mlp.get_output(0))
        interleave_post_top_mlp.reshape_dims = trt.Dims((-1, 0, 1, 0))
        interleave_post_top_mlp.name = "interleave_post_top_mlp"

        sigmoid_input = interleave_post_top_mlp.get_output(0)
        sigmoid_input.name = "interleave_post_top_mlp"
    else:
        sigmoid_input = top_mlp.get_output(0)

    # Sigmoid
    sigmoid_layer = self.network.add_activation(sigmoid_input, trt.ActivationType.SIGMOID)
    sigmoid_layer.name = "sigmoid"
    sigmoid_layer.get_output(0).name = "sigmoid_output"

    # Output
    self.network.mark_output(sigmoid_layer.get_output(0))

    # Make sure we release the memory to system
    del self.weights

    self.initialized = True