def call(self, x, target_durations, training, durations_scalar=1.):
    """Encode the input, predict durations, expand and decode to mels.

    :param x: input sequence; it is used both to build the encoder padding
        mask and as input to the encoder prenet.
    :param target_durations: durations used to expand the encoder output.
        When None, the (scaled, masked) predicted durations are used.
    :param training: forwarded as the ``training`` flag of the sub-layers.
    :param durations_scalar: factor multiplied into the predicted durations.
    :return: dict with keys 'mel', 'duration', 'expanded_mask',
        'encoder_attention' and 'decoder_attention'.
    """
    padding_mask = create_encoder_padding_mask(x)
    encoded, encoder_attention = self.encoder(
        self.encoder_prenet(x),
        training=training,
        padding_mask=padding_mask,
        drop_n_heads=self.drop_n_heads)

    # Scale the raw predictions, then multiply by (1 - mask) reshaped to the
    # prediction's shape (presumably zeroing padded positions - confirm the
    # mask convention against create_encoder_padding_mask).
    scaled_prediction = self.dur_pred(encoded, training=training) * durations_scalar
    keep = 1. - tf.reshape(padding_mask, tf.shape(scaled_prediction))
    durations = keep * scaled_prediction

    # Target durations, when given, take precedence over the predictions.
    expansion_durations = durations if target_durations is None else target_durations
    expanded = self.expand(encoded, expansion_durations)

    expanded_mask = create_mel_padding_mask(expanded)
    decoded, decoder_attention = self.decoder(
        self.decoder_prenet(expanded),
        training=training,
        padding_mask=expanded_mask,
        drop_n_heads=self.drop_n_heads,
        reduction_factor=1)
    mels = self.decoder_postnet(self.out(decoded), training=training)

    return {
        'mel': mels,
        'duration': durations,
        'expanded_mask': expanded_mask,
        'encoder_attention': encoder_attention,
        'decoder_attention': decoder_attention,
    }
def _call_decoder(self, encoder_output, targets, encoder_padding_mask, training, xvectors):  # xvec
    """Run the decoder on speaker-conditioned targets and project to mels.

    :param encoder_output: passed to the decoder as ``enc_output``.
    :param targets: target frames; used for the padding mask, the look-ahead
        mask length (axis 1) and as input to the decoder prenet.
    :param encoder_padding_mask: forwarded to the decoder.
    :param training: forwarded as the ``training`` flag of the sub-layers.
    :param xvectors: speaker vectors, passed through ``self.dec_speaker_mod``
        and concatenated to the prenet output along axis 1.
    :return: the output of ``self.decoder_postnet`` updated with the keys
        'decoder_attention', 'decoder_output' and 'linear'.
    """
    # Element-wise maximum of the padding mask and the look-ahead mask.
    combined_mask = tf.maximum(
        create_mel_padding_mask(targets),
        create_look_ahead_mask(tf.shape(targets)[1]))

    prenet_out = self.decoder_prenet(targets)
    speaker = self.dec_speaker_mod(xvectors)
    decoder_in = tf.keras.layers.concatenate([prenet_out, speaker], axis=1)

    dec_output, attention_weights = self.decoder(
        inputs=decoder_in,
        enc_output=encoder_output,
        training=training,
        decoder_padding_mask=combined_mask,
        encoder_padding_mask=encoder_padding_mask,
        drop_n_heads=self.drop_n_heads,
        reduction_factor=self.r)

    # Keep the first r * mel_channels features, then fold the factor r back
    # into the time axis: (b, t, r * mel_channels) -> (b, t * r, mel_channels).
    n_features = self.r * self.mel_channels
    out_proj = self.final_proj_mel(dec_output)[:, :, :n_features]
    batch_size = int(tf.shape(out_proj)[0])
    n_steps = int(tf.shape(out_proj)[1])
    mel = tf.reshape(out_proj, (batch_size, n_steps * self.r, self.mel_channels))

    model_output = self.decoder_postnet(mel, training=training)
    extras = {
        'decoder_attention': attention_weights,
        'decoder_output': dec_output,
        'linear': mel,
    }
    model_output.update(extras)
    return model_output
