def infer_waveglow_trt(waveglow, waveglow_context, mel, measurements, fp16):
    """Synthesize audio from mel spectrograms with a WaveGlow engine (4-D variant).

    This variant appends a trailing singleton dimension to ``mel`` and builds a
    4-D noise tensor ``z``. It then hands ``run_trt_engine`` a flat
    name -> tensor dict.

    NOTE(review): ``infer_waveglow_trt`` is defined twice in this file. The
    later definition rebinds the name when the module is imported, so this
    version is unreachable by name. Consider deleting one of the two.

    Args:
        waveglow: engine object passed through to ``run_trt_engine``.
        waveglow_context: execution context passed through to ``run_trt_engine``.
        mel: mel tensor whose dim 0 is the batch and dim 2 the frame axis.
            Presumably (batch, n_mel_channels, frames); TODO confirm.
        measurements: dict into which ``MeasureTime`` records "waveglow_time".
        fp16: when True, ``z``, ``mel`` and the audio buffer are cast to half.

    Returns:
        The buffer registered under 'audio', shaped (batch, frames * 256).
        It is expected to be written in place by ``run_trt_engine``.
    """
    hop_length = 256
    kernel_width = 1024
    groups = 8

    mel = mel.unsqueeze(3)
    n_batch = mel.size(0)
    n_frames = mel.size(2)

    # Mirrors a transposed-convolution output length
    # ((n_frames - 1) * hop + kernel, no padding) minus the (kernel - hop)
    # overhang. Algebraically this is exactly n_frames * hop_length.
    upsampled_len = (n_frames - 1) * hop_length + kernel_width
    upsampled_len -= kernel_width - hop_length

    noise = torch.randn(n_batch, groups, upsampled_len // groups, 1).cuda()
    audio_buffer = torch.zeros(n_batch, n_frames * hop_length).cuda()

    if fp16:
        noise, mel, audio_buffer = noise.half(), mel.half(), audio_buffer.half()

    # Flat layout: 'mel' and 'z' are engine inputs, 'audio' is the output.
    engine_tensors = dict(mel=mel, z=noise, audio=audio_buffer)

    print("Running WaveGlow")
    with MeasureTime(measurements, "waveglow_time"):
        run_trt_engine(waveglow_context, waveglow, engine_tensors)

    return audio_buffer
def infer_waveglow_trt(waveglow, waveglow_context, mel, measurements, fp16):
    """Synthesize audio from mel spectrograms with a WaveGlow engine.

    Builds a random noise tensor ``z`` and an output buffer sized to the mel
    frame count. It runs the engine through ``run_trt_engine`` with tensors
    grouped under "inputs" and "outputs".

    Args:
        waveglow: engine object passed through to ``run_trt_engine``.
        waveglow_context: execution context passed through to ``run_trt_engine``.
        mel: mel tensor whose dim 0 is the batch and dim 2 the frame axis.
            Presumably (batch, n_mel_channels, frames); TODO confirm.
        measurements: dict into which ``MeasureTime`` records "waveglow_time".
        fp16: when True, ``z``, ``mel`` and the audio buffer are cast to half.

    Returns:
        The buffer registered as the 'audio' output, shaped
        (batch, frames * 256). It is expected to be written in place by
        ``run_trt_engine``.
    """
    hop_length = 256  # audio samples allocated per mel frame
    groups = 8        # z has `groups` channels of n_samples // groups each

    n_batch, n_frames = mel.size(0), mel.size(2)
    n_samples = n_frames * hop_length

    noise = torch.randn(n_batch, groups, n_samples // groups).cuda()
    audio_buffer = torch.zeros(n_batch, n_samples).cuda()

    if fp16:
        noise, mel, audio_buffer = noise.half(), mel.half(), audio_buffer.half()

    engine_tensors = {
        "inputs": {'mel': mel, 'z': noise},
        "outputs": {'audio': audio_buffer},
    }

    print("Running WaveGlow")
    with MeasureTime(measurements, "waveglow_time"):
        run_trt_engine(waveglow_context, waveglow, engine_tensors)

    return audio_buffer
def infer_tacotron2_trt(encoder, decoder_iter, postnet,
                        encoder_context, decoder_context, postnet_context,
                        sequences, sequence_lengths, measurements, fp16):
    """Run Tacotron2 (encoder -> step-by-step decoder -> postnet) via run_trt_engine.

    Args:
        encoder, decoder_iter, postnet: engine objects for the three stages,
            passed through to ``run_trt_engine``.
        encoder_context, decoder_context, postnet_context: execution contexts
            for the matching engines, passed through to ``run_trt_engine``.
        sequences: encoder input, bound under the name 'sequences'.
        sequence_lengths: 1-D tensor of per-item input lengths. Its length is
            the batch size and its first element is used as the padded
            sequence length. It is presumably sorted in descending order;
            TODO confirm with callers.
        measurements: dict that ``MeasureTime`` writes timings into. This
            function also sets 'tacotron2_decoder_time' to the sum of all
            per-step decoder timings.
        fp16: when True, the encoder output buffers are allocated in half
            precision.

    Returns:
        ``(mel_outputs_postnet, mel_lengths)``:

        - ``mel_outputs_postnet``: postnet output with the same shape as the
          decoder frames concatenated along dim 2.
        - ``mel_lengths``: int32 tensor of shape (batch,). It counts, per item,
          the decoder steps taken before its stop gate first exceeded the
          threshold.
    """
    batch_size = len(sequence_lengths)
    max_len = sequence_lengths[0]

    # ---- Encoder -------------------------------------------------------
    memory = torch.zeros((batch_size, max_len, 512)).cuda()
    if fp16:
        memory = memory.half()
    device = memory.device
    dtype = memory.dtype
    processed_memory = torch.zeros((batch_size, max_len, 128),
                                   device=device, dtype=dtype)
    lens = torch.zeros_like(sequence_lengths)

    # NOTE(review): this is a flat name -> tensor dict, while another helper
    # in this file passes {"inputs": {...}, "outputs": {...}} to
    # run_trt_engine. Confirm which layout run_trt_engine expects.
    encoder_tensors = {
        # inputs
        'sequences': sequences,
        'sequence_lengths': sequence_lengths,
        # outputs
        'memory': memory,
        'lens': lens,
        'processed_memory': processed_memory,
    }
    print("Running Tacotron2 Encoder")
    with MeasureTime(measurements, "tacotron2_encoder_time"):
        run_trt_engine(encoder_context, encoder, encoder_tensors)

    # ---- Decoder (one engine call per mel frame) -----------------------
    mel_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device)
    not_finished = torch.ones([batch_size], dtype=torch.int32, device=device)
    gate_threshold = 0.5
    max_decoder_steps = 1664

    # Frames are collected and concatenated once at the end. Repeatedly
    # torch.cat-ing the running result would copy O(steps^2) data.
    mel_frames = []

    decoder_inputs = init_decoder_inputs(memory, processed_memory, sequence_lengths)
    decoder_outputs = init_decoder_outputs(memory, sequence_lengths)

    print("Running Tacotron2 Decoder")
    while True:
        decoder_tensors = init_decoder_tensors(decoder_inputs, decoder_outputs)
        with MeasureTime(measurements, "step"):
            run_trt_engine(decoder_context, decoder_iter, decoder_tensors)

        # Output buffers are recycled by swap_inputs_outputs and may be
        # overwritten by a later step. Copy the frame before keeping it.
        mel_frames.append(torch.unsqueeze(decoder_outputs[7], 2).clone())
        n_steps = len(mel_frames)

        if n_steps == 1:
            measurements['tacotron2_decoder_time'] = measurements['step']
        else:
            measurements['tacotron2_decoder_time'] += measurements['step']

        # An item stops once sigmoid(gate) > gate_threshold. Only steps taken
        # while it is still unfinished are added to its length.
        dec = torch.le(torch.sigmoid(decoder_outputs[8]),
                       gate_threshold).to(torch.int32).squeeze(1)
        not_finished = not_finished * dec
        mel_lengths += not_finished

        if torch.sum(not_finished) == 0:
            print("Stopping after", n_steps, "decoder steps")
            break
        if n_steps == max_decoder_steps:
            print("Warning! Reached max decoder steps")
            break

        decoder_inputs, decoder_outputs = swap_inputs_outputs(decoder_inputs,
                                                              decoder_outputs)

    mel_outputs = torch.cat(mel_frames, 2)

    # ---- Postnet -------------------------------------------------------
    mel_outputs_postnet = torch.zeros_like(mel_outputs, device=device, dtype=dtype)
    postnet_tensors = {
        # inputs
        'mel_outputs': mel_outputs,
        # outputs
        'mel_outputs_postnet': mel_outputs_postnet,
    }
    print("Running Tacotron2 Postnet")
    with MeasureTime(measurements, "tacotron2_postnet_time"):
        run_trt_engine(postnet_context, postnet, postnet_tensors)
    print("Tacotron2 Postnet done")

    return mel_outputs_postnet, mel_lengths