    def populate_length_regulator(self, name, network, weights, seq_tensor, seq_mask_tensor, batch_size, trt_max_input_seq_len, trt_max_output_seq_len, d_model):
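        # Length regulator: predict a duration for every input time step, then
        # repeat each step that many times so the phoneme-length sequence is
        # expanded to the mel-frame-length sequence.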
        out_dur = self.populate_duration_predictor(name="{}.duration_predictor".format(name),
                                                   network=network,
                                                   weights=weights,
                                                   seq_tensor=seq_tensor,
                                                   seq_mask_tensor=seq_mask_tensor,
                                                   batch_size=batch_size,
                                                   max_seq_len=trt_max_input_seq_len,
                                                   d_model=d_model)  # (b, t)

        # Pytorch: output.append(torch.repeat_interleave(input[i], repeats, dim=0))
        seq = network.add_plugin_v2([seq_tensor, out_dur], self.get_plugin('RepeatPlugin'))
        seq.name = "{}.repeat_seq".format(name)
        out_seq = seq.get_output(0)  # (b, t, d), (b, t) => (b, t', d), dtype: float32
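        # The Repeat plugin expands the sequence along the time axis: step i is
        # repeated out_dur[i] times (presumably padded/truncated to
        # trt_max_output_seq_len, since the network uses fixed shapes).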

        # Cast seq_mask_tensor from bool to int32. TODO: remove once the plugin accepts bool inputs.
        zeros = network.add_constant(weights=Weights(
            np.zeros(shape=(batch_size, trt_max_input_seq_len, 1), dtype=np.int32)),
            shape=(batch_size, trt_max_input_seq_len, 1))
        out_zeros = zeros.get_output(0)  # (b, t, 1)
        ones = network.add_constant(weights=Weights(
            np.ones(shape=(batch_size, trt_max_input_seq_len, 1), dtype=np.int32)),
            shape=(batch_size, trt_max_input_seq_len, 1))
        out_ones = ones.get_output(0)  # (b, t, 1)
        seq_mask = network.add_select(condition=seq_mask_tensor, then_input=out_ones, else_input=out_zeros)
        seq_mask.name = "{}.seq_mask".format(name)
        out_seq_mask = seq_mask.get_output(0)  # (b, t, 1)
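        # select(mask, 1, 0) acts as a bool -> int32 cast here, since the Repeat
        # plugin does not accept bool inputs (see the TODO above).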

        seq_mask = network.add_plugin_v2([out_seq_mask, out_dur], self.get_plugin('RepeatPlugin'))
        seq_mask.name = "{}.repeat_seq_mask".format(name)
        out_seq_mask = seq_mask.get_output(0)  # (b, t, 1), (b, t) => (b, t', 1), dtype: int32

        return out_seq, out_seq_mask, out_dur

    def populate_scaled_dot(self, name, network, q_tensor, k_tensor, v_tensor, mask_tensor, batch_size, max_seq_len, n_heads, temperature):
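        # Scaled dot-product attention, computed per head:
        #   attn = softmax(Q @ K^T / temperature) @ V
        # with Q, K, V of shape (b, n_heads, t, d_k).
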
        # if self.validate_accuracy:
        #     self.add_activation_as_output(network, q_tensor, "act.{}.q".format(name))
        #     self.add_activation_as_output(network, k_tensor, "act.{}.k".format(name))
        #     self.add_activation_as_output(network, v_tensor, "act.{}.v".format(name))

        # Pytorch: attn = self.bmm1(q, k.transpose(1, 2))
        attn = network.add_matrix_multiply(q_tensor, MatrixOperation.NONE, k_tensor, MatrixOperation.TRANSPOSE)  # (b, n, t, d_k) * (b, n, d_k, t) = (b, n, t, t)
        attn.name = "{}.bmm1".format(name)
        out = attn.get_output(0)

        # if self.validate_accuracy:
        #     self.add_activation_as_output(network, out, "act.{}.bmm1".format(name))

        # Pytorch: attn = attn / self.temperature
        temperature = network.add_constant(weights=Weights(np.full((batch_size, n_heads, max_seq_len, max_seq_len), temperature, dtype=np.float32)),
                                           shape=Dims((batch_size, n_heads, max_seq_len, max_seq_len)))  # (b, n, t, t)
        output_temperature = temperature.get_output(0)

        attn = network.add_elementwise(input1=out, input2=output_temperature, op=ElementWiseOperation.DIV)  # (b, n, t, t)
        attn.name = "{}.div".format(name)
        out = attn.get_output(0)

        # Pytorch: attn = attn.masked_fill(mask, -65504)
        minus_inf = network.add_constant(weights=Weights(np.full((batch_size, n_heads, max_seq_len, max_seq_len), -65504, dtype=np.float32)),
                                       shape=Dims((batch_size, n_heads, max_seq_len, max_seq_len)))  # (b, n, t, t)
        output_minus_inf = minus_inf.get_output(0)
        mask = network.add_shuffle(input=mask_tensor)
        mask.reshape_dims = Dims((batch_size, 1, 1, max_seq_len))  # (b, t, 1) -> (b, 1, 1, t)
        mask.name = "{}.mask_reshape".format(name)
        mask_tensor = mask.get_output(0)
        attn = network.add_select(condition=mask_tensor, # (b, 1->n, 1, t)
                                  then_input=out, # (b, n, t, t)
                                  else_input=output_minus_inf)  # (b, n, t, t)
        attn.name = "{}.mask".format(name)
        out = attn.get_output(0)
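        # -65504 is the most negative finite float16 value; it stands in for -inf
        # so that masked positions get ~zero weight after the softmax while staying
        # representable in fp16. The select keeps the scores where the mask is True
        # (valid positions) and writes -65504 where it is False (padding).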

        # if self.validate_accuracy:
        #     self.add_activation_as_output(network, out, "act.{}.masked_fill".format(name))

        # Pytorch: attn = self.softmax(attn)
        softmax = network.add_softmax(input=out)
        softmax.axes = (1 << 3)  # dim=3
        softmax.name = "{}.softmax".format(name)
        out = softmax.get_output(0)

        # if self.validate_accuracy:
        #     self.add_activation_as_output(network, out, "act.{}.softmax".format(name))

        # Pytorch: output = self.bmm2(attn, v)
        attn = network.add_matrix_multiply(out, MatrixOperation.NONE, v_tensor, MatrixOperation.NONE)  # (b, n, t, t) * (b, n, t, d_k) => (b, n, t, d_k)
        attn.name = "{}.bmm2".format(name)
        out = attn.get_output(0)

        # if self.validate_accuracy:
        #     self.add_activation_as_output(network, out, "act.{}.bmm2".format(name))

        return out

    def populate_fft(self, name, network, weights, seq_tensor, seq_mask_tensor, batch_size,
                     max_seq_len, d_model, n_heads, d_k, d_v, self_attn_temp,
                     conv_filter_size, conv_kernel_size, conv_padding):
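        # FFT block: self-attention followed by a position-wise conv feed-forward
        # network, with padded time steps masked to zero after each sub-layer.
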
        # Self attn
        out = self.populate_slf_attn("{}.slf_attn".format(name), network, weights, seq_tensor, seq_mask_tensor, batch_size,
                                     max_seq_len, d_model, n_heads, d_k, d_v)  # (b, t, d_model)

        # Masking
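        # (zero out padded time steps so they do not contribute to the following layers)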
        zeros = network.add_constant(weights=Weights(
            np.zeros(shape=(batch_size, max_seq_len, 1), dtype=np.float32)),
            shape=(batch_size, max_seq_len, 1))  # (b, t, 1)
        out_zeros = zeros.get_output(0)  # (b, t, 1)
        seq = network.add_select(condition=seq_mask_tensor, then_input=out, else_input=out_zeros)
        seq.name = "{}.mask1".format(name)
        out = seq.get_output(0)  # (b, t, d_model)

        # Position-wise
        out = self.populate_pos_wise("{}.pos_ffn".format(name), network, weights, out,
                          batch_size, max_seq_len, d_model,
                          conv_filter_size, conv_kernel_size, conv_padding)  # (b, t, d_model)

        # Masking
        seq = network.add_select(condition=seq_mask_tensor, then_input=out, else_input=out_zeros)
        seq.name = "{}.mask2".format(name)
        out = seq.get_output(0)  # (b, t, d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out, "act.{}".format(name))

        return out


def _make_implicit_batch_size_tensorrt_model() -> TensorRTModel:
    with Logger() as logger, Builder(logger) as builder, builder.create_network() as network:
        input_x = network.add_input(name='x', dtype=DataType.FLOAT, shape=[4])
        input_y = network.add_input(name='y', dtype=DataType.FLOAT, shape=[4])

        weight = network.add_constant(
            shape=[4],
            weights=Weights(a=numpy.array([2.0, 3.0, 4.0, 5.0], dtype=numpy.float32))
        ).get_output(0)

        output_z = network.add_elementwise(input1=network.add_elementwise(input1=input_x,
                                                                          input2=input_y,
                                                                          op=ElementWiseOperation.SUM).get_output(0),
                                           input2=weight,
                                           op=ElementWiseOperation.SUM).get_output(0)
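        # z = (x + y) + [2.0, 3.0, 4.0, 5.0], computed element-wise over the
        # length-4 vectors (the batch dimension is implicit in this engine).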

        output_z.name = 'z'

        network.mark_output(tensor=output_z)

        return TensorRTModel(cuda_engine=builder.build_cuda_engine(network), input_data_formats=[None, None])

    def populate_layernorm(self, name, network, weights, seq_tensor,
                           batch_size, max_seq_len, d_layer):
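        # LayerNorm over the last dimension, built from reductions:
        #   y = (x - E[x]) / sqrt(E[x^2] - E[x]^2 + eps) * weight + bias
        # using the identity Var[x] = E[x^2] - (E[x])^2.
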
        # m
        mean = network.add_reduce(input=seq_tensor,
                                  op=trt.ReduceOperation.AVG,
                                  axes=(1 << 2),
                                  keep_dims=True)
        mean.name = "{}.mean".format(name)
        out_mean = mean.get_output(0)  # (b, t, 1)

        # m^2
        square_mean = network.add_elementwise(input1=out_mean,
                                              input2=out_mean,
                                              op=ElementWiseOperation.PROD)
        square_mean.name = "{}.square_mean".format(name)
        out_square_mean = square_mean.get_output(0)  # (b, t, 1)

        # x^2
        square = network.add_elementwise(input1=seq_tensor,
                                         input2=seq_tensor,
                                         op=ElementWiseOperation.PROD)
        square.name = "{}.square".format(name)
        out_square = square.get_output(0)  # (b, t, h)

        # e[x^2]
        mean_square = network.add_reduce(input=out_square,
                                         op=trt.ReduceOperation.AVG,
                                         axes=(1 << 2),
                                         keep_dims=True)
        mean_square.name = "{}.mean_square".format(name)
        out_mean_square = mean_square.get_output(0)  # (b, t, 1)

        # e[x^2] - m^2
        sub_square = network.add_elementwise(input1=out_mean_square,
                                             input2=out_square_mean,
                                             op=ElementWiseOperation.SUB)
        sub_square.name = "{}.sub_square".format(name)
        out_sub_square = sub_square.get_output(0)  # (b, t, 1)

        # + eps
        eps = network.add_constant(weights=Weights(
            np.full((batch_size, max_seq_len, 1), 1e-5, dtype=np.float32)),
                                   shape=Dims((batch_size, max_seq_len,
                                               1)))  # (b, t, 1)
        out_eps = eps.get_output(0)
        eps.name = "{}.eps".format(name)
        std = network.add_elementwise(input1=out_sub_square,
                                      input2=out_eps,
                                      op=ElementWiseOperation.SUM)
        std.name = "{}.std".format(name)
        out_std = std.get_output(0)  # (b, t, 1)

        # std
        sqrt = network.add_unary(input=out_std, op=trt.UnaryOperation.SQRT)
        sqrt.name = "{}.sqrt".format(name)
        out_sqrt = sqrt.get_output(0)  # (b, t, 1)

        # y = (x - mean) / std
        sub = network.add_elementwise(input1=seq_tensor,
                                      input2=out_mean,
                                      op=ElementWiseOperation.SUB)
        sub.name = "{}.sub".format(name)
        out_sub = sub.get_output(0)  # (b, t, h)

        div = network.add_elementwise(input1=out_sub,
                                      input2=out_sqrt,
                                      op=ElementWiseOperation.DIV)
        div.name = "{}.div".format(name)
        out = div.get_output(0)  # (b, t, h)

        # Pytorch: y = self.weight * y + self.bias
        w = weights["{}.weight".format(name)]  # (h, )
        out_w = network.add_constant(shape=(1, 1, d_layer),
                                     weights=trt.Weights(w)).get_output(
                                         0)  # (1, 1, h)
        scale_w = network.add_elementwise(
            input1=out, input2=out_w, op=ElementWiseOperation.PROD
        )  # (b, t, h) * (1->b, 1->t, h) => (b, t, h)
        scale_w.name = "{}.scale.w".format(name)
        out = scale_w.get_output(0)  # (b, t, h)

        b = weights["{}.bias".format(name)]  # (h, )
        out_b = network.add_constant(shape=(1, 1, d_layer),
                                     weights=trt.Weights(b)).get_output(
                                         0)  # (1, 1, h)
        scale_b = network.add_elementwise(
            input1=out, input2=out_b, op=ElementWiseOperation.SUM
        )  # (b, t, h) + (1->b, 1->t, h) => (b, t, h)
        scale_b.name = "{}.scale.b".format(name)
        out = scale_b.get_output(0)  # (b, t, h)

        return out

    def populate_duration_predictor(self, name, network, weights, seq_tensor,
                                    seq_mask_tensor, batch_size, max_seq_len,
                                    d_model):
        duration_predictor_filter_size = self.model.duration_predictor_filter_size
        duration_predictor_kernel_size = self.model.duration_predictor_kernel_size

        # Pytorch: input *= input_mask.to(input.dtype)
        # can be skipped.

        # Pytorch: out = self.conv1d_1(input.transpose(1,2)).transpose(1,2)
        trans1 = network.add_shuffle(
            input=seq_tensor)  # (b, t, d_model) to  (b, d_model, t, 1)
        trans1.first_transpose = trt.Permutation([0, 2, 1])
        trans1.reshape_dims = Dims((batch_size, d_model, max_seq_len, 1))
        trans1.name = "{}.trans1".format(name)
        out = trans1.get_output(0)  # (b, d_model, t, 1)
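        # Conv1d is emulated with a 2D convolution: the sequence is laid out as a
        # (d_model, t, 1) "image", the kernel is (kernel_size, 1) and the padding
        # (1, 0), so only the time axis is convolved.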

        conv1_w = weights["{}.conv1d_1.weight".format(
            name
        )]  # (1, d_model, duration_predictor_filter_size, duration_predictor_kernel_size, 1)
        conv1_b = weights["{}.conv1d_1.bias".format(
            name)]  # (duration_predictor_filter_size, )
        conv1 = network.add_convolution(
            input=out,
            num_output_maps=duration_predictor_filter_size,
            kernel_shape=trt.DimsHW(duration_predictor_kernel_size, 1),
            kernel=Weights(conv1_w),
            bias=Weights(conv1_b))
        conv1.padding = trt.DimsHW(1, 0)
        conv1.name = "{}.conv1".format(name)
        out = conv1.get_output(0)  # (b, duration_predictor_filter_size, t, 1)

        trans2 = network.add_shuffle(
            input=out
        )  # (b, duration_predictor_filter_size, t, 1) to (b, t, duration_predictor_filter_size)
        trans2.first_transpose = trt.Permutation([0, 2, 1, 3])
        trans2.reshape_dims = Dims(
            (batch_size, max_seq_len, duration_predictor_filter_size))
        trans2.name = "{}.trans2".format(name)
        out = trans2.get_output(0)  # (b, t, duration_predictor_filter_size)

        # Pytorch: out = self.relu_1(out)
        relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
        relu.name = "{}.relu1".format(name)
        out_relu = relu.get_output(0)  # (b, t, duration_predictor_filter_size)

        # Pytorch: out = self.layer_norm_1(out)
        out = self.populate_layernorm(name="{}.layer_norm_1".format(name),
                                      network=network,
                                      weights=weights,
                                      seq_tensor=out_relu,
                                      d_layer=duration_predictor_filter_size,
                                      batch_size=batch_size,
                                      max_seq_len=max_seq_len)

        # Pytorch: out = self.conv1d_2(out.transpose(1,2)).transpose(1,2)
        trans3 = network.add_shuffle(
            input=out
        )  # (b, t, duration_predictor_filter_size) to (b, duration_predictor_filter_size, t, 1)
        trans3.first_transpose = trt.Permutation([0, 2, 1])
        trans3.reshape_dims = Dims(
            (batch_size, duration_predictor_filter_size, max_seq_len, 1))
        trans3.name = "{}.trans3".format(name)
        out = trans3.get_output(0)  # (b, duration_predictor_filter_size, t, 1)

        conv2_w = weights["{}.conv1d_2.weight".format(
            name
        )]  # (1, duration_predictor_filter_size, duration_predictor_filter_size, duration_predictor_kernel_size, 1)
        conv2_b = weights["{}.conv1d_2.bias".format(
            name)]  # (duration_predictor_filter_size, )
        conv2 = network.add_convolution(
            input=out,
            num_output_maps=duration_predictor_filter_size,
            kernel_shape=trt.DimsHW(duration_predictor_kernel_size, 1),
            kernel=Weights(conv2_w),
            bias=Weights(conv2_b))
        conv2.padding = trt.DimsHW(1, 0)
        conv2.name = "{}.conv2".format(name)
        out = conv2.get_output(0)

        trans4 = network.add_shuffle(
            input=out
        )  # (b, duration_predictor_filter_size, t, 1) to (b, t, duration_predictor_filter_size)
        trans4.first_transpose = trt.Permutation([0, 2, 1, 3])
        trans4.reshape_dims = Dims(
            (batch_size, max_seq_len, duration_predictor_filter_size))
        trans4.name = "{}.trans4".format(name)
        out = trans4.get_output(0)  # (b, t, duration_predictor_filter_size)

        # Pytorch: out = self.relu_2(out)
        relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
        relu.name = "{}.relu2".format(name)
        out_relu = relu.get_output(0)  # (b, t, duration_predictor_filter_size)

        # Pytorch: out = self.layer_norm_2(out)
        out = self.populate_layernorm(
            name="{}.layer_norm_2".format(name),
            network=network,
            weights=weights,
            seq_tensor=out_relu,
            d_layer=duration_predictor_filter_size,
            batch_size=batch_size,
            max_seq_len=max_seq_len,
        )  # (b, t, duration_predictor_filter_size)

        # Pytorch: out = self.linear_layer(out)
        w = weights["{}.linear_layer.weight".format(
            name)]  # (1, duration_predictor_filter_size)
        out_w = network.add_constant(
            shape=(1, 1, duration_predictor_filter_size),
            weights=trt.Weights(w)).get_output(
                0)  # (1, 1, duration_predictor_filter_size)
        linear_w = network.add_matrix_multiply(
            out, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
        )  # (b, t, duration_predictor_filter_size) * (1->b, duration_predictor_filter_size, 1) => (b, t, 1)
        linear_w.name = "{}.linear.w".format(name)
        out = linear_w.get_output(0)  # (b, t, 1)

        b = weights["{}.linear_layer.bias".format(name)]  # (1,)
        out_b = network.add_constant(
            shape=(1, 1, 1), weights=trt.Weights(b)).get_output(0)  # (1, 1, 1)
        linear_b = network.add_elementwise(input1=out,
                                           input2=out_b,
                                           op=trt.ElementWiseOperation.SUM)
        linear_b.name = "{}.linear.b".format(name)
        out = linear_b.get_output(0)  # (b, t, 1)

        # Pytorch: out *= input_mask.to(out.dtype)
        zeros = network.add_constant(weights=Weights(
            np.zeros(shape=(batch_size, max_seq_len, 1), dtype=np.float32)),
                                     shape=(batch_size, max_seq_len, 1))
        out_zeros = zeros.get_output(0)  # (b, t, 1)
        dur = network.add_select(condition=seq_mask_tensor,
                                 then_input=out,
                                 else_input=out_zeros)
        dur.name = "{}.mask".format(name)
        out_dur = dur.get_output(0)

        # Pytorch: duration = torch.clamp_min(torch.exp(duration) - 1, 0)
        exp = network.add_unary(input=out_dur, op=trt.UnaryOperation.EXP)
        exp.name = "{}.exp".format(name)
        out_exp = exp.get_output(0)
        ones = network.add_constant(weights=Weights(
            np.ones(shape=(batch_size, max_seq_len, 1), dtype=np.float32)),
                                    shape=(batch_size, max_seq_len, 1))
        out_ones = ones.get_output(0)  # (b, t, 1)
        sub = network.add_elementwise(input1=out_exp,
                                      input2=out_ones,
                                      op=trt.ElementWiseOperation.SUB)
        sub.name = "{}.sub_one".format(name)
        out_sub = sub.get_output(0)
        dur = network.add_elementwise(input1=out_sub,
                                      input2=out_zeros,
                                      op=trt.ElementWiseOperation.MAX)
        dur.name = "{}.max".format(name)
        out_dur = dur.get_output(0)
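        # The MAX against a zero tensor implements clamp_min(exp(x) - 1, 0).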

        # Pytorch: repeats = torch.round(repeats).long()
        half_ones = network.add_constant(weights=Weights(
            np.full((batch_size, max_seq_len, 1), 0.5, dtype=np.float32)),
                                         shape=(batch_size, max_seq_len, 1))
        out_half_ones = half_ones.get_output(0)  # (b, t, 1)
        add = network.add_elementwise(input1=out_dur,
                                      input2=out_half_ones,
                                      op=trt.ElementWiseOperation.SUM)
        add.name = "{}.round_add".format(name)
        out_add = add.get_output(0)  # (b, t, 1)
        dur = network.add_elementwise(input1=out_add,
                                      input2=out_ones,
                                      op=trt.ElementWiseOperation.FLOOR_DIV)
        dur.name = "{}.round_floor_div".format(name)
        out_dur = dur.get_output(0)  # (b, t, 1)
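        # Rounding is emulated as floor(x + 0.5): adding 0.5 and FLOOR_DIV by a
        # tensor of ones gives e.g. 2.4 -> floor(2.9) = 2 and 2.6 -> floor(3.1) = 3.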

        dur = network.add_shuffle(input=out_dur)  # (b, t, 1) to (b, t)
        dur.reshape_dims = Dims(shape=(batch_size, max_seq_len))
        out_dur = dur.get_output(0)  # (b, t)

        return out_dur

    def populate_pos_wise(self, name, network, weights, seq_tensor, batch_size,
                          max_seq_len, d_model, conv_filter_size,
                          conv_kernel_size, conv_padding):
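        # Position-wise feed-forward network of the FFT block:
        #   output = LayerNorm(x + Conv1d_2(ReLU(Conv1d_1(x))))
        # with the 1D convolutions expressed as 2D convolutions over (t, 1).
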
        # Pytorch: output = x.transpose(1, 2)
        trans1 = network.add_shuffle(
            input=seq_tensor)  # (b, t, d_model) to (b, d_model, t, 1)
        trans1.first_transpose = trt.Permutation([0, 2, 1])
        trans1.reshape_dims = Dims((batch_size, d_model, max_seq_len, 1))
        trans1.name = "{}.trans1".format(name)
        out = trans1.get_output(0)  # (b, d_model, t, 1)

        # Pytorch: output = self.w_1(output)
        conv1_w = weights["{}.w_1.weight".format(
            name)]  # (1, conv_filter_size, d_model, conv_kernel_size, 1)
        conv1_b = weights["{}.w_1.bias".format(name)]  # (conv_filter_size,)
        conv1 = network.add_convolution(input=out,
                                        num_output_maps=conv_filter_size,
                                        kernel_shape=trt.DimsHW(
                                            conv_kernel_size, 1),
                                        kernel=Weights(conv1_w),
                                        bias=Weights(conv1_b))
        conv1.padding = trt.DimsHW(1, 0)
        conv1.name = "{}.conv1".format(name)
        out = conv1.get_output(0)  # (b, conv_filter_size, t, 1)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.conv1".format(name))

        # Pytorch: output = F.relu(output)
        relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
        relu.name = "{}.relu".format(name)
        out = relu.get_output(0)  # (b, conv_filter_size, t, 1)

        # Pytorch: output = self.w_2(output)
        conv2_w = weights["{}.w_2.weight".format(
            name)]  # (1, d_model, conv_filter_size, conv_kernel_size, 1)
        conv2_b = weights["{}.w_2.bias".format(name)]  # (d_model, )
        conv2 = network.add_convolution(input=out,
                                        num_output_maps=d_model,
                                        kernel_shape=trt.DimsHW(
                                            conv_kernel_size, 1),
                                        kernel=Weights(conv2_w),
                                        bias=Weights(conv2_b))
        conv2.padding = trt.DimsHW(1, 0)
        conv2.name = "{}.conv2".format(name)
        out = conv2.get_output(0)  # (b, d_model, t, 1)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.conv2".format(name))

        # Pytorch: output = output.transpose(1, 2)
        trans2 = network.add_shuffle(
            input=out)  # (b, d_model, t, 1) to (b, t, d_model)
        trans2.first_transpose = trt.Permutation([0, 2, 1, 3])
        trans2.reshape_dims = Dims((batch_size, max_seq_len, d_model))
        trans2.name = "{}.trans2".format(name)
        out = trans2.get_output(0)  # (b, t, d_model)

        # Pytorch: output += residual
        residual = network.add_elementwise(input1=seq_tensor,
                                           input2=out,
                                           op=trt.ElementWiseOperation.SUM)
        residual.name = "{}.residual".format(name)
        out = residual.get_output(0)  # (b, t, d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.residual".format(name))

        # Pytorch: output = self.layer_norm(output)
        out = self.populate_layernorm(
            name="{}.layer_norm".format(name),
            network=network,
            weights=weights,
            seq_tensor=out,
            batch_size=self.batch_size,
            max_seq_len=max_seq_len,
            d_layer=d_model,
        )  # (b, t, d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out,
                                          "act.{}.ln".format(name))

        return out

    def populate_network(self, network, weights, batch_size,
                         trt_max_input_seq_len, trt_max_output_seq_len):
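        # Full network: input sequence + mask -> positional encoding ->
        # phoneme-side FFT blocks -> length regulator -> mel-side FFT blocks ->
        # linear projection to n_mels, with the mel spectrogram and the output
        # mask marked as network outputs.
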
        d_model = self.model.d_model

        ##
        # Inputs
        ##
        out_seq = network.add_input(name="input_seq",
                                    dtype=trt.float32,
                                    shape=(batch_size, trt_max_input_seq_len,
                                           d_model))  # (b, t, d_model)
        #
        zeros = network.add_constant(weights=Weights(
            np.zeros(shape=(batch_size, trt_max_input_seq_len, 1),
                     dtype=np.float32)),
                                     shape=(batch_size, trt_max_input_seq_len,
                                            1))  # (b, t, 1)
        out_zeros = zeros.get_output(0)  # (b, t, 1)
        seq = network.add_elementwise(input1=out_seq,
                                      input2=out_zeros,
                                      op=trt.ElementWiseOperation.SUM)
        out_seq = seq.get_output(0)  # (b, t, d_model)
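        # The element-wise sum with a zero constant is effectively an identity;
        # presumably it creates a distinct layer output so the embedding activation
        # can be exported for accuracy validation below.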

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq, "act.emb")
        #

        out_seq_mask = network.add_input(  # paddings are False
            name="input_mask",
            dtype=trt.bool,
            shape=(batch_size, trt_max_input_seq_len, 1))  # (b, t, 1)

        ##
        # Phoneme-side FFT Blocks
        ##

        # Positional Encoding
        # For performance, the plugin adds positional encoding to the padding positions as well, whereas the PyTorch implementation does not.
        # This is fine because the padding values are masked out in the following layers, so the final output is unaffected.
        seq = network.add_plugin_v2([out_seq],
                                    self.get_plugin('AddPosEncPlugin'))
        seq.name = "phoneme_side.add_pos_enc"
        out_seq = seq.get_output(0)  # (b, t, d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq,
                                          "act.phoneme_side.add_pos_enc")

        for layer_idx in range(self.model.phoneme_side_n_layer):
            out_seq = self.populate_fft(
                name='phoneme_side.layer_stack.{}'.format(layer_idx),
                network=network,
                weights=weights,
                seq_tensor=out_seq,
                seq_mask_tensor=out_seq_mask,
                batch_size=self.batch_size,
                max_seq_len=trt_max_input_seq_len,
                d_model=d_model,
                n_heads=self.model.phoneme_side_head,
                d_k=self.model.phoneme_side.d_k,
                d_v=self.model.phoneme_side.d_v,
                self_attn_temp=self.model.phoneme_side.d_k**0.5,
                conv_filter_size=self.model.phoneme_side_conv1d_filter_size,
                conv_kernel_size=self.model.fft_conv1d_kernel,
                conv_padding=self.model.fft_conv1d_padding)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq,
                                          "act.phoneme_side.seq")

        out_seq, out_seq_mask, out_dur = self.populate_length_regulator(
            name="length_regulator",
            network=network,
            weights=weights,
            seq_tensor=out_seq,
            seq_mask_tensor=out_seq_mask,
            batch_size=batch_size,
            trt_max_input_seq_len=trt_max_input_seq_len,
            trt_max_output_seq_len=trt_max_output_seq_len,
            d_model=d_model)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq,
                                          "act.length_regulator.seq")
            self.add_activation_as_output(network, out_dur,
                                          "act.length_regulator.dur")

        ##
        # Mel-side FFT Blocks
        ##

        # Cast out_seq_mask from int32 back to bool. TODO: remove once the plugin can output bool.
        ones = network.add_constant(weights=Weights(
            np.ones(shape=(batch_size, trt_max_output_seq_len, 1),
                    dtype=np.int32)),
                                    shape=(batch_size, trt_max_output_seq_len,
                                           1))  # (b, t, 1)
        out_ones = ones.get_output(0)  # (b, t, 1)
        seq_mask = network.add_elementwise(
            input1=out_seq_mask,
            input2=out_ones,
            op=ElementWiseOperation.EQUAL)  # (b, t, 1)
        seq_mask.name = "mel_side.seq_mask"
        out_seq_mask = seq_mask.get_output(0)

        # Positional Encoding
        seq = network.add_plugin_v2([out_seq],
                                    self.get_plugin('AddPosEncPlugin'))
        seq.name = "mel_side.add_pos_enc"
        out_seq = seq.get_output(0)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq,
                                          "act.mel_side.add_pos_enc")

        for layer_idx in range(self.model.mel_side_n_layer):
            out_seq = self.populate_fft(
                name="mel_side.layer_stack.{}".format(layer_idx),
                network=network,
                weights=weights,
                seq_tensor=out_seq,
                seq_mask_tensor=out_seq_mask,
                batch_size=self.batch_size,
                max_seq_len=trt_max_output_seq_len,
                d_model=d_model,
                n_heads=self.model.mel_side_head,
                d_k=self.model.mel_side.d_k,
                d_v=self.model.mel_side.d_v,
                self_attn_temp=self.model.mel_side.d_k**0.5,
                conv_filter_size=self.model.mel_side_conv1d_filter_size,
                conv_kernel_size=self.model.fft_conv1d_kernel,
                conv_padding=self.model.fft_conv1d_padding)

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq, "act.mel_side.seq")

        ##
        # Linear
        ##

        # Pytorch: self.mel_linear = nn.Linear(mel_side_output_size, n_mels, bias=True)
        w = weights["mel_linear.weight"]  # (n_mels, d_model)
        out_w = network.add_constant(shape=(1, self.model.n_mels, d_model),
                                     weights=trt.Weights(w)).get_output(
                                         0)  # (1, n_mels, d_model)
        linear_w = network.add_matrix_multiply(
            out_seq, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
        )  # (b, t, d_model) * (1->b, d_model, n_mels) => (b, t, n_mels)
        linear_w.name = "linear.w"
        out_seq = linear_w.get_output(0)  # (b, t, n_mels)

        b = weights["mel_linear.bias"]  # (n_mels,)
        out_b = network.add_constant(shape=(1, 1, self.model.n_mels),
                                     weights=trt.Weights(b)).get_output(
                                         0)  # (1, 1, n_mels)
        linear_b = network.add_elementwise(input1=out_seq,
                                           input2=out_b,
                                           op=trt.ElementWiseOperation.SUM)
        linear_b.name = "linear.b"
        out_seq = linear_b.get_output(0)  # (b, t, n_mels)

        ##
        # Outputs
        ##

        if self.validate_accuracy:
            self.add_activation_as_output(network, out_seq_mask,
                                          "out.seq_mask")
            self.add_activation_as_output(network, out_seq, "out.seq")

        seq = network.add_shuffle(
            input=out_seq)  # (b, t, n_mels) to (b, n_mels, t)
        seq.reshape_dims = Dims(
            (batch_size, trt_max_output_seq_len, self.model.n_mels))
        seq.second_transpose = trt.Permutation([0, 2, 1])
        seq.name = "trans_seq"
        out_seq = seq.get_output(0)

        seq_mask = network.add_shuffle(
            input=out_seq_mask)  # (b, t, 1) to (b, t)
        seq_mask.reshape_dims = Dims((batch_size, trt_max_output_seq_len))
        out_seq_mask = seq_mask.get_output(0)  # (b, t)

        network.mark_output(tensor=out_seq)  # (b, n_mels, t)
        network.mark_output(tensor=out_seq_mask)  # (b, t)

        return network