def parse_output(self, outputs, output_lengths):
    # type: (List[Tensor], Tensor) -> List[Tensor]
    """Mask padded positions of the model outputs in place.

    When ``self.mask_padding`` is enabled and ``output_lengths`` is given,
    every position flagged True by ``get_mask_from_lengths(output_lengths)``
    is overwritten: ``outputs[0]`` and ``outputs[1]`` with 0.0 and
    ``outputs[2]`` (gate energies) with 1e3. Otherwise nothing is touched.
    The same list object is returned in both cases.
    """
    if not self.mask_padding or output_lengths is None:
        return outputs

    pad_mask = get_mask_from_lengths(output_lengths)
    # Repeat the 2-D mask over a new axis 1 of size n_mel_channels so it
    # broadcasts against the mel outputs (presumably (batch, n_mel, time)).
    mel_mask = pad_mask.unsqueeze(1).expand(
        pad_mask.size(0), self.n_mel_channels, pad_mask.size(1))

    outputs[0].masked_fill_(mel_mask, 0.0)
    outputs[1].masked_fill_(mel_mask, 0.0)
    outputs[2].masked_fill_(pad_mask, 1e3)  # gate energies
    return outputs
def init_decoder_inputs(memory, processed_memory, memory_lengths, *,
                        attention_rnn_dim=1024, decoder_rnn_dim=1024,
                        encoder_embedding_dim=512, n_mel_channels=80):
    """Build the zero-initialised inputs for a single decoder step.

    The returned tuple has the same order as the ``decoder_iter.onnx`` input
    names: (decoder_input, attention_hidden, attention_cell, decoder_hidden,
    decoder_cell, attention_weights, attention_weights_cum, attention_context,
    memory, processed_memory, mask).

    Args:
        memory: encoder output tensor. ``size(0)`` is used as the batch size
            and ``size(1)`` as the sequence length; every new tensor is
            created with its device and dtype.
        processed_memory: passed through unchanged.
        memory_lengths: lengths given to ``get_mask_from_lengths``; the
            resulting mask is moved to ``memory.device``.
        attention_rnn_dim, decoder_rnn_dim, encoder_embedding_dim,
        n_mel_channels: model sizes. The defaults are the values this
            function previously hard-coded. Override them for checkpoints
            trained with a different configuration.

    Returns:
        The 11-tuple described above.
    """
    batch_size = memory.size(0)
    seq_len = memory.size(1)

    def _zeros(*shape):
        # New state tensors follow memory's device and dtype (e.g. fp16).
        return torch.zeros(*shape, device=memory.device, dtype=memory.dtype)

    decoder_input = _zeros(batch_size, n_mel_channels)
    attention_hidden = _zeros(batch_size, attention_rnn_dim)
    attention_cell = _zeros(batch_size, attention_rnn_dim)
    decoder_hidden = _zeros(batch_size, decoder_rnn_dim)
    decoder_cell = _zeros(batch_size, decoder_rnn_dim)
    attention_weights = _zeros(batch_size, seq_len)
    attention_weights_cum = _zeros(batch_size, seq_len)
    attention_context = _zeros(batch_size, encoder_embedding_dim)
    mask = get_mask_from_lengths(memory_lengths).to(memory.device)

    return (decoder_input, attention_hidden, attention_cell, decoder_hidden,
            decoder_cell, attention_weights, attention_weights_cum,
            attention_context, memory, processed_memory, mask)
def _export_encoder(tacotron2, onnx_path, opset_version):
    """Export the Tacotron 2 encoder wrapper to ``onnx_path``.

    Returns the eval-mode ``Encoder`` wrapper together with the dummy
    ``sequence_lengths`` tensor used for export; the decoder export reuses
    those lengths to size its dummy memory.
    """
    # Dummy text batch: one sequence of 50 random symbol ids in [0, 148).
    sequences = torch.randint(low=0, high=148, size=(1, 50),
                              dtype=torch.long).cuda()
    sequence_lengths = torch.IntTensor([sequences.size(1)]).cuda().long()
    dummy_input = (sequences, sequence_lengths)

    encoder = Encoder(tacotron2)
    encoder.eval()
    with torch.no_grad():
        encoder(*dummy_input)  # eager run on the dummy input before export

    torch.onnx.export(encoder, dummy_input, onnx_path,
                      opset_version=opset_version,
                      do_constant_folding=True,
                      input_names=["sequences", "sequence_lengths"],
                      output_names=["memory", "processed_memory", "lens"],
                      dynamic_axes={"sequences": {1: "text_seq"},
                                    "memory": {1: "mem_seq"},
                                    "processed_memory": {1: "mem_seq"}})
    return encoder, sequence_lengths


