def permuteHelper(self, network, inputs, dims):
    tensor = inputs[0]
    # Validate input types before building the layer.
    assert checkType([tensor, dims], [trt.ITensor, list]), "Input Type Error"
    # Only 4D permutations with a fixed batch dimension are supported here;
    # drop the (implicit) batch axis and shift the remaining indices down by one.
    assert max(dims) < 4 and dims[0] == 0
    dims = list(map(lambda x: x - 1, dims))[1:]
    shuffle = network.add_shuffle(tensor)
    shuffle.first_transpose = trt.Permutation(dims)
    return shuffle
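# Usage sketch (assumed context, not part of the original module): with an
# implicit-batch network, a PyTorch call like x.permute(0, 2, 3, 1) maps to
# dims=[0, 2, 3, 1]; the helper strips the leading batch index and emits
# first_transpose = (1, 2, 0) on the CHW tensor.
# shuffle = self.permuteHelper(network, [some_chw_tensor], [0, 2, 3, 1])
# out = shuffle.get_output(0)  # channels-last view of the CHW input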
def trt_transpose(network, input_tensor, axis=(0, 2, 3, 1)):
    """Equivalent of a PyTorch permute, e.g.:
    x = np.random.random_sample([5, 2, 3, 3]).astype(np.float32)
    torch_x = torch.from_numpy(x).permute(0, 2, 3, 1)
    """
    layer = network.add_shuffle(input=input_tensor)
    layer.first_transpose = trt.Permutation(axis)
    return layer
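# Usage sketch (assumption: a TensorRT Python API with explicit-batch networks
# is available; the builder/network names below are illustrative and not part
# of the original snippet). Builds a single-layer network that permutes
# NCHW -> NHWC, mirroring the permute(0, 2, 3, 1) example in the docstring.
def _trt_transpose_example():
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    x = network.add_input("x", trt.float32, (5, 2, 3, 3))
    layer = trt_transpose(network, x, axis=(0, 2, 3, 1))
    # The shuffle output should report the permuted shape (5, 3, 3, 2).
    return layer.get_output(0).shape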
def pixel_unshuffle(network: trt.INetworkDefinition, input: trt.ITensor,
                    downscale_factor: int) -> trt.ITensor:
    n, ic, ih, iw = input.shape
    assert ih % downscale_factor == 0 and iw % downscale_factor == 0
    oc = ic * (downscale_factor**2)
    oh = ih // downscale_factor
    ow = iw // downscale_factor
    reshape = network.add_shuffle(input)
    reshape.reshape_dims = trt.Dims(
        [n, ic, oh, downscale_factor, ow, downscale_factor])
    reshape.second_transpose = trt.Permutation([0, 1, 3, 5, 2, 4])
    reshape = network.add_shuffle(reshape.get_output(0))
    reshape.reshape_dims = trt.Dims([n, oc, oh, ow])
    return reshape.get_output(0)
def pixel_shuffle(network: trt.INetworkDefinition, input: trt.ITensor,
                  upscale_factor: int) -> trt.ITensor:
    n, ic, ih, iw = input.shape
    assert ic % (upscale_factor**2) == 0
    oc = ic // (upscale_factor**2)
    oh = ih * upscale_factor
    ow = iw * upscale_factor
    reshape = network.add_shuffle(input)
    reshape.reshape_dims = trt.Dims(
        [n, oc, upscale_factor, upscale_factor, ih, iw])
    reshape.second_transpose = trt.Permutation([0, 1, 4, 2, 5, 3])
    reshape = network.add_shuffle(reshape.get_output(0))
    reshape.reshape_dims = trt.Dims([n, oc, oh, ow])
    return reshape.get_output(0)
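# Shape-arithmetic sketch (numpy only, illustrative, not part of the original
# snippet): the same reshape/transpose recipe used above reproduces the
# pixel_shuffle layout on host data, which is a convenient way to sanity-check
# the Permutation order before building an engine.
def _pixel_shuffle_reference(x, r):
    # x: (n, c*r*r, h, w) numpy array, r: upscale factor
    n, crr, h, w = x.shape
    c = crr // (r * r)
    x = x.reshape(n, c, r, r, h, w)
    x = x.transpose(0, 1, 4, 2, 5, 3)  # matches trt.Permutation([0, 1, 4, 2, 5, 3])
    return x.reshape(n, c, h * r, w * r)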
def populate_network(network, weights, weights_color):
    # Configure the network layers based on the weights provided.
    # Define the network input.
    input_tensor = network.add_input(name=ModelData.INPUT_NAME,
                                     dtype=ModelData.DTYPE,
                                     shape=ModelData.INPUT_SHAPE)

    # Backbone Sequential
    conv0 = conv_trt(weights, input_tensor, 'cnn_1.conv0', network, 64)
    print('conv0')
    print(conv0.get_output(0).shape)
    bn0 = bn_trt(weights, conv0.get_output(0), 'cnn_1.batchnorm0', network)
    relu0 = network.add_activation(input=bn0.get_output(0),
                                   type=trt.ActivationType.RELU)
    pooling0 = max_pooling_trt(relu0.get_output(0), network)
    print('pooling0')
    print(pooling0.get_output(0).shape)

    conv1 = conv_trt(weights, pooling0.get_output(0), 'cnn_1.conv1', network, 128)
    print('conv1')
    print(conv1.get_output(0).shape)
    bn1 = bn_trt(weights, conv1.get_output(0), 'cnn_1.batchnorm1', network)
    relu1 = network.add_activation(input=bn1.get_output(0),
                                   type=trt.ActivationType.RELU)
    pooling1 = max_pooling_trt(relu1.get_output(0), network)
    print('pooling1')
    print(pooling1.get_output(0).shape)

    conv2 = conv_trt(weights, pooling1.get_output(0), 'cnn_1.conv2', network, 256)
    print('conv2')
    print(conv2.get_output(0).shape)
    bn2 = bn_trt(weights, conv2.get_output(0), 'cnn_1.batchnorm2', network)
    relu2 = network.add_activation(input=bn2.get_output(0),
                                   type=trt.ActivationType.RELU)

    # Backbone Sequential-2
    conv3 = conv_trt(weights, relu2.get_output(0), 'cnn_2.conv3', network, 256)
    print('conv3')
    print(conv3.get_output(0).shape)
    bn3 = bn_trt(weights, conv3.get_output(0), 'cnn_2.batchnorm3', network)
    relu3 = network.add_activation(input=bn3.get_output(0),
                                   type=trt.ActivationType.RELU)
    pooling2 = max_pooling_trt(relu3.get_output(0), network, stride=(2, 1))
    print('pooling2')
    print(pooling2.get_output(0).shape)

    conv4 = conv_trt(weights, pooling2.get_output(0), 'cnn_2.conv4', network, 512)
    print('conv4')
    print(conv4.get_output(0).shape)
    bn4 = bn_trt(weights, conv4.get_output(0), 'cnn_2.batchnorm4', network)
    relu4 = network.add_activation(input=bn4.get_output(0),
                                   type=trt.ActivationType.RELU)

    conv5 = conv_trt(weights, relu4.get_output(0), 'cnn_2.conv5', network, 512)
    print('conv5')
    print(conv5.get_output(0).shape)
    bn5 = bn_trt(weights, conv5.get_output(0), 'cnn_2.batchnorm5', network)
    relu5 = network.add_activation(input=bn5.get_output(0),
                                   type=trt.ActivationType.RELU)

    # Branch 1 Sequential
    pooling7 = max_pooling_trt(relu5.get_output(0), network)
    print('pooling7')
    print(pooling7.get_output(0).shape)
    conv9 = conv_trt(weights, pooling7.get_output(0), 'branch1.conv9', network,
                     512, kernel=(2, 2), stride=(2, 2), padding=(0, 0))
    print('conv9')
    print(conv9.get_output(0).shape)
    bn9 = bn_trt(weights, conv9.get_output(0), 'branch1.batchnorm9', network)
    relu9 = network.add_activation(input=bn9.get_output(0),
                                   type=trt.ActivationType.RELU)
    permute1 = network.add_shuffle(relu9.get_output(0))
    permute1.first_transpose = trt.Permutation([3, 0, 1, 2])
    permute1.reshape_dims = [
        permute1.get_output(0).shape[0], permute1.get_output(0).shape[1], 1, 1, -1
    ]
    print('permute1')
    print(permute1.get_output(0).shape)
    fc1 = fc_trt(weights, permute1.get_output(0), 'fc1', network,
                 ModelData.OUTPUT_SIZE)
    fc1.get_output(0).name = ModelData.OUTPUT_NAME1
    print('fc1')
    print(fc1.get_output(0).shape)

    # Branch 2 Sequential
    pooling3 = max_pooling_trt(relu5.get_output(0), network, stride=(2, 1))
    conv7 = conv_trt(weights, pooling3.get_output(0), 'branch2.conv7', network, 512)
    print('conv7')
    print(conv7.get_output(0).shape)
    bn7 = bn_trt(weights, conv7.get_output(0), 'branch2.batchnorm7', network)
    relu7 = network.add_activation(input=bn7.get_output(0),
                                   type=trt.ActivationType.RELU)
    conv8 = conv_trt(weights, relu7.get_output(0), 'branch2.conv8', network,
                     512, kernel=(2, 2), stride=(1, 1), padding=(0, 0))
    print('conv8')
    print(conv8.get_output(0).shape)
    bn8 = bn_trt(weights, conv8.get_output(0), 'branch2.batchnorm8', network)
    relu8 = network.add_activation(input=bn8.get_output(0),
                                   type=trt.ActivationType.RELU)
    permute2 = network.add_shuffle(relu8.get_output(0))
    permute2.first_transpose = trt.Permutation([3, 0, 1, 2])
    permute2.reshape_dims = [
        permute2.get_output(0).shape[0], permute2.get_output(0).shape[1], 1, 1, -1
    ]
    print('permute2')
    print(permute2.get_output(0).shape)
    fc2 = fc_trt(weights, permute2.get_output(0), 'fc2', network,
                 ModelData.OUTPUT_SIZE)
    fc2.get_output(0).name = ModelData.OUTPUT_NAME2
    print('fc2')
    print(fc2.get_output(0).shape)

    cat1 = network.add_concatenation([fc1.get_output(0), fc2.get_output(0)])
    cat1.axis = 0
    cat1.get_output(0).name = ModelData.OUTPUT_NAME3
    permute3 = network.add_shuffle(cat1.get_output(0))
    permute3.reshape_dims = [0, 0, -1]
    permute3.second_transpose = trt.Permutation([1, 2, 0])
    print('permute3')
    print(permute3.get_output(0).shape)
    re_softmax = network.add_softmax(input=permute3.get_output(0))
    re_softmax.axes = 2
    print('re_softmax')
    print(re_softmax.get_output(0).shape)
    permute4 = network.add_shuffle(re_softmax.get_output(0))
    permute4.first_transpose = trt.Permutation([0, 2, 1])
    print('permute4')
    print(permute4.get_output(0).shape)

    # Color branch
    c_pooling0 = max_pooling_trt(relu2.get_output(0), network, stride=(2, 2))
    c_conv0 = conv_trt(weights_color, c_pooling0.get_output(0),
                       'color_branch.c_conv0', network, 128)
    print('c_conv0')
    print(c_conv0.get_output(0).shape)
    c_relu0 = network.add_activation(input=c_conv0.get_output(0),
                                     type=trt.ActivationType.RELU)
    c_pooling1 = max_pooling_trt(c_relu0.get_output(0), network, stride=(2, 2))
    c_conv1 = conv_trt(weights_color, c_pooling1.get_output(0),
                       'color_branch.c_conv1', network, 64)
    print('c_conv1')
    print(c_conv1.get_output(0).shape)
    c_relu1 = network.add_activation(input=c_conv1.get_output(0),
                                     type=trt.ActivationType.RELU)
    fc_c = fc_trt(weights_color, c_relu1.get_output(0), 'fc_c', network,
                  ModelData.OUTPUT_COLOR_SIZE)
    fc_c.get_output(0).name = ModelData.OUTPUT_COLOR_NAME
    print('fc_c')
    print(fc_c.get_output(0).shape)
    c_softmax = network.add_softmax(input=fc_c.get_output(0))
    c_softmax.axes = 2
    print('c_softmax')
    print(c_softmax.get_output(0).shape)

    # Define the network outputs.
    network.mark_output(tensor=permute4.get_output(0))
    network.mark_output(tensor=c_softmax.get_output(0))
def populate_network(network, weights):
    # Configure the network layers based on the weights provided.
    input_tensor = network.add_input(name=ModelData.INPUT_NAME,
                                     dtype=ModelData.DTYPE,
                                     shape=ModelData.INPUT_SHAPE)

    # Backbone Sequential
    conv0 = conv_trt(weights, input_tensor, 'cnn.conv0', network, 64)
    print('conv0')
    print(conv0.get_output(0).shape)
    bn0 = bn_trt(weights, conv0.get_output(0), 'cnn.batchnorm0', network)
    relu0 = network.add_activation(input=bn0.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu0.alpha = ModelData.RELU_alpha
    pooling0 = max_pooling_trt(relu0.get_output(0), network)
    print('pooling0')
    print(pooling0.get_output(0).shape)

    conv1 = conv_trt(weights, pooling0.get_output(0), 'cnn.conv1', network, 128)
    print('conv1')
    print(conv1.get_output(0).shape)
    bn1 = bn_trt(weights, conv1.get_output(0), 'cnn.batchnorm1', network)
    relu1 = network.add_activation(input=bn1.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu1.alpha = ModelData.RELU_alpha
    pooling1 = max_pooling_trt(relu1.get_output(0), network)
    print('pooling1')
    print(pooling1.get_output(0).shape)

    conv2 = conv_trt(weights, pooling1.get_output(0), 'cnn.conv2', network, 256)
    print('conv2')
    print(conv2.get_output(0).shape)
    bn2 = bn_trt(weights, conv2.get_output(0), 'cnn.batchnorm2', network)
    relu2 = network.add_activation(input=bn2.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu2.alpha = ModelData.RELU_alpha

    conv3 = conv_trt(weights, relu2.get_output(0), 'cnn.conv3', network, 256)
    print('conv3')
    print(conv3.get_output(0).shape)
    bn3 = bn_trt(weights, conv3.get_output(0), 'cnn.batchnorm3', network)
    relu3 = network.add_activation(input=bn3.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu3.alpha = ModelData.RELU_alpha
    pooling2 = max_pooling_trt(relu3.get_output(0), network, stride=(2, 1))
    print('pooling2')
    print(pooling2.get_output(0).shape)

    conv4 = conv_trt(weights, pooling2.get_output(0), 'cnn.conv4', network, 512)
    print('conv4')
    print(conv4.get_output(0).shape)
    bn4 = bn_trt(weights, conv4.get_output(0), 'cnn.batchnorm4', network)
    relu4 = network.add_activation(input=bn4.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu4.alpha = ModelData.RELU_alpha

    conv5 = conv_trt(weights, relu4.get_output(0), 'cnn.conv5', network, 512)
    print('conv5')
    print(conv5.get_output(0).shape)
    bn5 = bn_trt(weights, conv5.get_output(0), 'cnn.batchnorm5', network)
    relu5 = network.add_activation(input=bn5.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu5.alpha = ModelData.RELU_alpha

    # Branch 1 Sequential
    pooling7 = max_pooling_trt(relu5.get_output(0), network)
    print('pooling7')
    print(pooling7.get_output(0).shape)
    conv9 = conv_trt(weights, pooling7.get_output(0), 'branch1.conv9', network,
                     512, kernel=(2, 2), stride=(2, 2), padding=(0, 0))
    print('conv9')
    print(conv9.get_output(0).shape)
    bn9 = bn_trt(weights, conv9.get_output(0), 'branch1.batchnorm9', network)
    relu9 = network.add_activation(input=bn9.get_output(0),
                                   type=trt.ActivationType.RELU)
    permute1 = network.add_shuffle(relu9.get_output(0))
    permute1.first_transpose = trt.Permutation([2, 0, 1])
    permute1.reshape_dims = [0, 1, 1, -1]
    print('permute1')
    print(permute1.get_output(0).shape)
    fc1 = fc_trt(weights, permute1.get_output(0), 'fc1', network,
                 ModelData.OUTPUT_SIZE)
    fc1.get_output(0).name = ModelData.OUTPUT_NAME1
    print('fc1')
    print(fc1.get_output(0).shape)

    # Branch 2 Sequential
    pooling3 = max_pooling_trt(relu5.get_output(0), network, stride=(2, 1))
    conv7 = conv_trt(weights, pooling3.get_output(0), 'branch2.conv7', network, 512)
    print('conv7')
    print(conv7.get_output(0).shape)
    bn7 = bn_trt(weights, conv7.get_output(0), 'branch2.batchnorm7', network)
    relu7 = network.add_activation(input=bn7.get_output(0),
                                   type=trt.ActivationType.RELU)
    conv8 = conv_trt(weights, relu7.get_output(0), 'branch2.conv8', network,
                     512, kernel=(2, 2), stride=(1, 1), padding=(0, 0))
    print('conv8')
    print(conv8.get_output(0).shape)
    bn8 = bn_trt(weights, conv8.get_output(0), 'branch2.batchnorm8', network)
    relu8 = network.add_activation(input=bn8.get_output(0),
                                   type=trt.ActivationType.RELU)
    permute2 = network.add_shuffle(relu8.get_output(0))
    permute2.first_transpose = trt.Permutation([2, 0, 1])
    permute2.reshape_dims = [0, 1, 1, -1]
    fc2 = fc_trt(weights, permute2.get_output(0), 'fc2', network,
                 ModelData.OUTPUT_SIZE)
    fc2.get_output(0).name = ModelData.OUTPUT_NAME2
    print('fc2')
    print(fc2.get_output(0).shape)

    cat1 = network.add_concatenation([fc1.get_output(0), fc2.get_output(0)])
    cat1.axis = 0
    cat1.get_output(0).name = ModelData.OUTPUT_NAME3
    permute3 = network.add_shuffle(cat1.get_output(0))
    permute3.reshape_dims = [0, -1]
    network.mark_output(tensor=permute3.get_output(0))
def initialize(self):
    useConvForFC_bottom = (self.precision == "int8")
    useConvForFC_top = (self.precision == "int8")
    interactionsOutputInterleaved = False if self.need_calibration or self.input_dtype != "int8" else True

    # Check if we should split the model into the binary file with embedding weights quantized
    # and the model without embeddings.
    if not (os.path.isfile(self.embedding_weights_binary_filepath)
            and os.path.isfile(self.model_without_embedding_weights_filepath)):
        logging.info("Loading checkpoint from " + self.model_filepath)
        self.weights = torch.load(self.model_filepath, map_location="cpu")["state_dict"]
        self.dump_embedding_weights_to_binary_file()
        logging.info("Writing model without embedding weights to " + self.model_without_embedding_weights_filepath)
        torch.save(self.weights, self.model_without_embedding_weights_filepath)
        del self.weights

    # Dump row frequencies to file in binary format
    if self.use_row_frequencies and not os.path.isfile(self.row_frequencies_binary_filepath):
        logging.info("Writing row frequencies to " + self.row_frequencies_binary_filepath)
        self.dump_row_frequencies_to_binary_file()

    # Load weights
    self.weights = torch.load(self.model_without_embedding_weights_filepath, map_location="cpu")

    # Create network.
    self.network = self.builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

    # Numerical input
    numerical_input = self.network.add_input("numerical_input", trt.DataType.FLOAT,
                                             (-1, self.num_numerical_inputs, 1, 1))
    if not self.need_calibration:
        if self.input_dtype == "int8":
            numerical_input.dtype = trt.int8
        elif self.input_dtype == "fp16":
            numerical_input.dtype = trt.float16
        if self.input_format == "linear":
            numerical_input.allowed_formats = 1 << int(trt.TensorFormat.LINEAR)
        elif self.input_format == "chw4":
            numerical_input.allowed_formats = 1 << int(trt.TensorFormat.CHW4)
        elif self.input_format == "chw32":
            numerical_input.allowed_formats = 1 << int(trt.TensorFormat.CHW32)

    # Bottom MLP
    if self.need_calibration or self.input_dtype != "int8":
        bottom_mlp = self.add_mlp(numerical_input, self.num_numerical_inputs, self.bottom_mlp_channels,
                                  self.bottom_mlp_names, last_relu=True, useConvForFC=useConvForFC_bottom)
    else:
        bottom_mlp_plugin, output_tensor_name = self.add_fused_bottom_mlp("DLRM_BOTTOM_MLP_TRT", numerical_input,
                                                                          self.num_numerical_inputs,
                                                                          self.bottom_mlp_channels,
                                                                          self.bottom_mlp_names)
        bottom_mlp = self.network.add_plugin_v2([numerical_input], bottom_mlp_plugin)
        bottom_mlp.get_output(0).name = output_tensor_name
    bottom_mlp_shuffle = self.network.add_shuffle(bottom_mlp.get_output(0))
    bottom_mlp_shuffle.reshape_dims = trt.Dims((-1, 1, self.embedding_size))

    # Index input
    index_input = self.network.add_input("index_input", trt.DataType.INT32, (-1, self.num_features))

    # Embedding lookup and interactions
    dlrm_interactions_plugin = self.get_dlrm_interactions_plugin(
        "DLRM_INTERACTIONS_TRT",
        np.cumsum(np.array([0] + self.embedding_rows[:-1]).astype(np.int32)).astype(np.int32),
        interactionsOutputInterleaved)
    interaction_output_concat = self.network.add_plugin_v2([bottom_mlp.get_output(0), index_input],
                                                           dlrm_interactions_plugin)
    interaction_output_concat.name = "interaction_plugin"
    interaction_output_concat.get_output(0).name = "interaction_output_concat_output"

    if self.INTERLEAVED_TOP_MLP and not interactionsOutputInterleaved:
        # Shuffle from [BS, C, 1, 1] to [BS//2, C, 2, 1] before top_mlp
        interleave_pre_top_mlp = self.network.add_shuffle(interaction_output_concat.get_output(0))
        interleave_pre_top_mlp.reshape_dims = trt.Dims((-1, 2, interaction_output_concat.get_output(0).shape[1], 0))
        interleave_pre_top_mlp.second_transpose = trt.Permutation([0, 2, 1, 3])
        interleave_pre_top_mlp.name = "interleave_pre_top_mlp"
        top_mlp_input = interleave_pre_top_mlp.get_output(0)
        top_mlp_input.name = "interleave_pre_top_mlp"
    else:
        top_mlp_input = interaction_output_concat.get_output(0)

    # Top MLP
    top_mlp = self.add_mlp(top_mlp_input, self.top_mlp_input_size, self.top_mlp_channels, self.top_mlp_names,
                           last_relu=False, useConvForFC=useConvForFC_top)

    if self.INTERLEAVED_TOP_MLP:
        # Shuffle back to [BS, 1, 1, 1] from [BS//2, 1, 2, 1]
        interleave_post_top_mlp = self.network.add_shuffle(top_mlp.get_output(0))
        interleave_post_top_mlp.reshape_dims = trt.Dims((-1, 0, 1, 0))
        interleave_post_top_mlp.name = "interleave_post_top_mlp"
        sigmoid_input = interleave_post_top_mlp.get_output(0)
        sigmoid_input.name = "interleave_post_top_mlp"
    else:
        sigmoid_input = top_mlp.get_output(0)

    # Sigmoid
    sigmoid_layer = self.network.add_activation(sigmoid_input, trt.ActivationType.SIGMOID)
    sigmoid_layer.name = "sigmoid"
    sigmoid_layer.get_output(0).name = "sigmoid_output"

    # Output
    self.network.mark_output(sigmoid_layer.get_output(0))

    # Make sure we release the memory to the system
    del self.weights

    self.initialized = True
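# Host-side sketch of the batch-interleave shuffle above (numpy only,
# illustrative, not part of the original class): two consecutive samples of a
# [BS, C] activation are paired into one [BS//2, C, 2, 1] tile, which is what
# the reshape to (-1, 2, C, 0) followed by the (0, 2, 1, 3) transpose produces
# on device.
def _interleave_pairs(x):
    # x: (bs, c) numpy array with an even batch size
    bs, c = x.shape
    return x.reshape(bs // 2, 2, c, 1).transpose(0, 2, 1, 3)  # (bs//2, c, 2, 1)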
def populate_duration_predictor(self, name, network, weights, seq_tensor,
                                seq_mask_tensor, batch_size, max_seq_len,
                                d_model):
    duration_predictor_filter_size = self.model.duration_predictor_filter_size
    duration_predictor_kernel_size = self.model.duration_predictor_kernel_size

    # Pytorch: input *= input_mask.to(input.dtype)  # can be skipped.

    # Pytorch: out = self.conv1d_1(input.transpose(1,2)).transpose(1,2)
    trans1 = network.add_shuffle(
        input=seq_tensor)  # (b, t, d_model) to (b, d_model, t, 1)
    trans1.first_transpose = trt.Permutation([0, 2, 1])
    trans1.reshape_dims = Dims((batch_size, d_model, max_seq_len, 1))
    trans1.name = "{}.trans1".format(name)
    out = trans1.get_output(0)  # (b, d_model, t, 1)

    conv1_w = weights["{}.conv1d_1.weight".format(
        name
    )]  # (1, d_model, duration_predictor_filter_size, duration_predictor_kernel_size, 1)
    conv1_b = weights["{}.conv1d_1.bias".format(
        name)]  # (duration_predictor_filter_size, )
    conv1 = network.add_convolution(
        input=out,
        num_output_maps=duration_predictor_filter_size,
        kernel_shape=trt.DimsHW(duration_predictor_kernel_size, 1),
        kernel=Weights(conv1_w),
        bias=Weights(conv1_b))
    conv1.padding = trt.DimsHW(1, 0)
    conv1.name = "{}.conv1".format(name)
    out = conv1.get_output(0)  # (b, duration_predictor_filter_size, t, 1)

    trans2 = network.add_shuffle(
        input=out
    )  # (b, duration_predictor_filter_size, t, 1) to (b, t, duration_predictor_filter_size)
    trans2.first_transpose = trt.Permutation([0, 2, 1, 3])
    trans2.reshape_dims = Dims(
        (batch_size, max_seq_len, duration_predictor_filter_size))
    trans2.name = "{}.trans2".format(name)
    out = trans2.get_output(0)  # (b, t, duration_predictor_filter_size)

    # Pytorch: out = self.relu_1(out)
    relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
    relu.name = "{}.relu1".format(name)
    out_relu = relu.get_output(0)  # (b, t, duration_predictor_filter_size)

    # Pytorch: out = self.layer_norm_1(out)
    out = self.populate_layernorm(name="{}.layer_norm_1".format(name),
                                  network=network,
                                  weights=weights,
                                  seq_tensor=out_relu,
                                  d_layer=duration_predictor_filter_size,
                                  batch_size=batch_size,
                                  max_seq_len=max_seq_len)

    # Pytorch: out = self.conv1d_2(out.transpose(1,2)).transpose(1,2)
    trans3 = network.add_shuffle(
        input=out
    )  # (b, t, duration_predictor_filter_size) to (b, duration_predictor_filter_size, t, 1)
    trans3.first_transpose = trt.Permutation([0, 2, 1])
    trans3.reshape_dims = Dims(
        (batch_size, duration_predictor_filter_size, max_seq_len, 1))
    trans3.name = "{}.trans3".format(name)
    out = trans3.get_output(0)  # (b, duration_predictor_filter_size, t, 1)

    conv2_w = weights["{}.conv1d_2.weight".format(
        name
    )]  # (1, duration_predictor_filter_size, duration_predictor_filter_size, duration_predictor_kernel_size, 1)
    conv2_b = weights["{}.conv1d_2.bias".format(
        name)]  # (duration_predictor_filter_size, )
    conv2 = network.add_convolution(
        input=out,
        num_output_maps=duration_predictor_filter_size,
        kernel_shape=trt.DimsHW(duration_predictor_kernel_size, 1),
        kernel=Weights(conv2_w),
        bias=Weights(conv2_b))
    conv2.padding = trt.DimsHW(1, 0)
    conv2.name = "{}.conv2".format(name)
    out = conv2.get_output(0)

    trans4 = network.add_shuffle(
        input=out
    )  # (b, duration_predictor_filter_size, t, 1) to (b, t, duration_predictor_filter_size)
    trans4.first_transpose = trt.Permutation([0, 2, 1, 3])
    trans4.reshape_dims = Dims(
        (batch_size, max_seq_len, duration_predictor_filter_size))
    trans4.name = "{}.trans4".format(name)
    out = trans4.get_output(0)  # (b, t, duration_predictor_filter_size)

    # Pytorch: out = self.relu_2(out)
    relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
    relu.name = "{}.relu2".format(name)
    out_relu = relu.get_output(0)  # (b, t, duration_predictor_filter_size)

    # Pytorch: out = self.layer_norm_2(out)
    out = self.populate_layernorm(
        name="{}.layer_norm_2".format(name),
        network=network,
        weights=weights,
        seq_tensor=out_relu,
        d_layer=duration_predictor_filter_size,
        batch_size=batch_size,
        max_seq_len=max_seq_len,
    )  # (b, t, duration_predictor_filter_size)

    # Pytorch: out = self.linear_layer(out)
    w = weights["{}.linear_layer.weight".format(
        name)]  # (1, duration_predictor_filter_size)
    out_w = network.add_constant(
        shape=(1, 1, duration_predictor_filter_size),
        weights=trt.Weights(w)).get_output(
            0)  # (1, 1, duration_predictor_filter_size)
    linear_w = network.add_matrix_multiply(
        out, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
    )  # (b, t, duration_predictor_filter_size) * (1->b, duration_predictor_filter_size, 1) => (b, t, 1)
    linear_w.name = "{}.linear.w".format(name)
    out = linear_w.get_output(0)  # (b, t, 1)

    b = weights["{}.linear_layer.bias".format(name)]  # (1,)
    out_b = network.add_constant(
        shape=(1, 1, 1), weights=trt.Weights(b)).get_output(0)  # (1, 1, 1)
    linear_b = network.add_elementwise(input1=out,
                                       input2=out_b,
                                       op=trt.ElementWiseOperation.SUM)
    linear_b.name = "{}.linear.b".format(name)
    out = linear_b.get_output(0)  # (b, t, 1)

    # Pytorch: out *= input_mask.to(out.dtype)
    zeros = network.add_constant(weights=Weights(
        np.zeros(shape=(batch_size, max_seq_len, 1), dtype=np.float32)),
                                 shape=(batch_size, max_seq_len, 1))
    out_zeros = zeros.get_output(0)  # (b, t, 1)
    dur = network.add_select(condition=seq_mask_tensor,
                             then_input=out,
                             else_input=out_zeros)
    dur.name = "{}.mask".format(name)
    out_dur = dur.get_output(0)

    # Pytorch: duration = torch.clamp_min(torch.exp(duration) - 1, 0)
    exp = network.add_unary(input=out_dur, op=trt.UnaryOperation.EXP)
    exp.name = "{}.exp".format(name)
    out_exp = exp.get_output(0)
    ones = network.add_constant(weights=Weights(
        np.ones(shape=(batch_size, max_seq_len, 1), dtype=np.float32)),
                                shape=(batch_size, max_seq_len, 1))
    out_ones = ones.get_output(0)  # (b, t, 1)
    sub = network.add_elementwise(input1=out_exp,
                                  input2=out_ones,
                                  op=trt.ElementWiseOperation.SUB)
    sub.name = "{}.sub_one".format(name)
    out_sub = sub.get_output(0)
    dur = network.add_elementwise(input1=out_sub,
                                  input2=out_zeros,
                                  op=trt.ElementWiseOperation.MAX)
    dur.name = "{}.max".format(name)
    out_dur = dur.get_output(0)

    # Pytorch: repeats = torch.round(repeats).long()
    half_ones = network.add_constant(weights=Weights(
        np.full((batch_size, max_seq_len, 1), 0.5, dtype=np.float32)),
                                     shape=(batch_size, max_seq_len, 1))
    out_half_ones = half_ones.get_output(0)  # (b, t, 1)
    add = network.add_elementwise(input1=out_dur,
                                  input2=out_half_ones,
                                  op=trt.ElementWiseOperation.SUM)
    add.name = "{}.round_add".format(name)
    out_add = add.get_output(0)  # (b, t, 1)
    dur = network.add_elementwise(input1=out_add,
                                  input2=out_ones,
                                  op=trt.ElementWiseOperation.FLOOR_DIV)
    dur.name = "{}.round_floor_div".format(name)
    out_dur = dur.get_output(0)  # (b, t, 1)

    dur = network.add_shuffle(input=out_dur)  # (b, t, 1) to (b, t)
    dur.reshape_dims = Dims(shape=(batch_size, max_seq_len))
    out_dur = dur.get_output(0)  # (b, t)

    return out_dur
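# Rounding sketch (numpy, illustrative, not part of the original class): the
# add of a constant 0.5 followed by FLOOR_DIV by a tensor of ones above
# approximates torch.round on the clamped durations as floor(x + 0.5),
# i.e. round-half-up on non-negative values.
def _round_half_up(x):
    # x: numpy array of non-negative durations
    return np.floor(x + 0.5)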
def populate_pos_wise(self, name, network, weights, seq_tensor, batch_size,
                      max_seq_len, d_model, conv_filter_size,
                      conv_kernel_size, conv_padding):
    # Pytorch: output = x.transpose(1, 2)
    trans1 = network.add_shuffle(
        input=seq_tensor)  # (b, t, d_model) to (b, d_model, t, 1)
    trans1.first_transpose = trt.Permutation([0, 2, 1])
    trans1.reshape_dims = Dims((batch_size, d_model, max_seq_len, 1))
    trans1.name = "{}.trans1".format(name)
    out = trans1.get_output(0)  # (b, d_model, t, 1)

    # Pytorch: output = self.w_1(output)
    conv1_w = weights["{}.w_1.weight".format(
        name)]  # (1, conv_filter_size, d_model, conv_kernel_size, 1)
    conv1_b = weights["{}.w_1.bias".format(name)]  # (conv_filter_size,)
    conv1 = network.add_convolution(input=out,
                                    num_output_maps=conv_filter_size,
                                    kernel_shape=trt.DimsHW(
                                        conv_kernel_size, 1),
                                    kernel=Weights(conv1_w),
                                    bias=Weights(conv1_b))
    conv1.padding = trt.DimsHW(1, 0)
    conv1.name = "{}.conv1".format(name)
    out = conv1.get_output(0)  # (b, conv_filter_size, t, 1)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out,
                                      "act.{}.conv1".format(name))

    # Pytorch: output = F.relu(output)
    relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
    relu.name = "{}.relu".format(name)
    out = relu.get_output(0)  # (b, conv_filter_size, t, 1)

    # Pytorch: output = self.w_2(output)
    conv2_w = weights["{}.w_2.weight".format(
        name)]  # (1, d_model, conv_filter_size, conv_kernel_size, 1)
    conv2_b = weights["{}.w_2.bias".format(name)]  # (d_model, )
    conv2 = network.add_convolution(input=out,
                                    num_output_maps=d_model,
                                    kernel_shape=trt.DimsHW(
                                        conv_kernel_size, 1),
                                    kernel=Weights(conv2_w),
                                    bias=Weights(conv2_b))
    conv2.padding = trt.DimsHW(1, 0)
    conv2.name = "{}.conv2".format(name)
    out = conv2.get_output(0)  # (b, d_model, t, 1)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out,
                                      "act.{}.conv2".format(name))

    # Pytorch: output = output.transpose(1, 2)
    trans2 = network.add_shuffle(
        input=out)  # (b, d_model, t, 1) to (b, t, d_model)
    trans2.first_transpose = trt.Permutation([0, 2, 1, 3])
    trans2.reshape_dims = Dims((batch_size, max_seq_len, d_model))
    trans2.name = "{}.trans2".format(name)
    out = trans2.get_output(0)  # (b, t, d_model)

    # Pytorch: output += residual
    residual = network.add_elementwise(input1=seq_tensor,
                                       input2=out,
                                       op=trt.ElementWiseOperation.SUM)
    residual.name = "{}.residual".format(name)
    out = residual.get_output(0)  # (b, t, d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out,
                                      "act.{}.residual".format(name))

    # Pytorch: output = self.layer_norm(output)
    out = self.populate_layernorm(
        name="{}.layer_norm".format(name),
        network=network,
        weights=weights,
        seq_tensor=out,
        batch_size=self.batch_size,
        max_seq_len=max_seq_len,
        d_layer=d_model,
    )  # (b, t, d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out, "act.{}.ln".format(name))

    return out
def populate_slf_attn(self, name, network, weights, seq_tensor,
                      seq_mask_tensor, batch_size, max_seq_len, d_model,
                      n_heads, d_k, d_v):
    d_qkv = d_k + d_k + d_v

    # Pytorch: x = self.linear(x)
    w = weights["{}.linear.weight".format(
        name)]  # (n_heads * d_qkv, d_model)
    out_w = network.add_constant(shape=(1, d_model, n_heads * d_qkv),
                                 weights=trt.Weights(w)).get_output(
                                     0)  # (1, n_heads * d_qkv, d_model)
    linear_w = network.add_matrix_multiply(
        seq_tensor, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
    )  # (b, t, d_model) * (1->b, d_model, n_heads * d_qkv) => (b, t, n_heads * d_qkv)
    linear_w.name = "{}.linear.w".format(name)
    out = linear_w.get_output(0)  # (b, t, n_heads * d_qkv)

    b = weights["{}.linear.bias".format(name)]  # (n_heads * d_qkv,)
    out_b = network.add_constant(shape=(1, 1, n_heads * d_qkv),
                                 weights=trt.Weights(b)).get_output(
                                     0)  # (1, 1, n_heads * d_qkv)
    linear_b = network.add_elementwise(input1=out,
                                       input2=out_b,
                                       op=trt.ElementWiseOperation.SUM)
    linear_b.name = "{}.linear.b".format(name)
    out = linear_b.get_output(0)  # (b, t, n_heads * d_qkv)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out,
                                      "act.{}.linear".format(name))

    trans1 = network.add_shuffle(
        input=out)  # (b, t, n_heads * d_qkv) to (b, n_heads, t, d_qkv)
    trans1.reshape_dims = Dims((batch_size, max_seq_len, n_heads, d_qkv))
    trans1.second_transpose = trt.Permutation([0, 2, 1, 3])
    trans1.name = "{}.trans1".format(name)
    out = trans1.get_output(0)  # (b, n_heads, t, d_qkv)

    # if self.validate_accuracy:
    #     self.add_activation_as_output(network, out, "act.{}.reshape".format(name))

    q = network.add_slice(input=out,
                          start=Dims((0, 0, 0, 0)),
                          shape=Dims(
                              (batch_size, n_heads, max_seq_len, d_k)),
                          stride=Dims((1, 1, 1, 1)))
    q.name = "{}.slide_q".format(name)
    k = network.add_slice(input=out,
                          start=Dims((0, 0, 0, d_k)),
                          shape=Dims(
                              (batch_size, n_heads, max_seq_len, d_k)),
                          stride=Dims((1, 1, 1, 1)))
    k.name = "{}.slide_k".format(name)
    v = network.add_slice(input=out,
                          start=Dims((0, 0, 0, 2 * d_k)),
                          shape=Dims(
                              (batch_size, n_heads, max_seq_len, d_k)),
                          stride=Dims((1, 1, 1, 1)))
    v.name = "{}.slide_v".format(name)
    out_q = q.get_output(0)  # (b, n_heads, t, d_q)
    out_k = k.get_output(0)  # (b, n_heads, t, d_k)
    out_v = v.get_output(0)  # (b, n_heads, t, d_v)

    # Pytorch: output, attn = self.attention(q, k, v, mask=mask)
    out = self.populate_scaled_dot(
        name="{}.scaled_dot".format(name),  # (b, n_heads, t, d_k)
        network=network,
        q_tensor=out_q,
        k_tensor=out_k,
        v_tensor=out_v,
        mask_tensor=seq_mask_tensor,
        batch_size=batch_size,
        max_seq_len=max_seq_len,
        n_heads=n_heads,
        temperature=d_k**0.5)

    # Pytorch:
    #   output = output.view(self.n_head, bs, seq_len, self.d_v)
    #   output = output.permute(1, 2, 0, 3).contiguous().view(bs, seq_len, self.n_head * self.d_v)
    trans2 = network.add_shuffle(
        input=out)  # (b, n_heads, t, d_k) to (b, t, n_heads * d_k)
    trans2.first_transpose = trt.Permutation([0, 2, 1, 3])
    trans2.reshape_dims = Dims((batch_size, max_seq_len, n_heads * d_v))
    trans2.name = "{}.trans2".format(name)
    out = trans2.get_output(0)  # (b, t, n_heads * d_k)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out,
                                      "act.{}.scaled_dot".format(name))

    # Pytorch: output = self.fc(output)
    w = weights["{}.fc.weight".format(name)]  # (d_model, n_heads * d_v)
    out_w = network.add_constant(shape=(1, d_model, n_heads * d_v),
                                 weights=trt.Weights(w)).get_output(
                                     0)  # (1, d_model, n_heads * d_v)
    fc_w = network.add_matrix_multiply(
        out, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
    )  # (b, t, n_heads * d_k) * (1->b, n_heads * d_k, d_model) => (b, t, d_model)
    fc_w.name = "{}.fc.w".format(name)
    out = fc_w.get_output(0)  # (b, t, d_model)

    b = weights["{}.fc.bias".format(name)]  # (d_model,)
    out_b = network.add_constant(shape=(1, 1, d_model),
                                 weights=trt.Weights(b)).get_output(
                                     0)  # (1, 1, d_model)
    fc_b = network.add_elementwise(input1=out,
                                   input2=out_b,
                                   op=trt.ElementWiseOperation.SUM)
    fc_b.name = "{}.fc.b".format(name)
    out = fc_b.get_output(0)  # (b, t, d_model)

    # if self.validate_accuracy:
    #     self.add_activation_as_output(network, out, "act.{}.fc".format(name))

    # Pytorch: output += residual
    residual = network.add_elementwise(input1=seq_tensor,
                                       input2=out,
                                       op=ElementWiseOperation.SUM)
    residual.name = "{}.residual".format(name)
    out = residual.get_output(0)  # (b, t, d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out,
                                      "act.{}.residual".format(name))

    # Pytorch: output = self.layer_norm(output)
    out = self.populate_layernorm(
        name="{}.layer_norm".format(name),
        network=network,
        weights=weights,
        seq_tensor=out,
        batch_size=self.batch_size,
        max_seq_len=max_seq_len,
        d_layer=d_model,
    )  # (b, t, d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out, "act.{}.ln".format(name))

    return out
def populate_network(self, network, weights, batch_size,
                     trt_max_input_seq_len, trt_max_output_seq_len):
    d_model = self.model.d_model

    ##
    # Inputs
    ##

    out_seq = network.add_input(name="input_seq",
                                dtype=trt.float32,
                                shape=(batch_size, trt_max_input_seq_len,
                                       d_model))  # (b, t, d_model)

    zeros = network.add_constant(weights=Weights(
        np.zeros(shape=(batch_size, trt_max_input_seq_len, 1),
                 dtype=np.float32)),
                                 shape=(batch_size, trt_max_input_seq_len,
                                        1))  # (b, t, 1)
    out_zeros = zeros.get_output(0)  # (b, t, 1)
    seq = network.add_elementwise(input1=out_seq,
                                  input2=out_zeros,
                                  op=trt.ElementWiseOperation.SUM)
    out_seq = seq.get_output(0)  # (b, t, d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq, "act.emb")

    out_seq_mask = network.add_input(  # paddings are False
        name="input_mask",
        dtype=trt.bool,
        shape=(batch_size, trt_max_input_seq_len, 1))  # (b, t, 1)

    ##
    # Phoneme-side FFT Blocks
    ##

    # Positional Encoding
    # The plugin adds positional encoding to the padding values as well (for better
    # performance), whereas the Pytorch impl does not. This is fine because the padding
    # values are eventually masked out in the following layers, giving accurate output.
    seq = network.add_plugin_v2([out_seq], self.get_plugin('AddPosEncPlugin'))
    seq.name = "phoneme_side.add_pos_enc"
    out_seq = seq.get_output(0)  # (b, t, d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq,
                                      "act.phoneme_side.add_pos_enc")

    for layer_idx in range(self.model.phoneme_side_n_layer):
        out_seq = self.populate_fft(
            name='phoneme_side.layer_stack.{}'.format(layer_idx),
            network=network,
            weights=weights,
            seq_tensor=out_seq,
            seq_mask_tensor=out_seq_mask,
            batch_size=self.batch_size,
            max_seq_len=trt_max_input_seq_len,
            d_model=d_model,
            n_heads=self.model.phoneme_side_head,
            d_k=self.model.phoneme_side.d_k,
            d_v=self.model.phoneme_side.d_v,
            self_attn_temp=self.model.phoneme_side.d_k**0.5,
            conv_filter_size=self.model.phoneme_side_conv1d_filter_size,
            conv_kernel_size=self.model.fft_conv1d_kernel,
            conv_padding=self.model.fft_conv1d_padding)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq,
                                      "act.phoneme_side.seq")

    out_seq, out_seq_mask, out_dur = self.populate_length_regulator(
        name="length_regulator",
        network=network,
        weights=weights,
        seq_tensor=out_seq,
        seq_mask_tensor=out_seq_mask,
        batch_size=batch_size,
        trt_max_input_seq_len=trt_max_input_seq_len,
        trt_max_output_seq_len=trt_max_output_seq_len,
        d_model=d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq,
                                      "act.length_regulator.seq")
        self.add_activation_as_output(network, out_dur,
                                      "act.length_regulator.dur")

    ##
    # Mel-side FFT Blocks
    ##

    # Type int to bool: out_seq_mask. TODO: remove if bool output is allowed in the plugin.
    ones = network.add_constant(weights=Weights(
        np.ones(shape=(batch_size, trt_max_output_seq_len, 1),
                dtype=np.int32)),
                                shape=(batch_size, trt_max_output_seq_len,
                                       1))  # (b, t, 1)
    out_ones = ones.get_output(0)  # (b, t, 1)
    seq_mask = network.add_elementwise(
        input1=out_seq_mask, input2=out_ones,
        op=ElementWiseOperation.EQUAL)  # (b, t, 1)
    seq_mask.name = "mel_side.seq_mask"
    out_seq_mask = seq_mask.get_output(0)

    # Positional Encoding
    seq = network.add_plugin_v2([out_seq], self.get_plugin('AddPosEncPlugin'))
    seq.name = "mel_side.add_pos_enc"
    out_seq = seq.get_output(0)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq,
                                      "act.mel_side.add_pos_enc")

    for layer_idx in range(self.model.mel_side_n_layer):
        out_seq = self.populate_fft(
            name="mel_side.layer_stack.{}".format(layer_idx),
            network=network,
            weights=weights,
            seq_tensor=out_seq,
            seq_mask_tensor=out_seq_mask,
            batch_size=self.batch_size,
            max_seq_len=trt_max_output_seq_len,
            d_model=d_model,
            n_heads=self.model.mel_side_head,
            d_k=self.model.mel_side.d_k,
            d_v=self.model.mel_side.d_v,
            self_attn_temp=self.model.mel_side.d_k**0.5,
            conv_filter_size=self.model.mel_side_conv1d_filter_size,
            conv_kernel_size=self.model.fft_conv1d_kernel,
            conv_padding=self.model.fft_conv1d_padding)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq, "act.mel_side.seq")

    ##
    # Linear
    ##

    # Pytorch: self.mel_linear = nn.Linear(mel_side_output_size, n_mels, bias=True)
    w = weights["mel_linear.weight"]  # (n_mels, d_model)
    out_w = network.add_constant(shape=(1, self.model.n_mels, d_model),
                                 weights=trt.Weights(w)).get_output(
                                     0)  # (1, n_mels, d_model)
    linear_w = network.add_matrix_multiply(
        out_seq, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
    )  # (b, t, d_model) * (1->b, d_model, n_mels) => (b, t, n_mels)
    linear_w.name = "linear.w"
    out_seq = linear_w.get_output(0)  # (b, t, n_mels)

    b = weights["mel_linear.bias"]  # (n_mels,)
    out_b = network.add_constant(shape=(1, 1, self.model.n_mels),
                                 weights=trt.Weights(b)).get_output(
                                     0)  # (1, 1, n_mels)
    linear_b = network.add_elementwise(input1=out_seq,
                                       input2=out_b,
                                       op=trt.ElementWiseOperation.SUM)
    linear_b.name = "linear.b"
    out_seq = linear_b.get_output(0)  # (b, t, n_mels)

    ##
    # Outputs
    ##

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq_mask, "out.seq_mask")
        self.add_activation_as_output(network, out_seq, "out.seq")

    seq = network.add_shuffle(
        input=out_seq)  # (b, t, n_mels) to (b, n_mels, t)
    seq.reshape_dims = Dims(
        (batch_size, trt_max_output_seq_len, self.model.n_mels))
    seq.second_transpose = trt.Permutation([0, 2, 1])
    seq.name = "trans_seq"
    out_seq = seq.get_output(0)

    seq_mask = network.add_shuffle(
        input=out_seq_mask)  # (b, t, 1) to (b, t)
    seq_mask.reshape_dims = Dims((batch_size, trt_max_output_seq_len))
    out_seq_mask = seq_mask.get_output(0)  # (b, t)

    network.mark_output(tensor=out_seq)  # (b, n_mels, t)
    network.mark_output(tensor=out_seq_mask)  # (b, t)

    return network
def initialize(self): """Create DLRM network using TRT API and plugins and set the weights.""" useConvForFC_bottom = (self.precision == "int8") useConvForFC_top = (self.precision == "int8") interactionsOutputInterleaved = False if self.need_calibration or self.input_dtype != "int8" else True # Turn off interleaved format if top_mlp use non-interleaved format if not self.enable_interleaved_top_mlp: interactionsOutputInterleaved = False else: print("Using batch-interleaved format for top_mlp.") # Check if we should split the model into the binary file with embedding weights quantized and model without embeddings if not (os.path.isfile(self.embedding_weights_binary_filepath) and os.path.isfile(self.model_without_embedding_weights_filepath)): logging.info("Loading checkpoint from " + self.model_filepath) self.weights = torch.load(self.model_filepath, map_location="cpu")["state_dict"] self.dump_embedding_weights_to_binary_file() logging.info("Writing model without embedding weights to " + self.model_without_embedding_weights_filepath) torch.save(self.weights, self.model_without_embedding_weights_filepath) del self.weights # Dump row frequencies to file in binary format if self.use_row_frequencies and not os.path.isfile(self.row_frequencies_binary_filepath): logging.info("Writing row frequencies to " + self.row_frequencies_binary_filepath) self.dump_row_frequencies_to_binary_file() # Load weights self.weights = torch.load(self.model_without_embedding_weights_filepath, map_location="cpu") # Create network. self.network = self.builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) # Numerical input numerical_input = self.network.add_input("numerical_input", trt.DataType.FLOAT, (-1, self.num_numerical_inputs, 1, 1)) if not self.need_calibration: if self.input_dtype == "int8": numerical_input.dtype = trt.int8 elif self.input_dtype == "fp16": numerical_input.dtype = trt.float16 if self.input_format == "linear": numerical_input.allowed_formats = 1 << int(trt.TensorFormat.LINEAR) elif self.input_format == "chw4": numerical_input.allowed_formats = 1 << int(trt.TensorFormat.CHW4) elif self.input_format == "chw32": numerical_input.allowed_formats = 1 << int(trt.TensorFormat.CHW32) # Bottom MLP if self.need_calibration or self.input_dtype != "int8": bottom_mlp = self.add_mlp(numerical_input, self.num_numerical_inputs, self.bottom_mlp_channels, self.bottom_mlp_names, last_relu=True, useConvForFC=useConvForFC_bottom) else: bottom_mlp_plugin, output_tensor_name = self.add_fused_bottom_mlp("DLRM_BOTTOM_MLP_TRT", numerical_input, self.num_numerical_inputs, self.bottom_mlp_channels, self.bottom_mlp_names) bottom_mlp = self.network.add_plugin_v2([numerical_input], bottom_mlp_plugin) bottom_mlp.get_output(0).name = output_tensor_name bottom_mlp_shuffle = self.network.add_shuffle(bottom_mlp.get_output(0)) bottom_mlp_shuffle.reshape_dims = trt.Dims((-1, 1, self.embedding_size)) # Index input index_input = self.network.add_input("index_input", trt.DataType.INT32, (-1, self.num_features)) # Embedding lookup and interactions dlrm_interactions_plugin = self.get_dlrm_interactions_plugin("DLRM_INTERACTIONS_TRT", np.cumsum( np.array([0] + self.embedding_rows[:-1]).astype(np.int32)).astype(np.int32), interactionsOutputInterleaved) interaction_output_concat = self.network.add_plugin_v2([bottom_mlp.get_output(0), index_input], dlrm_interactions_plugin) interaction_output_concat.name = "interaction_plugin" interaction_output_concat.get_output(0).name = "interaction_output_concat_output" if 
self.enable_interleaved_top_mlp and not interactionsOutputInterleaved: # Shuffle from [BS, C, 1, 1] to [BS//2, C, 2, 1] before top_mlp interleave_pre_top_mlp = self.network.add_shuffle(interaction_output_concat.get_output(0)) interleave_pre_top_mlp.reshape_dims = trt.Dims((-1, 2, interaction_output_concat.get_output(0).shape[1], 0)) interleave_pre_top_mlp.second_transpose = trt.Permutation([0, 2, 1, 3]) interleave_pre_top_mlp.name = "interleave_pre_top_mlp" top_mlp_input = interleave_pre_top_mlp.get_output(0) top_mlp_input.name = "interleave_pre_top_mlp" else: top_mlp_input = interaction_output_concat.get_output(0) # Insert small-tile GEMM plugin. The plugin supports Ampere-only. gpu_arch = get_system().arch system_id = get_system().gpu if self.use_small_tile_gemm_plugin: if gpu_arch != Architecture.Ampere: print("Small-Tile GEMM plugin does not support {}. Plugin disabled.".format(system_id)) self.use_small_tile_gemm_plugin = False # Enable gemm plugin with interleaved format is not recommended. # Note (2/7/21): GEMM plugin doesn't perform well when H*W > 1 if self.use_small_tile_gemm_plugin and self.enable_interleaved_top_mlp: print("Warning: small-Tile GEMM plugin performance will be " "significantly impacted by interleaved format. Turn off " "interleaved format for the best performance") tmp_mlp_input = top_mlp_input tmp_input_size = self.top_mlp_input_size # Helper function to check whether the provided shape is supported by # Small-Tile GEMM plugin def support_small_tile_gemm_func(C, K): return \ (C >= 256) and (C <= 1280) and (C % 128 == 0) and (K % 128 == 0) # Split the top_mlp layers, and use GEMM plugin for 2,4,6 # C, K for top_mlp.0,2,4,6,8: [480,1024],[1024,1024],[1024,512],[512,256],[256,1] for i in range(len(self.top_mlp_channels)): # Insert plugin if the layer meets the restriction if support_small_tile_gemm_func(tmp_input_size, self.top_mlp_channels[i]) and \ self.use_small_tile_gemm_plugin: print("Replacing {} with Small-Tile GEMM Plugin, with fairshare cache size {}". format(self.top_mlp_names[i], self.gemm_plugin_fairshare_cache_size)) layer_top_mlp = self.add_small_tile_gemm_top_mlp( tmp_mlp_input, tmp_input_size, self.top_mlp_channels[i], self.top_mlp_names[i], self.gemm_plugin_fairshare_cache_size ) else: layer_top_mlp = self.add_single_mlp( tmp_mlp_input, tmp_input_size, self.top_mlp_channels[i], self.top_mlp_names[i], useConvForFC=useConvForFC_top, add_relu=(i != len(self.top_mlp_channels) - 1)) tmp_mlp_input = layer_top_mlp.get_output(0) tmp_input_size = self.top_mlp_channels[i] top_mlp = layer_top_mlp if self.enable_interleaved_top_mlp: # Shuffle [BS//2, 1, 2, 1] back to [BS, 1, 1, 1] interleave_post_top_mlp = self.network.add_shuffle(top_mlp.get_output(0)) interleave_post_top_mlp.reshape_dims = trt.Dims((-1, 0, 1, 0)) interleave_post_top_mlp.name = "interleave_post_top_mlp" sigmoid_input = interleave_post_top_mlp.get_output(0) sigmoid_input.name = "interleave_post_top_mlp" else: sigmoid_input = top_mlp.get_output(0) # Sigmoid sigmoid_layer = self.network.add_activation(sigmoid_input, trt.ActivationType.SIGMOID) sigmoid_layer.name = "sigmoid" sigmoid_layer.get_output(0).name = "sigmoid_output" # Output self.network.mark_output(sigmoid_layer.get_output(0)) # Make sure we release the memory to system del self.weights self.initialized = True
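# Quick check (illustrative, not part of the original class; the (C, K) pairs
# are taken from the comment above): lists which documented top_mlp shapes
# fall inside the Small-Tile GEMM plugin's supported range.
def _list_small_tile_gemm_candidates():
    shapes = [(480, 1024), (1024, 1024), (1024, 512), (512, 256), (256, 1)]

    def supported(C, K):
        return (C >= 256) and (C <= 1280) and (C % 128 == 0) and (K % 128 == 0)

    # Expected result: [(1024, 1024), (1024, 512), (512, 256)], i.e. top_mlp layers 2, 4, 6.
    return [(C, K) for C, K in shapes if supported(C, K)]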