Example #1
    def forward(self, memory, targets, memory_lengths, gta=False):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        targets: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.
        gta: if True, the prenet is run in inference mode (ground-truth-aligned pass).

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        go_frame = memory.new(memory.size(0), self.n_mel_channels).zero_().unsqueeze(0)
        # (B, n_mel_channels, T_out) -> (T_out, B, n_mel_channels)
        targets = targets.permute(2, 0, 1)
        decoder_inputs = torch.cat((go_frame, targets), dim=0)
        prenet_outputs = self.prenet(decoder_inputs, inference=gta)

        mask = ~get_mask_from_lengths(memory_lengths) if memory.size(0) > 1 else None
        self.initialize_decoder_states(memory, mask)

        mel_outputs, gate_outputs, alignments = [], [], []
        # size - 1 to ignore the EOS symbol
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            prenet_output = prenet_outputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(prenet_output)

            mel_outputs += [mel_output]
            gate_outputs += [gate_output]
            alignments += [attention_weights]

        return self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)
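Examples #1, #4, and #8 pass an inference keyword to self.prenet, which the reference Tacotron 2 prenet does not accept, so its exact meaning is not visible in these snippets. A minimal, hypothetical sketch, assuming the flag simply toggles prenet dropout off for the inference / GTA pass:

import torch.nn as nn
import torch.nn.functional as F

class Prenet(nn.Module):
    # Hypothetical prenet matching the call sites above; the `inference`
    # flag is assumed to control whether dropout stays active.
    def __init__(self, in_dim=80, sizes=(256, 256), p_dropout=0.5):
        super().__init__()
        dims = [in_dim] + list(sizes)
        self.linears = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])])
        self.p_dropout = p_dropout

    def forward(self, x, inference=False):
        for linear in self.linears:
            x = F.relu(linear(x))
            # Assumption: dropout is applied during training / teacher forcing
            # and disabled when inference=True.
            x = F.dropout(x, p=self.p_dropout, training=not inference)
        return x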
Example #2
    def forward(self, memory, targets, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        targets: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        go_frame = self.get_go_frame(memory).unsqueeze(0)
        targets = self.parse_decoder_inputs(targets)
        decoder_inputs = torch.cat((go_frame, targets), dim=0)
        prenet_outputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            prenet_output = prenet_outputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(
                prenet_output)

            mel_outputs += [mel_output]
            gate_outputs += [gate_output]
            alignments += [attention_weights]

        return self.parse_decoder_outputs(mel_outputs, gate_outputs,
                                          alignments)
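Examples #2, #3, #8, and #9 call get_go_frame and parse_decoder_inputs, which Example #1 effectively inlines (an all-zeros first frame and a (B, n_mel_channels, T_out) -> (T_out, B, n_mel_channels) permute). A minimal sketch consistent with that, assuming a single mel frame per decoder step:

    def get_go_frame(self, memory):
        # All-zeros "go" frame, one row per batch element (see Example #1).
        return memory.new_zeros(memory.size(0), self.n_mel_channels)

    def parse_decoder_inputs(self, decoder_inputs):
        # (B, n_mel_channels, T_out) -> (T_out, B, n_mel_channels),
        # assuming one frame per decoder step.
        return decoder_inputs.permute(2, 0, 1)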
Example #3
    def forward(self, memory, decoder_inputs, memory_lengths, teacher_forcing):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. acoustic-feats.
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        acoustic_outputs: acoustic outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """

        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        acoustic_outputs, gate_outputs, alignments = [], [], []
        teach_force_flags = np.random.choice(
            2, [decoder_inputs.size(0) - 1],
            p=[1 - teacher_forcing, teacher_forcing])

        while len(acoustic_outputs) < decoder_inputs.size(0) - 1:
            step = len(acoustic_outputs)
            if step > 0 and not teach_force_flags[step]:
                decoder_input = self.prenet(acoustic_output)
            else:
                decoder_input = decoder_inputs[step]
            if self.attention_window_size is not None:
                attention_windowed_mask = get_mask_from_lengths_window_and_time_step(
                    memory_lengths, self.attention_window_size, step)
            else:
                attention_windowed_mask = None

            #decoder_input = decoder_inputs[len(acoustic_outputs)]
            #if self.attention_window_size is not None:
            #    time_step = len(acoustic_outputs)
            #    attention_windowed_mask = \
            #        get_mask_from_lengths_window_and_time_step(
            #            memory_lengths, self.attention_window_size, time_step)
            #else:
            #    attention_windowed_mask = None

            acoustic_output, gate_output, attention_weights = self.decode(
                decoder_input, attention_windowed_mask)

            acoustic_outputs += [acoustic_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        acoustic_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            acoustic_outputs, gate_outputs, alignments)

        return acoustic_outputs, gate_outputs, alignments
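Example #3 additionally restricts attention to a window around the current decoder step via get_mask_from_lengths_window_and_time_step, which is not shown in these snippets. A plausible sketch, assuming the returned mask is True for positions to be excluded (padding plus anything outside [time_step - window_size, time_step + window_size]):

import torch

def get_mask_from_lengths_window_and_time_step(lengths, window_size, time_step):
    # True = masked. Padding positions and encoder positions outside the
    # attention window around `time_step` are both excluded.
    max_len = int(torch.max(lengths).item())
    ids = torch.arange(max_len, device=lengths.device)
    padding_mask = ids >= lengths.unsqueeze(1)                      # (B, T_enc)
    outside_window = (ids < time_step - window_size) | \
                     (ids > time_step + window_size)                # (T_enc,)
    return padding_mask | outside_window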
Example #4
    def infer(self, memory, memory_lengths):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)

        if memory.size(0) > 1:
            mask = ~get_mask_from_lengths(memory_lengths)
        else:
            mask = None

        self.initialize_decoder_states(memory, mask=mask)

        mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32)
        not_finished = torch.ones([memory.size(0)], dtype=torch.int32)
        if torch.cuda.is_available():
            mel_lengths = mel_lengths.cuda()
            not_finished = not_finished.cuda()


        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_input = self.prenet(decoder_input, inference=True)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            dec = torch.le(torch.sigmoid(gate_output.data),
                           self.gate_threshold).to(torch.int32).squeeze(1)

            not_finished = not_finished*dec
            mel_lengths += not_finished

            if self.early_stopping and torch.sum(not_finished) == 0:
                break

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments, mel_lengths
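The early-stopping bookkeeping above (dec, not_finished, mel_lengths) is easy to verify with a tiny worked example; with an assumed gate threshold of 0.5 and a batch of two, the item whose gate has fired stops contributing to mel_lengths while the other keeps growing:

import torch

gate_prob = torch.tensor([[0.2], [0.9]])                    # sigmoid(gate_output), B = 2
not_finished = torch.ones(2, dtype=torch.int32)
mel_lengths = torch.zeros(2, dtype=torch.int32)

dec = torch.le(gate_prob, 0.5).to(torch.int32).squeeze(1)   # tensor([1, 0])
not_finished = not_finished * dec                           # second item is finished
mel_lengths += not_finished                                 # only the first item advances
print(not_finished.tolist(), mel_lengths.tolist())          # [1, 0] [1, 0]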
Example #5
    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs
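Every snippet here depends on get_mask_from_lengths. In Examples #1, #2, #5, and #7 the result marks valid positions and is inverted with ~ before masked_fill_; Examples #6 and #11 appear to use the opposite polarity. A minimal sketch of the valid-positions variant:

import torch

def get_mask_from_lengths(lengths):
    # True for valid (non-padding) positions; callers invert it with ~ to
    # obtain the padding mask. Shape: (B, max_len), dtype: bool.
    max_len = int(torch.max(lengths).item())
    ids = torch.arange(max_len, device=lengths.device)
    return ids < lengths.unsqueeze(1)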
Example #6
File: model.py  Project: houserjohn/res
    def parse_output(self, outputs, output_lengths):
        # type: (List[Tensor], Tensor) -> List[Tensor]
        mask = get_mask_from_lengths(output_lengths)
        mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)

        outputs[0].masked_fill_(mask, 0.0)
        outputs[1].masked_fill_(mask, 0.0)
        outputs[2].masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs
Example #7
    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_acoustic_feat_dims, mask.size(0),
                               mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
        return outputs
Example #8
    def infer(self, memory, memory_lengths):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        frame = self.get_go_frame(memory)

        mask = ~get_mask_from_lengths(memory_lengths) if memory.size(
            0) > 1 else None

        self.initialize_decoder_states(memory, mask=mask)

        mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32)
        if torch.cuda.is_available():
            mel_lengths = mel_lengths.cuda()

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            prenet_output = self.prenet(frame, inference=True)

            mel_output, gate_output, alignment = self.decode(prenet_output)
            gate_output = torch.sigmoid(gate_output)

            finished = torch.gt(gate_output, self.gate_threshold).all(-1)
            mel_lengths += (~finished).to(torch.int32)

            if finished.all():
                break

            mel_outputs += [mel_output]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            frame = mel_output[:, :self.n_mel_channels]

        return self.parse_decoder_outputs(mel_outputs, gate_outputs,
                                          alignments, mel_lengths)
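Example #8 feeds only the first n_mel_channels of mel_output back as the next input, which suggests the decoder emits several frames per step flattened along the channel axis. A small illustration of that layout (n_mel_channels = 80 and 2 frames per step are both assumptions):

import torch

n_mel_channels, n_frames_per_step = 80, 2
mel_output = torch.randn(1, n_mel_channels * n_frames_per_step)   # one decoder step
frames = mel_output.view(1, n_frames_per_step, n_mel_channels)    # split into frames
next_frame = mel_output[:, :n_mel_channels]                       # what Example #8 feeds back
print(torch.equal(next_frame, frames[:, 0]))                      # True: the first frame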
Example #9
    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """

        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        mask = get_mask_from_lengths(memory_lengths)
        (attention_hidden, attention_cell, decoder_hidden, decoder_cell,
         attention_weights, attention_weights_cum, attention_context,
         processed_memory) = self.initialize_decoder_states(memory)

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            (mel_output, gate_output, attention_hidden, attention_cell,
             decoder_hidden, decoder_cell, attention_weights,
             attention_weights_cum, attention_context) = self.decode(
                 decoder_input, attention_hidden, attention_cell,
                 decoder_hidden, decoder_cell, attention_weights,
                 attention_weights_cum, attention_context, memory,
                 processed_memory, mask)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            torch.stack(mel_outputs), torch.stack(gate_outputs),
            torch.stack(alignments))

        return mel_outputs, gate_outputs, alignments
Example #10
def init_decoder_inputs(memory, processed_memory, memory_lengths):

    device = memory.device
    dtype = memory.dtype
    bs = memory.size(0)
    seq_len = memory.size(1)
    attention_rnn_dim = 1024
    decoder_rnn_dim = 1024
    encoder_embedding_dim = 512
    n_mel_channels = 80

    attention_hidden = torch.zeros(bs,
                                   attention_rnn_dim,
                                   device=device,
                                   dtype=dtype)
    attention_cell = torch.zeros(bs,
                                 attention_rnn_dim,
                                 device=device,
                                 dtype=dtype)
    decoder_hidden = torch.zeros(bs,
                                 decoder_rnn_dim,
                                 device=device,
                                 dtype=dtype)
    decoder_cell = torch.zeros(bs, decoder_rnn_dim, device=device, dtype=dtype)
    attention_weights = torch.zeros(bs, seq_len, device=device, dtype=dtype)
    attention_weights_cum = torch.zeros(bs,
                                        seq_len,
                                        device=device,
                                        dtype=dtype)
    attention_context = torch.zeros(bs,
                                    encoder_embedding_dim,
                                    device=device,
                                    dtype=dtype)
    mask = get_mask_from_lengths(memory_lengths).to(device)
    decoder_input = torch.zeros(bs, n_mel_channels, device=device, dtype=dtype)

    return (decoder_input, attention_hidden, attention_cell, decoder_hidden,
            decoder_cell, attention_weights, attention_weights_cum,
            attention_context, memory, processed_memory, mask)
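The tuple returned by init_decoder_inputs lines up with the stateless decode signature in Example #9 (and with the decoder_iter ONNX inputs in Example #13), so a decoding step can be driven purely from these tensors. A rough sketch of such a loop, assuming a decoder object with prenet/decode as in Example #9, plus gate_threshold and max_decoder_steps, and batch size 1:

(decoder_input, attention_hidden, attention_cell, decoder_hidden,
 decoder_cell, attention_weights, attention_weights_cum, attention_context,
 memory, processed_memory, mask) = init_decoder_inputs(
     memory, processed_memory, memory_lengths)

mel_frames = []
for _ in range(max_decoder_steps):
    prenet_output = decoder.prenet(decoder_input)
    (mel_output, gate_output, attention_hidden, attention_cell,
     decoder_hidden, decoder_cell, attention_weights,
     attention_weights_cum, attention_context) = decoder.decode(
         prenet_output, attention_hidden, attention_cell, decoder_hidden,
         decoder_cell, attention_weights, attention_weights_cum,
         attention_context, memory, processed_memory, mask)
    mel_frames.append(mel_output)
    if torch.sigmoid(gate_output).item() > gate_threshold:   # batch size 1 assumed
        break
    decoder_input = mel_output                                # feed the frame back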
Example #11
    def parse_output(self, outputs: list,
                     output_lengths: torch.Tensor) -> list:
        """Fill `mel_outputs`, `mel_outputs_postnet`, `gate_outputs` with 0.0/1e3 at tails.

        Args:
            outputs (list): [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]
            output_lengths (torch.Tensor): valid length of each output sequence, used to build the padding mask.

        Returns:
            [list]: `outputs` filled with 0.0/1e3 according to mask.
        """
        # type: (List[Tensor], Tensor) -> List[Tensor]
        if self.mask_padding and output_lengths is not None:
            mask = get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].masked_fill_(mask, 0.0)
            outputs[1].masked_fill_(mask, 0.0)
            outputs[2].masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs
Example #12
def init_decoder_inputs(memory, processed_memory, memory_lengths):

    bs = memory.size(0)
    seq_len = memory.size(1)
    attention_rnn_dim = 1024
    decoder_rnn_dim = 1024
    encoder_embedding_dim = 512
    n_mel_channels = 80

    attention_hidden = torch.zeros(bs, attention_rnn_dim).cuda().float()
    attention_cell = torch.zeros(bs, attention_rnn_dim).cuda().float()
    decoder_hidden = torch.zeros(bs, decoder_rnn_dim).cuda().float()
    decoder_cell = torch.zeros(bs, decoder_rnn_dim).cuda().float()
    attention_weights = torch.zeros(bs, seq_len).cuda().float()
    attention_weights_cum = torch.zeros(bs, seq_len).cuda().float()
    attention_context = torch.zeros(bs, encoder_embedding_dim).cuda().float()
    mask = get_mask_from_lengths(memory_lengths).cuda()
    decoder_input = torch.zeros(bs, n_mel_channels).cuda().float()

    return (decoder_input, attention_hidden, attention_cell, decoder_hidden,
            decoder_cell, attention_weights, attention_weights_cum,
            attention_context, memory, processed_memory, mask)
Example #13
def main():

    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 export to TRT')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
                                     fp16_run=args.fp16, cpu_run=False)

    opset_version = 10

    sequences = torch.randint(low=0, high=148, size=(1,50),
                             dtype=torch.long).cuda()
    sequence_lengths = torch.IntTensor([sequences.size(1)]).cuda().long()
    dummy_input = (sequences, sequence_lengths)

    encoder = Encoder(tacotron2)
    encoder.eval()
    with torch.no_grad():
        encoder(*dummy_input)

    torch.onnx.export(encoder, dummy_input, args.output+"/"+"encoder.onnx",
                      opset_version=opset_version,
                      do_constant_folding=True,
                      input_names=["sequences", "sequence_lengths"],
                      output_names=["memory", "processed_memory", "lens"],
                      dynamic_axes={"sequences": {1: "text_seq"},
                                    "memory": {1: "mem_seq"},
                                    "processed_memory": {1: "mem_seq"}
                      })

    decoder_iter = DecoderIter(tacotron2)
    memory = torch.randn((1,sequence_lengths[0],512)).cuda() #encoder_outputs
    if args.fp16:
        memory = memory.half()
    memory_lengths = sequence_lengths
    # initialize decoder states for dummy_input
    decoder_input = tacotron2.decoder.get_go_frame(memory)
    mask = get_mask_from_lengths(memory_lengths)
    (attention_hidden,
     attention_cell,
     decoder_hidden,
     decoder_cell,
     attention_weights,
     attention_weights_cum,
     attention_context,
     processed_memory) = tacotron2.decoder.initialize_decoder_states(memory)
    dummy_input = (decoder_input,
                   attention_hidden,
                   attention_cell,
                   decoder_hidden,
                   decoder_cell,
                   attention_weights,
                   attention_weights_cum,
                   attention_context,
                   memory,
                   processed_memory,
                   mask)

    decoder_iter = DecoderIter(tacotron2)
    decoder_iter.eval()
    with torch.no_grad():
        decoder_iter(*dummy_input)

    torch.onnx.export(decoder_iter, dummy_input, args.output+"/"+"decoder_iter.onnx",
                      opset_version=opset_version,
                      do_constant_folding=True,
                      input_names=["decoder_input",
                                   "attention_hidden",
                                   "attention_cell",
                                   "decoder_hidden",
                                   "decoder_cell",
                                   "attention_weights",
                                   "attention_weights_cum",
                                   "attention_context",
                                   "memory",
                                   "processed_memory",
                                   "mask"],
                      output_names=["decoder_output",
                                    "gate_prediction",
                                    "out_attention_hidden",
                                    "out_attention_cell",
                                    "out_decoder_hidden",
                                    "out_decoder_cell",
                                    "out_attention_weights",
                                    "out_attention_weights_cum",
                                    "out_attention_context"],
                      dynamic_axes={"attention_weights" : {1: "seq_len"},
                                    "attention_weights_cum" : {1: "seq_len"},
                                    "memory" : {1: "seq_len"},
                                    "processed_memory" : {1: "seq_len"},
                                    "mask" : {1: "seq_len"},
                                    "out_attention_weights" : {1: "seq_len"},
                                    "out_attention_weights_cum" : {1: "seq_len"}
                      })

    postnet = Postnet(tacotron2)
    dummy_input = torch.randn((1,80,620)).cuda()
    if args.fp16:
        dummy_input = dummy_input.half()
    torch.onnx.export(postnet, dummy_input, args.output+"/"+"postnet.onnx",
                      opset_version=opset_version,
                      do_constant_folding=True,
                      input_names=["mel_outputs"],
                      output_names=["mel_outputs_postnet"],
                      dynamic_axes={"mel_outputs": {2: "mel_seq"},
                                    "mel_outputs_postnet": {2: "mel_seq"}})

    mel = test_inference(encoder, decoder_iter, postnet)
    torch.save(mel, "mel.pt")
Example #14
    def infer(self, memory, memory_lengths):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)

        mask = get_mask_from_lengths(memory_lengths)
        (attention_hidden, attention_cell, decoder_hidden, decoder_cell,
         attention_weights, attention_weights_cum, attention_context,
         processed_memory) = self.initialize_decoder_states(memory)

        mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32)
        not_finished = torch.ones([memory.size(0)], dtype=torch.int32)

        mel_outputs, gate_outputs, alignments = (torch.zeros(1),
                                                 torch.zeros(1),
                                                 torch.zeros(1))
        first_iter = True
        while True:
            decoder_input = self.prenet(decoder_input)
            (mel_output, gate_output, attention_hidden, attention_cell,
             decoder_hidden, decoder_cell, attention_weights,
             attention_weights_cum, attention_context) = self.decode(
                 decoder_input, attention_hidden, attention_cell,
                 decoder_hidden, decoder_cell, attention_weights,
                 attention_weights_cum, attention_context, memory,
                 processed_memory, mask)

            if first_iter:
                mel_outputs = mel_output.unsqueeze(0)
                gate_outputs = gate_output
                alignments = attention_weights
                first_iter = False
            else:
                mel_outputs = torch.cat((mel_outputs, mel_output.unsqueeze(0)),
                                        dim=0)
                gate_outputs = torch.cat((gate_outputs, gate_output), dim=0)
                alignments = torch.cat((alignments, attention_weights), dim=0)

            dec = torch.le(torch.sigmoid(gate_output),
                           self.gate_threshold).to(torch.int32).squeeze(1)

            not_finished = not_finished * dec
            mel_lengths += not_finished

            if self.early_stopping and torch.sum(not_finished) == 0:
                break
            if len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments, mel_lengths