Example #1
def permuteHelper(self, network, inputs, dims):
    tensor = inputs[0]
    assert checkType([tensor, dims], [trt.ITensor, list]), "Input Type Error"
    assert max(dims) < 4 and dims[0] == 0
    # Drop the leading batch axis: in implicit-batch mode the Permutation
    # indexes only the non-batch dimensions, so shift every index down by one.
    dims = list(map(lambda x: x - 1, dims))[1:]
    shuffle = network.add_shuffle(tensor)
    shuffle.first_transpose = trt.Permutation(dims)
    return shuffle
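
A minimal sketch of the index shift performed above, assuming an NCHW tensor in implicit-batch mode; everything except trt.Permutation is illustrative:

dims = [0, 3, 1, 2]                      # full-tensor permutation, batch kept first
permutation = [d - 1 for d in dims[1:]]  # -> [2, 0, 1], indexes only the C/H/W axes
# shuffle.first_transpose = trt.Permutation(permutation)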
Example #2
def trt_transpose(network, input_tensor, axes=(0, 2, 3, 1)):
    """
    PyTorch reference for the layout change this layer performs:
    x = np.random.random_sample([5, 2, 3, 3]).astype(np.float32)
    torch_x = torch.from_numpy(x).permute(0, 2, 3, 1)
    """
    layer = network.add_shuffle(input=input_tensor)
    layer.first_transpose = trt.Permutation(axes)

    return layer
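
For reference, the docstring's PyTorch snippet can be checked against NumPy directly; a small sketch of the expected layout change (not run through TensorRT here):

import numpy as np

x = np.random.random_sample([5, 2, 3, 3]).astype(np.float32)
ref = x.transpose(0, 2, 3, 1)  # NCHW -> NHWC, the layout trt_transpose is asked to produce
print(ref.shape)               # (5, 3, 3, 2)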
Example #3
def pixel_unshuffle(network: trt.INetworkDefinition, input: trt.ITensor,
                    downscale_factor: int) -> trt.ITensor:

    n, ic, ih, iw = input.shape
    assert ih % downscale_factor == 0 and iw % downscale_factor == 0
    oc = ic * (downscale_factor**2)
    oh = ih // downscale_factor
    ow = iw // downscale_factor

    reshape = network.add_shuffle(input)
    reshape.reshape_dims = trt.Dims(
        [n, ic, oh, downscale_factor, ow, downscale_factor])
    reshape.second_transpose = trt.Permutation([0, 1, 3, 5, 2, 4])

    reshape = network.add_shuffle(reshape.get_output(0))
    reshape.reshape_dims = trt.Dims([n, oc, oh, ow])

    return reshape.get_output(0)
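
The reshape/transpose pair above mirrors torch.nn.functional.pixel_unshuffle; a NumPy sketch of the same index gymnastics (illustrative reference only, not the TensorRT path):

import numpy as np

def pixel_unshuffle_ref(x, r):
    # x: (n, c, h, w) with h and w divisible by r
    n, c, h, w = x.shape
    x = x.reshape(n, c, h // r, r, w // r, r)
    x = x.transpose(0, 1, 3, 5, 2, 4)  # same permutation as second_transpose above
    return x.reshape(n, c * r * r, h // r, w // r)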
Example #4
def pixel_shuffle(network: trt.INetworkDefinition, input: trt.ITensor,
                  upscale_factor: int) -> trt.ITensor:

    n, ic, ih, iw = input.shape
    assert ic % (upscale_factor**2) == 0
    oc = ic // (upscale_factor**2)
    oh = ih * upscale_factor
    ow = iw * upscale_factor

    reshape = network.add_shuffle(input)
    reshape.reshape_dims = trt.Dims(
        [n, oc, upscale_factor, upscale_factor, ih, iw])
    reshape.second_transpose = trt.Permutation([0, 1, 4, 2, 5, 3])

    reshape = network.add_shuffle(reshape.get_output(0))
    reshape.reshape_dims = trt.Dims([n, oc, oh, ow])

    return reshape.get_output(0)
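
A quick way to sanity-check the wiring without running inference is to build a throwaway explicit-batch network and inspect the static output shape; a minimal sketch assuming the TensorRT 7/8 Python bindings:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
x = network.add_input("x", trt.float32, (1, 16, 8, 8))
y = pixel_shuffle(network, x, upscale_factor=2)  # function defined above
print(y.shape)                                   # expected: (1, 4, 16, 16)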
def populate_network(network, weights, weights_color):
    # Configure the network layers based on the weights provided.
    # Define the network input
    input_tensor = network.add_input(name=ModelData.INPUT_NAME,
                                     dtype=ModelData.DTYPE,
                                     shape=ModelData.INPUT_SHAPE)

    # Backbone Sequential
    conv0 = conv_trt(weights, input_tensor, 'cnn_1.conv0', network, 64)
    print('conv0')
    print(conv0.get_output(0).shape)
    bn0 = bn_trt(weights, conv0.get_output(0), 'cnn_1.batchnorm0', network)
    relu0 = network.add_activation(input=bn0.get_output(0),
                                   type=trt.ActivationType.RELU)
    pooling0 = max_pooling_trt(relu0.get_output(0), network)
    print('pooling0')
    print(pooling0.get_output(0).shape)

    conv1 = conv_trt(weights, pooling0.get_output(0), 'cnn_1.conv1', network,
                     128)
    print('conv1')
    print(conv1.get_output(0).shape)
    bn1 = bn_trt(weights, conv1.get_output(0), 'cnn_1.batchnorm1', network)
    relu1 = network.add_activation(input=bn1.get_output(0),
                                   type=trt.ActivationType.RELU)
    pooling1 = max_pooling_trt(relu1.get_output(0), network)
    print('pooling1')
    print(pooling1.get_output(0).shape)

    conv2 = conv_trt(weights, pooling1.get_output(0), 'cnn_1.conv2', network,
                     256)
    print('conv2')
    print(conv2.get_output(0).shape)
    bn2 = bn_trt(weights, conv2.get_output(0), 'cnn_1.batchnorm2', network)
    relu2 = network.add_activation(input=bn2.get_output(0),
                                   type=trt.ActivationType.RELU)
    # Backbone Sequential-2
    conv3 = conv_trt(weights, relu2.get_output(0), 'cnn_2.conv3', network, 256)
    print('conv3')
    print(conv3.get_output(0).shape)
    bn3 = bn_trt(weights, conv3.get_output(0), 'cnn_2.batchnorm3', network)
    relu3 = network.add_activation(input=bn3.get_output(0),
                                   type=trt.ActivationType.RELU)
    pooling2 = max_pooling_trt(relu3.get_output(0), network, stride=(2, 1))
    print('pooling2')
    print(pooling2.get_output(0).shape)

    conv4 = conv_trt(weights, pooling2.get_output(0), 'cnn_2.conv4', network,
                     512)
    print('conv4')
    print(conv4.get_output(0).shape)
    bn4 = bn_trt(weights, conv4.get_output(0), 'cnn_2.batchnorm4', network)
    relu4 = network.add_activation(input=bn4.get_output(0),
                                   type=trt.ActivationType.RELU)
    conv5 = conv_trt(weights, relu4.get_output(0), 'cnn_2.conv5', network, 512)
    print('conv5')
    print(conv5.get_output(0).shape)
    bn5 = bn_trt(weights, conv5.get_output(0), 'cnn_2.batchnorm5', network)
    relu5 = network.add_activation(input=bn5.get_output(0),
                                   type=trt.ActivationType.RELU)
    # Branch 1 Sequential
    pooling7 = max_pooling_trt(relu5.get_output(0), network)
    print('pooling7')
    print(pooling7.get_output(0).shape)

    conv9 = conv_trt(weights,
                     pooling7.get_output(0),
                     'branch1.conv9',
                     network,
                     512,
                     kernel=(2, 2),
                     stride=(2, 2),
                     padding=(0, 0))
    print('conv9')
    print(conv9.get_output(0).shape)
    bn9 = bn_trt(weights, conv9.get_output(0), 'branch1.batchnorm9', network)
    relu9 = network.add_activation(input=bn9.get_output(0),
                                   type=trt.ActivationType.RELU)

    permute1 = network.add_shuffle(relu9.get_output(0))
    permute1.first_transpose = trt.Permutation([3, 0, 1, 2])
    permute1.reshape_dims = [
        permute1.get_output(0).shape[0],
        permute1.get_output(0).shape[1], 1, 1, -1
    ]
    print('permute1')
    print(permute1.get_output(0).shape)

    fc1 = fc_trt(weights, permute1.get_output(0), 'fc1', network,
                 ModelData.OUTPUT_SIZE)
    fc1.get_output(0).name = ModelData.OUTPUT_NAME1
    print('fc1')
    print(fc1.get_output(0).shape)

    # Branch 2 Sequential
    pooling3 = max_pooling_trt(relu5.get_output(0), network, stride=(2, 1))
    conv7 = conv_trt(weights, pooling3.get_output(0), 'branch2.conv7', network,
                     512)
    print('conv7')
    print(conv7.get_output(0).shape)
    bn7 = bn_trt(weights, conv7.get_output(0), 'branch2.batchnorm7', network)
    relu7 = network.add_activation(input=bn7.get_output(0),
                                   type=trt.ActivationType.RELU)

    conv8 = conv_trt(weights,
                     relu7.get_output(0),
                     'branch2.conv8',
                     network,
                     512,
                     kernel=(2, 2),
                     stride=(1, 1),
                     padding=(0, 0))
    print('conv8')
    print(conv8.get_output(0).shape)
    bn8 = bn_trt(weights, conv8.get_output(0), 'branch2.batchnorm8', network)
    relu8 = network.add_activation(input=bn8.get_output(0),
                                   type=trt.ActivationType.RELU)

    permute2 = network.add_shuffle(relu8.get_output(0))
    permute2.first_transpose = trt.Permutation([3, 0, 1, 2])
    permute2.reshape_dims = [
        permute2.get_output(0).shape[0],
        permute2.get_output(0).shape[1], 1, 1, -1
    ]
    print('permute2')
    print(permute2.get_output(0).shape)

    fc2 = fc_trt(weights, permute2.get_output(0), 'fc2', network,
                 ModelData.OUTPUT_SIZE)
    fc2.get_output(0).name = ModelData.OUTPUT_NAME2
    print('fc2')
    print(fc2.get_output(0).shape)

    cat1 = network.add_concatenation([fc1.get_output(0), fc2.get_output(0)])
    cat1.axis = 0
    cat1.get_output(0).name = ModelData.OUTPUT_NAME3
    permute3 = network.add_shuffle(cat1.get_output(0))
    permute3.reshape_dims = [0, 0, -1]
    permute3.second_transpose = trt.Permutation([1, 2, 0])
    print('permute3')
    print(permute3.get_output(0).shape)
    re_softmax = network.add_softmax(input=permute3.get_output(0))
    re_softmax.axes = 2
    print('re_softmax')
    print(re_softmax.get_output(0).shape)
    permute4 = network.add_shuffle(re_softmax.get_output(0))
    permute4.first_transpose = trt.Permutation([0, 2, 1])
    print('permute4')
    print(permute4.get_output(0).shape)

    #branch color
    c_pooling0 = max_pooling_trt(relu2.get_output(0), network, stride=(2, 2))
    c_conv0 = conv_trt(weights_color, c_pooling0.get_output(0),
                       'color_branch.c_conv0', network, 128)
    print('c_conv0')
    print(c_conv0.get_output(0).shape)
    c_relu0 = network.add_activation(input=c_conv0.get_output(0),
                                     type=trt.ActivationType.RELU)

    c_pooling1 = max_pooling_trt(c_relu0.get_output(0), network, stride=(2, 2))

    c_conv1 = conv_trt(weights_color, c_pooling1.get_output(0),
                       'color_branch.c_conv1', network, 64)
    print('c_conv1')
    print(c_conv1.get_output(0).shape)
    c_relu1 = network.add_activation(input=c_conv1.get_output(0),
                                     type=trt.ActivationType.RELU)

    fc_c = fc_trt(weights_color, c_relu1.get_output(0), 'fc_c', network,
                  ModelData.OUTPUT_COLOR_SIZE)
    fc_c.get_output(0).name = ModelData.OUTPUT_COLOR_NAME
    print('fc_c')
    print(fc_c.get_output(0).shape)
    c_softmax = network.add_softmax(input=fc_c.get_output(0))
    c_softmax.axes = 2
    print('c_softmax')
    print(c_softmax.get_output(0).shape)

    # Define the network outputs
    network.mark_output(tensor=permute4.get_output(0))
    network.mark_output(tensor=c_softmax.get_output(0))
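
The helpers conv_trt, bn_trt, max_pooling_trt, and fc_trt used above are not shown in this example. Below is a plausible sketch of how they might look with the classic TensorRT layer API; the weight-dictionary keys (<name>.weight, <name>.bias, batch-norm running statistics) and the torch-tensor .numpy() calls are assumptions, not taken from the original code:

import numpy as np
import tensorrt as trt

def conv_trt(weights, input_tensor, name, network, out_channels,
             kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
    # Plain convolution; weight keys follow the usual PyTorch state_dict layout.
    w = weights[name + '.weight'].numpy()
    b = weights[name + '.bias'].numpy()
    layer = network.add_convolution(input=input_tensor,
                                    num_output_maps=out_channels,
                                    kernel_shape=trt.DimsHW(*kernel),
                                    kernel=trt.Weights(w),
                                    bias=trt.Weights(b))
    layer.stride = trt.DimsHW(*stride)
    layer.padding = trt.DimsHW(*padding)
    return layer

def bn_trt(weights, input_tensor, name, network, eps=1e-5):
    # Fold BatchNorm into a per-channel scale layer: y = scale * x + shift.
    gamma = weights[name + '.weight'].numpy()
    beta = weights[name + '.bias'].numpy()
    mean = weights[name + '.running_mean'].numpy()
    var = weights[name + '.running_var'].numpy()
    scale = gamma / np.sqrt(var + eps)
    shift = beta - mean * scale
    power = np.ones_like(scale)
    return network.add_scale(input_tensor, trt.ScaleMode.CHANNEL,
                             trt.Weights(shift), trt.Weights(scale),
                             trt.Weights(power))

def max_pooling_trt(input_tensor, network, window=(2, 2), stride=(2, 2)):
    layer = network.add_pooling(input_tensor, trt.PoolingType.MAX,
                                trt.DimsHW(*window))
    layer.stride = trt.DimsHW(*stride)
    return layer

def fc_trt(weights, input_tensor, name, network, out_size):
    w = weights[name + '.weight'].numpy()
    b = weights[name + '.bias'].numpy()
    return network.add_fully_connected(input_tensor, out_size,
                                       trt.Weights(w), trt.Weights(b))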
Example #6
def populate_network(network, weights):
    # Configure the network layers based on the weights provided.
    input_tensor = network.add_input(name=ModelData.INPUT_NAME,
                                     dtype=ModelData.DTYPE,
                                     shape=ModelData.INPUT_SHAPE)

    # Backbone Sequential
    conv0 = conv_trt(weights, input_tensor, 'cnn.conv0', network, 64)
    print('conv0')
    print(conv0.get_output(0).shape)
    bn0 = bn_trt(weights, conv0.get_output(0), 'cnn.batchnorm0', network)
    relu0 = network.add_activation(input=bn0.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu0.alpha = ModelData.RELU_alpha
    pooling0 = max_pooling_trt(relu0.get_output(0), network)
    print('pooling0')
    print(pooling0.get_output(0).shape)

    conv1 = conv_trt(weights, pooling0.get_output(0), 'cnn.conv1', network,
                     128)
    print('conv1')
    print(conv1.get_output(0).shape)
    bn1 = bn_trt(weights, conv1.get_output(0), 'cnn.batchnorm1', network)
    relu1 = network.add_activation(input=bn1.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu1.alpha = ModelData.RELU_alpha
    pooling1 = max_pooling_trt(relu1.get_output(0), network)
    print('pooling1')
    print(pooling1.get_output(0).shape)

    conv2 = conv_trt(weights, pooling1.get_output(0), 'cnn.conv2', network,
                     256)
    print('conv2')
    print(conv2.get_output(0).shape)
    bn2 = bn_trt(weights, conv2.get_output(0), 'cnn.batchnorm2', network)
    relu2 = network.add_activation(input=bn2.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu2.alpha = ModelData.RELU_alpha
    conv3 = conv_trt(weights, relu2.get_output(0), 'cnn.conv3', network, 256)
    print('conv3')
    print(conv3.get_output(0).shape)
    bn3 = bn_trt(weights, conv3.get_output(0), 'cnn.batchnorm3', network)
    relu3 = network.add_activation(input=bn3.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu3.alpha = ModelData.RELU_alpha
    pooling2 = max_pooling_trt(relu3.get_output(0), network, stride=(2, 1))
    print('pooling2')
    print(pooling2.get_output(0).shape)

    conv4 = conv_trt(weights, pooling2.get_output(0), 'cnn.conv4', network,
                     512)
    print('conv4')
    print(conv4.get_output(0).shape)
    bn4 = bn_trt(weights, conv4.get_output(0), 'cnn.batchnorm4', network)
    relu4 = network.add_activation(input=bn4.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu4.alpha = ModelData.RELU_alpha
    conv5 = conv_trt(weights, relu4.get_output(0), 'cnn.conv5', network, 512)
    print('conv5')
    print(conv5.get_output(0).shape)
    bn5 = bn_trt(weights, conv5.get_output(0), 'cnn.batchnorm5', network)
    relu5 = network.add_activation(input=bn5.get_output(0),
                                   type=trt.ActivationType.LEAKY_RELU)
    relu5.alpha = ModelData.RELU_alpha
    # Branch 1 Sequential
    pooling7 = max_pooling_trt(relu5.get_output(0), network)
    print('pooling7')
    print(pooling7.get_output(0).shape)

    conv9 = conv_trt(weights,
                     pooling7.get_output(0),
                     'branch1.conv9',
                     network,
                     512,
                     kernel=(2, 2),
                     stride=(2, 2),
                     padding=(0, 0))
    print('conv9')
    print(conv9.get_output(0).shape)
    bn9 = bn_trt(weights, conv9.get_output(0), 'branch1.batchnorm9', network)
    relu9 = network.add_activation(input=bn9.get_output(0),
                                   type=trt.ActivationType.RELU)
    permute1 = network.add_shuffle(relu9.get_output(0))
    permute1.first_transpose = trt.Permutation([2, 0, 1])
    permute1.reshape_dims = [0, 1, 1, -1]
    print('permute1')
    print(permute1.get_output(0).shape)
    fc1 = fc_trt(weights, permute1.get_output(0), 'fc1', network,
                 ModelData.OUTPUT_SIZE)
    fc1.get_output(0).name = ModelData.OUTPUT_NAME1
    print('fc1')
    print(fc1.get_output(0).shape)

    # Branch 2 Sequential
    pooling3 = max_pooling_trt(relu5.get_output(0), network, stride=(2, 1))
    conv7 = conv_trt(weights, pooling3.get_output(0), 'branch2.conv7', network,
                     512)
    print('conv7')
    print(conv7.get_output(0).shape)
    bn7 = bn_trt(weights, conv7.get_output(0), 'branch2.batchnorm7', network)
    relu7 = network.add_activation(input=bn7.get_output(0),
                                   type=trt.ActivationType.RELU)

    conv8 = conv_trt(weights,
                     relu7.get_output(0),
                     'branch2.conv8',
                     network,
                     512,
                     kernel=(2, 2),
                     stride=(1, 1),
                     padding=(0, 0))
    print('conv8')
    print(conv8.get_output(0).shape)
    bn8 = bn_trt(weights, conv8.get_output(0), 'branch2.batchnorm8', network)
    relu8 = network.add_activation(input=bn8.get_output(0),
                                   type=trt.ActivationType.RELU)

    permute2 = network.add_shuffle(relu8.get_output(0))
    permute2.first_transpose = trt.Permutation([2, 0, 1])
    permute2.reshape_dims = [0, 1, 1, -1]
    fc2 = fc_trt(weights, permute2.get_output(0), 'fc2', network,
                 ModelData.OUTPUT_SIZE)
    fc2.get_output(0).name = ModelData.OUTPUT_NAME2
    print('fc2')
    print(fc2.get_output(0).shape)

    cat1 = network.add_concatenation([fc1.get_output(0), fc2.get_output(0)])
    cat1.axis = 0
    cat1.get_output(0).name = ModelData.OUTPUT_NAME3
    permute3 = network.add_shuffle(cat1.get_output(0))
    permute3.reshape_dims = [0, -1]

    network.mark_output(tensor=permute3.get_output(0))
Example #7
    def initialize(self):
        useConvForFC_bottom = (self.precision == "int8")
        useConvForFC_top = (self.precision == "int8")
        interactionsOutputInterleaved = False if self.need_calibration or self.input_dtype != "int8" else True

        # Check if we should split the model into the binary file with embedding weights quantized and model without embeddings
        if not (os.path.isfile(self.embedding_weights_binary_filepath) and os.path.isfile(self.model_without_embedding_weights_filepath)):
            logging.info("Loading checkpoint from " + self.model_filepath)
            self.weights = torch.load(self.model_filepath, map_location="cpu")["state_dict"]
            self.dump_embedding_weights_to_binary_file()
            logging.info("Writing model without embedding weights to " + self.model_without_embedding_weights_filepath)
            torch.save(self.weights, self.model_without_embedding_weights_filepath)
            del self.weights

        # Dump row frequencies to file in binary format
        if self.use_row_frequencies and not os.path.isfile(self.row_frequencies_binary_filepath):
            logging.info("Writing row frequencies to " + self.row_frequencies_binary_filepath)
            self.dump_row_frequencies_to_binary_file()

        # Load weights
        self.weights = torch.load(self.model_without_embedding_weights_filepath, map_location="cpu")

        # Create network.
        self.network = self.builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

        # Numerical input
        numerical_input = self.network.add_input("numerical_input", trt.DataType.FLOAT, (-1, self.num_numerical_inputs, 1, 1))
        if not self.need_calibration:
            if self.input_dtype == "int8":
                numerical_input.dtype = trt.int8
            elif self.input_dtype == "fp16":
                numerical_input.dtype = trt.float16
            if self.input_format == "linear":
                numerical_input.allowed_formats = 1 << int(trt.TensorFormat.LINEAR)
            elif self.input_format == "chw4":
                numerical_input.allowed_formats = 1 << int(trt.TensorFormat.CHW4)
            elif self.input_format == "chw32":
                numerical_input.allowed_formats = 1 << int(trt.TensorFormat.CHW32)

        # Bottom MLP
        if self.need_calibration or self.input_dtype != "int8":
            bottom_mlp = self.add_mlp(numerical_input, self.num_numerical_inputs, self.bottom_mlp_channels, self.bottom_mlp_names,
                last_relu=True, useConvForFC=useConvForFC_bottom)
        else:
            bottom_mlp_plugin, output_tensor_name = self.add_fused_bottom_mlp("DLRM_BOTTOM_MLP_TRT", numerical_input, self.num_numerical_inputs, self.bottom_mlp_channels, self.bottom_mlp_names)
            bottom_mlp = self.network.add_plugin_v2([numerical_input], bottom_mlp_plugin)
            bottom_mlp.get_output(0).name = output_tensor_name
        bottom_mlp_shuffle = self.network.add_shuffle(bottom_mlp.get_output(0))
        bottom_mlp_shuffle.reshape_dims = trt.Dims((-1, 1, self.embedding_size))

        # Index input
        index_input = self.network.add_input("index_input", trt.DataType.INT32, (-1, self.num_features))
        
        # Embedding lookup and interactions
        dlrm_interactions_plugin = self.get_dlrm_interactions_plugin("DLRM_INTERACTIONS_TRT", np.cumsum(np.array([0] + self.embedding_rows[:-1]).astype(np.int32)).astype(np.int32), interactionsOutputInterleaved)
        interaction_output_concat = self.network.add_plugin_v2([bottom_mlp.get_output(0), index_input], dlrm_interactions_plugin)
        interaction_output_concat.name = "interaction_plugin"
        interaction_output_concat.get_output(0).name = "interaction_output_concat_output"

        if self.INTERLEAVED_TOP_MLP and not interactionsOutputInterleaved:
            # Shuffle from [BS, C, 1, 1] to [BS//2, C, 2, 1] before top_mlp
            interleave_pre_top_mlp = self.network.add_shuffle(interaction_output_concat.get_output(0))
            interleave_pre_top_mlp.reshape_dims = trt.Dims((-1, 2, interaction_output_concat.get_output(0).shape[1], 0))
            interleave_pre_top_mlp.second_transpose = trt.Permutation([0, 2, 1, 3])
            interleave_pre_top_mlp.name = "interleave_pre_top_mlp"

            top_mlp_input = interleave_pre_top_mlp.get_output(0)
            top_mlp_input.name = "interleave_pre_top_mlp"
        else:
            top_mlp_input = interaction_output_concat.get_output(0)

        # Top MLP
        top_mlp = self.add_mlp(top_mlp_input, self.top_mlp_input_size, self.top_mlp_channels, self.top_mlp_names,
                               last_relu=False, useConvForFC=useConvForFC_top)

        if self.INTERLEAVED_TOP_MLP:
            # Shuffle back to [BS, 1, 1, 1] from [BS//2, 1, 2, 1]
            interleave_post_top_mlp = self.network.add_shuffle(top_mlp.get_output(0))
            interleave_post_top_mlp.reshape_dims = trt.Dims((-1, 0, 1, 0))
            interleave_post_top_mlp.name = "interleave_post_top_mlp"

            sigmoid_input = interleave_post_top_mlp.get_output(0)
            sigmoid_input.name = "interleave_post_top_mlp"
        else:
            sigmoid_input = top_mlp.get_output(0)

        # Sigmoid
        sigmoid_layer = self.network.add_activation(sigmoid_input, trt.ActivationType.SIGMOID)
        sigmoid_layer.name = "sigmoid"
        sigmoid_layer.get_output(0).name = "sigmoid_output"

        # Output
        self.network.mark_output(sigmoid_layer.get_output(0))

        # Make sure we release the memory to system
        del self.weights

        self.initialized = True
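
add_mlp and add_fused_bottom_mlp are defined elsewhere in the harness. As a rough, hypothetical sketch, a single FC layer on the useConvForFC path can be expressed as a 1x1 convolution (assuming `tensorrt as trt` is imported and the weights dict holds torch tensors; the function and key names are illustrative):

def add_single_fc_as_conv(network, input_tensor, in_channels, out_channels,
                          weights, name, add_relu=True):
    # FC expressed as a 1x1 convolution over a (C, 1, 1) activation, which is
    # what the int8 path above relies on.
    w = weights[name + ".weight"].numpy().reshape(out_channels, in_channels, 1, 1)
    b = weights[name + ".bias"].numpy()
    fc = network.add_convolution(input_tensor, out_channels, trt.DimsHW(1, 1),
                                 trt.Weights(w), trt.Weights(b))
    fc.name = name
    if add_relu:
        relu = network.add_activation(fc.get_output(0), trt.ActivationType.RELU)
        relu.name = name + ".relu"
        return relu
    return fc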
    def populate_duration_predictor(self, name, network, weights, seq_tensor,
                                    seq_mask_tensor, batch_size, max_seq_len,
                                    d_model):
        duration_predictor_filter_size = self.model.duration_predictor_filter_size
        duration_predictor_kernel_size = self.model.duration_predictor_kernel_size

        # Pytorch: input *= input_mask.to(input.dtype)
        # can be skipped.

        # Pytorch: out = self.conv1d_1(input.transpose(1,2)).transpose(1,2)
        trans1 = network.add_shuffle(
            input=seq_tensor)  # (b, t, d_model) to  (b, d_model, t, 1)
        trans1.first_transpose = trt.Permutation([0, 2, 1])
        trans1.reshape_dims = Dims((batch_size, d_model, max_seq_len, 1))
        trans1.name = "{}.trans1".format(name)
        out = trans1.get_output(0)  # (b, d_model, t, 1)

        conv1_w = weights["{}.conv1d_1.weight".format(
            name
        )]  # (1, d_model, duration_predictor_filter_size, duration_predictor_kernel_size, 1)
        conv1_b = weights["{}.conv1d_1.bias".format(
            name)]  # (duration_predictor_filter_size, )
        conv1 = network.add_convolution(
            input=out,
            num_output_maps=duration_predictor_filter_size,
            kernel_shape=trt.DimsHW(duration_predictor_kernel_size, 1),
            kernel=Weights(conv1_w),
            bias=Weights(conv1_b))
        conv1.padding = trt.DimsHW(1, 0)
        conv1.name = "{}.conv1".format(name)
        out = conv1.get_output(0)  # (b, duration_predictor_filter_size, t, 1)

        trans2 = network.add_shuffle(
            input=out
        )  # (b, duration_predictor_filter_size, t, 1) to (b, t, duration_predictor_filter_size)
        trans2.first_transpose = trt.Permutation([0, 2, 1, 3])
        trans2.reshape_dims = Dims(
            (batch_size, max_seq_len, duration_predictor_filter_size))
        trans2.name = "{}.trans2".format(name)
        out = trans2.get_output(0)  # (b, t, duration_predictor_filter_size)

        # Pytorch: out = self.relu_1(out)
        relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
        relu.name = "{}.relu1".format(name)
        out_relu = relu.get_output(0)  # (b, t, duration_predictor_filter_size)

        # Pytorch: out = self.layer_norm_1(out)
        out = self.populate_layernorm(name="{}.layer_norm_1".format(name),
                                      network=network,
                                      weights=weights,
                                      seq_tensor=out_relu,
                                      d_layer=duration_predictor_filter_size,
                                      batch_size=batch_size,
                                      max_seq_len=max_seq_len)

        # Pytorch: out = self.conv1d_2(out.transpose(1,2)).transpose(1,2)
        trans3 = network.add_shuffle(
            input=out
        )  # (b, t, duration_predictor_filter_size) to (b, duration_predictor_filter_size, t, 1)
        trans3.first_transpose = trt.Permutation([0, 2, 1])
        trans3.reshape_dims = Dims(
            (batch_size, duration_predictor_filter_size, max_seq_len, 1))
        trans3.name = "{}.trans3".format(name)
        out = trans3.get_output(0)  # (b, duration_predictor_filter_size, t, 1)

        conv2_w = weights["{}.conv1d_2.weight".format(
            name
        )]  # (1, duration_predictor_filter_size, duration_predictor_filter_size, duration_predictor_kernel_size, 1)
        conv2_b = weights["{}.conv1d_2.bias".format(
            name)]  # (duration_predictor_filter_size, )
        conv2 = network.add_convolution(
            input=out,
            num_output_maps=duration_predictor_filter_size,
            kernel_shape=trt.DimsHW(duration_predictor_kernel_size, 1),
            kernel=Weights(conv2_w),
            bias=Weights(conv2_b))
        conv2.padding = trt.DimsHW(1, 0)
        conv2.name = "{}.conv2".format(name)
        out = conv2.get_output(0)

        trans4 = network.add_shuffle(
            input=out
        )  # (b, duration_predictor_filter_size, t, 1) to (b, t, duration_predictor_filter_size)
        trans4.first_transpose = trt.Permutation([0, 2, 1, 3])
        trans4.reshape_dims = Dims(
            (batch_size, max_seq_len, duration_predictor_filter_size))
        trans4.name = "{}.trans4".format(name)
        out = trans4.get_output(0)  # (b, t, duration_predictor_filter_size)

        # Pytorch: out = self.relu_2(out)
        relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
        relu.name = "{}.relu2".format(name)
        out_relu = relu.get_output(0)  # (b, t, duration_predictor_filter_size)

        # Pytorch: out = self.layer_norm_2(out)
        out = self.populate_layernorm(
            name="{}.layer_norm_2".format(name),
            network=network,
            weights=weights,
            seq_tensor=out_relu,
            d_layer=duration_predictor_filter_size,
            batch_size=batch_size,
            max_seq_len=max_seq_len,
        )  # (b, t, duration_predictor_filter_size)

        # Pytorch: out = self.linear_layer(out)
        w = weights["{}.linear_layer.weight".format(
            name)]  # (1, duration_predictor_filter_size)
        out_w = network.add_constant(
            shape=(1, 1, duration_predictor_filter_size),
            weights=trt.Weights(w)).get_output(
                0)  # (1, 1, duration_predictor_filter_size)
        linear_w = network.add_matrix_multiply(
            out, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
        )  # (b, t, duration_predictor_filter_size) * (1->b, duration_predictor_filter_size, 1) => (b, t, 1)
        linear_w.name = "{}.linear.w".format(name)
        out = linear_w.get_output(0)  # (b, t, 1)

        b = weights["{}.linear_layer.bias".format(name)]  # (1,)
        out_b = network.add_constant(
            shape=(1, 1, 1), weights=trt.Weights(b)).get_output(0)  # (1, 1, 1)
        linear_b = network.add_elementwise(input1=out,
                                           input2=out_b,
                                           op=trt.ElementWiseOperation.SUM)
        linear_b.name = "{}.linear.b".format(name)
        out = linear_b.get_output(0)  # (b, t, 1)

        # Pytorch: out *= input_mask.to(out.dtype)
        zeros = network.add_constant(weights=Weights(
            np.zeros(shape=(batch_size, max_seq_len, 1), dtype=np.float32)),
                                     shape=(batch_size, max_seq_len, 1))
        out_zeros = zeros.get_output(0)  # (b, t, 1)
        dur = network.add_select(condition=seq_mask_tensor,
                                 then_input=out,
                                 else_input=out_zeros)
        dur.name = "{}.mask".format(name)
        out_dur = dur.get_output(0)

        # Pytorch: duration = torch.clamp_min(torch.exp(duration) - 1, 0)
        exp = network.add_unary(input=out_dur, op=trt.UnaryOperation.EXP)
        exp.name = "{}.exp".format(name)
        out_exp = exp.get_output(0)
        ones = network.add_constant(weights=Weights(
            np.ones(shape=(batch_size, max_seq_len, 1), dtype=np.float32)),
                                    shape=(batch_size, max_seq_len, 1))
        out_ones = ones.get_output(0)  # (b, t, 1)
        sub = network.add_elementwise(input1=out_exp,
                                      input2=out_ones,
                                      op=trt.ElementWiseOperation.SUB)
        sub.name = "{}.sub_one".format(name)
        out_sub = sub.get_output(0)
        dur = network.add_elementwise(input1=out_sub,
                                      input2=out_zeros,
                                      op=trt.ElementWiseOperation.MAX)
        dur.name = "{}.max".format(name)
        out_dur = dur.get_output(0)

        # Pytorch: repeats = torch.round(repeats).long()
        half_ones = network.add_constant(weights=Weights(
            np.full((batch_size, max_seq_len, 1), 0.5, dtype=np.float32)),
                                         shape=(batch_size, max_seq_len, 1))
        out_half_ones = half_ones.get_output(0)  # (b, t, 1)
        add = network.add_elementwise(input1=out_dur,
                                      input2=out_half_ones,
                                      op=trt.ElementWiseOperation.SUM)
        add.name = "{}.round_add".format(name)
        out_add = add.get_output(0)  # (b, t, 1)
        dur = network.add_elementwise(input1=out_add,
                                      input2=out_ones,
                                      op=trt.ElementWiseOperation.FLOOR_DIV)
        dur.name = "{}.round_floor_div".format(name)
        out_dur = dur.get_output(0)  # (b, t, 1)

        dur = network.add_shuffle(input=out_dur)  # (b, t, 1) to (b, t)
        dur.reshape_dims = Dims(shape=(batch_size, max_seq_len))
        out_dur = dur.get_output(0)  # (b, t)

        return out_dur
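
The SUM + FLOOR_DIV pair at the end emulates torch.round with a round-half-up rule; a NumPy sketch of the same arithmetic (note that torch.round itself rounds halves to even, so exact .5 inputs can differ):

import numpy as np

x = np.array([0.2, 0.5, 1.49, 2.7], dtype=np.float32)
rounded = np.floor((x + 0.5) / 1.0)  # -> [0., 1., 1., 3.], mirrors the two elementwise layers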
    def populate_pos_wise(self, name, network, weights, seq_tensor, batch_size,
                          max_seq_len, d_model, conv_filter_size,
                          conv_kernel_size, conv_padding):
        # Pytorch: output = x.transpose(1, 2)
        trans1 = network.add_shuffle(
            input=seq_tensor)  # (b, t, d_model) to (b, d_model, t, 1)
        trans1.first_transpose = trt.Permutation([0, 2, 1])
        trans1.reshape_dims = Dims((batch_size, d_model, max_seq_len, 1))
        trans1.name = "{}.trans1".format(name)
        out = trans1.get_output(0)  # (b, d_model, t, 1)

        # Pytorch: output = self.w_1(output)
        conv1_w = weights["{}.w_1.weight".format(
            name)]  # (1, conv_filter_size, d_model, conv_kernel_size, 1)
        conv1_b = weights["{}.w_1.bias".format(name)]  # (cov_filter_size,)
        conv1 = network.add_convolution(input=out,
                                        num_output_maps=conv_filter_size,
                                        kernel_shape=trt.DimsHW(
                                            conv_kernel_size, 1),
                                        kernel=Weights(conv1_w),
                                        bias=Weights(conv1_b))
        conv1.padding = trt.DimsHW(1, 0)
        conv1.name = "{}.conv1".format(name)
        out = conv1.get_output(0)  # (b, conv_filter_size, t, 1)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.conv1".format(name))

        # Pytorch: output = F.relu(output)
        relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
        relu.name = "{}.relu".format(name)
        out = relu.get_output(0)  # (b, conv_filter_size, t, 1)

        # Pytorch: output = self.w_2(output)
        conv2_w = weights["{}.w_2.weight".format(
            name)]  # (1, d_model, conv_filter_size, conv_kernel_size, 1)
        conv2_b = weights["{}.w_2.bias".format(name)]  # (d_model, )
        conv2 = network.add_convolution(input=out,
                                        num_output_maps=d_model,
                                        kernel_shape=trt.DimsHW(
                                            conv_kernel_size, 1),
                                        kernel=Weights(conv2_w),
                                        bias=Weights(conv2_b))
        conv2.padding = trt.DimsHW(1, 0)
        conv2.name = "{}.conv2".format(name)
        out = conv2.get_output(0)  # (b, d_model, t, 1)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.conv2".format(name))

        # Pytorch: output = output.transpose(1, 2)
        trans2 = network.add_shuffle(
            input=out)  # (b, d_model, t, 1) to (b, t, d_model)
        trans2.first_transpose = trt.Permutation([0, 2, 1, 3])
        trans2.reshape_dims = Dims((batch_size, max_seq_len, d_model))
        trans2.name = "{}.trans2".format(name)
        out = trans2.get_output(0)  # (b, t, d_model)

        # Pytorch: output += residual
        residual = network.add_elementwise(input1=seq_tensor,
                                           input2=out,
                                           op=trt.ElementWiseOperation.SUM)
        residual.name = "{}.residual".format(name)
        out = residual.get_output(0)  # (b, t, d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.residual".format(name))

        # Pytorch: output = self.layer_norm(output)
        out = self.populate_layernorm(
            name="{}.layer_norm".format(name),
            network=network,
            weights=weights,
            seq_tensor=out,
            batch_size=self.batch_size,
            max_seq_len=max_seq_len,
            d_layer=d_model,
        )  # (b, t, d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.ln".format(name))

        return out
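
For orientation, the block above mirrors the standard FastSpeech position-wise feed-forward module; a condensed PyTorch reference sketch (a reconstruction under the weight names used above, not the original module):

import torch
import torch.nn as nn

class PositionwiseFeedForwardRef(nn.Module):
    def __init__(self, d_model, conv_filter_size, kernel_size=3, padding=1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_model, conv_filter_size, kernel_size, padding=padding)
        self.w_2 = nn.Conv1d(conv_filter_size, d_model, kernel_size, padding=padding)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):                    # x: (b, t, d_model)
        residual = x
        out = self.w_1(x.transpose(1, 2))    # (b, conv_filter_size, t)
        out = torch.relu(out)
        out = self.w_2(out).transpose(1, 2)  # back to (b, t, d_model)
        return self.layer_norm(out + residual)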
    def populate_slf_attn(self, name, network, weights, seq_tensor,
                          seq_mask_tensor, batch_size, max_seq_len, d_model,
                          n_heads, d_k, d_v):
        d_qkv = d_k + d_k + d_v

        # Pytorch: x = self.linear(x)
        w = weights["{}.linear.weight".format(
            name)]  # (n_heads * d_qkv, d_model)
        out_w = network.add_constant(shape=(1, n_heads * d_qkv, d_model),
                                     weights=trt.Weights(w)).get_output(
                                         0)  # (1, n_heads * d_qkv, d_model)
        linear_w = network.add_matrix_multiply(
            seq_tensor, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
        )  # (b, t, d_model) * (1->b, d_model, n_heads * d_qkv) => (b, t, n_heads * d_qkv)
        linear_w.name = "{}.linear.w".format(name)
        out = linear_w.get_output(0)  # (b, t, n_heads * d_qkv)

        b = weights["{}.linear.bias".format(name)]  # (n_heads * d_qkv,)
        out_b = network.add_constant(shape=(1, 1, n_heads * d_qkv),
                                     weights=trt.Weights(b)).get_output(
                                         0)  # (1, 1, n_heads * d_qkv)
        linear_b = network.add_elementwise(input1=out,
                                           input2=out_b,
                                           op=trt.ElementWiseOperation.SUM)
        linear_b.name = "{}.linear.b".format(name)
        out = linear_b.get_output(0)  # (b, t, n_heads * d_qkv)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.linear".format(name))

        trans1 = network.add_shuffle(
            input=out)  # (b, t, n_heads * d_qkv) to (b, n_heads, t, d_qkv)
        trans1.reshape_dims = Dims((batch_size, max_seq_len, n_heads, d_qkv))
        trans1.second_transpose = trt.Permutation([0, 2, 1, 3])
        trans1.name = "{}.trans1".format(name)
        out = trans1.get_output(0)  # (b, n_heads, t, d_qkv)

        # if self.validate_accuracy:
        #     self.add_activation_as_output(network, out, "act.{}.reshape".format(name))

        q = network.add_slice(input=out,
                              start=Dims((0, 0, 0, 0)),
                              shape=Dims(
                                  (batch_size, n_heads, max_seq_len, d_k)),
                              stride=Dims((1, 1, 1, 1)))
        q.name = "{}.slide_q".format(name)

        k = network.add_slice(input=out,
                              start=Dims((0, 0, 0, d_k)),
                              shape=Dims(
                                  (batch_size, n_heads, max_seq_len, d_k)),
                              stride=Dims((1, 1, 1, 1)))
        k.name = "{}.slide_k".format(name)

        v = network.add_slice(input=out,
                              start=Dims((0, 0, 0, 2 * d_k)),
                              shape=Dims(
                                  (batch_size, n_heads, max_seq_len, d_k)),
                              stride=Dims((1, 1, 1, 1)))
        v.name = "{}.slide_v".format(name)

        out_q = q.get_output(0)  # (b, n_heads, t, d_q)
        out_k = k.get_output(0)  # (b, n_heads, t, d_k)
        out_v = v.get_output(0)  # (b, n_heads, t, d_v)

        # Pytorch: output, attn = self.attention(q, k, v, mask=mask)
        out = self.populate_scaled_dot(
            name="{}.scaled_dot".format(name),  # (b, n_heads, t, d_k)
            network=network,
            q_tensor=out_q,
            k_tensor=out_k,
            v_tensor=out_v,
            mask_tensor=seq_mask_tensor,
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            n_heads=n_heads,
            temperature=d_k**0.5)

        # Pytorch:
        # output = output.view(self.n_head, bs, seq_len, self.d_v)
        # output = output.permute(1, 2, 0, 3).contiguous().view(bs, seq_len, self.n_head * self.d_v)
        trans2 = network.add_shuffle(
            input=out)  # (b, n_heads, t, d_k) to (b, t, n_heads * d_v)
        trans2.first_transpose = trt.Permutation([0, 2, 1, 3])
        trans2.reshape_dims = Dims((batch_size, max_seq_len, n_heads * d_v))
        trans2.name = "{}.trans2".format(name)
        out = trans2.get_output(0)  # (b, t, n_heads * d_k)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.scaled_dot".format(name))

        # Pytorch: output = self.fc(output)
        w = weights["{}.fc.weight".format(name)]  # (d_model, n_heads * d_v)
        out_w = network.add_constant(shape=(1, d_model, n_heads * d_v),
                                     weights=trt.Weights(w)).get_output(
                                         0)  # (1, d_model, n_heads * d_v)
        fc_w = network.add_matrix_multiply(
            out, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
        )  # (b, t, n_heads * d_k) * (1->b, n_heads * d_k, d_model) => (b, t, d_model)
        fc_w.name = "{}.fc.w".format(name)
        out = fc_w.get_output(0)  # (b, t, d_model)

        b = weights["{}.fc.bias".format(name)]  # (d_model,)
        out_b = network.add_constant(shape=(1, 1, d_model),
                                     weights=trt.Weights(b)).get_output(
                                         0)  # (1, 1, d_model)
        fc_b = network.add_elementwise(input1=out,
                                       input2=out_b,
                                       op=trt.ElementWiseOperation.SUM)
        fc_b.name = "{}.fc.b".format(name)
        out = fc_b.get_output(0)  # (b, t, d_model)

        # if self.validate_accuracy:
        #     self.add_activation_as_output(network, out, "act.{}.fc".format(name))

        # Pytorch: output += residual
        residual = network.add_elementwise(input1=seq_tensor,
                                           input2=out,
                                           op=ElementWiseOperation.SUM)
        residual.name = "{}.residual".format(name)
        out = residual.get_output(0)  # (b, t, d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.residual".format(name))

        # Pytorch: output = self.layer_norm(output)
        out = self.populate_layernorm(
            name="{}.layer_norm".format(name),
            network=network,
            weights=weights,
            seq_tensor=out,
            batch_size=self.batch_size,
            max_seq_len=max_seq_len,
            d_layer=d_model,
        )  # (b, t, d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.ln".format(name))

        return out
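
populate_scaled_dot is referenced above but not listed here. Below is a hedged sketch of masked scaled dot-product attention assembled from TensorRT layers; the layer names, the add_select masking, and the assumption that mask_tensor is already broadcastable to the (b, n_heads, t, t) score shape are mine, not the original implementation (the caller's (b, t, 1) mask would need a shuffle to (b, 1, 1, t) first). The module-level numpy/tensorrt imports used above are assumed:

def populate_scaled_dot_sketch(name, network, q_tensor, k_tensor, v_tensor,
                               mask_tensor, batch_size, max_seq_len, n_heads,
                               temperature):
    # scores = q @ k^T / temperature  -> (b, n_heads, t, t)
    qk = network.add_matrix_multiply(q_tensor, trt.MatrixOperation.NONE,
                                     k_tensor, trt.MatrixOperation.TRANSPOSE)
    qk.name = "{}.qk".format(name)
    scale = network.add_constant(
        shape=(1, 1, 1, 1),
        weights=trt.Weights(np.array([1.0 / temperature], dtype=np.float32)))
    scaled = network.add_elementwise(qk.get_output(0), scale.get_output(0),
                                     trt.ElementWiseOperation.PROD)
    # Mask out padded keys with a large negative value before the softmax.
    neg = network.add_constant(
        shape=(batch_size, 1, 1, max_seq_len),
        weights=trt.Weights(np.full((batch_size, 1, 1, max_seq_len), -1e9,
                                    dtype=np.float32)))
    masked = network.add_select(condition=mask_tensor,
                                then_input=scaled.get_output(0),
                                else_input=neg.get_output(0))
    softmax = network.add_softmax(masked.get_output(0))
    softmax.axes = 1 << 3  # normalize over the last (key) axis
    # attn @ v -> (b, n_heads, t, d_v)
    out = network.add_matrix_multiply(softmax.get_output(0),
                                      trt.MatrixOperation.NONE,
                                      v_tensor, trt.MatrixOperation.NONE)
    out.name = "{}.attn_v".format(name)
    return out.get_output(0)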
    def populate_network(self, network, weights, batch_size,
                         trt_max_input_seq_len, trt_max_output_seq_len):
        d_model = self.model.d_model

        ##
        # Inputs
        ##
        out_seq = network.add_input(name="input_seq",
                                    dtype=trt.float32,
                                    shape=(batch_size, trt_max_input_seq_len,
                                           d_model))  # (b, t, d_model)
        #
        zeros = network.add_constant(weights=Weights(
            np.zeros(shape=(batch_size, trt_max_input_seq_len, 1),
                     dtype=np.float32)),
                                     shape=(batch_size, trt_max_input_seq_len,
                                            1))  # (b, t, 1)
        out_zeros = zeros.get_output(0)  # (b, t, 1)
        seq = network.add_elementwise(input1=out_seq,
                                      input2=out_zeros,
                                      op=trt.ElementWiseOperation.SUM)
        out_seq = seq.get_output(0)  # (b, t, d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq, "act.emb")
        #

        out_seq_mask = network.add_input(  # paddings are False
            name="input_mask",
            dtype=trt.bool,
            shape=(batch_size, trt_max_input_seq_len, 1))  # (b, t, 1)

        ##
        # Phoneme-side FFT Blocks
        ##

        # Positional Encoding
        # The plugin adds positional encoding to the padding values also (for better performance), whereas Pytorch impl does not.
        # It's fine because the padding values will be eventually masked out in coming layers, giving accurate output.
        seq = network.add_plugin_v2([out_seq],
                                    self.get_plugin('AddPosEncPlugin'))
        seq.name = "phoneme_side.add_pos_enc"
        out_seq = seq.get_output(0)  # (b, t, d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq,
                                          "act.phoneme_side.add_pos_enc")

        for layer_idx in range(self.model.phoneme_side_n_layer):
            out_seq = self.populate_fft(
                name='phoneme_side.layer_stack.{}'.format(layer_idx),
                network=network,
                weights=weights,
                seq_tensor=out_seq,
                seq_mask_tensor=out_seq_mask,
                batch_size=self.batch_size,
                max_seq_len=trt_max_input_seq_len,
                d_model=d_model,
                n_heads=self.model.phoneme_side_head,
                d_k=self.model.phoneme_side.d_k,
                d_v=self.model.phoneme_side.d_v,
                self_attn_temp=self.model.phoneme_side.d_k**0.5,
                conv_filter_size=self.model.phoneme_side_conv1d_filter_size,
                conv_kernel_size=self.model.fft_conv1d_kernel,
                conv_padding=self.model.fft_conv1d_padding)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq,
                                          "act.phoneme_side.seq")

        out_seq, out_seq_mask, out_dur = self.populate_length_regulator(
            name="length_regulator",
            network=network,
            weights=weights,
            seq_tensor=out_seq,
            seq_mask_tensor=out_seq_mask,
            batch_size=batch_size,
            trt_max_input_seq_len=trt_max_input_seq_len,
            trt_max_output_seq_len=trt_max_output_seq_len,
            d_model=d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq,
                                          "act.length_regulator.seq")
            self.add_activation_as_output(network, out_dur,
                                          "act.length_regulator.dur")

        ##
        # Mel-side FFT Blocks
        ##

        # Type int to bool: out_seq_mask. TODO: remove if bool output is allowed in the plugin.
        ones = network.add_constant(weights=Weights(
            np.ones(shape=(batch_size, trt_max_output_seq_len, 1),
                    dtype=np.int32)),
                                    shape=(batch_size, trt_max_output_seq_len,
                                           1))  # (b, t, 1)
        out_ones = ones.get_output(0)  # (b, t, 1)
        seq_mask = network.add_elementwise(
            input1=out_seq_mask,
            input2=out_ones,
            op=ElementWiseOperation.EQUAL)  # (b, t, 1)
        seq_mask.name = "mel_side.seq_mask"
        out_seq_mask = seq_mask.get_output(0)

        # Positional Encoding
        seq = network.add_plugin_v2([out_seq],
                                    self.get_plugin('AddPosEncPlugin'))
        seq.name = "mel_side.add_pos_enc"
        out_seq = seq.get_output(0)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq,
                                          "act.mel_side.add_pos_enc")

        for layer_idx in range(self.model.mel_side_n_layer):
            out_seq = self.populate_fft(
                name="mel_side.layer_stack.{}".format(layer_idx),
                network=network,
                weights=weights,
                seq_tensor=out_seq,
                seq_mask_tensor=out_seq_mask,
                batch_size=self.batch_size,
                max_seq_len=trt_max_output_seq_len,
                d_model=d_model,
                n_heads=self.model.mel_side_head,
                d_k=self.model.mel_side.d_k,
                d_v=self.model.mel_side.d_v,
                self_attn_temp=self.model.mel_side.d_k**0.5,
                conv_filter_size=self.model.mel_side_conv1d_filter_size,
                conv_kernel_size=self.model.fft_conv1d_kernel,
                conv_padding=self.model.fft_conv1d_padding)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq, "act.mel_side.seq")

        ##
        # Linear
        ##

        # Pytorch: self.mel_linear = nn.Linear(mel_side_output_size, n_mels, bias=True)
        w = weights["mel_linear.weight"]  # (n_mels, d_model)
        out_w = network.add_constant(shape=(1, self.model.n_mels, d_model),
                                     weights=trt.Weights(w)).get_output(
                                         0)  # (1, n_mels, d_model)
        linear_w = network.add_matrix_multiply(
            out_seq, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
        )  # (b, t, d_model) * (1->b, d_model, n_mels) => (b, t, n_mels)
        linear_w.name = "linear.w"
        out_seq = linear_w.get_output(0)  # (b, t, n_mels)

        b = weights["mel_linear.bias"]  # (n_mels,)
        out_b = network.add_constant(shape=(1, 1, self.model.n_mels),
                                     weights=trt.Weights(b)).get_output(
                                         0)  # (1, 1, n_mels)
        linear_b = network.add_elementwise(input1=out_seq,
                                           input2=out_b,
                                           op=trt.ElementWiseOperation.SUM)
        linear_b.name = "linear.b"
        out_seq = linear_b.get_output(0)  # (b, t, n_mels)

        ##
        # Outputs
        ##

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq_mask,
                                          "out.seq_mask")
            self.add_activation_as_output(network, out_seq, "out.seq")

        seq = network.add_shuffle(
            input=out_seq)  # (b, t, n_mels) to (b, n_mels, t)
        seq.reshape_dims = Dims(
            (batch_size, trt_max_output_seq_len, self.model.n_mels))
        seq.second_transpose = trt.Permutation([0, 2, 1])
        seq.name = "trans_seq"
        out_seq = seq.get_output(0)

        seq_mask = network.add_shuffle(
            input=out_seq_mask)  # (b, t, 1) to (b, t)
        seq_mask.reshape_dims = Dims((batch_size, trt_max_output_seq_len))
        out_seq_mask = seq_mask.get_output(0)  # (b, t)

        network.mark_output(tensor=out_seq)  # (b, n_mels, t)
        network.mark_output(tensor=out_seq_mask)  # (b, t)

        return network
Example #12
    def initialize(self):
        """Create DLRM network using TRT API and plugins and set the weights."""

        useConvForFC_bottom = (self.precision == "int8")
        useConvForFC_top = (self.precision == "int8")
        interactionsOutputInterleaved = False if self.need_calibration or self.input_dtype != "int8" else True

        # Turn off interleaved format if top_mlp use non-interleaved format
        if not self.enable_interleaved_top_mlp:
            interactionsOutputInterleaved = False
        else:
            print("Using batch-interleaved format for top_mlp.")

        # Check if we should split the model into the binary file with embedding weights quantized and model without embeddings
        if not (os.path.isfile(self.embedding_weights_binary_filepath) and os.path.isfile(self.model_without_embedding_weights_filepath)):
            logging.info("Loading checkpoint from " + self.model_filepath)
            self.weights = torch.load(self.model_filepath, map_location="cpu")["state_dict"]
            self.dump_embedding_weights_to_binary_file()
            logging.info("Writing model without embedding weights to " + self.model_without_embedding_weights_filepath)
            torch.save(self.weights, self.model_without_embedding_weights_filepath)
            del self.weights

        # Dump row frequencies to file in binary format
        if self.use_row_frequencies and not os.path.isfile(self.row_frequencies_binary_filepath):
            logging.info("Writing row frequencies to " + self.row_frequencies_binary_filepath)
            self.dump_row_frequencies_to_binary_file()

        # Load weights
        self.weights = torch.load(self.model_without_embedding_weights_filepath, map_location="cpu")

        # Create network.
        self.network = self.builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

        # Numerical input
        numerical_input = self.network.add_input("numerical_input", trt.DataType.FLOAT, (-1, self.num_numerical_inputs, 1, 1))
        if not self.need_calibration:
            if self.input_dtype == "int8":
                numerical_input.dtype = trt.int8
            elif self.input_dtype == "fp16":
                numerical_input.dtype = trt.float16
            if self.input_format == "linear":
                numerical_input.allowed_formats = 1 << int(trt.TensorFormat.LINEAR)
            elif self.input_format == "chw4":
                numerical_input.allowed_formats = 1 << int(trt.TensorFormat.CHW4)
            elif self.input_format == "chw32":
                numerical_input.allowed_formats = 1 << int(trt.TensorFormat.CHW32)

        # Bottom MLP
        if self.need_calibration or self.input_dtype != "int8":
            bottom_mlp = self.add_mlp(numerical_input, self.num_numerical_inputs, self.bottom_mlp_channels, self.bottom_mlp_names,
                                      last_relu=True, useConvForFC=useConvForFC_bottom)
        else:
            bottom_mlp_plugin, output_tensor_name = self.add_fused_bottom_mlp("DLRM_BOTTOM_MLP_TRT", numerical_input, self.num_numerical_inputs, self.bottom_mlp_channels, self.bottom_mlp_names)
            bottom_mlp = self.network.add_plugin_v2([numerical_input], bottom_mlp_plugin)
            bottom_mlp.get_output(0).name = output_tensor_name
        bottom_mlp_shuffle = self.network.add_shuffle(bottom_mlp.get_output(0))
        bottom_mlp_shuffle.reshape_dims = trt.Dims((-1, 1, self.embedding_size))

        # Index input
        index_input = self.network.add_input("index_input", trt.DataType.INT32, (-1, self.num_features))

        # Embedding lookup and interactions
        dlrm_interactions_plugin = self.get_dlrm_interactions_plugin("DLRM_INTERACTIONS_TRT", np.cumsum(
            np.array([0] + self.embedding_rows[:-1]).astype(np.int32)).astype(np.int32), interactionsOutputInterleaved)
        interaction_output_concat = self.network.add_plugin_v2([bottom_mlp.get_output(0), index_input], dlrm_interactions_plugin)
        interaction_output_concat.name = "interaction_plugin"
        interaction_output_concat.get_output(0).name = "interaction_output_concat_output"

        if self.enable_interleaved_top_mlp and not interactionsOutputInterleaved:
            # Shuffle from [BS, C, 1, 1] to [BS//2, C, 2, 1] before top_mlp
            interleave_pre_top_mlp = self.network.add_shuffle(interaction_output_concat.get_output(0))
            interleave_pre_top_mlp.reshape_dims = trt.Dims((-1, 2, interaction_output_concat.get_output(0).shape[1], 0))
            interleave_pre_top_mlp.second_transpose = trt.Permutation([0, 2, 1, 3])
            interleave_pre_top_mlp.name = "interleave_pre_top_mlp"

            top_mlp_input = interleave_pre_top_mlp.get_output(0)
            top_mlp_input.name = "interleave_pre_top_mlp"
        else:
            top_mlp_input = interaction_output_concat.get_output(0)

        # Insert the Small-Tile GEMM plugin. The plugin supports Ampere only.
        gpu_arch = get_system().arch
        system_id = get_system().gpu
        if self.use_small_tile_gemm_plugin:
            if gpu_arch != Architecture.Ampere:
                print("Small-Tile GEMM plugin does not support {}. Plugin disabled.".format(system_id))
                self.use_small_tile_gemm_plugin = False

        # Enabling the GEMM plugin with the interleaved format is not recommended.
        # Note (2/7/21): the GEMM plugin doesn't perform well when H*W > 1
        if self.use_small_tile_gemm_plugin and self.enable_interleaved_top_mlp:
            print("Warning: Small-Tile GEMM plugin performance will be "
                  "significantly impacted by interleaved format. Turn off "
                  "interleaved format for the best performance")

        tmp_mlp_input = top_mlp_input
        tmp_input_size = self.top_mlp_input_size

        # Helper function to check whether the provided shape is supported by
        # Small-Tile GEMM plugin
        def support_small_tile_gemm_func(C, K):
            return (C >= 256) and (C <= 1280) and (C % 128 == 0) and (K % 128 == 0)

        # Split the top_mlp layers, and use GEMM plugin for 2,4,6
        # C, K for top_mlp.0,2,4,6,8: [480,1024],[1024,1024],[1024,512],[512,256],[256,1]
        for i in range(len(self.top_mlp_channels)):
            # Insert plugin if the layer meets the restriction
            if support_small_tile_gemm_func(tmp_input_size, self.top_mlp_channels[i]) and \
                    self.use_small_tile_gemm_plugin:
                print("Replacing {} with Small-Tile GEMM Plugin, with fairshare cache size {}".
                      format(self.top_mlp_names[i], self.gemm_plugin_fairshare_cache_size))
                layer_top_mlp = self.add_small_tile_gemm_top_mlp(
                    tmp_mlp_input, tmp_input_size,
                    self.top_mlp_channels[i], self.top_mlp_names[i],
                    self.gemm_plugin_fairshare_cache_size
                )
            else:
                layer_top_mlp = self.add_single_mlp(
                    tmp_mlp_input, tmp_input_size,
                    self.top_mlp_channels[i], self.top_mlp_names[i],
                    useConvForFC=useConvForFC_top,
                    add_relu=(i != len(self.top_mlp_channels) - 1))

            tmp_mlp_input = layer_top_mlp.get_output(0)
            tmp_input_size = self.top_mlp_channels[i]

        top_mlp = layer_top_mlp

        if self.enable_interleaved_top_mlp:
            # Shuffle [BS//2, 1, 2, 1] back to [BS, 1, 1, 1]
            interleave_post_top_mlp = self.network.add_shuffle(top_mlp.get_output(0))
            interleave_post_top_mlp.reshape_dims = trt.Dims((-1, 0, 1, 0))
            interleave_post_top_mlp.name = "interleave_post_top_mlp"

            sigmoid_input = interleave_post_top_mlp.get_output(0)
            sigmoid_input.name = "interleave_post_top_mlp"
        else:
            sigmoid_input = top_mlp.get_output(0)

        # Sigmoid
        sigmoid_layer = self.network.add_activation(sigmoid_input, trt.ActivationType.SIGMOID)
        sigmoid_layer.name = "sigmoid"
        sigmoid_layer.get_output(0).name = "sigmoid_output"

        # Output
        self.network.mark_output(sigmoid_layer.get_output(0))

        # Make sure we release the memory to system
        del self.weights

        self.initialized = True
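
Once initialize() has populated self.network, an engine is typically built from it with a builder config. A minimal sketch, assuming the TensorRT 7/8 builder API; self.batch_size and self.calibrator are hypothetical attributes used only for illustration:

config = self.builder.create_builder_config()
config.max_workspace_size = 4 << 30                  # 4 GiB scratch space
if self.precision == "fp16":
    config.set_flag(trt.BuilderFlag.FP16)
elif self.precision == "int8":
    config.set_flag(trt.BuilderFlag.INT8)
    config.int8_calibrator = self.calibrator         # hypothetical attribute
profile = self.builder.create_optimization_profile()
profile.set_shape("numerical_input",
                  min=(1, self.num_numerical_inputs, 1, 1),
                  opt=(self.batch_size, self.num_numerical_inputs, 1, 1),
                  max=(self.batch_size, self.num_numerical_inputs, 1, 1))
profile.set_shape("index_input",
                  min=(1, self.num_features),
                  opt=(self.batch_size, self.num_features),
                  max=(self.batch_size, self.num_features))
config.add_optimization_profile(profile)
engine = self.builder.build_engine(self.network, config)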