def forward(self, src_seq, src_len, mel_len=None, d_target=None, p_target=None,
            e_target=None, max_src_len=None, max_mel_len=None,
            d_control=1.0, p_control=1.0, e_control=1.0):
    src_mask = get_mask_from_lengths(src_len, max_src_len)
    mel_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None

    encoder_output = self.encoder(src_seq, src_mask)

    if d_target is not None:
        variance_adaptor_output, d_prediction, p_prediction, e_prediction, _, _ = self.variance_adaptor(
            encoder_output, src_mask, mel_mask, d_target, p_target, e_target,
            max_mel_len, d_control, p_control, e_control)
    else:
        variance_adaptor_output, d_prediction, p_prediction, e_prediction, mel_len, mel_mask = self.variance_adaptor(
            encoder_output, src_mask, mel_mask, d_target, p_target, e_target,
            max_mel_len, d_control, p_control, e_control)

    decoder_output = self.decoder(variance_adaptor_output, mel_mask)
    mel_output = self.mel_linear(decoder_output)

    if self.use_postnet:
        mel_output_postnet = self.postnet(mel_output) + mel_output
    else:
        mel_output_postnet = mel_output

    return (mel_output, mel_output_postnet, d_prediction, p_prediction,
            e_prediction, src_mask, mel_mask, mel_len)
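# Every snippet here calls get_mask_from_lengths() without showing it. Below is a
# minimal sketch, an assumption rather than any repo's actual helper, following the
# FastSpeech2-style convention in which True marks padding (as the comment in the
# fourth snippet confirms). Note that the Tacotron2-style decoders further down
# invert the result (~get_mask_from_lengths(...)), implying the opposite
# convention there (True marks valid frames).
import torch

def get_mask_from_lengths(lengths, max_len=None):
    # lengths: (B,) LongTensor of valid lengths; max_len: padded length, if fixed
    if max_len is None:
        max_len = int(torch.max(lengths).item())
    ids = torch.arange(0, max_len, device=lengths.device)   # (max_len,)
    return ids.unsqueeze(0) >= lengths.unsqueeze(1)         # (B, max_len), True = padded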
def forward(self, src_seq, src_len, mel_len=None, d_target=None, p_target=None,
            e_target=None, max_src_len=None, max_mel_len=None, speaker_ids=None):
    src_mask = get_mask_from_lengths(src_len, max_src_len)
    mel_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None

    encoder_output = self.encoder(src_seq, src_mask)
    if self.use_spk_embed and speaker_ids is not None:
        spk_embed = self.embed_speakers(speaker_ids)
        encoder_output = self.speaker_integrator(encoder_output, spk_embed)

    if d_target is not None:
        variance_adaptor_output, d_prediction, p_prediction, e_prediction, _, _ = self.variance_adaptor(
            encoder_output, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len)
    else:
        variance_adaptor_output, d_prediction, p_prediction, e_prediction, mel_len, mel_mask = self.variance_adaptor(
            encoder_output, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len)

    if self.use_spk_embed and speaker_ids is not None:
        variance_adaptor_output = self.speaker_integrator(variance_adaptor_output, spk_embed)

    decoder_output = self.decoder(variance_adaptor_output, mel_mask)
    mel_output = self.mel_linear(decoder_output)

    if self.use_postnet:
        mel_output_postnet = self.postnet(mel_output) + mel_output
    else:
        mel_output_postnet = mel_output

    return (mel_output, mel_output_postnet, d_prediction, p_prediction,
            e_prediction, src_mask, mel_mask, mel_len)
def forward(self, src_seq, src_len, mel_len=None, max_src_len=None, max_mel_len=None):
    src_mask = get_mask_from_lengths(src_len, max_src_len)
    mel_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None

    encoder_output, enc_attns = self.encoder(src_seq, src_mask, True)
    variance_adaptor_output, d_prediction, W, pred_mel_mask = self.variance_adaptor(
        encoder_output, src_mask, max_mel_len)
    decoder_output, dec_attns, pred_mel_mask = self.decoder(
        variance_adaptor_output, pred_mel_mask, True)

    mel_output = self.mel_linear(decoder_output)
    if self.use_postnet:
        mel_output_postnet = self.postnet(mel_output) + mel_output
    else:
        mel_output_postnet = mel_output

    return (mel_output, mel_output_postnet, d_prediction, src_mask,
            pred_mel_mask, enc_attns, dec_attns, W)
def forward(self, src_seq, src_len, mel_len=None, d_target=None, p_target=None,
            e_target=None, max_src_len=None, max_mel_len=None):
    # mask = [False, False, False, True, True]: the first src_len positions are
    # False, the rest (up to max_src_len) are True, i.e. padding is True
    src_mask = get_mask_from_lengths(src_len, max_src_len)
    mel_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None

    # encoder_output: <16, 110, 256>
    encoder_output = self.encoder(src_seq, src_mask)

    if d_target is not None:
        variance_adaptor_output, d_prediction, p_prediction, e_prediction, _, _ = self.variance_adaptor(
            encoder_output, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len)
    else:
        variance_adaptor_output, d_prediction, p_prediction, e_prediction, mel_len, mel_mask = self.variance_adaptor(
            encoder_output, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len)

    # variance_adaptor_output, decoder_output: <16, 965, 256>
    decoder_output = self.decoder(variance_adaptor_output, mel_mask)
    # mel_output: <16, 965, 80>
    mel_output = self.mel_linear(decoder_output)

    if self.use_postnet:
        mel_output_postnet = self.postnet(mel_output) + mel_output
    else:
        mel_output_postnet = mel_output

    return (mel_output, mel_output_postnet, d_prediction, p_prediction,
            e_prediction, src_mask, mel_mask, mel_len)
def forward(self, embedded_text, text_lengths, mels, mels_lengths):
    embedded_prosody, _ = self.encoder(mels)

    # Bottleneck
    embedded_prosody = self.encoder_bottleneck(embedded_prosody)

    # Obtain k and v from the prosody embedding
    key, value = torch.split(
        embedded_prosody, self.prosody_embedding_dim,
        dim=-1)  # [N, Ty, prosody_embedding_dim] * 2

    # Build the attention mask
    text_mask = get_mask_from_lengths(text_lengths).float().unsqueeze(-1)   # [B, seq_len, 1]
    mels_mask = get_mask_from_lengths(mels_lengths).float().unsqueeze(-1)   # [B, ref_len, 1]
    attn_mask = torch.bmm(text_mask, mels_mask.transpose(-2, -1))           # [N, seq_len, ref_len]

    # Attention
    style_embed, alignments = self.ref_attn(embedded_text, key, value, attn_mask)

    # ReLU forces the values of the prosody embedding to lie in [0, inf)
    style_embed = F.relu(style_embed)

    return style_embed, alignments
def forward(self, src_seq, ref_mel, src_len, mel_len=None, d_target=None,
            p_target=None, e_target=None, max_src_len=None, max_mel_len=None):
    src_mask = get_mask_from_lengths(src_len, max_src_len)
    mel_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None
    if hp.vocoder == 'WORLD':
        # aperiodicity and spectral-envelope streams share the mel mask
        ap_mask = mel_mask
        sp_mask = mel_mask

    encoder_output = self.encoder(src_seq, src_mask)
    # Reference-style conditioning, currently disabled:
    # style_embed = self.gst(ref_mel)  # [N, 256]
    # style_embed = style_embed.expand_as(encoder_output)
    # encoder_output = encoder_output + style_embed

    variance_adaptor_output, d_prediction, p_prediction, e_prediction, mel_len, mel_mask = self.variance_adaptor(
        encoder_output, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len)

    decoder_output = self.decoder(variance_adaptor_output, mel_mask)

    if hp.vocoder == 'WORLD':
        ap_output = self.ap_linear(decoder_output)
        sp_output = self.sp_linear(decoder_output)
        if self.use_postnet:
            sp_output_postnet = self.postnet(sp_output) + sp_output
        else:
            sp_output_postnet = sp_output
        return (ap_output, sp_output, sp_output_postnet, d_prediction, p_prediction,
                e_prediction, src_mask, ap_mask, sp_mask)
    else:
        mel_output = self.mel_linear(decoder_output)
        if self.use_postnet:
            mel_output_postnet = self.postnet(mel_output) + mel_output
        else:
            mel_output_postnet = mel_output
        return (mel_output, mel_output_postnet, d_prediction, p_prediction,
                e_prediction, src_mask, mel_mask, mel_len)
def forward(self, src_seq, src_len, mel_len=None, d_target=None, p_target=None,
            e_target=None, max_src_len=None, max_mel_len=None):
    src_mask = get_mask_from_lengths(src_len, max_src_len)
    mel_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None
    if hp.vocoder == 'WORLD':
        # aperiodicity and spectral-envelope streams share the mel mask
        ap_mask = mel_mask
        sp_mask = mel_mask

    encoder_output = self.encoder(src_seq, src_mask)
    # Reference-style conditioning, currently disabled:
    # style_embed = self.gst(ref_mel)  # [N, 256]
    # style_embed = style_embed.expand_as(encoder_output)
    # encoder_output = encoder_output + style_embed

    # this variant also passes the raw src_seq to the variance adaptor
    variance_adaptor_output, d_prediction, p_prediction, e_prediction, mel_len, mel_mask = self.variance_adaptor(
        encoder_output, src_seq, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len)

    decoder_output = self.decoder(variance_adaptor_output, mel_mask)

    if hp.vocoder == 'WORLD':
        ap_output = self.ap_linear(decoder_output)
        sp_output = self.sp_linear(decoder_output)
        if self.use_postnet:
            sp_output_postnet = self.postnet(sp_output) + sp_output
        else:
            sp_output_postnet = sp_output
        return (ap_output, sp_output, sp_output_postnet, d_prediction, p_prediction,
                e_prediction, src_mask, ap_mask, sp_mask,
                variance_adaptor_output, decoder_output)
    else:
        mel_output = self.mel_linear(decoder_output)
        if self.use_postnet:
            mel_output_postnet = self.postnet(mel_output) + mel_output
        else:
            mel_output_postnet = mel_output
        return (mel_output, mel_output_postnet, d_prediction, p_prediction,
                e_prediction, src_mask, mel_mask, mel_len)
def forward(self, attention_hidden_state, memory, processed_memory,
            attention_weights_cat, mask, mel_iter, duration=None):
    """
    PARAMS
    ------
    attention_hidden_state: attention rnn last output
    memory: encoder outputs
    processed_memory: processed encoder outputs
    attention_weights_cat: previous and cumulative attention weights
    mask: binary mask for padded data
    """
    if memory.size(0) == 1:  # batch size 1, i.e. synthesis
        if duration is not None:
            mel_per = round(duration / memory.size(1) + 0.5)
            A = 400
            B = duration - A + 100
            a = [memory.size(1), mel_iter // mel_per + A // 2]
            b = [memory.size(1), mel_iter // mel_per - A // 2]
            c = [memory.size(1), B // mel_per]
        else:
            A = 500
            B = memory.size(1) * 14 // 3
            a = [memory.size(1), mel_iter // 6 + 250]
            b = [memory.size(1), mel_iter // 6 - 250]
            c = [memory.size(1), 500]
        a = torch.tensor(a).cuda()
        b = torch.tensor(b).cuda()
        c = torch.tensor(c).cuda()

        if memory.size(1) < A:  # short utterance: no sliding window needed
            alignment = self.get_alignment_energies(
                attention_hidden_state, processed_memory, attention_weights_cat, mel_iter)
        else:
            if mel_iter < A:
                sw_mask = ~get_mask_from_lengths(a)[1:]
            elif mel_iter >= A and mel_iter < B:
                sw_mask = get_mask_from_lengths2(a)[1:] ^ ~get_mask_from_lengths2(b)[1:]
            else:
                sw_mask = get_mask_from_lengths2(b)[1:]
            alignment = self.get_alignment_energies(
                attention_hidden_state, processed_memory, attention_weights_cat, mel_iter)
            alignment.data.masked_fill_(sw_mask, self.score_mask_value)
    else:  # batch size >= 2, i.e. training
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat, mel_iter)

    if mask is not None:
        alignment.data.masked_fill_(mask, self.score_mask_value)

    attention_weights = F.softmax(alignment, dim=1)
    attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
    attention_context = attention_context.squeeze(1)

    return attention_context, attention_weights
def forward(self, memory, decoder_inputs, memory_lengths):
    """ Decoder forward pass for training
    PARAMS
    ------
    memory: Encoder outputs
    decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
    memory_lengths: Encoder output lengths for attention masking.

    RETURNS
    -------
    mel_outputs: mel outputs from the decoder
    gate_outputs: gate outputs from the decoder
    alignments: sequence of attention weights from the decoder
    """
    decoder_input = self.get_go_frame(memory).unsqueeze(0)
    decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
    decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
    decoder_inputs = self.prenet(decoder_inputs)

    # The original f0 prenet path is replaced by a zero-valued dummy so a
    # pretrained model can run without real f0 features.
    f0 = memory.new_zeros(
        memory.size(0),  # batch_size
        self.f0_size)

    self.initialize_decoder_states(
        memory, mask=~get_mask_from_lengths(memory_lengths))

    mel_outputs, gate_outputs, alignments = [], [], []
    while len(mel_outputs) < decoder_inputs.size(0) - 1:
        if len(mel_outputs) == 0 or np.random.uniform(0.0, 1.0) < self.p_teacher_forcing:
            decoder_input = torch.cat((decoder_inputs[len(mel_outputs)], f0), dim=1)
        else:
            decoder_input = torch.cat((self.prenet(mel_outputs[-1]), f0), dim=1)
        mel_output, gate_output, attention_weights = self.decode(decoder_input)
        mel_outputs += [mel_output.squeeze(1)]
        gate_outputs += [gate_output.squeeze()]
        alignments += [attention_weights]

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
        mel_outputs, gate_outputs, alignments)

    return mel_outputs, gate_outputs, alignments
def forward(self, inputs):
    text, text_lengths, mels, mel_lengths = inputs

    encoder_outputs = self.encoder(text, text_lengths)

    if self.speaker_encoder is not None:
        # slice mels into half-overlapping windows and embed each fragment
        fragments = mels.unfold(1, self.n_fragment_mel_windows,
                                self.n_fragment_mel_windows // 2).transpose(2, 3)
        fragment_counts = mel_lengths // (self.n_fragment_mel_windows // 2) - 1
        fragments = torch.cat([f[:fc] for f, fc in zip(fragments, fragment_counts)])
        speaker_embeddings = self.speaker_encoder.inference(fragments, fragment_counts.tolist())
        encoder_outputs = torch.cat(
            [encoder_outputs,
             speaker_embeddings.unsqueeze(1).repeat(1, encoder_outputs.size(1), 1)], dim=2)

    mel_outputs, gate_outputs, alignments = self.decoder(encoder_outputs, text_lengths, mels)

    mel_outputs_postnet = self.postnet(mel_outputs)
    mel_outputs_postnet = mel_outputs + mel_outputs_postnet

    if mel_lengths is not None:
        mask = ~get_mask_from_lengths(mel_lengths)
        mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
        mask = mask.permute(1, 2, 0)

        mel_outputs.data.masked_fill_(mask, 0.0)
        mel_outputs_postnet.data.masked_fill_(mask, 0.0)
        gate_outputs.data.masked_fill_(mask[:, :, 0], 1e3)  # gate energies

    return mel_outputs, mel_outputs_postnet, gate_outputs, alignments
def forward(self, x, src_mask, mel_mask=None, duration_target=None,
            pitch_target=None, energy_target=None, max_len=None):
    log_duration_prediction = self.duration_predictor(x, src_mask)

    pitch_prediction = self.pitch_predictor(x, src_mask)
    pitch_embedding = self.pitch_embedding_producer(pitch_prediction.unsqueeze(2))

    energy_prediction = self.energy_predictor(x, src_mask)
    energy_embedding = self.energy_embedding_producer(energy_prediction.unsqueeze(2))

    x = x + pitch_embedding + energy_embedding

    if duration_target is not None:
        x, mel_len = self.length_regulator(x, duration_target, max_len)
    else:
        duration_rounded = torch.clamp(
            torch.round(torch.exp(log_duration_prediction) - hp.log_offset), min=0)
        x, mel_len = self.length_regulator(x, duration_rounded, max_len)
        mel_mask = utils.get_mask_from_lengths(mel_len)

    return x, log_duration_prediction, pitch_prediction, energy_prediction, mel_len, mel_mask
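# The length_regulator used above is not shown. Here is a minimal sketch, assuming
# the standard FastSpeech-style regulator that repeats each encoder frame by its
# integer duration and right-pads the batch to a common length; the class and its
# padding behavior are illustrative assumptions, not this snippet's actual code.
import torch
import torch.nn as nn
import torch.nn.functional as F

class LengthRegulator(nn.Module):
    def forward(self, x, durations, max_len=None):
        # x: (B, src_len, d_model); durations: (B, src_len) non-negative frame counts
        expanded = [xi.repeat_interleave(di.long(), dim=0) for xi, di in zip(x, durations)]
        mel_len = torch.tensor([e.size(0) for e in expanded], device=x.device)
        target = max_len if max_len is not None else int(mel_len.max())
        # right-pad every expanded sequence to the target number of frames
        out = torch.stack([F.pad(e, (0, 0, 0, target - e.size(0))) for e in expanded])
        return out, mel_len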
def forward(self, memory, decoder_inputs, memory_lengths):
    # memory: (B, seq_len, 512) -- encoder outputs
    # decoder_inputs: (B, n_mel_channels=80, frames)
    # memory_lengths: (B,)
    decoder_input = self.get_go_frame(memory).unsqueeze(0)

    decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
    decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
    decoder_inputs = self.prenet(decoder_inputs)

    self.initialize_decoder_states(memory, mask=~get_mask_from_lengths(memory_lengths))

    mel_outputs, alignments, gate_outputs = [], [], []
    while len(mel_outputs) < decoder_inputs.size(0) - 1:
        decoder_input = decoder_inputs[len(mel_outputs)]
        mel_output, attention_weights, gate_output = self.decode(decoder_input)
        # mel_output: (1, B, 240)
        mel_outputs.append(mel_output)
        alignments.append(attention_weights)
        gate_outputs.append(gate_output)

    mel_outputs, alignments, gate_outputs = self.parse_decoder_outputs(
        mel_outputs, alignments, gate_outputs)

    return mel_outputs, alignments, gate_outputs
def forward(self, memory, decoder_inputs, mel_length, memory_lengths):
    """ Decoder forward pass for training
    PARAMS
    ------
    memory: Encoder outputs
    decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
    memory_lengths: Encoder output lengths for attention masking.

    RETURNS
    -------
    mel_outputs: mel outputs from the decoder
    gate_outputs: gate outputs from the decoder
    alignments: sequence of attention weights from the decoder
    """
    decoder_input = self.get_go_frame(memory).unsqueeze(0)
    decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
    decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
    decoder_inputs = self.prenet(decoder_inputs)

    self.initialize_decoder_states(
        memory, mask=~get_mask_from_lengths(memory_lengths))

    mel_outputs, gate_outputs, alignments = [], [], []
    while len(mel_outputs) < decoder_inputs.size(0) - 1:
        decoder_input = decoder_inputs[len(mel_outputs)]
        mel_output, gate_output, attention_weights = self.decode(
            decoder_input, memory_lengths)
        mel_outputs += [mel_output.squeeze(1)]
        gate_outputs += [gate_output.squeeze()]
        alignments += [attention_weights]

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
        mel_outputs, gate_outputs, alignments)

    return mel_outputs, gate_outputs, alignments
def forward(self, model_output, targets):
    mel_target, gate_target, output_lengths, *_ = targets
    mel_target.requires_grad = False
    gate_target.requires_grad = False

    mel_out, mel_out_postnet, gate_out, _ = model_output
    gate_target = gate_target.view(-1, 1)
    gate_out = gate_out.view(-1, 1)

    # remove paddings before loss calc
    if self.masked_select:
        mask = get_mask_from_lengths(output_lengths)
        mask = mask.expand(mel_target.size(1), mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)
        mel_target = torch.masked_select(mel_target, mask)
        mel_out = torch.masked_select(mel_out, mask)
        mel_out_postnet = torch.masked_select(mel_out_postnet, mask)

    if self.loss_func == 'MSELoss':
        mel_loss = nn.MSELoss()(mel_out, mel_target) + \
            nn.MSELoss()(mel_out_postnet, mel_target)
    elif self.loss_func == 'SmoothL1Loss':
        mel_loss = nn.SmoothL1Loss()(mel_out, mel_target) + \
            nn.SmoothL1Loss()(mel_out_postnet, mel_target)

    gate_loss = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)(gate_out, gate_target)
    return mel_loss + gate_loss, gate_loss
def mask(self, x, predicted, cemb, input_lengths, output_lengths, max_c_len, max_mel_len):
    x_mask = ~utils.get_mask_from_lengths(output_lengths, max_mel_len)
    x_mask = x_mask.expand(hparams.n_mel_channels, x_mask.size(0), x_mask.size(1))
    x_mask = x_mask.permute(1, 2, 0)

    c_mask = ~utils.get_mask_from_lengths(input_lengths, max_c_len)
    c_mask = c_mask.expand(hparams.pre_gru_out_dim, c_mask.size(0), c_mask.size(1))
    c_mask = c_mask.permute(1, 2, 0)

    x.data.masked_fill_(x_mask, 0.0)
    cemb.data.masked_fill_(c_mask, 0.0)
    predicted.data.masked_fill_(c_mask[:, :, 0], 0.0)

    return x, predicted, cemb
def hybrid_forward(self, F, memory, decoder_inputs, memory_lengths, *args, **kwargs):
    """ Decoder forward pass for training
    PARAMS
    ------
    memory: Encoder outputs
    decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
    memory_lengths: Encoder output lengths for attention masking.

    RETURNS
    -------
    mel_outputs: mel outputs from the decoder
    gate_outputs: gate outputs from the decoder
    alignments: sequence of attention weights from the decoder
    """
    # memory = memory.transpose((1, 0, 2))
    decoder_input = self.get_go_frame(F, memory).expand_dims(0)
    decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
    decoder_inputs = F.concat(decoder_input, decoder_inputs, dim=0)
    decoder_inputs = self._prenet(decoder_inputs)

    self.initialize_decoder_states(F, memory, mask=get_mask_from_lengths(F, memory_lengths))
    # self.initialize_decoder_states(F, memory, mask=~get_mask_from_lengths(F, memory_lengths))

    mel_outputs, gate_outputs, alignments = [], [], []
    while len(mel_outputs) < decoder_inputs.shape[0] - 1:
        decoder_input = decoder_inputs[len(mel_outputs)]
        mel_output, gate_output, attention_weights = self.decode(F, decoder_input)
        mel_outputs += [mel_output.squeeze(1)]
        gate_outputs += [gate_output.squeeze()]
        alignments += [attention_weights]

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
        mel_outputs, gate_outputs, alignments)

    return mel_outputs, gate_outputs, alignments
def parse_output(self, outputs, outputs_length=None):
    if outputs_length is not None:
        mask = ~get_mask_from_lengths(outputs_length)
        mask = mask.expand(self.n_landmarks_channels, mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)
        outputs.masked_fill_(mask, 0.0)
    return outputs
def mask(self, mel_1, mel_2, length_mel, max_mel_len):
    x_mask = ~utils.get_mask_from_lengths(length_mel, max_mel_len)
    x_mask = x_mask.expand(hp.n_mel_channels, x_mask.size(0), x_mask.size(1))
    x_mask = x_mask.permute(1, 2, 0)

    mel_1.data.masked_fill_(x_mask, 0.0)
    mel_2.data.masked_fill_(x_mask, 0.0)

    return mel_1, mel_2
def forward(self, text, text_len, is_mask=True):
    x, x_len = self.embed(text).transpose(1, 2), text_len

    if is_mask:
        mask = get_mask_from_lengths(x_len)
    else:
        mask = None

    out = self.predictor(x, mask)
    out = self.projection(out).squeeze(1)

    return out
def parse_output(self, outputs, output_lengths=None):
    if self.mask_padding and output_lengths is not None:
        mask = ~get_mask_from_lengths(output_lengths)
        mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)

        outputs[0].data.masked_fill_(mask, 0.0)
        outputs[1].data.masked_fill_(mask, 0.0)
        outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

    return outputs
def forward(self, memory, decoder_inputs, memory_lengths, f0s):
    """ Decoder forward pass for training
    PARAMS
    ------
    memory: Encoder outputs
    decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
    memory_lengths: Encoder output lengths for attention masking.

    RETURNS
    -------
    mel_outputs: mel outputs from the decoder
    gate_outputs: gate outputs from the decoder
    alignments: sequence of attention weights from the decoder
    """
    decoder_input = self.get_go_frame(memory).unsqueeze(0)  # initialize with a zero frame
    decoder_inputs = self.parse_decoder_inputs(decoder_inputs)  # reflect n_frames_per_step
    decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
    decoder_inputs = self.prenet(decoder_inputs)  # mel dim to hparams.prenet_dim

    # audio features
    f0_dummy = self.get_end_f0(f0s)
    f0s = torch.cat((f0s, f0_dummy), dim=2)
    f0s = F.relu(self.prenet_f0(f0s))
    f0s = f0s.permute(2, 0, 1)  # (T, B, C)

    # initialize decoder member variables for the dynamic tensors
    self.initialize_decoder_states(
        memory, mask=~get_mask_from_lengths(memory_lengths))

    mel_outputs, gate_outputs, alignments = [], [], []
    # the loop continues until the ground-truth mel length is reached
    while len(mel_outputs) < decoder_inputs.size(0) - 1:
        if len(mel_outputs) == 0 or np.random.uniform(0.0, 1.0) <= self.p_teacher_forcing:
            decoder_input = torch.cat(
                (decoder_inputs[len(mel_outputs)], f0s[len(mel_outputs)]), dim=1)
        else:
            decoder_input = torch.cat(
                (self.prenet(mel_outputs[-1]), f0s[len(mel_outputs)]), dim=1)
        mel_output, gate_output, attention_weights = self.decode(decoder_input)
        mel_outputs += [mel_output.squeeze(1)]
        gate_outputs += [gate_output.squeeze()]
        alignments += [attention_weights]

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
        mel_outputs, gate_outputs, alignments)

    return mel_outputs, gate_outputs, alignments
def forward(self, memory, decoder_inputs, memory_lengths):
    """ Decoder forward pass for training
    PARAMS
    ------
    memory: Encoder outputs
    decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
    memory_lengths: Encoder output lengths for attention masking.

    RETURNS
    -------
    mel_outputs: mel outputs from the decoder
    gate_outputs: gate outputs from the decoder
    alignments: sequence of attention weights from the decoder
    """
    decoder_input = self.get_go_frame(memory).unsqueeze(0)
    decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
    decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
    decoder_inputs = self.prenet(decoder_inputs)

    self.initialize_decoder_states(
        memory, mask=~get_mask_from_lengths(memory_lengths))

    mel_outputs, gate_outputs, alignments = [], [], []

    # scheduled sampling: keep the model's own previous outputs as candidates
    decoder_inputs_student = list()
    while len(mel_outputs) < decoder_inputs.size(0) - 1:
        # with probability student_input_prob, feed back the model's last output
        use_student_input = False
        if len(decoder_inputs_student) > 0:
            prob = self.student_input_prob
            toss = random.random()
            use_student_input = toss < prob

        if use_student_input:
            decoder_input = self.prenet(decoder_inputs_student[-1])
        else:
            decoder_input = decoder_inputs[len(mel_outputs)]

        mel_output, gate_output, attention_weights = self.decode(decoder_input)
        decoder_inputs_student.append(mel_output)

        mel_outputs += [mel_output.squeeze(1)]
        gate_outputs += [gate_output.squeeze()] * self.n_frames_per_step
        alignments += [attention_weights]

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
        mel_outputs, gate_outputs, alignments)

    return mel_outputs, gate_outputs, alignments
def forward(self, text, durs, is_mask=True):
    x, x_len = self.embed(text, durs).transpose(1, 2), durs.sum(-1)

    if is_mask:
        mask = get_mask_from_lengths(x_len)
    else:
        mask = None

    out = self.predictor(x, mask)
    uv = self.sil_proj(out).squeeze(1)
    value = self.body_proj(out).squeeze(1)

    return uv, value
def forward(self, src_seq, mel_target, mel_aug, p_norm, e_input, src_len, mel_len,
            d_target=None, p_target=None, e_target=None, max_src_len=None,
            max_mel_len=None, speaker_embed=None,
            d_control=1.0, p_control=1.0, e_control=1.0):
    src_mask = get_mask_from_lengths(src_len, max_src_len)
    mel_mask = get_mask_from_lengths(mel_len, max_mel_len)

    # Style modeling
    if d_target is not None:
        style_modeling_output, noise_encoding, d_prediction, p_prediction, e_prediction, _, _, \
            (aug_posterior_d, aug_posterior_p, aug_posterior_e) = self.style_modeling(
                src_seq, speaker_embed, mel_target, mel_aug, p_norm, e_input,
                src_len, mel_len, src_mask, mel_mask, d_target, p_target, e_target,
                max_mel_len, d_control, p_control, e_control)
    else:
        style_modeling_output, noise_encoding, d_prediction, p_prediction, e_prediction, mel_len, mel_mask, \
            (aug_posterior_d, aug_posterior_p, aug_posterior_e) = self.style_modeling(
                src_seq, speaker_embed, mel_target, mel_aug, p_norm, e_input,
                src_len, mel_len, src_mask, mel_mask, d_target, p_target, e_target,
                max_mel_len, d_control, p_control, e_control)

    # Clean decoding
    mel_output, mel_output_postnet = self.decode(style_modeling_output, mel_mask)

    # Noisy decoding
    mel_output_noisy, mel_output_postnet_noisy = self.decode(
        style_modeling_output.detach() + noise_encoding, mel_mask)

    return ((mel_output, mel_output_noisy),
            (mel_output_postnet, mel_output_postnet_noisy),
            d_prediction, p_prediction, e_prediction, src_mask, mel_mask, mel_len,
            (aug_posterior_d, aug_posterior_p, aug_posterior_e))
def parse_output(outputs, output_lengths=None, n_mel_channels=80):
    if output_lengths is not None:
        max_len = torch.max(output_lengths).item()
        mask = ~get_mask_from_lengths(output_lengths, max_len)
        mask = mask.expand(n_mel_channels, mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)

        outputs[0].data.masked_fill_(mask, 0.0)
        outputs[1].data.masked_fill_(mask, 0.0)
        outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

    return outputs
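# The expand/permute masking idiom above recurs throughout these snippets. A
# self-contained usage sketch follows (the tensor shapes and the local helper are
# illustrative assumptions): the broadcast mask zeroes padded mel frames and
# pushes the gate logits high so the stop token fires on padding.
import torch

def get_mask_from_lengths(lengths, max_len):
    # Tacotron2-style convention here: True marks valid frames
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)

B, n_mel, T = 2, 80, 6
lengths = torch.tensor([6, 4])
outputs = [torch.randn(B, n_mel, T),   # mel_outputs
           torch.randn(B, n_mel, T),   # mel_outputs_postnet
           torch.randn(B, T)]          # gate_outputs

outputs = parse_output(outputs, lengths, n_mel_channels=n_mel)
assert outputs[0][1, :, 4:].abs().sum() == 0  # frames past length 4 are zeroed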
def parse_output(self, outputs, text_lengths=None, output_lengths=None):
    if self.mask_padding and text_lengths is not None and output_lengths is not None:
        mask = ~get_mask_from_lengths(output_lengths)

        mask1 = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
        mask1 = mask1.permute(1, 0, 2)
        outputs[0].data.masked_fill_(mask1, 0.0)
        outputs[1].data.masked_fill_(mask1, 0.0)

        mask2 = mask.expand(max(text_lengths).item() + 1, mask.size(0), mask.size(1))
        mask2 = mask2.permute(1, 2, 0)
        outputs[2].data.masked_fill_(mask2, 0.0)

        mask = ~get_mask_from_lengths(text_lengths)
        mask3 = mask.expand(self.frame_level_rnn_dim, mask.size(0), mask.size(1))
        mask3 = mask3.permute(1, 2, 0)
        outputs[3].data.masked_fill_(mask3, 0.0)
        outputs[4].data.masked_fill_(mask3, 0.0)

    return outputs
def parse_output(self, outputs, output_lengths=None):
    if self.mask_padding and output_lengths is not None:
        mask = ~get_mask_from_lengths(output_lengths)
        mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)

        outputs[0].data.masked_fill_(mask, 0.0)
        outputs[1].data.masked_fill_(mask, 0.0)
        # outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

    outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
    return outputs
def forward(self, src_seq, speaker_emb, src_len, hz_seq=None, mel_len=None,
            d_target=None, max_src_len=None, max_mel_len=None,
            d_control=1.0, p_control=1.0, e_control=1.0):
    src_mask = get_mask_from_lengths(src_len, max_src_len)
    mel_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None

    encoder_output = self.encoder(src_seq, src_mask, hz_seq=hz_seq)

    if d_target is not None:
        variance_adaptor_output, d_prediction, _, _ = self.variance_adaptor(
            encoder_output, src_mask, mel_mask, d_target, max_mel_len,
            d_control, p_control, e_control)
    else:
        variance_adaptor_output, d_prediction, mel_len, mel_mask = self.variance_adaptor(
            encoder_output, src_mask, mel_mask, d_target, max_mel_len,
            d_control, p_control, e_control)

    decoder_output = self.decoder(variance_adaptor_output, mel_mask, speaker_emb)
    mel_output = self.mel_linear(decoder_output)

    if self.use_postnet:
        unet_out = self.postnet(torch.unsqueeze(mel_output, 1))
        mel_output_postnet = unet_out[:, 0, :, :] + mel_output
    else:
        mel_output_postnet = mel_output

    return mel_output, mel_output_postnet, d_prediction, src_mask, mel_mask, mel_len
def parse_outputs(self, mel_outputs, mel_outputs_postnet, gate_outputs, output_lengths):
    mask = ~get_mask_from_lengths(output_lengths, pad=True)
    mask = mask.expand(80, mask.size(0), mask.size(1))
    mask = mask.permute(1, 0, 2)
    # mask: (B, 80, frames)

    mel_outputs.data.masked_fill_(mask, 0.0)
    mel_outputs_postnet.data.masked_fill_(mask, 0.0)

    # gate_outputs: (B, frames // 3)
    slice_mask = torch.arange(0, mask.size(2), 1)
    gate_outputs.data.masked_fill_(mask[:, 0, slice_mask], 1e3)

    return mel_outputs, mel_outputs_postnet, gate_outputs
def forward(self, x, src_mask, mel_mask=None, duration_target=None,
            pitch_target=None, energy_target=None, max_len=None):
    log_duration_prediction = self.duration_predictor(x, src_mask)

    if duration_target is not None:
        x, mel_len = self.length_regulator(x, duration_target, max_len)
    else:
        duration_rounded = torch.clamp(
            torch.round(torch.exp(log_duration_prediction) - hp.log_offset), min=0)
        x, mel_len = self.length_regulator(x, duration_rounded, max_len)
        mel_mask = utils.get_mask_from_lengths(mel_len)

    pitch_prediction = self.pitch_predictor(x, mel_mask)
    # quantize the target during training, the prediction during inference
    pitch_source = pitch_target if pitch_target is not None else pitch_prediction
    src = torch.ceil(
        (pitch_source - hp.f0_min) / (hp.f0_max - hp.f0_min) * hp.n_bins).long()
    pitch_embedding = self.pitch_embedding(src)

    energy_prediction = self.energy_predictor(x, mel_mask)
    energy_source = energy_target if energy_target is not None else energy_prediction
    src = torch.ceil(
        (energy_source - hp.energy_min) / (hp.energy_max - hp.energy_min) * hp.n_bins).long()
    energy_embedding = self.energy_embedding(src)

    x = x + pitch_embedding + energy_embedding

    return x, log_duration_prediction, pitch_prediction, energy_prediction, mel_len, mel_mask
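# Caveat with the ceil-based quantization above: when a value equals the configured
# maximum (e.g. pitch_target == hp.f0_max), the computed index is n_bins, one past
# the last row of an n_bins-sized embedding table. A common alternative is
# torch.bucketize over n_bins - 1 boundaries, which keeps indices in
# [0, n_bins - 1]. A minimal sketch, with illustrative values standing in for the
# hp.* hyperparameters:
import torch
import torch.nn as nn

n_bins, f0_min, f0_max = 256, 71.0, 795.8     # assumed config values
pitch_bins = torch.linspace(f0_min, f0_max, n_bins - 1)
pitch_embedding = nn.Embedding(n_bins, 256)

pitch = torch.tensor([[70.0, 300.0, 795.8]])  # includes both extremes
idx = torch.bucketize(pitch, pitch_bins)      # indices always in [0, n_bins - 1]
emb = pitch_embedding(idx)                    # (1, 3, 256), never out of range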