def _export_decoder_iter(tacotron2, sequence_lengths, onnx_path,
                         opset_version, fp16):
    """Export a single decoder step (``DecoderIter``) to ``onnx_path``.

    Returns the eval-mode ``DecoderIter`` wrapper.
    """
    # Random stand-in for encoder outputs: (1, sequence_lengths[0], 512).
    memory = torch.randn((1, sequence_lengths[0], 512)).cuda()
    if fp16:
        memory = memory.half()
    memory_lengths = sequence_lengths

    # Initial decoder state for the dummy input, produced by the model itself.
    decoder_input = tacotron2.decoder.get_go_frame(memory)
    mask = get_mask_from_lengths(memory_lengths)
    (attention_hidden, attention_cell, decoder_hidden, decoder_cell,
     attention_weights, attention_weights_cum, attention_context,
     processed_memory) = tacotron2.decoder.initialize_decoder_states(memory)
    dummy_input = (decoder_input, attention_hidden, attention_cell,
                   decoder_hidden, decoder_cell, attention_weights,
                   attention_weights_cum, attention_context,
                   memory, processed_memory, mask)

    decoder_iter = DecoderIter(tacotron2)
    decoder_iter.eval()
    with torch.no_grad():
        decoder_iter(*dummy_input)

    # Inputs/outputs whose axis 1 is exported as the dynamic "seq_len" axis.
    seq_len_tensors = ("attention_weights", "attention_weights_cum",
                       "memory", "processed_memory", "mask",
                       "out_attention_weights", "out_attention_weights_cum")
    torch.onnx.export(decoder_iter, dummy_input, onnx_path,
                      opset_version=opset_version,
                      do_constant_folding=True,
                      input_names=["decoder_input", "attention_hidden",
                                   "attention_cell", "decoder_hidden",
                                   "decoder_cell", "attention_weights",
                                   "attention_weights_cum",
                                   "attention_context", "memory",
                                   "processed_memory", "mask"],
                      output_names=["decoder_output", "gate_prediction",
                                    "out_attention_hidden",
                                    "out_attention_cell",
                                    "out_decoder_hidden", "out_decoder_cell",
                                    "out_attention_weights",
                                    "out_attention_weights_cum",
                                    "out_attention_context"],
                      dynamic_axes={name: {1: "seq_len"}
                                    for name in seq_len_tensors})
    return decoder_iter


def _export_postnet(tacotron2, onnx_path, opset_version, fp16):
    """Export the postnet wrapper to ``onnx_path`` and return it in eval mode."""
    postnet = Postnet(tacotron2)
    # Put the wrapper in eval mode, as is done for the encoder and decoder.
    # A freshly built module wrapper starts in training mode, and
    # torch.onnx.export restores the module's original mode (recursively)
    # after exporting. Without this call the wrapped postnet would run in
    # training mode in test_inference below.
    postnet.eval()

    # Dummy mel batch of shape (1, 80, 620); axis 2 is exported as dynamic.
    dummy_input = torch.randn((1, 80, 620)).cuda()
    if fp16:
        dummy_input = dummy_input.half()

    torch.onnx.export(postnet, dummy_input, onnx_path,
                      opset_version=opset_version,
                      do_constant_folding=True,
                      input_names=["mel_outputs"],
                      output_names=["mel_outputs_postnet"],
                      dynamic_axes={"mel_outputs": {2: "mel_seq"},
                                    "mel_outputs_postnet": {2: "mel_seq"}})
    return postnet


def main():
    """Export Tacotron 2 to ONNX as three graphs and run a test inference.

    Loads the checkpoint given by ``--tacotron2`` onto the GPU (fp16 when
    ``--fp16`` is set). It writes ``encoder.onnx``, ``decoder_iter.onnx`` and
    ``postnet.onnx`` into ``args.output``. It then runs ``test_inference``
    with the three eval-mode wrappers and saves the resulting mel to
    ``mel.pt`` in the current working directory.
    """
    import os  # local import keeps this block self-contained

    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 export to TRT')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
                                     fp16_run=args.fp16, cpu_run=False)
    opset_version = 10

    encoder, sequence_lengths = _export_encoder(
        tacotron2, os.path.join(args.output, "encoder.onnx"), opset_version)
    decoder_iter = _export_decoder_iter(
        tacotron2, sequence_lengths,
        os.path.join(args.output, "decoder_iter.onnx"), opset_version,
        args.fp16)
    postnet = _export_postnet(
        tacotron2, os.path.join(args.output, "postnet.onnx"), opset_version,
        args.fp16)

    mel = test_inference(encoder, decoder_iter, postnet)
    torch.save(mel, "mel.pt")