def validate(model, val_dataset, summary_manager):
    """Compute the mean validation loss and log summaries.

    The loss is averaged over every batch of ``val_dataset.all_batches()``.
    All summaries (losses, attention heads, histograms, mels, audio) and the
    free-running prediction use the LAST batch yielded by the dataset only.

    :param model: object exposing ``val_step`` and ``predict``.
    :param val_dataset: object exposing ``all_batches()``, yielding
        ``(phonemes, mel, durations, spk_emb, fname)`` tuples.
    :param summary_manager: receives the display/histogram calls.
    :return: the summed batch losses divided by the number of batches.
    :raises ZeroDivisionError: when the dataset yields no batches (the loss
        and the batch counter are both still Python floats equal to 0).
    """
    val_loss = {'loss': 0.}
    norm = 0.
    for phonemes, mel, durations, spk_emb, fname in val_dataset.all_batches():
        model_out = model.val_step(input_sequence=phonemes,
                                   target_sequence=mel,
                                   target_durations=durations,
                                   spk_emb=spk_emb)
        norm += 1
        val_loss['loss'] += model_out['loss']
    val_loss['loss'] /= norm

    # Summaries for the last batch, produced with target durations enforced.
    summary_manager.display_loss(model_out, tag='Validation', plot_all=True)
    summary_manager.display_attention_heads(model_out, tag='ValidationAttentionHeads')
    summary_manager.add_histogram(tag='Validation/Predicted durations',
                                  values=model_out['duration'])
    summary_manager.add_histogram(tag='Validation/Target durations',
                                  values=durations)
    first_name = fname[0].numpy().decode("utf-8")
    summary_manager.display_mel(mel=model_out['mel'][0],
                                tag=f'Validation/{first_name} predicted_mel')
    summary_manager.display_mel(mel=mel[0],
                                tag=f'Validation/{first_name} target_mel')
    summary_manager.display_audio(tag=f'Validation {first_name}/prediction',
                                  mel=model_out['mel'][0])
    summary_manager.display_audio(tag=f'Validation {first_name}/target',
                                  mel=mel[0])

    # predict without enforcing durations and pitch
    model_out = model.predict(phonemes, spk_emb=spk_emb, encode=False)

    # Per-sample lengths are flattened with reshape instead of tf.squeeze:
    # squeeze without an axis drops EVERY size-1 dimension, so a batch made
    # of a single sample would collapse to a scalar and the `[j]` indexing
    # below would fail. For larger batches the result is unchanged.
    pred_lengths = tf.cast(
        tf.reduce_sum(1 - model_out['expanded_mask'], axis=-1), tf.int32)
    pred_lengths = tf.reshape(pred_lengths, [-1])
    tar_lengths = tf.cast(
        tf.reduce_sum(1 - create_mel_padding_mask(mel), axis=-1), tf.int32)
    tar_lengths = tf.reshape(tar_lengths, [-1])

    for j, pred_mel in enumerate(model_out['mel']):
        # Trim both spectrograms to their unmasked length before logging.
        predval = pred_mel[:pred_lengths[j], :]
        tar_value = mel[j, :tar_lengths[j], :]
        sample_name = fname[j].numpy().decode("utf-8")
        summary_manager.display_mel(mel=predval,
                                    tag=f'Test/{sample_name}/predicted')
        summary_manager.display_mel(mel=tar_value,
                                    tag=f'Test/{sample_name}/target')
        summary_manager.display_audio(tag=f'Prediction {sample_name}/target',
                                      mel=tar_value)
        summary_manager.display_audio(tag=f'Prediction {sample_name}/prediction',
                                      mel=predval)
    return val_loss['loss']
