Example #1
    def parse_output(self, outputs, output_lengths):
        # type: (List[Tensor], Tensor) -> List[Tensor]
        if self.mask_padding and output_lengths is not None:
            mask = get_mask_from_lengths(output_lengths)
            # Broadcast the (B, T) padding mask over every mel channel:
            # (B, T) -> (n_mel_channels, B, T) -> (B, n_mel_channels, T)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].masked_fill_(mask, 0.0)
            outputs[1].masked_fill_(mask, 0.0)
            outputs[2].masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs
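The masking above depends on get_mask_from_lengths returning True at padded
positions. A minimal sketch of such a helper, under that assumption (the real
implementation lives elsewhere in the repository):

import torch

def get_mask_from_lengths(lengths):
    # True wherever a timestep index falls beyond the sequence length,
    # i.e. at padded positions; result has shape (batch, max_len).
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype)
    return ids >= lengths.unsqueeze(1)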
Example #2
def init_decoder_inputs(memory, processed_memory, memory_lengths):
    # Zero-initialize all recurrent decoder states for a batch; the
    # Tacotron 2 default dimensions are hard-coded below.
    device = memory.device
    dtype = memory.dtype
    bs = memory.size(0)
    seq_len = memory.size(1)
    attention_rnn_dim = 1024
    decoder_rnn_dim = 1024
    encoder_embedding_dim = 512
    n_mel_channels = 80

    attention_hidden = torch.zeros(bs,
                                   attention_rnn_dim,
                                   device=device,
                                   dtype=dtype)
    attention_cell = torch.zeros(bs,
                                 attention_rnn_dim,
                                 device=device,
                                 dtype=dtype)
    decoder_hidden = torch.zeros(bs,
                                 decoder_rnn_dim,
                                 device=device,
                                 dtype=dtype)
    decoder_cell = torch.zeros(bs, decoder_rnn_dim, device=device, dtype=dtype)
    attention_weights = torch.zeros(bs, seq_len, device=device, dtype=dtype)
    attention_weights_cum = torch.zeros(bs,
                                        seq_len,
                                        device=device,
                                        dtype=dtype)
    attention_context = torch.zeros(bs,
                                    encoder_embedding_dim,
                                    device=device,
                                    dtype=dtype)
    mask = get_mask_from_lengths(memory_lengths).to(device)
    decoder_input = torch.zeros(bs, n_mel_channels, device=device, dtype=dtype)

    return (decoder_input, attention_hidden, attention_cell, decoder_hidden,
            decoder_cell, attention_weights, attention_weights_cum,
            attention_context, memory, processed_memory, mask)
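A quick usage sketch with random stand-ins for the encoder tensors (the batch
size, lengths, and the 128-dim attention projection are illustrative):

memory = torch.randn(4, 50, 512).cuda()            # (B, T_enc, encoder_embedding_dim)
processed_memory = torch.randn(4, 50, 128).cuda()  # memory projected for attention scoring
memory_lengths = torch.tensor([50, 42, 37, 30]).cuda()
states = init_decoder_inputs(memory, processed_memory, memory_lengths)
decoder_input, attention_hidden = states[0], states[1]  # etc.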
Example #3
def main():

    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 export to TRT')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    tacotron2 = load_and_setup_model('Tacotron2',
                                     parser,
                                     args.tacotron2,
                                     fp16_run=args.fp16,
                                     cpu_run=False)

    opset_version = 10

    # Dummy encoder input: random token ids in [0, 148).
    sequences = torch.randint(low=0, high=148, size=(1, 50),
                              dtype=torch.long).cuda()
    sequence_lengths = torch.IntTensor([sequences.size(1)]).cuda().long()
    dummy_input = (sequences, sequence_lengths)

    encoder = Encoder(tacotron2)
    encoder.eval()
    with torch.no_grad():
        encoder(*dummy_input)

    torch.onnx.export(encoder,
                      dummy_input,
                      args.output + "/" + "encoder.onnx",
                      opset_version=opset_version,
                      do_constant_folding=True,
                      input_names=["sequences", "sequence_lengths"],
                      output_names=["memory", "processed_memory", "lens"],
                      dynamic_axes={
                          "sequences": {
                              1: "text_seq"
                          },
                          "memory": {
                              1: "mem_seq"
                          },
                          "processed_memory": {
                              1: "mem_seq"
                          }
                      })

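    # The decoder below is exported as a single step (DecoderIter): all
    # recurrent states are explicit inputs and outputs, so the exported
    # graph can be driven in a Python loop at inference time.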
    decoder_iter = DecoderIter(tacotron2)
    memory = torch.randn(
        (1, sequence_lengths[0], 512)).cuda()  # stands in for encoder outputs
    if args.fp16:
        memory = memory.half()
    memory_lengths = sequence_lengths
    # initialize decoder states for dummy_input
    decoder_input = tacotron2.decoder.get_go_frame(memory)
    mask = get_mask_from_lengths(memory_lengths)
    (attention_hidden, attention_cell, decoder_hidden, decoder_cell,
     attention_weights, attention_weights_cum, attention_context,
     processed_memory) = tacotron2.decoder.initialize_decoder_states(memory)
    dummy_input = (decoder_input, attention_hidden, attention_cell,
                   decoder_hidden, decoder_cell, attention_weights,
                   attention_weights_cum, attention_context, memory,
                   processed_memory, mask)

    decoder_iter.eval()
    with torch.no_grad():
        decoder_iter(*dummy_input)

    torch.onnx.export(decoder_iter,
                      dummy_input,
                      args.output + "/" + "decoder_iter.onnx",
                      opset_version=opset_version,
                      do_constant_folding=True,
                      input_names=[
                          "decoder_input", "attention_hidden",
                          "attention_cell", "decoder_hidden", "decoder_cell",
                          "attention_weights", "attention_weights_cum",
                          "attention_context", "memory", "processed_memory",
                          "mask"
                      ],
                      output_names=[
                          "decoder_output", "gate_prediction",
                          "out_attention_hidden", "out_attention_cell",
                          "out_decoder_hidden", "out_decoder_cell",
                          "out_attention_weights", "out_attention_weights_cum",
                          "out_attention_context"
                      ],
                      dynamic_axes={
                          "attention_weights": {
                              1: "seq_len"
                          },
                          "attention_weights_cum": {
                              1: "seq_len"
                          },
                          "memory": {
                              1: "seq_len"
                          },
                          "processed_memory": {
                              1: "seq_len"
                          },
                          "mask": {
                              1: "seq_len"
                          },
                          "out_attention_weights": {
                              1: "seq_len"
                          },
                          "out_attention_weights_cum": {
                              1: "seq_len"
                          }
                      })

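    # The postnet processes the whole mel spectrogram in one pass; the
    # 620-frame dummy length is arbitrary since axis 2 is exported as the
    # dynamic mel_seq axis.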
    postnet = Postnet(tacotron2)
    dummy_input = torch.randn((1, 80, 620)).cuda()
    if args.fp16:
        dummy_input = dummy_input.half()
    torch.onnx.export(postnet,
                      dummy_input,
                      args.output + "/" + "postnet.onnx",
                      opset_version=opset_version,
                      do_constant_folding=True,
                      input_names=["mel_outputs"],
                      output_names=["mel_outputs_postnet"],
                      dynamic_axes={
                          "mel_outputs": {
                              2: "mel_seq"
                          },
                          "mel_outputs_postnet": {
                              2: "mel_seq"
                          }
                      })

    mel = test_inference(encoder, decoder_iter, postnet)
    torch.save(mel, "mel.pt")
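The exported graphs can be sanity-checked with onnxruntime before building TRT
engines; a minimal sketch for the encoder (onnxruntime and the output path are
assumptions, not part of the script):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("output/encoder.onnx")  # hypothetical path
seqs = np.random.randint(0, 148, size=(1, 50), dtype=np.int64)
lens = np.array([50], dtype=np.int64)
memory, processed_memory, out_lens = sess.run(
    None, {"sequences": seqs, "sequence_lengths": lens})
print(memory.shape)  # expected (1, 50, 512)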
Example #4
    def infer(self, memory, memory_lengths):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)

        mask = get_mask_from_lengths(memory_lengths)
        (attention_hidden,
         attention_cell,
         decoder_hidden,
         decoder_cell,
         attention_weights,
         attention_weights_cum,
         attention_context,
         processed_memory) = self.initialize_decoder_states(memory)

        mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32, device=memory.device)
        not_finished = torch.ones([memory.size(0)], dtype=torch.int32, device=memory.device)

        mel_outputs, gate_outputs, alignments = (
            torch.zeros(1), torch.zeros(1), torch.zeros(1))
        first_iter = True
        while True:
            decoder_input = self.prenet(decoder_input)
            (mel_output,
             gate_output,
             attention_hidden,
             attention_cell,
             decoder_hidden,
             decoder_cell,
             attention_weights,
             attention_weights_cum,
             attention_context) = self.decode(decoder_input,
                                              attention_hidden,
                                              attention_cell,
                                              decoder_hidden,
                                              decoder_cell,
                                              attention_weights,
                                              attention_weights_cum,
                                              attention_context,
                                              memory,
                                              processed_memory,
                                              mask)

            if first_iter:
                mel_outputs = mel_output.unsqueeze(0)
                gate_outputs = gate_output
                alignments = attention_weights
                first_iter = False
            else:
                mel_outputs = torch.cat(
                    (mel_outputs, mel_output.unsqueeze(0)), dim=0)
                gate_outputs = torch.cat((gate_outputs, gate_output), dim=0)
                alignments = torch.cat((alignments, attention_weights), dim=0)

            dec = torch.le(torch.sigmoid(gate_output),
                           self.gate_threshold).to(torch.int32).squeeze(1)

            not_finished = not_finished*dec
            mel_lengths += not_finished

            if self.early_stopping and torch.sum(not_finished) == 0:
                break
            if len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments, mel_lengths
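The stop condition hinges on the gate bookkeeping: an item is finished once
sigmoid(gate) exceeds gate_threshold, after which not_finished stays 0 and its
mel length stops growing. A tiny worked sketch (the 0.5 threshold is assumed):

import torch

gate_threshold = 0.5  # assumed default
not_finished = torch.ones(3, dtype=torch.int32)
mel_lengths = torch.zeros(3, dtype=torch.int32)
for gate_logits in (torch.tensor([-2.0, -2.0, 1.0]),   # item 2 fires its gate
                    torch.tensor([-2.0, 1.0, 1.0])):   # then item 1
    dec = torch.le(torch.sigmoid(gate_logits), gate_threshold).to(torch.int32)
    not_finished = not_finished * dec  # once 0, always 0
    mel_lengths += not_finished        # only unfinished items keep growing
print(mel_lengths)  # tensor([2, 1, 0], dtype=torch.int32)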
Example #5
    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing, i.e. mel spectrograms
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """

        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        mask = get_mask_from_lengths(memory_lengths)
        (attention_hidden,
         attention_cell,
         decoder_hidden,
         decoder_cell,
         attention_weights,
         attention_weights_cum,
         attention_context,
         processed_memory) = self.initialize_decoder_states(memory)

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            (mel_output,
             gate_output,
             attention_hidden,
             attention_cell,
             decoder_hidden,
             decoder_cell,
             attention_weights,
             attention_weights_cum,
             attention_context) = self.decode(decoder_input,
                                              attention_hidden,
                                              attention_cell,
                                              decoder_hidden,
                                              decoder_cell,
                                              attention_weights,
                                              attention_weights_cum,
                                              attention_context,
                                              memory,
                                              processed_memory,
                                              mask)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            torch.stack(mel_outputs),
            torch.stack(gate_outputs),
            torch.stack(alignments))

        return mel_outputs, gate_outputs, alignments
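A teacher-forcing usage sketch with random stand-ins for encoder outputs and
ground-truth mels (shapes follow the 512/80-channel setup used above; decoder
is assumed to be an instance of this module):

memory = torch.randn(4, 50, 512).cuda()   # encoder outputs (B, T_enc, 512)
memory_lengths = torch.tensor([50, 48, 45, 40]).cuda()
mels = torch.randn(4, 80, 200).cuda()     # ground-truth mels (B, n_mel, T_dec)
mel_outputs, gate_outputs, alignments = decoder(memory, mels, memory_lengths)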