class Decoder(NeuralModule):
    def __init__(
        self,
        n_mel_channels: int,
        n_frames_per_step: int,
        encoder_embedding_dim: int,
        attention_dim: int,
        attention_location_n_filters: int,
        attention_location_kernel_size: int,
        attention_rnn_dim: int,
        decoder_rnn_dim: int,
        prenet_dim: int,
        max_decoder_steps: int,
        gate_threshold: float,
        p_attention_dropout: float,
        p_decoder_dropout: float,
        early_stopping: bool,
        prenet_p_dropout: float = 0.5,
    ):
        """
        Tacotron 2 Decoder. Consists of two LSTM layers, one of which interfaces with the attention mechanism
        while the other is used as a regular LSTM. Includes the prenet and attention modules as well.

        Args:
            n_mel_channels (int): Number of mel channels to output.
            n_frames_per_step (int): Number of spectrogram frames to predict per decoder step.
            encoder_embedding_dim (int): The size of the output from the encoder.
            attention_dim (int): The output dimension of the attention layer.
            attention_location_n_filters (int): Channel size for the convolution used in the attention mechanism.
            attention_location_kernel_size (int): Kernel size for the convolution used in the attention mechanism.
            attention_rnn_dim (int): The output dimension of the attention LSTM layer.
            decoder_rnn_dim (int): The output dimension of the second LSTM layer.
            prenet_dim (int): The output dimension of the prenet.
            max_decoder_steps (int): For evaluation, the maximum number of steps to predict.
            gate_threshold (float): At each step, tacotron 2 predicts a probability of stopping. Rather than
                sampling, this module checks whether the predicted probability is above the gate_threshold.
                Only used in evaluation.
            p_attention_dropout (float): Dropout probability on the attention LSTM.
            p_decoder_dropout (float): Dropout probability on the second LSTM.
            early_stopping (bool): In evaluation mode, whether to stop when all batches hit the gate_threshold
                or to continue until max_decoder_steps.
            prenet_p_dropout (float): Dropout probability for the prenet. Note that dropout stays on even in
                eval() mode. Defaults to 0.5.
        """
        super().__init__()
        self.n_mel_channels = n_mel_channels
        self.n_frames_per_step = n_frames_per_step
        self.encoder_embedding_dim = encoder_embedding_dim
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dim = prenet_dim
        self.max_decoder_steps = max_decoder_steps
        self.gate_threshold = gate_threshold
        self.p_attention_dropout = p_attention_dropout
        self.p_decoder_dropout = p_decoder_dropout
        self.early_stopping = early_stopping

        self.prenet = Prenet(n_mel_channels * n_frames_per_step, [prenet_dim, prenet_dim], prenet_p_dropout)
        self.attention_rnn = torch.nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)
        self.attention_layer = Attention(
            attention_rnn_dim,
            encoder_embedding_dim,
            attention_dim,
            attention_location_n_filters,
            attention_location_kernel_size,
        )
        self.decoder_rnn = torch.nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, 1)
        self.linear_projection = LinearNorm(
            decoder_rnn_dim + encoder_embedding_dim, n_mel_channels * n_frames_per_step
        )
        self.gate_layer = LinearNorm(decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain='sigmoid')

    @property
    def input_types(self):
        input_dict = {
            "memory": NeuralType(('B', 'T', 'D'), EmbeddedTextType()),
            "memory_lengths": NeuralType(('B'), LengthsType()),
        }
        if self.training:
            input_dict["decoder_inputs"] = NeuralType(('B', 'D', 'T'), MelSpectrogramType())
        return input_dict

    @property
    def output_types(self):
        output_dict = {
            "mel_outputs": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "gate_outputs": NeuralType(('B', 'T'), LogitsType()),
            "alignments": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
        }
        if not self.training:
            output_dict["mel_lengths"] = NeuralType(('B'), LengthsType())
        return output_dict

    @typecheck()
    def forward(self, *args, **kwargs):
        if self.training:
            return self.train_forward(**kwargs)
        return self.infer(**kwargs)

    def get_go_frame(self, memory):
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0), int(decoder_inputs.size(1) / self.n_frames_per_step), -1,
        )
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        # (T_out, B) -> (B, T_out)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        # Add a -1 to prevent squeezing the batch dimension in case batch is 1
        gate_outputs = torch.stack(gate_outputs).squeeze(-1).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        cell_input = torch.cat((decoder_input, self.attention_context), -1)

        # TODO: Pytorch 1.6 has issues with rnns and amp, so cast to float until fixed
        if _NATIVE_AMP:
            with autocast(enabled=False):
                self.attention_hidden, self.attention_cell = self.attention_rnn(
                    cell_input.float(), (self.attention_hidden, self.attention_cell)
                )
        else:
            self.attention_hidden, self.attention_cell = self.attention_rnn(
                cell_input, (self.attention_hidden, self.attention_cell)
            )
        self.attention_hidden = F.dropout(self.attention_hidden, self.p_attention_dropout, self.training)

        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1,
        )
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory, attention_weights_cat, self.mask,
        )

        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)

        # TODO: Pytorch 1.6 has issues with rnns and amp, so cast to float until fixed
        if _NATIVE_AMP:
            with autocast(enabled=False):
                self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
                    decoder_input, (self.decoder_hidden, self.decoder_cell)
                )
        else:
            self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
                decoder_input, (self.decoder_hidden, self.decoder_cell)
            )
        self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = self.linear_projection(decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def train_forward(self, *, memory, decoder_inputs, memory_lengths):
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments

    def infer(self, *, memory, memory_lengths):
        decoder_input = self.get_go_frame(memory)

        if memory.size(0) > 1:
            mask = ~get_mask_from_lengths(memory_lengths)
        else:
            mask = None

        self.initialize_decoder_states(memory, mask=mask)

        mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32)
        not_finished = torch.ones([memory.size(0)], dtype=torch.int32)
        if torch.cuda.is_available():
            mel_lengths = mel_lengths.cuda()
            not_finished = not_finished.cuda()
        mel_outputs, gate_outputs, alignments = [], [], []
        stepped = False
        while True:
            decoder_input = self.prenet(decoder_input, inference=True)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            dec = torch.le(torch.sigmoid(gate_output.data), self.gate_threshold).to(torch.int32).squeeze(1)

            not_finished = not_finished * dec
            mel_lengths += not_finished

            if self.early_stopping and torch.sum(not_finished) == 0 and stepped:
                break
            stepped = True

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if len(mel_outputs) == self.max_decoder_steps:
                logging.warning("Reached max decoder steps %d.", self.max_decoder_steps)
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments, mel_lengths

    def save_to(self, save_path: str):
        # TODO: Implement me!
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        # TODO: Implement me!
        pass
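
# ---------------------------------------------------------------------------
# Usage sketch (editor's illustration, not part of the original module):
# shows how a Decoder might be constructed and called in training mode.
# The hyperparameter values below are illustrative assumptions, roughly the
# Tacotron 2 paper defaults, and the tensors `encoder_outputs`, `target_mels`,
# and `text_lengths` are hypothetical placeholders for the caller's data.
#
#   decoder = Decoder(
#       n_mel_channels=80, n_frames_per_step=1, encoder_embedding_dim=512,
#       attention_dim=128, attention_location_n_filters=32,
#       attention_location_kernel_size=31, attention_rnn_dim=1024,
#       decoder_rnn_dim=1024, prenet_dim=256, max_decoder_steps=1000,
#       gate_threshold=0.5, p_attention_dropout=0.1, p_decoder_dropout=0.1,
#       early_stopping=True,
#   )
#   decoder.train()
#   mel_out, gate_out, alignments = decoder(
#       memory=encoder_outputs,       # (B, T_text, encoder_embedding_dim)
#       decoder_inputs=target_mels,   # (B, n_mel_channels, T_mel)
#       memory_lengths=text_lengths,  # (B,)
#   )
#
# In eval() mode, `decoder_inputs` is omitted and the call returns an extra
# `mel_lengths` tensor, as described by `output_types` above.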