def call(self, x, target_durations, spk_emb, training, durations_scalar=1.,
         max_durations_mask=None, min_durations_mask=None):
    """Encode the input, add the speaker embedding, predict durations and decode.

    :param x: input sequence; used for the encoder padding mask and as input
        to the encoder prenet.
    :param target_durations: durations used for the expansion. When None,
        the predicted durations times ``durations_scalar`` are used.
    :param spk_emb: speaker embedding, passed through ``self.speaker_fc`` and
        a softplus before being added to the encoder output.
    :param training: forwarded as the ``training`` flag of the sub-layers.
    :param durations_scalar: factor applied to predicted durations only.
    :param max_durations_mask: optional upper bound; applied element-wise
        after a trailing axis is added to it.
    :param min_durations_mask: optional lower bound; applied element-wise
        after a trailing axis is added to it.
    :return: dict with keys 'mel', 'duration' (the unscaled, unclamped
        prediction), 'expanded_mask', 'encoder_attention' and
        'decoder_attention'.
    """
    encoder_padding_mask = create_encoder_padding_mask(x)
    encoded, encoder_attention = self.encoder(
        self.encoder_prenet(x),
        training=training,
        padding_mask=encoder_padding_mask,
        drop_n_heads=0)

    # Invert the mask, dropping axes 1 and 2 and appending a trailing axis
    # (presumably giving (batch, length, 1) - confirm the mask layout).
    padding_mask = 1. - tf.squeeze(encoder_padding_mask, axis=(1, 2))[:, :, None]

    # The speaker vector gets an extra axis 1 and is added to the encoder
    # output directly; no explicit tf.tile is performed before the addition.
    speaker = tf.expand_dims(tf.math.softplus(self.speaker_fc(spk_emb)), 1)
    encoded = encoded + speaker

    durations = self.dur_pred(encoded, training=training, mask=padding_mask)

    # Target durations win over predictions; the scalar only affects the latter.
    use_durations = (durations * durations_scalar
                     if target_durations is None else target_durations)
    if max_durations_mask is not None:
        upper = tf.expand_dims(max_durations_mask, -1)
        use_durations = tf.math.minimum(use_durations, upper)
    if min_durations_mask is not None:
        lower = tf.expand_dims(min_durations_mask, -1)
        use_durations = tf.math.maximum(use_durations, lower)

    expanded = self.expand(encoded, use_durations)
    expanded_mask = create_mel_padding_mask(expanded)
    decoded, decoder_attention = self.decoder(
        expanded,
        training=training,
        padding_mask=expanded_mask,
        drop_n_heads=0)

    return {
        'mel': self.out(decoded),
        'duration': durations,
        'expanded_mask': expanded_mask,
        'encoder_attention': encoder_attention,
        'decoder_attention': decoder_attention,
    }
def _call_decoder(self, dec_input, enc_output, enc_padding_mask, training):
    """Run the decoder with a combined padding/look-ahead mask and project the result.

    :param dec_input: decoder input; also used to build the padding mask and
        to size the look-ahead mask (length of axis 1).
    :param enc_output: passed to the decoder as ``enc_output``.
    :param enc_padding_mask: passed to the decoder as ``padding_mask``.
    :param training: forwarded as the decoder's ``training`` flag.
    :return: dict with keys 'linear', 'decoder_attention' and 'decoder_output'.
    """
    # Element-wise maximum of the padding mask and the look-ahead mask.
    combined_mask = tf.maximum(
        create_mel_padding_mask(dec_input),
        create_look_ahead_mask(tf.shape(dec_input)[1]))

    dec_output, attention_weights = self.decoder(
        x=dec_input,
        enc_output=enc_output,
        training=training,
        look_ahead_mask=combined_mask,
        padding_mask=enc_padding_mask)

    return {
        'linear': self.linear(dec_output),
        'decoder_attention': attention_weights,
        'decoder_output': dec_output,
    }
