def populate_length_regulator(self, name, network, weights, seq_tensor,
                              seq_mask_tensor, batch_size,
                              trt_max_input_seq_len, trt_max_output_seq_len,
                              d_model):
    out_dur = self.populate_duration_predictor(
        name="{}.duration_predictor".format(name),
        network=network,
        weights=weights,
        seq_tensor=seq_tensor,
        seq_mask_tensor=seq_mask_tensor,
        batch_size=batch_size,
        max_seq_len=trt_max_input_seq_len,
        d_model=d_model)  # (b, t)

    # Pytorch: output.append(torch.repeat_interleave(input[i], repeats, dim=0))
    seq = network.add_plugin_v2([seq_tensor, out_dur],
                                self.get_plugin('RepeatPlugin'))
    seq.name = "{}.repeat_seq".format(name)
    out_seq = seq.get_output(0)  # (b, t, d), (b, t) => (b, t', d), dtype: float32

    # Type bool to int: seq_mask_tensor. TODO: remove if bool input is allowed in the plugin.
    zeros = network.add_constant(
        weights=Weights(np.zeros(shape=(batch_size, trt_max_input_seq_len, 1),
                                 dtype=np.int32)),
        shape=(batch_size, trt_max_input_seq_len, 1))
    out_zeros = zeros.get_output(0)  # (b, t, 1)
    ones = network.add_constant(
        weights=Weights(np.ones(shape=(batch_size, trt_max_input_seq_len, 1),
                                dtype=np.int32)),
        shape=(batch_size, trt_max_input_seq_len, 1))
    out_ones = ones.get_output(0)  # (b, t, 1)
    seq_mask = network.add_select(condition=seq_mask_tensor,
                                  then_input=out_ones,
                                  else_input=out_zeros)
    seq_mask.name = "{}.seq_mask".format(name)
    out_seq_mask = seq_mask.get_output(0)  # (b, t, 1)

    seq_mask = network.add_plugin_v2([out_seq_mask, out_dur],
                                     self.get_plugin('RepeatPlugin'))
    seq_mask.name = "{}.repeat_seq_mask".format(name)
    out_seq_mask = seq_mask.get_output(0)  # (b, t, 1), (b, t) => (b, t', 1), dtype: int32

    return out_seq, out_seq_mask, out_dur
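# Illustrative reference (not part of the original file): a minimal NumPy sketch of
# what the RepeatPlugin is assumed to compute, mirroring the torch.repeat_interleave
# call quoted above. The helper name and padding behavior are hypothetical.
def _length_regulator_reference(seq, durations, max_output_len):
    """seq: (t, d) float array; durations: (t,) int array of per-step repeats."""
    out = np.concatenate([np.repeat(seq[i:i + 1], durations[i], axis=0)
                          for i in range(seq.shape[0])], axis=0)
    # Pad (or truncate) to the fixed TensorRT output length.
    padded = np.zeros((max_output_len, seq.shape[1]), dtype=seq.dtype)
    n = min(out.shape[0], max_output_len)
    padded[:n] = out[:n]
    return padded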
def populate_scaled_dot(self, name, network, q_tensor, k_tensor, v_tensor,
                        mask_tensor, batch_size, max_seq_len, n_heads,
                        temperature):
    # if self.validate_accuracy:
    #     self.add_activation_as_output(network, q_tensor, "act.{}.q".format(name))
    #     self.add_activation_as_output(network, k_tensor, "act.{}.k".format(name))
    #     self.add_activation_as_output(network, v_tensor, "act.{}.v".format(name))

    # Pytorch: attn = self.bmm1(q, k.transpose(1, 2))
    attn = network.add_matrix_multiply(
        q_tensor, MatrixOperation.NONE, k_tensor,
        MatrixOperation.TRANSPOSE)  # (b, n, t, d_k) * (b, n, d_k, t) = (b, n, t, t)
    attn.name = "{}.bmm1".format(name)
    out = attn.get_output(0)

    # if self.validate_accuracy:
    #     self.add_activation_as_output(network, out, "act.{}.bmm1".format(name))

    # Pytorch: attn = attn / self.temperature
    temperature = network.add_constant(
        weights=Weights(np.full((batch_size, n_heads, max_seq_len, max_seq_len),
                                temperature, dtype=np.float32)),
        shape=Dims((batch_size, n_heads, max_seq_len, max_seq_len)))  # (b, n, t, t)
    output_temperature = temperature.get_output(0)
    attn = network.add_elementwise(input1=out,
                                   input2=output_temperature,
                                   op=ElementWiseOperation.DIV)  # (b, n, t, t)
    attn.name = "{}.div".format(name)
    out = attn.get_output(0)

    # Pytorch: attn = attn.masked_fill(mask, -65504)
    minus_inf = network.add_constant(
        weights=Weights(np.full((batch_size, n_heads, max_seq_len, max_seq_len),
                                -65504, dtype=np.float32)),
        shape=Dims((batch_size, n_heads, max_seq_len, max_seq_len)))  # (b, n, t, t)
    output_minus_inf = minus_inf.get_output(0)
    mask = network.add_shuffle(input=mask_tensor)
    mask.reshape_dims = Dims((batch_size, 1, 1, max_seq_len))  # (b, t, 1) -> (b, 1, 1, t)
    mask.name = "{}.mask_reshape".format(name)
    mask_tensor = mask.get_output(0)
    attn = network.add_select(condition=mask_tensor,        # (b, 1->n, 1, t)
                              then_input=out,               # (b, n, t, t)
                              else_input=output_minus_inf)  # (b, n, t, t)
    attn.name = "{}.mask".format(name)
    out = attn.get_output(0)

    # if self.validate_accuracy:
    #     self.add_activation_as_output(network, out, "act.{}.masked_fill".format(name))

    # Pytorch: attn = self.softmax(attn)
    softmax = network.add_softmax(input=out)
    softmax.axes = (1 << 3)  # dim=3
    softmax.name = "{}.softmax".format(name)
    out = softmax.get_output(0)

    # if self.validate_accuracy:
    #     self.add_activation_as_output(network, out, "act.{}.softmax".format(name))

    # Pytorch: output = self.bmm2(attn, v)
    attn = network.add_matrix_multiply(
        out, MatrixOperation.NONE, v_tensor,
        MatrixOperation.NONE)  # (b, n, t, t) * (b, n, t, d_k) => (b, n, t, d_k)
    attn.name = "{}.bmm2".format(name)
    out = attn.get_output(0)

    # if self.validate_accuracy:
    #     self.add_activation_as_output(network, out, "act.{}.bmm2".format(name))

    return out
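# Illustrative reference (not part of the original file): a minimal NumPy sketch of
# the masked scaled dot-product attention built above. `mask` is True on valid
# positions, matching the add_select(condition=..., else_input=-65504) semantics.
def _scaled_dot_attention_reference(q, k, v, mask, temperature):
    """q, k, v: (n, t, d_k); mask: (t,) bool; returns (n, t, d_k)."""
    attn = q @ k.transpose(0, 2, 1) / temperature         # (n, t, t)
    attn = np.where(mask[None, None, :], attn, -65504.0)  # fill pads with fp16 min
    attn = np.exp(attn - attn.max(axis=-1, keepdims=True))
    attn /= attn.sum(axis=-1, keepdims=True)              # softmax over last dim
    return attn @ v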
def populate_fft(self, name, network, weights, seq_tensor, seq_mask_tensor,
                 batch_size, max_seq_len, d_model, n_heads, d_k, d_v,
                 self_attn_temp, conv_filter_size, conv_kernel_size,
                 conv_padding):
    # Self attn
    out = self.populate_slf_attn("{}.slf_attn".format(name), network, weights,
                                 seq_tensor, seq_mask_tensor, batch_size,
                                 max_seq_len, d_model, n_heads, d_k,
                                 d_v)  # (b, t, d_model)

    # Masking
    zeros = network.add_constant(
        weights=Weights(np.zeros(shape=(batch_size, max_seq_len, 1),
                                 dtype=np.float32)),
        shape=(batch_size, max_seq_len, 1))  # (b, t, 1)
    out_zeros = zeros.get_output(0)  # (b, t, 1)
    seq = network.add_select(condition=seq_mask_tensor,
                             then_input=out,
                             else_input=out_zeros)
    seq.name = "{}.mask1".format(name)
    out = seq.get_output(0)  # (b, t, d_model)

    # Position-wise
    out = self.populate_pos_wise("{}.pos_ffn".format(name), network, weights,
                                 out, batch_size, max_seq_len, d_model,
                                 conv_filter_size, conv_kernel_size,
                                 conv_padding)  # (b, t, d_model)

    # Masking
    seq = network.add_select(condition=seq_mask_tensor,
                             then_input=out,
                             else_input=out_zeros)
    seq.name = "{}.mask2".format(name)
    out = seq.get_output(0)  # (b, t, d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out, "act.{}".format(name))

    return out
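# Illustrative reference (not part of the original file): the FFT block above as
# PyTorch-style pseudocode, assuming `slf_attn` and `pos_ffn` stand for the two
# populated sublayers; the mask zeroes padded time steps between sublayers.
def _fft_block_reference(x, mask, slf_attn, pos_ffn):
    """x: (b, t, d_model); mask: (b, t, 1) bool."""
    x = slf_attn(x) * mask  # self-attention, then zero out padding
    x = pos_ffn(x) * mask   # position-wise conv FFN, then zero out padding
    return x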
def _make_implicit_batch_size_tensorrt_model() -> TensorRTModel:
    with Logger() as logger, Builder(logger) as builder, \
            builder.create_network() as network:
        input_x = network.add_input(name='x', dtype=DataType.FLOAT, shape=[4])
        input_y = network.add_input(name='y', dtype=DataType.FLOAT, shape=[4])
        weight = network.add_constant(
            shape=[4],
            weights=Weights(a=numpy.array([2.0, 3.0, 4.0, 5.0],
                                          dtype=numpy.float32))).get_output(0)
        output_z = network.add_elementwise(
            input1=network.add_elementwise(input1=input_x,
                                           input2=input_y,
                                           op=ElementWiseOperation.SUM).get_output(0),
            input2=weight,
            op=ElementWiseOperation.SUM).get_output(0)
        output_z.name = 'z'
        network.mark_output(tensor=output_z)
        return TensorRTModel(cuda_engine=builder.build_cuda_engine(network),
                             input_data_formats=[None, None])
def populate_layernorm(self, name, network, weights, seq_tensor, batch_size,
                       max_seq_len, d_layer):
    # m
    mean = network.add_reduce(input=seq_tensor,
                              op=trt.ReduceOperation.AVG,
                              axes=(1 << 2),
                              keep_dims=True)
    mean.name = "{}.mean".format(name)
    out_mean = mean.get_output(0)  # (b, t, 1)

    # m^2
    square_mean = network.add_elementwise(input1=out_mean,
                                          input2=out_mean,
                                          op=ElementWiseOperation.PROD)
    square_mean.name = "{}.square_mean".format(name)
    out_square_mean = square_mean.get_output(0)  # (b, t, 1)

    # x^2
    square = network.add_elementwise(input1=seq_tensor,
                                     input2=seq_tensor,
                                     op=ElementWiseOperation.PROD)
    square.name = "{}.square".format(name)
    out_square = square.get_output(0)  # (b, t, h)

    # e[x^2]
    mean_square = network.add_reduce(input=out_square,
                                     op=trt.ReduceOperation.AVG,
                                     axes=(1 << 2),
                                     keep_dims=True)
    mean_square.name = "{}.mean_square".format(name)
    out_mean_square = mean_square.get_output(0)  # (b, t, 1)

    # var = e[x^2] - m^2
    sub_square = network.add_elementwise(input1=out_mean_square,
                                         input2=out_square_mean,
                                         op=ElementWiseOperation.SUB)
    sub_square.name = "{}.sub_square".format(name)
    out_sub_square = sub_square.get_output(0)  # (b, t, 1)

    # + eps
    eps = network.add_constant(
        weights=Weights(np.full((batch_size, max_seq_len, 1), 1e-5,
                                dtype=np.float32)),
        shape=Dims((batch_size, max_seq_len, 1)))  # (b, t, 1)
    out_eps = eps.get_output(0)
    eps.name = "{}.eps".format(name)
    std = network.add_elementwise(input1=out_sub_square,
                                  input2=out_eps,
                                  op=ElementWiseOperation.SUM)
    std.name = "{}.std".format(name)
    out_std = std.get_output(0)  # (b, t, 1)

    # std = sqrt(var + eps)
    sqrt = network.add_unary(input=out_std, op=trt.UnaryOperation.SQRT)
    sqrt.name = "{}.sqrt".format(name)
    out_sqrt = sqrt.get_output(0)  # (b, t, 1)

    # y = (x - mean) / std
    sub = network.add_elementwise(input1=seq_tensor,
                                  input2=out_mean,
                                  op=ElementWiseOperation.SUB)
    sub.name = "{}.sub".format(name)
    out_sub = sub.get_output(0)  # (b, t, h)
    div = network.add_elementwise(input1=out_sub,
                                  input2=out_sqrt,
                                  op=ElementWiseOperation.DIV)
    div.name = "{}.div".format(name)
    out = div.get_output(0)  # (b, t, h)

    # Pytorch: y = self.weight * y + self.bias
    w = weights["{}.weight".format(name)]  # (h,)
    out_w = network.add_constant(shape=(1, 1, d_layer),
                                 weights=trt.Weights(w)).get_output(0)  # (1, 1, h)
    scale_w = network.add_elementwise(
        input1=out, input2=out_w, op=ElementWiseOperation.PROD
    )  # (b, t, h) * (1->b, 1->t, h) => (b, t, h)
    scale_w.name = "{}.scale.w".format(name)
    out = scale_w.get_output(0)  # (b, t, h)

    b = weights["{}.bias".format(name)]  # (h,)
    out_b = network.add_constant(shape=(1, 1, d_layer),
                                 weights=trt.Weights(b)).get_output(0)  # (1, 1, h)
    scale_b = network.add_elementwise(
        input1=out, input2=out_b, op=ElementWiseOperation.SUM
    )  # (b, t, h) + (1->b, 1->t, h) => (b, t, h)
    scale_b.name = "{}.scale.b".format(name)
    out = scale_b.get_output(0)  # (b, t, h)

    return out
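# Illustrative reference (not part of the original file): a minimal NumPy sketch of
# the layer norm above, which obtains the variance via the identity
# Var[x] = E[x^2] - E[x]^2 so that only reduce/elementwise TensorRT layers are needed.
def _layernorm_reference(x, weight, bias, eps=1e-5):
    """x: (b, t, h); weight, bias: (h,)."""
    m = x.mean(axis=-1, keepdims=True)                   # E[x]
    var = (x * x).mean(axis=-1, keepdims=True) - m * m   # E[x^2] - E[x]^2
    y = (x - m) / np.sqrt(var + eps)
    return weight * y + bias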
def populate_duration_predictor(self, name, network, weights, seq_tensor,
                                seq_mask_tensor, batch_size, max_seq_len,
                                d_model):
    duration_predictor_filter_size = self.model.duration_predictor_filter_size
    duration_predictor_kernel_size = self.model.duration_predictor_kernel_size

    # Pytorch: input *= input_mask.to(input.dtype)  # can be skipped.

    # Pytorch: out = self.conv1d_1(input.transpose(1,2)).transpose(1,2)
    trans1 = network.add_shuffle(input=seq_tensor)  # (b, t, d_model) to (b, d_model, t, 1)
    trans1.first_transpose = trt.Permutation([0, 2, 1])
    trans1.reshape_dims = Dims((batch_size, d_model, max_seq_len, 1))
    trans1.name = "{}.trans1".format(name)
    out = trans1.get_output(0)  # (b, d_model, t, 1)

    conv1_w = weights["{}.conv1d_1.weight".format(name)]
    # (1, d_model, duration_predictor_filter_size, duration_predictor_kernel_size, 1)
    conv1_b = weights["{}.conv1d_1.bias".format(name)]  # (duration_predictor_filter_size,)
    conv1 = network.add_convolution(
        input=out,
        num_output_maps=duration_predictor_filter_size,
        kernel_shape=trt.DimsHW(duration_predictor_kernel_size, 1),
        kernel=Weights(conv1_w),
        bias=Weights(conv1_b))
    conv1.padding = trt.DimsHW(1, 0)
    conv1.name = "{}.conv1".format(name)
    out = conv1.get_output(0)  # (b, duration_predictor_filter_size, t, 1)

    trans2 = network.add_shuffle(input=out)
    # (b, duration_predictor_filter_size, t, 1) to (b, t, duration_predictor_filter_size)
    trans2.first_transpose = trt.Permutation([0, 2, 1, 3])
    trans2.reshape_dims = Dims((batch_size, max_seq_len,
                                duration_predictor_filter_size))
    trans2.name = "{}.trans2".format(name)
    out = trans2.get_output(0)  # (b, t, duration_predictor_filter_size)

    # Pytorch: out = self.relu_1(out)
    relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
    relu.name = "{}.relu1".format(name)
    out_relu = relu.get_output(0)  # (b, t, duration_predictor_filter_size)

    # Pytorch: out = self.layer_norm_1(out)
    out = self.populate_layernorm(name="{}.layer_norm_1".format(name),
                                  network=network,
                                  weights=weights,
                                  seq_tensor=out_relu,
                                  d_layer=duration_predictor_filter_size,
                                  batch_size=batch_size,
                                  max_seq_len=max_seq_len)

    # Pytorch: out = self.conv1d_2(out.transpose(1,2)).transpose(1,2)
    trans3 = network.add_shuffle(input=out)
    # (b, t, duration_predictor_filter_size) to (b, duration_predictor_filter_size, t, 1)
    trans3.first_transpose = trt.Permutation([0, 2, 1])
    trans3.reshape_dims = Dims((batch_size, duration_predictor_filter_size,
                                max_seq_len, 1))
    trans3.name = "{}.trans3".format(name)
    out = trans3.get_output(0)  # (b, duration_predictor_filter_size, t, 1)

    conv2_w = weights["{}.conv1d_2.weight".format(name)]
    # (1, duration_predictor_filter_size, duration_predictor_filter_size, duration_predictor_kernel_size, 1)
    conv2_b = weights["{}.conv1d_2.bias".format(name)]  # (duration_predictor_filter_size,)
    conv2 = network.add_convolution(
        input=out,
        num_output_maps=duration_predictor_filter_size,
        kernel_shape=trt.DimsHW(duration_predictor_kernel_size, 1),
        kernel=Weights(conv2_w),
        bias=Weights(conv2_b))
    conv2.padding = trt.DimsHW(1, 0)
    conv2.name = "{}.conv2".format(name)
    out = conv2.get_output(0)

    trans4 = network.add_shuffle(input=out)
    # (b, duration_predictor_filter_size, t, 1) to (b, t, duration_predictor_filter_size)
    trans4.first_transpose = trt.Permutation([0, 2, 1, 3])
    trans4.reshape_dims = Dims((batch_size, max_seq_len,
                                duration_predictor_filter_size))
    trans4.name = "{}.trans4".format(name)
    out = trans4.get_output(0)  # (b, t, duration_predictor_filter_size)

    # Pytorch: out = self.relu_2(out)
    relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
    relu.name = "{}.relu2".format(name)
    out_relu = relu.get_output(0)  # (b, t, duration_predictor_filter_size)

    # Pytorch: out = self.layer_norm_2(out)
    out = self.populate_layernorm(
        name="{}.layer_norm_2".format(name),
        network=network,
        weights=weights,
        seq_tensor=out_relu,
        d_layer=duration_predictor_filter_size,
        batch_size=batch_size,
        max_seq_len=max_seq_len,
    )  # (b, t, duration_predictor_filter_size)

    # Pytorch: out = self.linear_layer(out)
    w = weights["{}.linear_layer.weight".format(name)]  # (1, duration_predictor_filter_size)
    out_w = network.add_constant(
        shape=(1, 1, duration_predictor_filter_size),
        weights=trt.Weights(w)).get_output(0)  # (1, 1, duration_predictor_filter_size)
    linear_w = network.add_matrix_multiply(
        out, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
    )  # (b, t, duration_predictor_filter_size) * (1->b, duration_predictor_filter_size, 1) => (b, t, 1)
    linear_w.name = "{}.linear.w".format(name)
    out = linear_w.get_output(0)  # (b, t, 1)

    b = weights["{}.linear_layer.bias".format(name)]  # (1,)
    out_b = network.add_constant(shape=(1, 1, 1),
                                 weights=trt.Weights(b)).get_output(0)  # (1, 1, 1)
    linear_b = network.add_elementwise(input1=out,
                                       input2=out_b,
                                       op=trt.ElementWiseOperation.SUM)
    linear_b.name = "{}.linear.b".format(name)
    out = linear_b.get_output(0)  # (b, t, 1)

    # Pytorch: out *= input_mask.to(out.dtype)
    zeros = network.add_constant(
        weights=Weights(np.zeros(shape=(batch_size, max_seq_len, 1),
                                 dtype=np.float32)),
        shape=(batch_size, max_seq_len, 1))
    out_zeros = zeros.get_output(0)  # (b, t, 1)
    dur = network.add_select(condition=seq_mask_tensor,
                             then_input=out,
                             else_input=out_zeros)
    dur.name = "{}.mask".format(name)
    out_dur = dur.get_output(0)

    # Pytorch: duration = torch.clamp_min(torch.exp(duration) - 1, 0)
    exp = network.add_unary(input=out_dur, op=trt.UnaryOperation.EXP)
    exp.name = "{}.exp".format(name)
    out_exp = exp.get_output(0)
    ones = network.add_constant(
        weights=Weights(np.ones(shape=(batch_size, max_seq_len, 1),
                                dtype=np.float32)),
        shape=(batch_size, max_seq_len, 1))
    out_ones = ones.get_output(0)  # (b, t, 1)
    sub = network.add_elementwise(input1=out_exp,
                                  input2=out_ones,
                                  op=trt.ElementWiseOperation.SUB)
    sub.name = "{}.sub_one".format(name)
    out_sub = sub.get_output(0)
    dur = network.add_elementwise(input1=out_sub,
                                  input2=out_zeros,
                                  op=trt.ElementWiseOperation.MAX)
    dur.name = "{}.max".format(name)
    out_dur = dur.get_output(0)

    # Pytorch: repeats = torch.round(repeats).long()
    # round(x) is computed here as floor(x + 0.5), using FLOOR_DIV by ones.
    half_ones = network.add_constant(
        weights=Weights(np.full((batch_size, max_seq_len, 1), 0.5,
                                dtype=np.float32)),
        shape=(batch_size, max_seq_len, 1))
    out_half_ones = half_ones.get_output(0)  # (b, t, 1)
    add = network.add_elementwise(input1=out_dur,
                                  input2=out_half_ones,
                                  op=trt.ElementWiseOperation.SUM)
    add.name = "{}.round_add".format(name)
    out_add = add.get_output(0)  # (b, t, 1)
    dur = network.add_elementwise(input1=out_add,
                                  input2=out_ones,
                                  op=trt.ElementWiseOperation.FLOOR_DIV)
    dur.name = "{}.round_floor_div".format(name)
    out_dur = dur.get_output(0)  # (b, t, 1)

    dur = network.add_shuffle(input=out_dur)  # (b, t, 1) to (b, t)
    dur.reshape_dims = Dims(shape=(batch_size, max_seq_len))
    out_dur = dur.get_output(0)  # (b, t)

    return out_dur
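# Illustrative reference (not part of the original file): the duration head's
# post-processing in NumPy, assuming the FLOOR_DIV-by-ones trick above is how
# torch.round is emulated (exact for the non-negative values produced here).
def _duration_postprocess_reference(log_dur):
    dur = np.maximum(np.exp(log_dur) - 1.0, 0.0)  # clamp_min(exp(d) - 1, 0)
    return np.floor(dur + 0.5)                    # round(x) == floor(x + 0.5) for x >= 0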
def populate_pos_wise(self, name, network, weights, seq_tensor, batch_size,
                      max_seq_len, d_model, conv_filter_size,
                      conv_kernel_size, conv_padding):
    # Pytorch: output = x.transpose(1, 2)
    trans1 = network.add_shuffle(input=seq_tensor)  # (b, t, d_model) to (b, d_model, t, 1)
    trans1.first_transpose = trt.Permutation([0, 2, 1])
    trans1.reshape_dims = Dims((batch_size, d_model, max_seq_len, 1))
    trans1.name = "{}.trans1".format(name)
    out = trans1.get_output(0)  # (b, d_model, t, 1)

    # Pytorch: output = self.w_1(output)
    conv1_w = weights["{}.w_1.weight".format(name)]  # (1, conv_filter_size, d_model, conv_kernel_size, 1)
    conv1_b = weights["{}.w_1.bias".format(name)]  # (conv_filter_size,)
    conv1 = network.add_convolution(input=out,
                                    num_output_maps=conv_filter_size,
                                    kernel_shape=trt.DimsHW(conv_kernel_size, 1),
                                    kernel=Weights(conv1_w),
                                    bias=Weights(conv1_b))
    conv1.padding = trt.DimsHW(1, 0)
    conv1.name = "{}.conv1".format(name)
    out = conv1.get_output(0)  # (b, conv_filter_size, t, 1)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out, "act.{}.conv1".format(name))

    # Pytorch: output = F.relu(output)
    relu = network.add_activation(input=out, type=trt.ActivationType.RELU)
    relu.name = "{}.relu".format(name)
    out = relu.get_output(0)  # (b, conv_filter_size, t, 1)

    # Pytorch: output = self.w_2(output)
    conv2_w = weights["{}.w_2.weight".format(name)]  # (1, d_model, conv_filter_size, conv_kernel_size, 1)
    conv2_b = weights["{}.w_2.bias".format(name)]  # (d_model,)
    conv2 = network.add_convolution(input=out,
                                    num_output_maps=d_model,
                                    kernel_shape=trt.DimsHW(conv_kernel_size, 1),
                                    kernel=Weights(conv2_w),
                                    bias=Weights(conv2_b))
    conv2.padding = trt.DimsHW(1, 0)
    conv2.name = "{}.conv2".format(name)
    out = conv2.get_output(0)  # (b, d_model, t, 1)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out, "act.{}.conv2".format(name))

    # Pytorch: output = output.transpose(1, 2)
    trans2 = network.add_shuffle(input=out)  # (b, d_model, t, 1) to (b, t, d_model)
    trans2.first_transpose = trt.Permutation([0, 2, 1, 3])
    trans2.reshape_dims = Dims((batch_size, max_seq_len, d_model))
    trans2.name = "{}.trans2".format(name)
    out = trans2.get_output(0)  # (b, t, d_model)

    # Pytorch: output += residual
    residual = network.add_elementwise(input1=seq_tensor,
                                       input2=out,
                                       op=trt.ElementWiseOperation.SUM)
    residual.name = "{}.residual".format(name)
    out = residual.get_output(0)  # (b, t, d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out, "act.{}.residual".format(name))

    # Pytorch: output = self.layer_norm(output)
    out = self.populate_layernorm(
        name="{}.layer_norm".format(name),
        network=network,
        weights=weights,
        seq_tensor=out,
        batch_size=self.batch_size,
        max_seq_len=max_seq_len,
        d_layer=d_model,
    )  # (b, t, d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out, "act.{}.ln".format(name))

    return out
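# Illustrative reference (not part of the original file): a minimal sketch checking
# the assumption behind the shuffles above, namely that a (k, 1) 2D convolution on
# input reshaped to (b, c, t, 1) matches a plain 1D convolution over time. Uses
# PyTorch only for the check; all names here are hypothetical.
def _poswise_conv_equivalence_check():
    import torch
    x = torch.randn(2, 8, 16)  # (b, channels, t)
    conv1d = torch.nn.Conv1d(8, 32, kernel_size=3, padding=1)
    conv2d = torch.nn.Conv2d(8, 32, kernel_size=(3, 1), padding=(1, 0))
    with torch.no_grad():
        conv2d.weight.copy_(conv1d.weight.unsqueeze(-1))  # (32, 8, 3) -> (32, 8, 3, 1)
        conv2d.bias.copy_(conv1d.bias)
        assert torch.allclose(conv1d(x),
                              conv2d(x.unsqueeze(-1)).squeeze(-1),
                              atol=1e-6)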
def populate_network(self, network, weights, batch_size,
                     trt_max_input_seq_len, trt_max_output_seq_len):
    d_model = self.model.d_model

    ##
    # Inputs
    ##

    out_seq = network.add_input(name="input_seq",
                                dtype=trt.float32,
                                shape=(batch_size, trt_max_input_seq_len,
                                       d_model))  # (b, t, d_model)

    #
    zeros = network.add_constant(
        weights=Weights(np.zeros(shape=(batch_size, trt_max_input_seq_len, 1),
                                 dtype=np.float32)),
        shape=(batch_size, trt_max_input_seq_len, 1))  # (b, t, 1)
    out_zeros = zeros.get_output(0)  # (b, t, 1)
    seq = network.add_elementwise(input1=out_seq,
                                  input2=out_zeros,
                                  op=trt.ElementWiseOperation.SUM)
    out_seq = seq.get_output(0)  # (b, t, d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq, "act.emb")

    #
    out_seq_mask = network.add_input(  # paddings are False
        name="input_mask",
        dtype=trt.bool,
        shape=(batch_size, trt_max_input_seq_len, 1))  # (b, t, 1)

    ##
    # Phoneme-side FFT Blocks
    ##

    # Positional Encoding
    # The plugin adds positional encoding to the padding values as well (for better
    # performance), whereas the Pytorch impl does not. That is fine because the
    # padding values are eventually masked out in the coming layers, giving
    # accurate output.
    seq = network.add_plugin_v2([out_seq], self.get_plugin('AddPosEncPlugin'))
    seq.name = "phoneme_side.add_pos_enc"
    out_seq = seq.get_output(0)  # (b, t, d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq,
                                      "act.phoneme_side.add_pos_enc")

    for layer_idx in range(self.model.phoneme_side_n_layer):
        out_seq = self.populate_fft(
            name='phoneme_side.layer_stack.{}'.format(layer_idx),
            network=network,
            weights=weights,
            seq_tensor=out_seq,
            seq_mask_tensor=out_seq_mask,
            batch_size=self.batch_size,
            max_seq_len=trt_max_input_seq_len,
            d_model=d_model,
            n_heads=self.model.phoneme_side_head,
            d_k=self.model.phoneme_side.d_k,
            d_v=self.model.phoneme_side.d_v,
            self_attn_temp=self.model.phoneme_side.d_k**0.5,
            conv_filter_size=self.model.phoneme_side_conv1d_filter_size,
            conv_kernel_size=self.model.fft_conv1d_kernel,
            conv_padding=self.model.fft_conv1d_padding)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq, "act.phoneme_side.seq")

    out_seq, out_seq_mask, out_dur = self.populate_length_regulator(
        name="length_regulator",
        network=network,
        weights=weights,
        seq_tensor=out_seq,
        seq_mask_tensor=out_seq_mask,
        batch_size=batch_size,
        trt_max_input_seq_len=trt_max_input_seq_len,
        trt_max_output_seq_len=trt_max_output_seq_len,
        d_model=d_model)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq, "act.length_regulator.seq")
        self.add_activation_as_output(network, out_dur, "act.length_regulator.dur")

    ##
    # Mel-side FFT Blocks
    ##

    # Type int to bool: out_seq_mask. TODO: remove if bool output is allowed in the plugin.
    ones = network.add_constant(
        weights=Weights(np.ones(shape=(batch_size, trt_max_output_seq_len, 1),
                                dtype=np.int32)),
        shape=(batch_size, trt_max_output_seq_len, 1))  # (b, t, 1)
    out_ones = ones.get_output(0)  # (b, t, 1)
    seq_mask = network.add_elementwise(input1=out_seq_mask,
                                       input2=out_ones,
                                       op=ElementWiseOperation.EQUAL)  # (b, t, 1)
    seq_mask.name = "mel_side.seq_mask"
    out_seq_mask = seq_mask.get_output(0)

    # Positional Encoding
    seq = network.add_plugin_v2([out_seq], self.get_plugin('AddPosEncPlugin'))
    seq.name = "mel_side.add_pos_enc"
    out_seq = seq.get_output(0)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq, "act.mel_side.add_pos_enc")

    for layer_idx in range(self.model.mel_side_n_layer):
        out_seq = self.populate_fft(
            name="mel_side.layer_stack.{}".format(layer_idx),
            network=network,
            weights=weights,
            seq_tensor=out_seq,
            seq_mask_tensor=out_seq_mask,
            batch_size=self.batch_size,
            max_seq_len=trt_max_output_seq_len,
            d_model=d_model,
            n_heads=self.model.mel_side_head,
            d_k=self.model.mel_side.d_k,
            d_v=self.model.mel_side.d_v,
            self_attn_temp=self.model.mel_side.d_k**0.5,
            conv_filter_size=self.model.mel_side_conv1d_filter_size,
            conv_kernel_size=self.model.fft_conv1d_kernel,
            conv_padding=self.model.fft_conv1d_padding)

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq, "act.mel_side.seq")

    ##
    # Linear
    ##

    # Pytorch: self.mel_linear = nn.Linear(mel_side_output_size, n_mels, bias=True)
    w = weights["mel_linear.weight"]  # (n_mels, d_model)
    out_w = network.add_constant(shape=(1, self.model.n_mels, d_model),
                                 weights=trt.Weights(w)).get_output(0)  # (1, n_mels, d_model)
    linear_w = network.add_matrix_multiply(
        out_seq, MatrixOperation.NONE, out_w, MatrixOperation.TRANSPOSE
    )  # (b, t, d_model) * (1->b, d_model, n_mels) => (b, t, n_mels)
    linear_w.name = "linear.w"
    out_seq = linear_w.get_output(0)  # (b, t, n_mels)

    b = weights["mel_linear.bias"]  # (n_mels,)
    out_b = network.add_constant(shape=(1, 1, self.model.n_mels),
                                 weights=trt.Weights(b)).get_output(0)  # (1, 1, n_mels)
    linear_b = network.add_elementwise(input1=out_seq,
                                       input2=out_b,
                                       op=trt.ElementWiseOperation.SUM)
    linear_b.name = "linear.b"
    out_seq = linear_b.get_output(0)  # (b, t, n_mels)

    ##
    # Outputs
    ##

    if self.validate_accuracy:
        self.add_activation_as_output(network, out_seq_mask, "out.seq_mask")
        self.add_activation_as_output(network, out_seq, "out.seq")

    seq = network.add_shuffle(input=out_seq)  # (b, t, n_mels) to (b, n_mels, t)
    seq.reshape_dims = Dims((batch_size, trt_max_output_seq_len,
                             self.model.n_mels))
    seq.second_transpose = trt.Permutation([0, 2, 1])
    seq.name = "trans_seq"
    out_seq = seq.get_output(0)

    seq_mask = network.add_shuffle(input=out_seq_mask)  # (b, t, 1) to (b, t)
    seq_mask.reshape_dims = Dims((batch_size, trt_max_output_seq_len))
    out_seq_mask = seq_mask.get_output(0)  # (b, t)

    network.mark_output(tensor=out_seq)  # (b, n_mels, t)
    network.mark_output(tensor=out_seq_mask)  # (b, t)

    return network
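# Illustrative reference (not part of the original file): the mel linear layer above
# expressed in NumPy; the (1, n_mels, d_model) constant broadcasts across the batch
# dimension in the TensorRT matrix multiply. The helper name is hypothetical.
def _mel_linear_reference(seq, weight, bias):
    """seq: (b, t, d_model); weight: (n_mels, d_model); bias: (n_mels,)."""
    return seq @ weight.T + bias  # (b, t, n_mels)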