Exemple #1
0
def fix_attention_jumps(binary_attn, binary_score, mel_len, phon_len):
    """Clean jumps out of a binary attention matrix, relaxing the jump threshold if needed.

    ``binary_attn`` is cleaned with ``clean_attention`` at jump thresholds 2, 3, 4 and 5,
    and the highest-scoring result is kept. A lower score suggests the attention
    collapsed, which small thresholds are more prone to (they are otherwise more
    accurate). While the kept score is still below ``binary_score``, the threshold is
    raised one step at a time, up to 20, and the latest result replaces the kept one.
    If no cleaned version reaches ``binary_score``, ``binary_attn`` is returned unchanged.
    """
    def mean_score(attention):
        # Score a single 2-D matrix as a batch of one with a single head.
        scores = attention_score(att=tf.cast(attention[None, None, :, :], tf.float32),
                                 mel_len=mel_len,
                                 phon_len=phon_len,
                                 r=1)
        return tf.reduce_mean(scores)

    candidates, candidate_scores = [], []
    for threshold in (2, 3, 4, 5):
        cleaned = clean_attention(binary_attention=binary_attn,
                                  jump_threshold=threshold)
        candidates.append(cleaned)
        candidate_scores.append(mean_score(cleaned))

    winner = np.argmax(candidate_scores)
    best_cleaned_attention = candidates[winner]
    best_score = candidate_scores[winner]

    # Relax the threshold until the cleaned attention scores at least as well as the raw one.
    while binary_score > best_score and threshold < 20:
        threshold += 1
        best_cleaned_attention = clean_attention(binary_attention=binary_attn,
                                                 jump_threshold=threshold)
        best_score = mean_score(best_cleaned_attention)

    return binary_attn if binary_score > best_score else best_cleaned_attention
def create_align_features(
    model: Tacotron,
    train_set: DataLoader,
    val_set: DataLoader,
    save_path_alg: Path,
):
    """Extract phoneme durations from Tacotron attention and save them as ``.npy`` files.

    Runs ``model`` teacher-forced (``model(x, mels)``) without gradients over
    ``train_set`` and then ``val_set`` (when not None). For the first item of each batch
    it:

    - scores the attention;
    - extracts integer durations with ``extract_durations_with_dijkstra`` or
      ``extract_durations_per_count`` (selected by ``hp.extract_durations_with_dijkstra``);
    - writes them to ``save_path_alg / f'{item_id}.npy'``.

    The per-item ``(align_score, sharp_score)`` pairs are pickled to
    ``paths.data / 'att_score_dict.pkl'``.

    :param model: Tacotron model; its reduction factor ``r`` must be 1.
    :param train_set: loader yielding ``(x, mels, ids, x_lens, mel_lens)`` batches.
    :param val_set: loader with the same batch layout, or None to skip validation data.
    :param save_path_alg: directory that receives the per-item duration files.
    :raises AssertionError: if ``model.r != 1``.
    """
    assert model.r == 1, f'Reduction factor of tacotron must be 1 for creating alignment features! ' \
                         f'Reduction factor was: {model.r}'
    model.eval()
    device = next(
        model.parameters()).device  # use same device as model parameters
    if val_set is not None:
        iters = len(val_set) + len(train_set)
        dataset = itertools.chain(train_set, val_set)
    else:
        iters = len(train_set)
        dataset = train_set

    att_score_dict = {}

    if hp.extract_durations_with_dijkstra:
        print('Extracting durations using dijkstra...')
        dur_extraction_func = extract_durations_with_dijkstra
    else:
        print('Extracting durations using attention peak counts...')
        dur_extraction_func = extract_durations_per_count

    for i, (x, mels, ids, x_lens, mel_lens) in enumerate(dataset, 1):
        x, mels = x.to(device), mels.to(device)
        with torch.no_grad():
            _, _, att_batch = model(x, mels)
        align_score, sharp_score = attention_score(att_batch, mel_lens, r=1)
        att_batch = np_now(att_batch)
        # NOTE(review): only index 0 of each batch is used, so this presumably expects
        # loaders with batch size 1 — confirm against the caller.
        seq, att, mel_len, item_id = x[0], att_batch[0], mel_lens[0], ids[0]
        align_score, sharp_score = float(align_score[0]), float(sharp_score[0])
        att_score_dict[item_id] = (align_score, sharp_score)
        durs = dur_extraction_func(seq, att, mel_len)
        if np.sum(durs) != mel_len:
            print(
                f'WARNING: Sum of durations did not match mel length for item {item_id}!'
            )
        np.save(str(save_path_alg / f'{item_id}.npy'),
                durs,
                allow_pickle=False)
        bar = progbar(i, iters)
        msg = f'{bar} {i}/{iters} Batches '
        stream(msg)
    # NOTE(review): written to the global ``paths.data`` rather than ``save_path_alg``;
    # confirm this location is intended.
    pickle_binary(att_score_dict, paths.data / 'att_score_dict.pkl')