def _call_decoder(self, encoder_output, targets, encoder_padding_mask, training):
    """Run the decoder on the prenet-processed targets and project to mels.

    :param encoder_output: passed to the decoder as ``enc_output``.
    :param targets: target frames; used for the padding mask, the look-ahead
        mask length (axis 1) and as input to the decoder prenet.
    :param encoder_padding_mask: forwarded to the decoder.
    :param training: forwarded as the ``training`` flag of the sub-layers.
    :return: the output of ``self.decoder_postnet`` updated with the keys
        'decoder_attention', 'decoder_output' and 'out_proj'.
    """
    # Element-wise maximum of the padding mask and the look-ahead mask.
    combined_mask = tf.maximum(
        create_mel_padding_mask(targets),
        create_look_ahead_mask(tf.shape(targets)[1]))

    prenet_out = self.decoder_prenet(
        targets,
        training=training,
        dropout_rate=self.decoder_prenet_dropout)
    dec_output, attention_weights = self.decoder(
        inputs=prenet_out,
        enc_output=encoder_output,
        training=training,
        decoder_padding_mask=combined_mask,
        encoder_padding_mask=encoder_padding_mask,
        drop_n_heads=self.drop_n_heads,
        reduction_factor=self.r)

    # Keep the first r * mel_channels features, then fold the factor r back
    # into the time axis: (b, t, r * mel_channels) -> (b, t * r, mel_channels).
    n_features = self.r * self.mel_channels
    out_proj = self.final_proj_mel(dec_output)[:, :, :n_features]
    batch_size = int(tf.shape(out_proj)[0])
    n_steps = int(tf.shape(out_proj)[1])
    mel = tf.reshape(out_proj, (batch_size, n_steps * self.r, self.mel_channels))

    model_output = self.decoder_postnet(mel, training=training)
    extras = {
        'decoder_attention': attention_weights,
        'decoder_output': dec_output,
        'out_proj': out_proj,
    }
    model_output.update(extras)
    return model_output
# Choose which decoder cross-attention layer to read: the conv block numbered
# n_convs when there is at least one conv block, otherwise the dense block
# numbered n_dense (presumably the last layer of the decoder in either case -
# confirm against the model's layer naming).
n_convs = int(n_layers - n_dense)
if n_convs > 0:
    last_layer_key = f'Decoder_ConvBlock{n_convs}_CrossAttention'
else:
    last_layer_key = f'Decoder_DenseBlock{n_dense}_CrossAttention'
print(f'Extracting attention from layer {last_layer_key}')

# Collect one (mel, text, attention) tuple of numpy arrays per validation batch.
val_batches = []
iterator = tqdm(enumerate(val_dataset.all_batches()))
for c, (val_mel, val_text, val_stop) in iterator:
    iterator.set_description(f'Processing validation set')
    outputs = model.val_step(inp=val_text, tar=val_mel, stop_prob=val_stop)
    if args.use_GT:
        # Pair the attention with the batch's own mel.
        batch = (val_mel.numpy(), val_text.numpy(),
                 outputs['decoder_attention'][last_layer_key].numpy())
    else:
        # NOTE(review): `mask` is not read anywhere in the visible code.
        mask = create_mel_padding_mask(val_mel)
        # Multiply the model's final output by (1 - padding mask) computed on
        # the mel with its first time step dropped (presumably zeroing padded
        # frames - confirm the mask convention).
        out_val = tf.expand_dims(
            1 - tf.squeeze(create_mel_padding_mask(val_mel[:, 1:, :])),
            -1) * outputs['final_output'].numpy()
        batch = (out_val.numpy(), val_text.numpy(),
                 outputs['decoder_attention'][last_layer_key].numpy())
    if args.store_predictions:
        # The tuple is serialized with pickle even though the file name ends
        # in .npy, so it must be read back with pickle.
        with open(str(val_predictions_dir / f'{c}_batch_prediction.npy'),
                  'wb') as file:
            pickle.dump(batch, file)
    val_batches.append(batch)

# Iterate the training batches; the body visible here only updates the
# progress-bar label.
train_batches = []
iterator = tqdm(enumerate(train_dataset.all_batches()))
for c, (train_mel, train_text, train_stop) in iterator:
    iterator.set_description(f'Processing training set')
