def infer_waveglow_trt(waveglow, waveglow_context, mel, measurements, fp16):

    # WaveGlow upsamples each mel frame by `stride` audio samples; the latent
    # noise z is split into `n_group` channels of length mel_size*stride/n_group.
    mel_size = mel.size(2)
    batch_size = mel.size(0)
    stride = 256
    n_group = 8
    z_size = mel_size * stride
    z_size = z_size // n_group
    z = torch.randn(batch_size, n_group, z_size).cuda()
    audios = torch.zeros(batch_size, mel_size * stride).cuda()

    mel = mel.unsqueeze(3)
    z = z.unsqueeze(3)

    if fp16:
        z = z.half()
        mel = mel.half()
        audios = audios.half()

    # Input/output bindings for the WaveGlow TensorRT engine.
    waveglow_tensors = {
        "inputs": {'mel': mel, 'z': z},
        "outputs": {'audio': audios}
    }

    print("Running WaveGlow with TensorRT")
    with MeasureTime(measurements, "waveglow_time"):
        run_trt_engine(waveglow_context, waveglow, waveglow_tensors)

    return audios
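# Illustrative usage sketch (not part of the original file): shows one way to drive
# infer_waveglow_trt once a serialized WaveGlow TensorRT engine is available.
# `load_engine` is assumed to be the project's engine-deserialization helper and the
# engine file name is a placeholder; `create_execution_context()` is standard TensorRT API.
def example_waveglow_usage(mel, engine_path="waveglow_fp16.engine"):
    import tensorrt as trt
    trt_logger = trt.Logger(trt.Logger.WARNING)
    waveglow = load_engine(engine_path, trt_logger)           # assumed project helper
    waveglow_context = waveglow.create_execution_context()    # TensorRT API
    measurements = {}
    # `mel` is expected as a (batch, 80, frames) spectrogram produced by Tacotron2.
    audio = infer_waveglow_trt(waveglow, waveglow_context, mel, measurements, fp16=True)
    print(measurements)                                       # contains 'waveglow_time'
    return audio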
def infer_tacotron2_trt(encoder, decoder_iter, postnet,
                        encoder_context, decoder_context, postnet_context,
                        sequences, sequence_lengths, measurements, fp16, loop):

    batch_size = len(sequence_lengths)
    max_sequence_len = sequence_lengths[0]
    memory = torch.zeros((batch_size, max_sequence_len, 512)).cuda()
    if fp16:
        memory = memory.half()
    device = memory.device
    dtype = memory.dtype

    # Preallocated output bindings for the encoder engine:
    # memory (B, T, 512), processed_memory (B, T, 128), lens (B,).
    processed_memory = torch.zeros((batch_size, max_sequence_len, 128),
                                   device=device, dtype=dtype)
    lens = torch.zeros_like(sequence_lengths)
    print(f"batch_size: {batch_size}, max sequence length: {max_sequence_len}")

    encoder_tensors = {
        "inputs": {'sequences': sequences, 'sequence_lengths': sequence_lengths},
        "outputs": {'memory': memory, 'lens': lens, 'processed_memory': processed_memory}
    }

    print("Running Tacotron2 Encoder")
    with MeasureTime(measurements, "tacotron2_encoder_time"):
        run_trt_engine(encoder_context, encoder, encoder_tensors)

    max_decoder_steps = 1024
    device = memory.device
    mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32, device=device)
    not_finished = torch.ones([memory.size(0)], dtype=torch.int32, device=device)
    mel_outputs = torch.ones((batch_size, 80, max_decoder_steps),
                             device=device, dtype=dtype).cuda()
    gate_threshold = 0.5
    first_iter = True

    decoder_inputs = init_decoder_inputs(memory, processed_memory, sequence_lengths)
    decoder_outputs = init_decoder_outputs(memory, sequence_lengths)

    if loop:
        if decoder_context is None:
            # Decoder exported with the loop baked in, executed with ONNX Runtime.
            print("Running Tacotron2 Decoder with loop with ONNX-RT")
            decoder_inputs_onnxrt = [x.cpu().numpy().copy() for x in decoder_inputs]
            import onnx
            import onnxruntime
            sess = onnxruntime.InferenceSession(decoder_iter)

            with MeasureTime(measurements, "tacotron2_decoder_time"):
                result = sess.run(["mel_outputs", "mel_lengths_t"], {
                    'decoder_input_0': decoder_inputs_onnxrt[0],
                    'attention_hidden_0': decoder_inputs_onnxrt[1],
                    'attention_cell_0': decoder_inputs_onnxrt[2],
                    'decoder_hidden_0': decoder_inputs_onnxrt[3],
                    'decoder_cell_0': decoder_inputs_onnxrt[4],
                    'attention_weights_0': decoder_inputs_onnxrt[5],
                    'attention_weights_cum_0': decoder_inputs_onnxrt[6],
                    'attention_context_0': decoder_inputs_onnxrt[7],
                    'memory': decoder_inputs_onnxrt[8],
                    'processed_memory': decoder_inputs_onnxrt[9],
                    'mask': decoder_inputs_onnxrt[10]
                })
            mel_outputs = torch.tensor(result[0], device=device)
            mel_lengths = torch.tensor(result[1], device=device)
        else:
            # Decoder with the loop baked in, executed as a single TensorRT engine run.
            print("Running Tacotron2 Decoder with loop")
            decoder_tensors = {
                "inputs": {
                    'decoder_input_0': decoder_inputs[0],
                    'attention_hidden_0': decoder_inputs[1],
                    'attention_cell_0': decoder_inputs[2],
                    'decoder_hidden_0': decoder_inputs[3],
                    'decoder_cell_0': decoder_inputs[4],
                    'attention_weights_0': decoder_inputs[5],
                    'attention_weights_cum_0': decoder_inputs[6],
                    'attention_context_0': decoder_inputs[7],
                    'memory': decoder_inputs[8],
                    'processed_memory': decoder_inputs[9],
                    'mask': decoder_inputs[10]
                },
                "outputs": {
                    'mel_outputs': mel_outputs,
                    'mel_lengths_t': mel_lengths
                }
            }

            with MeasureTime(measurements, "tacotron2_decoder_time"):
                run_trt_engine(decoder_context, decoder_iter, decoder_tensors)

            mel_outputs = mel_outputs[:, :, :torch.max(mel_lengths)]
    else:
        # Autoregressive decoding: run the single-step decoder engine until every
        # sequence in the batch has fired its stop gate or max_decoder_steps is reached.
        print("Running Tacotron2 Decoder")
        measurements_decoder = {}
        while True:
            decoder_tensors = init_decoder_tensors(decoder_inputs, decoder_outputs)
            with MeasureTime(measurements_decoder, "step"):
                run_trt_engine(decoder_context, decoder_iter, decoder_tensors)

            # decoder_outputs[7]: mel frame, [8]: gate output, [4]: attention weights.
            if first_iter:
                mel_outputs = torch.unsqueeze(decoder_outputs[7], 2)
                gate_outputs = torch.unsqueeze(decoder_outputs[8], 2)
                alignments = torch.unsqueeze(decoder_outputs[4], 2)
                measurements['tacotron2_decoder_time'] = measurements_decoder['step']
                first_iter = False
            else:
                mel_outputs = torch.cat(
                    (mel_outputs, torch.unsqueeze(decoder_outputs[7], 2)), 2)
                gate_outputs = torch.cat(
                    (gate_outputs, torch.unsqueeze(decoder_outputs[8], 2)), 2)
                alignments = torch.cat(
                    (alignments, torch.unsqueeze(decoder_outputs[4], 2)), 2)
                measurements['tacotron2_decoder_time'] += measurements_decoder['step']

            dec = torch.le(torch.sigmoid(decoder_outputs[8]),
                           gate_threshold).to(torch.int32).squeeze(1)
            not_finished = not_finished * dec
            mel_lengths += not_finished

            if torch.sum(not_finished) == 0:
                print("Stopping after", mel_outputs.size(2), "decoder steps")
                break
            if mel_outputs.size(2) == max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            # Outputs of this step become the inputs of the next step.
            decoder_inputs, decoder_outputs = swap_inputs_outputs(
                decoder_inputs, decoder_outputs)

        mel_outputs = mel_outputs.clone().detach()

    mel_outputs_postnet = torch.zeros_like(mel_outputs, device=device, dtype=dtype)
    postnet_tensors = {
        "inputs": {'mel_outputs': mel_outputs},
        "outputs": {'mel_outputs_postnet': mel_outputs_postnet}
    }

    print("Running Tacotron2 Postnet")
    with MeasureTime(measurements, "tacotron2_postnet_time"):
        run_trt_engine(postnet_context, postnet, postnet_tensors)
    print("Tacotron2 Postnet done")

    return mel_outputs_postnet, mel_lengths
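# Illustrative end-to-end sketch (not part of the original file): chains the Tacotron2
# and WaveGlow TensorRT helpers above. The engines and execution contexts are assumed
# to have been built and deserialized elsewhere, and `prepare_input_sequence` is assumed
# to be the project's text-preprocessing helper returning padded token ids and lengths.
def example_tts_pipeline(texts, encoder, decoder_iter, postnet, waveglow,
                         encoder_context, decoder_context, postnet_context,
                         waveglow_context, fp16=True):
    measurements = {}
    sequences, sequence_lengths = prepare_input_sequence(texts)   # assumed helper
    mel, mel_lengths = infer_tacotron2_trt(encoder, decoder_iter, postnet,
                                           encoder_context, decoder_context,
                                           postnet_context, sequences,
                                           sequence_lengths, measurements,
                                           fp16, loop=True)
    audio = infer_waveglow_trt(waveglow, waveglow_context, mel, measurements, fp16)
    # Trim each waveform to its true length: WaveGlow emits 256 audio samples per mel frame.
    stride = 256
    audios = [audio[i, :mel_lengths[i] * stride].float().cpu().numpy()
              for i in range(audio.size(0))]
    print(measurements)
    return audios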