def infer(self, memory, memory_lengths):
    """ Decoder inference
    PARAMS
    ------
    memory: Encoder outputs
    memory_lengths: Encoder output lengths, passed to get_mask_from_lengths
        to build the attention mask

    RETURNS
    -------
    mel_outputs: mel outputs from the decoder
    gate_outputs: gate outputs from the decoder
    alignments: sequence of attention weights from the decoder
    mel_lengths: int32 tensor with one entry per batch item (memory.size(0))
        counting the decoder steps taken before that item's
        sigmoid(gate_output) first exceeded self.gate_threshold

    Decoding is autoregressive: after each step, the predicted mel frame is
    fed back through the prenet as the next step's input. The loop stops when
    every item has finished (only if self.early_stopping is set) or after
    self.max_decoder_steps steps.
    """
    decoder_input = self.get_go_frame(memory)

    mask = get_mask_from_lengths(memory_lengths)
    (attention_hidden,
     attention_cell,
     decoder_hidden,
     decoder_cell,
     attention_weights,
     attention_weights_cum,
     attention_context,
     processed_memory) = self.initialize_decoder_states(memory)

    # Per-item count of steps taken while the item was still decoding.
    mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32, device=memory.device)
    # 1 while a batch item is still decoding, 0 once its gate has fired.
    not_finished = torch.ones([memory.size(0)], dtype=torch.int32, device=memory.device)

    # Dummy tensors give these names a Tensor type before the loop
    # (presumably needed for TorchScript); the first iteration replaces them.
    mel_outputs, gate_outputs, alignments = (
        torch.zeros(1), torch.zeros(1), torch.zeros(1))
    first_iter = True
    while True:
        decoder_input = self.prenet(decoder_input)
        (mel_output,
         gate_output,
         attention_hidden,
         attention_cell,
         decoder_hidden,
         decoder_cell,
         attention_weights,
         attention_weights_cum,
         attention_context) = self.decode(decoder_input,
                                          attention_hidden,
                                          attention_cell,
                                          decoder_hidden,
                                          decoder_cell,
                                          attention_weights,
                                          attention_weights_cum,
                                          attention_context,
                                          memory,
                                          processed_memory,
                                          mask)

        if first_iter:
            mel_outputs = mel_output.unsqueeze(0)
            gate_outputs = gate_output
            alignments = attention_weights
            first_iter = False
        else:
            # mel frames get a new leading step axis; gate outputs and
            # attention weights are concatenated along their existing dim 0.
            # NOTE(review): if gate_output is (batch, 1), steps and batch
            # items end up interleaved along dim 0 when batch > 1 — confirm
            # that parse_decoder_outputs expects this layout.
            mel_outputs = torch.cat(
                (mel_outputs, mel_output.unsqueeze(0)), dim=0)
            gate_outputs = torch.cat((gate_outputs, gate_output), dim=0)
            alignments = torch.cat((alignments, attention_weights), dim=0)

        # dec is 1 for items whose sigmoid(gate) is still <= gate_threshold.
        dec = torch.le(torch.sigmoid(gate_output),
                       self.gate_threshold).to(torch.int32).squeeze(1)

        # Multiplication latches: once an item's dec is 0 it stays finished.
        not_finished = not_finished*dec
        mel_lengths += not_finished

        if self.early_stopping and torch.sum(not_finished) == 0:
            break
        # len() is the number of decoded steps: mel_outputs grows by one
        # entry along dim 0 per step.
        if len(mel_outputs) == self.max_decoder_steps:
            print("Warning! Reached max decoder steps")
            break

        # Autoregressive feedback: this step's frame is the next step's input.
        decoder_input = mel_output

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
        mel_outputs, gate_outputs, alignments)

    return mel_outputs, gate_outputs, alignments, mel_lengths
def forward(self, memory, decoder_inputs, memory_lengths):
    """ Decoder forward pass for training (teacher forcing)

    Step t is decoded from the prenet output of ground-truth frame t-1, with
    the go frame standing in for the frame before the first one. One decoder
    step is run for each frame in ``decoder_inputs``.

    PARAMS
    ------
    memory: Encoder outputs
    decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
    memory_lengths: Encoder output lengths for attention masking.

    RETURNS
    -------
    mel_outputs: mel outputs from the decoder
    gate_outputs: gate outputs from the decoder
    alignments: sequence of attention weights from the decoder
    """
    go_frame = self.get_go_frame(memory).unsqueeze(0)
    teacher_frames = self.parse_decoder_inputs(decoder_inputs)
    # Prepend the go frame along the step axis, then run the prenet once
    # over the whole sequence.
    prenet_inputs = self.prenet(
        torch.cat((go_frame, teacher_frames), dim=0))

    mask = get_mask_from_lengths(memory_lengths)
    (attention_hidden,
     attention_cell,
     decoder_hidden,
     decoder_cell,
     attention_weights,
     attention_weights_cum,
     attention_context,
     processed_memory) = self.initialize_decoder_states(memory)

    mel_steps, gate_steps, alignment_steps = [], [], []
    # The final prenet frame is never used as an input.
    for step in range(prenet_inputs.size(0) - 1):
        (mel_output,
         gate_output,
         attention_hidden,
         attention_cell,
         decoder_hidden,
         decoder_cell,
         attention_weights,
         attention_weights_cum,
         attention_context) = self.decode(prenet_inputs[step],
                                          attention_hidden,
                                          attention_cell,
                                          decoder_hidden,
                                          decoder_cell,
                                          attention_weights,
                                          attention_weights_cum,
                                          attention_context,
                                          memory,
                                          processed_memory,
                                          mask)
        mel_steps.append(mel_output.squeeze(1))
        gate_steps.append(gate_output.squeeze())
        alignment_steps.append(attention_weights)

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
        torch.stack(mel_steps),
        torch.stack(gate_steps),
        torch.stack(alignment_steps))
    return mel_outputs, gate_outputs, alignments