model.step >= config['prediction_start_step']): tar_mel, phonemes, durs = test_batch t.display(f'Predicting', pos=len(config['n_steps_avg_losses']) + 4) timed_pred = time_it(model.predict) model_out, time_taken = timed_pred(phonemes, encode=False) summary_manager.display_attention_heads(model_out, tag='TestAttentionHeads') summary_manager.add_histogram(tag=f'Test/Predicted durations', values=model_out['duration']) summary_manager.add_histogram(tag=f'Test/Target durations', values=durs) pred_lengths = tf.cast( tf.reduce_sum(1 - model_out['expanded_mask'], axis=-1), tf.int32) pred_lengths = tf.squeeze(pred_lengths) tar_lengths = tf.cast( tf.reduce_sum(1 - create_mel_padding_mask(tar_mel), axis=-1), tf.int32) tar_lengths = tf.squeeze(tar_lengths) display_start = time() for j, pred_mel in enumerate(model_out['mel']): predval = pred_mel[:pred_lengths[j], :] tar_value = tar_mel[j, :tar_lengths[j], :] summary_manager.display_mel(mel=predval, tag=f'Test/sample {j}/predicted_mel') summary_manager.display_mel(mel=tar_value, tag=f'Test/sample {j}/target_mel') if j < config['n_predictions']: if model.step >= config['audio_start_step'] and ( model.step % config['audio_prediction_frequency'] == 0): summary_manager.display_audio(tag=f'Target/sample {j}',
def get_durations_from_alignment(batch_alignments, mels, phonemes, weighted=False,
                                 binary=False, fill_gaps=False, fix_jumps=False,
                                 fill_mode='max'):
    """
    Convert attention alignments into integer per-phoneme durations.

    For every batch item the first and last mel frame and phoneme are dropped,
    a reference attention map is selected, and durations summing to the number
    of remaining mel frames are derived from it.

    :param batch_alignments: attention weights from autoregressive model; for
        each item the first dimension is heads.
    :param mels: mel spectrograms.
    :param phonemes: phoneme sequence.
    :param weighted: if True the reference attention is the sum over heads of
        each head divided by its score; if False it is the head with the
        lowest score.
    :param binary: if True durations come from `binary_attention` (maximum
        attention peak); if False attention values are summed, rescaled and
        rounded.
    :param fill_gaps: if True zero durations are filled via `fill_zeros`.
    :param fix_jumps: if True, tries to scan alignments for attention jumps and
        interpolate. Only allowed together with binary=True.
    :param fill_mode: used only if fill_gaps is True. Is either 'max' or 'next'.
        Defines where to take the duration needed to fill the gap. Next takes
        it from the next non-zero duration value, max from the sequence
        maximum.
    :return: tuple of lists (durations, unpad_mels, unpad_phonemes,
        final_alignment), each with one entry per batch item.
    """
    assert (binary is True) or (
        fix_jumps is False), 'Cannot fix jumps in non-binary attention.'
    mel_pad_mask = create_mel_padding_mask(mels)
    phon_pad_mask = create_encoder_padding_mask(phonemes)
    durations, unpad_mels, unpad_phonemes, final_alignment = [], [], [], []

    for i, al in enumerate(batch_alignments):
        # Unpadded length = mask width minus the mask sum.
        mel_len = int(mel_pad_mask[i].shape[-1] - np.sum(mel_pad_mask[i]))
        phon_len = int(phon_pad_mask[i].shape[-1] - np.sum(phon_pad_mask[i]))
        n_frames = mel_len - 2  # frames left once start/end are removed

        # remove start end token or vector; first dim is heads
        unpad_alignments = al[:, 1:mel_len - 1, 1:phon_len - 1]
        unpad_mels.append(mels[i, 1:mel_len - 1, :])
        unpad_phonemes.append(phonemes[i, 1:phon_len - 1])

        # Score every head against the weight mask; the lowest score is "best".
        alignments_weights = weight_mask(unpad_alignments[0])
        heads_scores = [np.sum(alignments_weights * head)
                        for head in unpad_alignments]
        scored_attention = [head / score
                            for head, score in zip(unpad_alignments, heads_scores)]
        best_head = np.argmin(heads_scores)
        best_attention = unpad_alignments[best_head]
        ref_attention_weights = (np.sum(scored_attention, axis=0)
                                 if weighted else best_attention)

        if binary:
            # pick max attention for each mel time-step
            binary_attn, binary_score = binary_attention(ref_attention_weights)
            if fix_jumps:
                binary_attn = fix_attention_jumps(
                    binary_attn=binary_attn,
                    alignments_weights=alignments_weights,
                    binary_score=binary_score)
            integer_durations = binary_attn.sum(axis=0)
        else:
            # takes actual attention values and normalizes to mel_len
            attention_durations = np.sum(ref_attention_weights, axis=0)
            normalized_durations = attention_durations * (
                n_frames / np.sum(attention_durations))
            integer_durations = np.round(normalized_durations)
            duration_diff = np.sum(integer_durations) - n_frames
            # Rounding may leave the total off by a few frames: move one
            # frame at a time where the rounding error is largest.
            while duration_diff != 0:
                rounding_diff = integer_durations - normalized_durations
                if duration_diff > 0:
                    # too long -> reduce highest (positive) rounding difference
                    integer_durations[np.argmax(rounding_diff)] -= 1
                elif duration_diff < 0:
                    # too short -> increase lowest (negative) rounding difference
                    integer_durations[np.argmin(rounding_diff)] += 1
                duration_diff = np.sum(integer_durations) - n_frames

        if fill_gaps:
            # fill zeros durations
            integer_durations = fill_zeros(integer_durations, take_from=fill_mode)

        assert np.sum(integer_durations) == n_frames, \
            f'{np.sum(integer_durations)} vs {n_frames}'

        # The stored alignment is the best head's attention (transposed) plus
        # the alignment matrix rebuilt from the integer durations.
        new_alignment = duration_to_alignment_matrix(integer_durations.astype(int))
        final_alignment.append(best_attention.T + new_alignment)
        durations.append(integer_durations)

    return durations, unpad_mels, unpad_phonemes, final_alignment
# Walk every batch of the dataset, run a validation step and derive durations
# from the attention of the layer named by `last_layer_key`.
new_alignments = []
iterator = tqdm(enumerate(dataset.all_batches()))
step = 0
for c, (mel_batch, text_batch, stop_batch, file_name_batch) in iterator:
    iterator.set_description(f'Processing dataset')
    outputs = model.val_step(inp=text_batch, tar=mel_batch, stop_prob=stop_batch)
    attention_values = outputs['decoder_attention'][last_layer_key].numpy()
    text = text_batch.numpy()
    if args.use_GT:
        # Use the batch's own mel.
        mel = mel_batch.numpy()
    else:
        pred_mel = outputs['final_output'].numpy()
        # NOTE(review): `mask` is not read anywhere in the visible code.
        mask = create_mel_padding_mask(mel_batch)
        # Multiply the prediction by (1 - padding mask) computed on the mel
        # with its first time step dropped (presumably zeroing padded frames -
        # confirm the mask convention).
        pred_mel = tf.expand_dims(
            1 - tf.squeeze(create_mel_padding_mask(mel_batch[:, 1:, :])),
            -1) * pred_mel
        mel = pred_mel.numpy()
    # This call unpacks five return values, so it needs a
    # get_durations_from_alignment variant that returns that many.
    durations, final_align, jumpiness, peakiness, diag_measure = get_durations_from_alignment(
        batch_alignments=attention_values,
        mels=mel,
        phonemes=text,
        weighted=weighted,
        binary=binary,
        fill_gaps=fill_gaps,
        fill_mode=fill_mode,
        fix_jumps=fix_jumps)
    # Mean of the jumpiness values along axis 0.
    batch_avg_jumpiness = tf.reduce_mean(jumpiness, axis=0)