Example #1
 def call(self, x, target_durations, training, durations_scalar=1.):
     padding_mask = create_encoder_padding_mask(x)
     x = self.encoder_prenet(x)
     x, encoder_attention = self.encoder(x,
                                         training=training,
                                         padding_mask=padding_mask,
                                         drop_n_heads=self.drop_n_heads)
     # predict per-phoneme durations and zero them at padded positions
     durations = self.dur_pred(x, training=training) * durations_scalar
     durations = (1. -
                  tf.reshape(padding_mask, tf.shape(durations))) * durations
     # teacher-force ground-truth durations when available, else use predictions
     if target_durations is not None:
         mels = self.expand(x, target_durations)
     else:
         mels = self.expand(x, durations)
     expanded_mask = create_mel_padding_mask(mels)
     mels = self.decoder_prenet(mels)
     mels, decoder_attention = self.decoder(mels,
                                            training=training,
                                            padding_mask=expanded_mask,
                                            drop_n_heads=self.drop_n_heads,
                                            reduction_factor=1)
     mels = self.out(mels)
     mels = self.decoder_postnet(mels, training=training)
     model_out = {
         'mel': mels,
         'duration': durations,
         'expanded_mask': expanded_mask,
         'encoder_attention': encoder_attention,
         'decoder_attention': decoder_attention
     }
     return model_out
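All of these examples lean on three mask helpers from the surrounding codebase. Below is a minimal sketch of what they are assumed to compute (1.0 marks padded or future positions, shaped to broadcast over attention logits); the exact definitions may differ in the source repository:

import tensorflow as tf

def create_encoder_padding_mask(seq):
    # 1.0 wherever the phoneme id is 0 (padding), shaped (batch, 1, 1, time)
    mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]

def create_mel_padding_mask(mel):
    # a mel frame counts as padding when all of its channels are zero
    frame_energy = tf.reduce_sum(tf.abs(mel), axis=-1)
    mask = tf.cast(tf.math.equal(frame_energy, 0), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    # strict upper-triangular matrix: step i may not attend to steps > i
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)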
Example #2
 def _call_decoder(self, encoder_output, targets, encoder_padding_mask,
                   training, xvectors):
     # x-vector speaker conditioning for the decoder
     dec_target_padding_mask = create_mel_padding_mask(targets)
     look_ahead_mask = create_look_ahead_mask(tf.shape(targets)[1])
     combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
     dec_input = self.decoder_prenet(targets)
     xvec = self.dec_speaker_mod(xvectors)
     dec_input = tf.keras.layers.concatenate([dec_input, xvec], axis=1)
     dec_output, attention_weights = self.decoder(
         inputs=dec_input,
         enc_output=encoder_output,
         training=training,
         decoder_padding_mask=combined_mask,
         encoder_padding_mask=encoder_padding_mask,
         drop_n_heads=self.drop_n_heads,
         reduction_factor=self.r)
     out_proj = self.final_proj_mel(dec_output)[:, :, :self.r *
                                                self.mel_channels]
     b = int(tf.shape(out_proj)[0])
     t = int(tf.shape(out_proj)[1])
     mel = tf.reshape(out_proj, (b, t * self.r, self.mel_channels))
     model_output = self.decoder_postnet(mel, training=training)
     model_output.update({
         'decoder_attention': attention_weights,
         'decoder_output': dec_output,
         'linear': mel
     })
     return model_output
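The final reshape undoes the reduction factor r: each decoder step emits r mel frames packed along the channel axis, and reshaping (b, t, r * mel_channels) to (b, t * r, mel_channels) restores per-frame order. A toy check of that layout, with hypothetical sizes (r=2, mel_channels=3):

import tensorflow as tf

r, mel_channels = 2, 3                      # hypothetical values
b, t = 1, 4                                 # one sequence, 4 decoder steps
out_proj = tf.reshape(tf.range(b * t * r * mel_channels, dtype=tf.float32),
                      (b, t, r * mel_channels))
# consecutive frames sit side by side in the channel axis, so this
# reshape simply unpacks them back into time order
mel = tf.reshape(out_proj, (b, t * r, mel_channels))
print(mel.shape)  # (1, 8, 3)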
Example #3
def validate(model, val_dataset, summary_manager):
    val_loss = {'loss': 0.}
    norm = 0.
    for phonemes, mel, durations, spk_emb, fname in val_dataset.all_batches():
        model_out = model.val_step(input_sequence=phonemes,
                                   target_sequence=mel,
                                   target_durations=durations,
                                   spk_emb=spk_emb)
        norm += 1
        val_loss['loss'] += model_out['loss']
    val_loss['loss'] /= norm
    summary_manager.display_loss(model_out, tag='Validation', plot_all=True)
    summary_manager.display_attention_heads(model_out,
                                            tag='ValidationAttentionHeads')
    summary_manager.add_histogram(tag='Validation/Predicted durations',
                                  values=model_out['duration'])
    summary_manager.add_histogram(tag='Validation/Target durations',
                                  values=durations)
    summary_manager.display_mel(
        mel=model_out['mel'][0],
        tag=f'Validation/{fname[0].numpy().decode("utf-8")} predicted_mel')
    summary_manager.display_mel(
        mel=mel[0],
        tag=f'Validation/{fname[0].numpy().decode("utf-8")} target_mel')
    summary_manager.display_audio(
        tag=f'Validation {fname[0].numpy().decode("utf-8")}/prediction',
        mel=model_out['mel'][0])
    summary_manager.display_audio(
        tag=f'Validation {fname[0].numpy().decode("utf-8")}/target',
        mel=mel[0])
    # predict without enforcing durations and pitch
    model_out = model.predict(phonemes, spk_emb=spk_emb, encode=False)
    pred_lengths = tf.cast(
        tf.reduce_sum(1 - model_out['expanded_mask'], axis=-1), tf.int32)
    pred_lengths = tf.squeeze(pred_lengths)
    tar_lengths = tf.cast(
        tf.reduce_sum(1 - create_mel_padding_mask(mel), axis=-1), tf.int32)
    tar_lengths = tf.squeeze(tar_lengths)
    for j, pred_mel in enumerate(model_out['mel']):
        predval = pred_mel[:pred_lengths[j], :]
        tar_value = mel[j, :tar_lengths[j], :]
        summary_manager.display_mel(
            mel=predval,
            tag=f'Test/{fname[j].numpy().decode("utf-8")}/predicted')
        summary_manager.display_mel(
            mel=tar_value,
            tag=f'Test/{fname[j].numpy().decode("utf-8")}/target')
        summary_manager.display_audio(
            tag=f'Prediction {fname[j].numpy().decode("utf-8")}/target',
            mel=tar_value)
        summary_manager.display_audio(
            tag=f'Prediction {fname[j].numpy().decode("utf-8")}/prediction',
            mel=predval)
    return val_loss['loss']
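The length computation above recovers valid frame counts from the padding mask, whose shape is assumed to be (batch, 1, 1, time) with 1.0 at padded positions. A minimal sketch of the same arithmetic on dummy data, reusing the create_mel_padding_mask sketch after Example #1:

import tensorflow as tf

# two single-channel mel 'sequences'; the second is padded with two zero frames
mel = tf.constant([[[1.], [1.], [1.], [1.]],
                   [[1.], [1.], [0.], [0.]]])
mask = create_mel_padding_mask(mel)              # (2, 1, 1, 4)
lengths = tf.squeeze(tf.cast(tf.reduce_sum(1 - mask, axis=-1), tf.int32))
print(lengths)  # [4 2]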
Example #4
    def call(self,
             x,
             target_durations,
             spk_emb,
             training,
             durations_scalar=1.,
             max_durations_mask=None,
             min_durations_mask=None):
        encoder_padding_mask = create_encoder_padding_mask(x)
        x = self.encoder_prenet(x)
        x, encoder_attention = self.encoder(x,
                                            training=training,
                                            padding_mask=encoder_padding_mask,
                                            drop_n_heads=0)
        padding_mask = 1. - tf.squeeze(encoder_padding_mask, axis=(1, 2))[:, :,
                                                                          None]
        spk_emb = tf.math.softplus(self.speaker_fc(spk_emb))
        spk_emb = tf.expand_dims(spk_emb, 1)
        x = x + spk_emb  # spk_emb is (batch, 1, d); broadcasts over time

        durations = self.dur_pred(x, training=training, mask=padding_mask)

        if target_durations is not None:
            use_durations = target_durations
        else:
            use_durations = durations * durations_scalar
        if max_durations_mask is not None:
            use_durations = tf.math.minimum(
                use_durations, tf.expand_dims(max_durations_mask, -1))
        if min_durations_mask is not None:
            use_durations = tf.math.maximum(
                use_durations, tf.expand_dims(min_durations_mask, -1))
        mels = self.expand(x, use_durations)
        expanded_mask = create_mel_padding_mask(mels)
        mels, decoder_attention = self.decoder(mels,
                                               training=training,
                                               padding_mask=expanded_mask,
                                               drop_n_heads=0)
        mels = self.out(mels)
        model_out = {
            'mel': mels,
            'duration': durations,
            'expanded_mask': expanded_mask,
            'encoder_attention': encoder_attention,
            'decoder_attention': decoder_attention
        }
        return model_out
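When target durations are given the model teacher-forces them; otherwise the predicted durations are scaled and optionally clamped element-wise between the min and max duration masks. A toy illustration of the clamping step, with hypothetical values:

import tensorflow as tf

durations = tf.constant([[0.5, 4.0, 2.0]])       # (batch, phonemes)
max_mask = tf.constant([[3.0, 3.0, 3.0]])        # per-phoneme upper bounds
min_mask = tf.constant([[1.0, 1.0, 1.0]])        # per-phoneme lower bounds
use_durations = durations[:, :, None]            # (batch, phonemes, 1)
use_durations = tf.math.minimum(use_durations, tf.expand_dims(max_mask, -1))
use_durations = tf.math.maximum(use_durations, tf.expand_dims(min_mask, -1))
print(tf.squeeze(use_durations))  # [1. 3. 2.], clipped into [min, max]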
Example #5
 def _call_decoder(self, dec_input, enc_output, enc_padding_mask, training):
     dec_target_padding_mask = create_mel_padding_mask(dec_input)
     look_ahead_mask = create_look_ahead_mask(tf.shape(dec_input)[1])
     combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
     dec_output, attention_weights = self.decoder(
         x=dec_input,
         enc_output=enc_output,
         training=training,
         look_ahead_mask=combined_mask,
         padding_mask=enc_padding_mask)
     linear = self.linear(dec_output)
     model_out = {
         'linear': linear,
         'decoder_attention': attention_weights,
         'decoder_output': dec_output
     }
     return model_out
Example #6
 def _call_decoder(self, encoder_output, targets, encoder_padding_mask, training):
     dec_target_padding_mask = create_mel_padding_mask(targets)
     look_ahead_mask = create_look_ahead_mask(tf.shape(targets)[1])
     combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
     dec_input = self.decoder_prenet(targets, training=training, dropout_rate=self.decoder_prenet_dropout)
     dec_output, attention_weights = self.decoder(inputs=dec_input,
                                                  enc_output=encoder_output,
                                                  training=training,
                                                  decoder_padding_mask=combined_mask,
                                                  encoder_padding_mask=encoder_padding_mask,
                                                  drop_n_heads=self.drop_n_heads,
                                                  reduction_factor=self.r)
     out_proj = self.final_proj_mel(dec_output)[:, :, :self.r * self.mel_channels]
     b = int(tf.shape(out_proj)[0])
     t = int(tf.shape(out_proj)[1])
     mel = tf.reshape(out_proj, (b, t * self.r, self.mel_channels))
     model_output = self.decoder_postnet(mel, training=training)
     model_output.update(
         {'decoder_attention': attention_weights, 'decoder_output': dec_output, 'out_proj': out_proj})
     return model_output
Example #7
    n_convs = int(n_layers - n_dense)
    if n_convs > 0:
        last_layer_key = f'Decoder_ConvBlock{n_convs}_CrossAttention'
    else:
        last_layer_key = f'Decoder_DenseBlock{n_dense}_CrossAttention'
    print(f'Extracting attention from layer {last_layer_key}')
    val_batches = []
    iterator = tqdm(enumerate(val_dataset.all_batches()))
    for c, (val_mel, val_text, val_stop) in iterator:
        iterator.set_description('Processing validation set')
        outputs = model.val_step(inp=val_text, tar=val_mel, stop_prob=val_stop)
        if args.use_GT:
            batch = (val_mel.numpy(), val_text.numpy(),
                     outputs['decoder_attention'][last_layer_key].numpy())
        else:
            out_val = tf.expand_dims(
                1 - tf.squeeze(create_mel_padding_mask(val_mel[:, 1:, :])),
                -1) * outputs['final_output'].numpy()
            batch = (out_val.numpy(), val_text.numpy(),
                     outputs['decoder_attention'][last_layer_key].numpy())
        if args.store_predictions:
            with open(str(val_predictions_dir / f'{c}_batch_prediction.npy'),
                      'wb') as file:
                pickle.dump(batch, file)
        val_batches.append(batch)

    train_batches = []
    iterator = tqdm(enumerate(train_dataset.all_batches()))
    for c, (train_mel, train_text, train_stop) in iterator:
        iterator.set_description('Processing training set')
Example #8
     model.step >= config['prediction_start_step']):
 tar_mel, phonemes, durs = test_batch
 t.display('Predicting', pos=len(config['n_steps_avg_losses']) + 4)
 timed_pred = time_it(model.predict)
 model_out, time_taken = timed_pred(phonemes, encode=False)
 summary_manager.display_attention_heads(model_out,
                                         tag='TestAttentionHeads')
 summary_manager.add_histogram(tag='Test/Predicted durations',
                               values=model_out['duration'])
 summary_manager.add_histogram(tag='Test/Target durations',
                               values=durs)
 pred_lengths = tf.cast(
     tf.reduce_sum(1 - model_out['expanded_mask'], axis=-1), tf.int32)
 pred_lengths = tf.squeeze(pred_lengths)
 tar_lengths = tf.cast(
     tf.reduce_sum(1 - create_mel_padding_mask(tar_mel), axis=-1),
     tf.int32)
 tar_lengths = tf.squeeze(tar_lengths)
 display_start = time()
 for j, pred_mel in enumerate(model_out['mel']):
     predval = pred_mel[:pred_lengths[j], :]
     tar_value = tar_mel[j, :tar_lengths[j], :]
     summary_manager.display_mel(mel=predval,
                                 tag=f'Test/sample {j}/predicted_mel')
     summary_manager.display_mel(mel=tar_value,
                                 tag=f'Test/sample {j}/target_mel')
     if j < config['n_predictions']:
         if model.step >= config['audio_start_step'] and (
                 model.step % config['audio_prediction_frequency']
                 == 0):
             summary_manager.display_audio(tag=f'Target/sample {j}',
Example #9
def get_durations_from_alignment(batch_alignments,
                                 mels,
                                 phonemes,
                                 weighted=False,
                                 binary=False,
                                 fill_gaps=False,
                                 fix_jumps=False,
                                 fill_mode='max'):
    """
    
    :param batch_alignments: attention weights from autoregressive model.
    :param mels: mel spectrograms.
    :param phonemes: phoneme sequence.
    :param weighted: if True use weighted average of durations of heads, best head if False.
    :param binary: if True take maximum attention peak, sum if False.
    :param fill_gaps: if True fills zeros durations with ones.
    :param fix_jumps: if True, tries to scan alingments for attention jumps and interpolate.
    :param fill_mode: used only if fill_gaps is True. Is either 'max' or 'next'. Defines where to take the duration
        needed to fill the gap. Next takes it from the next non-zeros duration value, max from the sequence maximum.
    :return:
    """
    assert (binary is True) or (
        fix_jumps is False), 'Cannot fix jumps in non-binary attention.'
    mel_pad_mask = create_mel_padding_mask(mels)
    phon_pad_mask = create_encoder_padding_mask(phonemes)
    durations = []
    # remove start end token or vector
    unpad_mels = []
    unpad_phonemes = []
    final_alignment = []
    for i, al in enumerate(batch_alignments):
        mel_len = int(mel_pad_mask[i].shape[-1] - np.sum(mel_pad_mask[i]))
        phon_len = int(phon_pad_mask[i].shape[-1] - np.sum(phon_pad_mask[i]))
        unpad_alignments = al[:, 1:mel_len - 1,
                              1:phon_len - 1]  # first dim is heads
        unpad_mels.append(mels[i, 1:mel_len - 1, :])
        unpad_phonemes.append(phonemes[i, 1:phon_len - 1])
        alignments_weights = weight_mask(unpad_alignments[0])
        heads_scores = []
        scored_attention = []
        for attention_weights in unpad_alignments:
            score = np.sum(alignments_weights * attention_weights)
            scored_attention.append(attention_weights / score)
            heads_scores.append(score)

        if weighted:
            ref_attention_weights = np.sum(scored_attention, axis=0)
        else:
            best_head = np.argmin(heads_scores)
            ref_attention_weights = unpad_alignments[best_head]

        if binary:  # pick max attention for each mel time-step
            binary_attn, binary_score = binary_attention(ref_attention_weights)
            if fix_jumps:
                binary_attn = fix_attention_jumps(
                    binary_attn=binary_attn,
                    alignments_weights=alignments_weights,
                    binary_score=binary_score)
            integer_durations = binary_attn.sum(axis=0)

        else:  # takes actual attention values and normalizes to mel_len
            attention_durations = np.sum(ref_attention_weights, axis=0)
            normalized_durations = attention_durations * (
                (mel_len - 2) / np.sum(attention_durations))
            integer_durations = np.round(normalized_durations)
            tot_duration = np.sum(integer_durations)
            duration_diff = tot_duration - (mel_len - 2)
            while duration_diff != 0:
                rounding_diff = integer_durations - normalized_durations
                if duration_diff > 0:  # duration is too long -> reduce highest (positive) rounding difference
                    max_error_idx = np.argmax(rounding_diff)
                    integer_durations[max_error_idx] -= 1
                elif duration_diff < 0:  # duration is too short -> increase lowest (negative) rounding difference
                    min_error_idx = np.argmin(rounding_diff)
                    integer_durations[min_error_idx] += 1
                tot_duration = np.sum(integer_durations)
                duration_diff = tot_duration - (mel_len - 2)

        if fill_gaps:  # fill zeros durations
            integer_durations = fill_zeros(integer_durations,
                                           take_from=fill_mode)

        assert np.sum(
            integer_durations
        ) == mel_len - 2, f'{np.sum(integer_durations)} vs {mel_len - 2}'
        new_alignment = duration_to_alignment_matrix(
            integer_durations.astype(int))
        best_head = np.argmin(heads_scores)
        best_attention = unpad_alignments[best_head]
        final_alignment.append(best_attention.T + new_alignment)
        durations.append(integer_durations)
    return durations, unpad_mels, unpad_phonemes, final_alignment
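A hedged usage sketch: feeding one batch of autoregressive attention maps through the extractor. The shapes and keyword values here are assumptions inferred from how the surrounding scripts (Examples #7 and #10) call it:

# assumed shapes: batch_alignments (batch, heads, mel_len, phon_len),
# mels (batch, mel_len, mel_channels), phonemes (batch, phon_len)
durations, unpad_mels, unpad_phonemes, final_alignment = \
    get_durations_from_alignment(batch_alignments=attention_values,
                                 mels=mel,
                                 phonemes=text,
                                 weighted=False,
                                 binary=True,
                                 fix_jumps=True,
                                 fill_gaps=True,
                                 fill_mode='max')
# each durations[i] sums to that item's unpadded mel length minus the
# start/end frames (mel_len - 2), as asserted inside the function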
Example #10
new_alignments = []
iterator = tqdm(enumerate(dataset.all_batches()))
step = 0
for c, (mel_batch, text_batch, stop_batch, file_name_batch) in iterator:
    iterator.set_description('Processing dataset')
    outputs = model.val_step(inp=text_batch,
                             tar=mel_batch,
                             stop_prob=stop_batch)
    attention_values = outputs['decoder_attention'][last_layer_key].numpy()
    text = text_batch.numpy()

    if args.use_GT:
        mel = mel_batch.numpy()
    else:
        pred_mel = outputs['final_output'].numpy()
        pred_mel = tf.expand_dims(
            1 - tf.squeeze(create_mel_padding_mask(mel_batch[:, 1:, :])),
            -1) * pred_mel
        mel = pred_mel.numpy()

    durations, final_align, jumpiness, peakiness, diag_measure = get_durations_from_alignment(
        batch_alignments=attention_values,
        mels=mel,
        phonemes=text,
        weighted=weighted,
        binary=binary,
        fill_gaps=fill_gaps,
        fill_mode=fill_mode,
        fix_jumps=fix_jumps)
    batch_avg_jumpiness = tf.reduce_mean(jumpiness, axis=0)