Exemple #3
0
def get_durations_from_alignment(batch_alignments,
                                 mels,
                                 phonemes,
                                 weighted=False):
    """Extract integer phoneme durations from autoregressive attention using Dijkstra.

    :param batch_alignments: attention weights from the autoregressive model. Each
        batch item is indexed as ``[head, mel_step, phoneme]`` below.
    :param mels: mel spectrograms (used only to compute unpadded mel lengths).
    :param phonemes: phoneme sequences (used only to compute unpadded phoneme lengths).
    :param weighted: if True, durations are extracted from the sum of all heads, each
        scaled by its attention score. If False, only the best-scoring head is used.
    :return: ``(durations, final_alignment, jumpiness, peakiness, diag_measure)``.
        ``durations`` holds one integer-duration array per batch item.
        ``final_alignment`` holds, per item, the best head's attention transposed plus
        the duration-derived alignment matrix. The last three are the raw attention
        scores.
    """
    # mel_len - 1 because we remove last timestep, which is end_vector. start vector is not predicted (or removed from GTA)
    mel_len = mel_lengths(mels, padding_value=0.) - 1  # [N]
    # phonemes contain start and end tokens (start will be removed later)
    phon_len = phoneme_lengths(phonemes) - 1
    jumpiness, peakiness, diag_measure = attention_score(att=batch_alignments,
                                                         mel_len=mel_len,
                                                         phon_len=phon_len,
                                                         r=1)
    # Combined per-head quality score; the head with the highest value is "best".
    attn_scores = diag_measure + jumpiness + peakiness
    durations = []
    final_alignment = []
    for batch_num, al in enumerate(batch_alignments):
        unpad_mel_len = mel_len[batch_num]
        unpad_phon_len = phon_len[batch_num]
        # Trim padding and drop index 0 on both the mel and phoneme axes
        # (phoneme 0 is the start token); first dim is heads.
        unpad_alignments = al[:, 1:unpad_mel_len, 1:unpad_phon_len]
        best_head = np.argmax(attn_scores[batch_num])
        best_attention = unpad_alignments[best_head]

        if weighted:
            scored_attention = unpad_alignments * attn_scores[batch_num][:, None,
                                                                         None]
            ref_attention_weights = np.sum(scored_attention, axis=0)
        else:
            ref_attention_weights = best_attention
        integer_durations = extract_durations_with_dijkstra(
            ref_attention_weights)

        # The mel slice starts at index 1, so durations cover mel_len - 1 frames.
        assert np.sum(integer_durations) == mel_len[
            batch_num] - 1, f'{np.sum(integer_durations)} vs {mel_len[batch_num]-1}'
        new_alignment = duration_to_alignment_matrix(
            integer_durations.astype(int))
        final_alignment.append(best_attention.T + new_alignment)
        durations.append(integer_durations)
    return durations, final_alignment, jumpiness, peakiness, diag_measure
    def evaluate(self, model: Tacotron,
                 val_set: Dataset) -> Tuple[float, float]:
        """Return the mean validation loss and mean attention score over ``val_set``.

        The loss per batch is the sum of the L1 losses of both mel outputs against the
        target. Both sums are divided by ``len(val_set)``. Leaves ``model`` in eval mode.
        """
        model.eval()
        target_device = next(model.parameters()).device
        loss_total, score_total = 0, 0
        for text, mel, _ids, _text_lens, mel_lens in val_set:
            text = text.to(target_device)
            mel = mel.to(target_device)
            with torch.no_grad():
                mel_pre, mel_post, attention = model(text, mel)
                loss_total += (F.l1_loss(mel_pre, mel).item()
                               + F.l1_loss(mel_post, mel).item())
            _, batch_scores = attention_score(attention, mel_lens)
            score_total += torch.mean(batch_scores).item()

        num_batches = len(val_set)
        return loss_total / num_batches, score_total / num_batches
    def evaluate(self, model: Tacotron,
                 val_set: Dataset) -> Tuple[float, float]:
        """Return the mean validation loss and mean attention score over ``val_set``.

        Batches are dicts with ``'x'``, ``'mel'`` and ``'mel_len'`` entries. The
        per-batch loss is the sum of the L1 losses of both mel outputs against
        ``'mel'``. Both sums are divided by ``len(val_set)``.
        """
        model.eval()
        target_device = next(model.parameters()).device
        loss_total, score_total = 0, 0
        for batch in val_set:
            batch = to_device(batch, device=target_device)
            target_mel = batch['mel']
            with torch.no_grad():
                mel_pre, mel_post, attention = model(batch['x'], target_mel)
                loss_total += (F.l1_loss(mel_pre, target_mel).item()
                               + F.l1_loss(mel_post, target_mel).item())
            _, batch_scores = attention_score(attention, batch['mel_len'])
            score_total += torch.mean(batch_scores).item()

        num_batches = len(val_set)
        return loss_total / num_batches, score_total / num_batches
    def train_session(self, model: Tacotron, optimizer: Optimizer,
                      session: TTSSession) -> None:
        """Train ``model`` for one session of the schedule.

        Runs enough full epochs over ``session.train_set`` to cover the steps remaining
        until ``session.max_step`` (plus one epoch from the ``+ 1``). The model's
        reduction factor and the optimizer's learning rate are set from ``session``.
        Each step logs to ``self.writer``, and checkpoints/plots are written every
        ``hp.tts_checkpoint_every`` / ``hp.tts_plot_every`` steps. After each epoch the
        model is evaluated on ``session.val_set`` and a checkpoint is saved.
        """
        steps_left = session.max_step - model.get_step()
        batches_per_epoch = len(session.train_set)
        num_epochs = steps_left // batches_per_epoch + 1

        model.r = session.r
        simple_table([(f'Steps with r={session.r}',
                       str(steps_left // 1000) + 'k Steps'),
                      ('Batch Size', session.bs),
                      ('Learning Rate', session.lr),
                      ('Outputs/Step (r)', model.r)])
        for param_group in optimizer.param_groups:
            param_group['lr'] = session.lr

        loss_avg = Averager()
        duration_avg = Averager()
        device = next(model.parameters()).device  # train on the model's own device

        for epoch in range(1, num_epochs + 1):
            for batch_idx, (x, m, _ids, _x_lens,
                            mel_lens) in enumerate(session.train_set, 1):
                tic = time.time()
                model.train()
                x = x.to(device)
                m = m.to(device)

                m1_hat, m2_hat, attention = model(x, m)
                loss = F.l1_loss(m1_hat, m) + F.l1_loss(m2_hat, m)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               hp.tts_clip_grad_norm)
                optimizer.step()

                loss_avg.add(loss.item())
                step = model.get_step()
                k = step // 1000

                duration_avg.add(time.time() - tic)
                speed = 1. / duration_avg.get()
                msg = (f'| Epoch: {epoch}/{num_epochs} ({batch_idx}/{batches_per_epoch}) '
                       f'| Loss: {loss_avg.get():#.4} '
                       f'| {speed:#.2} steps/s | Step: {k}k | ')

                if step % hp.tts_checkpoint_every == 0:
                    save_checkpoint('tts', self.paths, model, optimizer,
                                    name=f'taco_step{k}K', is_silent=True)

                if step % hp.tts_plot_every == 0:
                    self.generate_plots(model, session)

                _, att_score = attention_score(attention, mel_lens)
                train_scalars = (
                    ('Attention_Score/train', torch.mean(att_score)),
                    ('Loss/train', loss),
                    ('Params/reduction_factor', session.r),
                    ('Params/batch_size', session.bs),
                    ('Params/learning_rate', session.lr),
                )
                for tag, value in train_scalars:
                    self.writer.add_scalar(tag, value, model.get_step())

                stream(msg)

            val_loss, val_att_score = self.evaluate(model, session.val_set)
            self.writer.add_scalar('Loss/val', val_loss, model.get_step())
            self.writer.add_scalar('Attention_Score/val', val_att_score,
                                   model.get_step())
            save_checkpoint('tts', self.paths, model, optimizer, is_silent=True)

            loss_avg.reset()
            duration_avg.reset()
            print(' ')
 summary_manager.display_mel(mel=output['final_output'][0],
                             tag=f'Train/predicted_mel')
 residual = abs(output['mel_linear'] - output['final_output'])
 summary_manager.display_mel(mel=residual[0],
                             tag=f'Train/conv-linear_residual')
 summary_manager.display_mel(mel=mel[0], tag=f'Train/target_mel')
 summary_manager.display_audio(tag=f'Train/prediction',
                               mel=output['final_output'][0])
 summary_manager.display_audio(tag=f'Train/target', mel=mel[0])
 for layer, k in enumerate(output['decoder_attention'].keys()):
     mel_lens = mel_lengths(mel_batch=mel,
                            padding_value=0) // model.r  # [N]
     phon_len = phoneme_lengths(phonemes)
     loc_score, peak_score, diag_measure = attention_score(
         att=output['decoder_attention'][k],
         mel_len=mel_lens,
         phon_len=phon_len,
         r=model.r)
     loc_score = tf.reduce_mean(loc_score, axis=0)
     peak_score = tf.reduce_mean(peak_score, axis=0)
     diag_measure = tf.reduce_mean(diag_measure, axis=0)
     for i in range(tf.shape(loc_score)[0]):
         summary_manager.display_scalar(
             tag=f'TrainDecoderAttentionJumpiness/layer{layer}_head{i}',
             scalar_value=tf.reduce_mean(loc_score[i]))
         summary_manager.display_scalar(
             tag=f'TrainDecoderAttentionPeakiness/layer{layer}_head{i}',
             scalar_value=tf.reduce_mean(peak_score[i]))
         summary_manager.display_scalar(
             tag=
             f'TrainDecoderAttentionDiagonality/layer{layer}_head{i}',
Exemple #8
0
def _round_durations_to_total(attention_durations, total):
    """Scale soft per-phoneme durations to sum to ``total`` and round them to integers.

    Each scaled value is rounded, then the rounding drift is repaired one frame at a
    time. If the rounded sum is too long, the entry rounded up the most is decremented.
    If it is too short, the entry rounded down the most is incremented.

    :param attention_durations: per-phoneme attention mass (summed over mel steps).
    :param total: required sum of the returned durations.
    :return: float array of integer values summing to ``total``.
    :raises ValueError: if the attention mass sums to zero or is not finite; scaling is
        undefined then and the repair loop would never terminate.
    """
    attention_sum = np.sum(attention_durations)
    if not np.isfinite(attention_sum) or attention_sum == 0:
        raise ValueError(
            f'Cannot normalize durations: attention weights sum to {attention_sum}.')
    normalized_durations = attention_durations * (total / attention_sum)
    integer_durations = np.round(normalized_durations)
    duration_diff = np.sum(integer_durations) - total
    while duration_diff != 0:
        rounding_diff = integer_durations - normalized_durations
        if duration_diff > 0:  # duration is too long -> reduce highest (positive) rounding difference
            integer_durations[np.argmax(rounding_diff)] -= 1
        elif duration_diff < 0:  # duration is too short -> increase lowest (negative) rounding difference
            integer_durations[np.argmin(rounding_diff)] += 1
        duration_diff = np.sum(integer_durations) - total
    return integer_durations


def get_durations_from_alignment(batch_alignments,
                                 mels,
                                 phonemes,
                                 weighted=False,
                                 binary=False,
                                 fill_gaps=False,
                                 fix_jumps=False,
                                 fill_mode='max'):
    """Extract integer phoneme durations from autoregressive attention.

    :param batch_alignments: attention weights from the autoregressive model. Each
        batch item is indexed as ``[head, mel_step, phoneme]`` below.
    :param mels: mel spectrograms (used only to compute unpadded mel lengths).
    :param phonemes: phoneme sequences (used only to compute unpadded phoneme lengths).
    :param weighted: if True, use the sum of all heads, each scaled by its attention
        score. If False, use the best-scoring head.
    :param binary: if True, take the maximum attention peak per mel step and count
        them. If False, sum the soft attention and rescale it to the mel length.
    :param fill_gaps: if True, fills zero durations with ones (via ``fill_zeros``).
    :param fix_jumps: if True, tries to clean attention jumps in the binary attention
        (via ``fix_attention_jumps``). Requires ``binary``.
    :param fill_mode: used only if fill_gaps is True. Is either 'max' or 'next'. Defines
        where to take the duration needed to fill the gap: 'next' takes it from the next
        non-zero duration value, 'max' from the sequence maximum.
    :return: ``(durations, final_alignment, jumpiness, peakiness, diag_measure)``.
        ``durations`` holds one integer-duration array per batch item.
        ``final_alignment`` holds, per item, the best head's attention transposed plus
        the duration-derived alignment matrix. The last three are the raw attention
        scores.
    """
    # Truthiness instead of `is True` / `is False`, so non-bool flags (e.g. 0/1) work.
    assert binary or not fix_jumps, 'Cannot fix jumps in non-binary attention.'
    # mel_len - 1 because we remove last timestep, which is end_vector. start vector is not predicted (or removed from GTA)
    mel_len = mel_lengths(mels, padding_value=0.) - 1  # [N]
    # phonemes contain start and end tokens (start will be removed later)
    phon_len = phoneme_lengths(phonemes) - 1
    jumpiness, peakiness, diag_measure = attention_score(att=batch_alignments,
                                                         mel_len=mel_len,
                                                         phon_len=phon_len,
                                                         r=1)
    # Combined per-head quality score; the head with the highest value is "best".
    attn_scores = diag_measure + jumpiness + peakiness
    durations = []
    final_alignment = []
    for batch_num, al in enumerate(batch_alignments):
        unpad_mel_len = mel_len[batch_num]
        unpad_phon_len = phon_len[batch_num]
        # Trim padding and drop the start token (phoneme 0); first dim is heads.
        unpad_alignments = al[:, :unpad_mel_len, 1:unpad_phon_len]
        best_head = np.argmax(attn_scores[batch_num])
        best_attention = unpad_alignments[best_head]

        if weighted:
            scored_attention = unpad_alignments * attn_scores[batch_num][:, None,
                                                                         None]
            ref_attention_weights = np.sum(scored_attention, axis=0)
        else:
            ref_attention_weights = best_attention

        if binary:  # pick max attention for each mel time-step
            binary_attn = binary_attention(ref_attention_weights)
            # The start token was sliced off above, hence phon_len - 1 columns.
            binary_attn_score = attention_score(
                tf.cast(binary_attn, tf.float32)[None, None, :, :],
                mel_len=unpad_mel_len[None],
                phon_len=unpad_phon_len[None] - 1,
                r=1)
            binary_score = tf.reduce_mean(binary_attn_score)
            if fix_jumps:
                binary_attn = fix_attention_jumps(
                    binary_attn=binary_attn,
                    mel_len=unpad_mel_len[None],
                    phon_len=unpad_phon_len[None] - 1,
                    binary_score=binary_score)
            integer_durations = binary_attn.sum(axis=0)
        else:  # takes actual attention values and normalizes to mel_len
            attention_durations = np.sum(ref_attention_weights, axis=0)
            integer_durations = _round_durations_to_total(attention_durations,
                                                          unpad_mel_len)

        if fill_gaps:  # fill zeros durations
            integer_durations = fill_zeros(integer_durations,
                                           take_from=fill_mode)

        assert np.sum(integer_durations) == mel_len[
            batch_num], f'{np.sum(integer_durations)} vs {mel_len[batch_num]}'
        new_alignment = duration_to_alignment_matrix(
            integer_durations.astype(int))
        final_alignment.append(best_attention.T + new_alignment)
        durations.append(integer_durations)
    return durations, final_alignment, jumpiness, peakiness, diag_measure