def fix_attention_jumps(binary_attn, binary_score, mel_len, phon_len):
    """ Scans for jumps in the attention and attempts to fix them. If the score
    decreases, a collapse is likely, so the jump size is relaxed.
    A lower jump size is more accurate, but more prone to collapse.
    """
    clean_scores = []
    clean_attns = []
    for jumpth in [2, 3, 4, 5]:
        cl_at = clean_attention(binary_attention=binary_attn, jump_threshold=jumpth)
        clean_attns.append(cl_at)
        sclean_score = attention_score(att=tf.cast(cl_at[None, None, :, :], tf.float32),
                                       mel_len=mel_len,
                                       phon_len=phon_len,
                                       r=1)
        clean_scores.append(tf.reduce_mean(sclean_score))
    best_idx = np.argmax(clean_scores)
    best_score = clean_scores[best_idx]
    best_cleaned_attention = clean_attns[best_idx]
    # if cleaning degraded the score, keep relaxing the jump threshold (up to 20)
    while (binary_score > best_score) and (jumpth < 20):
        jumpth += 1
        best_cleaned_attention = clean_attention(binary_attention=binary_attn, jump_threshold=jumpth)
        best_score = attention_score(att=tf.cast(best_cleaned_attention[None, None, :, :], tf.float32),
                                     mel_len=mel_len,
                                     phon_len=phon_len,
                                     r=1)
        best_score = tf.reduce_mean(best_score)
    # fall back to the original binary attention if no cleaned version scores better
    if binary_score > best_score:
        best_cleaned_attention = binary_attn
    return best_cleaned_attention
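# `clean_attention` is referenced above but not shown here. A minimal sketch of one
# plausible implementation, assuming it re-draws the binary attention path while
# capping how far the attended phoneme index may move between consecutive mel frames
# (the name suffix and the body below are assumptions, not the verified helper):
import numpy as np

def clean_attention_sketch(binary_attention, jump_threshold):
    clean_attn = np.zeros(binary_attention.shape)
    phon_idx = 0
    for i, frame in enumerate(binary_attention):
        next_idx = np.argmax(frame)
        # reject jumps larger than the threshold and stay on the current phoneme
        if abs(next_idx - phon_idx) > jump_threshold:
            next_idx = phon_idx
        phon_idx = next_idx
        clean_attn[i, phon_idx] = 1.
    return clean_attn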
def create_align_features(model: Tacotron,
                          train_set: DataLoader,
                          val_set: DataLoader,
                          save_path_alg: Path,
                          # save_path_pitch: Path
                          ):
    assert model.r == 1, f'Reduction factor of tacotron must be 1 for creating alignment features! ' \
                         f'Reduction factor was: {model.r}'
    model.eval()
    device = next(model.parameters()).device  # use same device as model parameters
    if val_set is not None:
        iters = len(val_set) + len(train_set)
        dataset = itertools.chain(train_set, val_set)
    else:
        iters = len(train_set)
        dataset = itertools.chain(train_set)
    att_score_dict = {}

    if hp.extract_durations_with_dijkstra:
        print('Extracting durations using dijkstra...')
        dur_extraction_func = extract_durations_with_dijkstra
    else:
        print('Extracting durations using attention peak counts...')
        dur_extraction_func = extract_durations_per_count

    for i, (x, mels, ids, x_lens, mel_lens) in enumerate(dataset, 1):
        x, mels = x.to(device), mels.to(device)
        with torch.no_grad():
            _, _, att_batch = model(x, mels)
        align_score, sharp_score = attention_score(att_batch, mel_lens, r=1)
        att_batch = np_now(att_batch)
        seq, att, mel_len, item_id = x[0], att_batch[0], mel_lens[0], ids[0]
        align_score, sharp_score = float(align_score[0]), float(sharp_score[0])
        att_score_dict[item_id] = (align_score, sharp_score)
        durs = dur_extraction_func(seq, att, mel_len)
        if np.sum(durs) != mel_len:
            print(f'WARNING: Sum of durations did not match mel length for item {item_id}!')
        np.save(str(save_path_alg / f'{item_id}.npy'), durs, allow_pickle=False)
        bar = progbar(i, iters)
        msg = f'{bar} {i}/{iters} Batches '
        stream(msg)
    pickle_binary(att_score_dict, paths.data / 'att_score_dict.pkl')
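# `extract_durations_per_count` is the fallback extractor named above. A minimal
# sketch under the assumption that `att` is a [mel_T, phon_T] numpy array and `seq`
# a 1-D phoneme sequence (the project's actual helper may differ):
import numpy as np

def extract_durations_per_count_sketch(seq, att, mel_len):
    mel_len = int(mel_len)
    durs = np.zeros(len(seq), dtype=np.int32)
    # each mel frame votes for the phoneme with peak attention,
    # so the durations sum to mel_len by construction
    peaks = np.argmax(att[:mel_len, :len(seq)], axis=1)
    for p in peaks:
        durs[p] += 1
    return durs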
def get_durations_from_alignment(batch_alignments, mels, phonemes, weighted=False):
    """
    :param batch_alignments: attention weights from the autoregressive model.
    :param mels: mel spectrograms.
    :param phonemes: phoneme sequence.
    :param weighted: if True use a weighted average of the durations of all heads, else the best head.
    :return:
    """
    # mel_len - 1 because we remove the last timestep, which is the end vector.
    # The start vector is not predicted (or is removed from GTA).
    mel_len = mel_lengths(mels, padding_value=0.) - 1  # [N]
    # phonemes contain start and end tokens (start will be removed later)
    phon_len = phoneme_lengths(phonemes) - 1
    jumpiness, peakiness, diag_measure = attention_score(att=batch_alignments,
                                                         mel_len=mel_len,
                                                         phon_len=phon_len,
                                                         r=1)
    attn_scores = diag_measure + jumpiness + peakiness
    durations = []
    final_alignment = []
    for batch_num, al in enumerate(batch_alignments):
        unpad_mel_len = mel_len[batch_num]
        unpad_phon_len = phon_len[batch_num]
        unpad_alignments = al[:, 1:unpad_mel_len, 1:unpad_phon_len]  # first dim is heads
        scored_attention = unpad_alignments * attn_scores[batch_num][:, None, None]
        if weighted:
            ref_attention_weights = np.sum(scored_attention, axis=0)
        else:
            best_head = np.argmax(attn_scores[batch_num])
            ref_attention_weights = unpad_alignments[best_head]
        integer_durations = extract_durations_with_dijkstra(ref_attention_weights)
        assert np.sum(integer_durations) == mel_len[batch_num] - 1, \
            f'{np.sum(integer_durations)} vs {mel_len[batch_num] - 1}'
        new_alignment = duration_to_alignment_matrix(integer_durations.astype(int))
        best_head = np.argmax(attn_scores[batch_num])
        best_attention = unpad_alignments[best_head]
        final_alignment.append(best_attention.T + new_alignment)
        durations.append(integer_durations)
    return durations, final_alignment, jumpiness, peakiness, diag_measure
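# `duration_to_alignment_matrix` is referenced but not shown. A minimal sketch,
# assuming it expands integer durations into a binary [phon_T, mel_T] matrix where
# phoneme i covers durations[i] consecutive frames (the body is an assumption):
import numpy as np

def duration_to_alignment_matrix_sketch(durations):
    starts = np.cumsum(np.append([0], durations[:-1]))
    total = np.sum(durations)
    rows = [np.concatenate([np.zeros(s), np.ones(d), np.zeros(total - s - d)])
            for s, d in zip(starts, durations)]
    return np.array(rows)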
def evaluate(self, model: Tacotron, val_set: Dataset) -> Tuple[float, float]:
    model.eval()
    val_loss = 0
    val_att_score = 0
    device = next(model.parameters()).device
    for i, (x, m, ids, x_lens, mel_lens) in enumerate(val_set, 1):
        x, m = x.to(device), m.to(device)
        with torch.no_grad():
            m1_hat, m2_hat, attention = model(x, m)
            m1_loss = F.l1_loss(m1_hat, m)
            m2_loss = F.l1_loss(m2_hat, m)
            val_loss += m1_loss.item() + m2_loss.item()
        _, att_score = attention_score(attention, mel_lens)
        val_att_score += torch.mean(att_score).item()
    return val_loss / len(val_set), val_att_score / len(val_set)
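# The two-valued PyTorch `attention_score` used above returns an alignment score and
# a sharpness score. A minimal sketch of one plausible formulation, assuming `att` is
# [N, mel_T, phon_T] (names and the exact scoring are assumptions):
import torch

def attention_score_sketch(att, mel_lens, r=1):
    with torch.no_grad():
        mel_lens = mel_lens // r
        b, t_max, _ = att.size()
        # mask out padded decoder steps
        mask = (torch.arange(t_max, device=att.device)[None, :] < mel_lens[:, None]).float()
        # alignment: fraction of steps whose attention peak moves by at most r phonemes
        max_loc = torch.argmax(att, dim=2)
        loc_diff = torch.abs(max_loc[:, 1:] - max_loc[:, :-1])
        align_score = torch.sum((loc_diff <= r).float() * mask[:, 1:], dim=1) / (mel_lens - 1)
        # sharpness: mean peak attention value over the valid steps
        sharp_score = torch.sum(att.max(dim=2)[0] * mask, dim=1) / mel_lens
        return align_score, sharp_score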
def evaluate(self, model: Tacotron, val_set: Dataset) -> Tuple[float, float]:
    model.eval()
    val_loss = 0
    val_att_score = 0
    device = next(model.parameters()).device
    for i, batch in enumerate(val_set, 1):
        batch = to_device(batch, device=device)
        with torch.no_grad():
            m1_hat, m2_hat, attention = model(batch['x'], batch['mel'])
            m1_loss = F.l1_loss(m1_hat, batch['mel'])
            m2_loss = F.l1_loss(m2_hat, batch['mel'])
            val_loss += m1_loss.item() + m2_loss.item()
        _, att_score = attention_score(attention, batch['mel_len'])
        val_att_score += torch.mean(att_score).item()
    return val_loss / len(val_set), val_att_score / len(val_set)
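# `to_device` is referenced but not shown. A minimal sketch, assuming batches are
# flat dicts of tensors (the project's helper may handle nested structures too):
import torch

def to_device_sketch(batch: dict, device: torch.device) -> dict:
    return {key: val.to(device) if torch.is_tensor(val) else val
            for key, val in batch.items()}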
def train_session(self, model: Tacotron, optimizer: Optimizer, session: TTSSession) -> None:
    current_step = model.get_step()
    training_steps = session.max_step - current_step
    total_iters = len(session.train_set)
    epochs = training_steps // total_iters + 1
    model.r = session.r
    simple_table([(f'Steps with r={session.r}', str(training_steps // 1000) + 'k Steps'),
                  ('Batch Size', session.bs),
                  ('Learning Rate', session.lr),
                  ('Outputs/Step (r)', model.r)])
    for g in optimizer.param_groups:
        g['lr'] = session.lr
    loss_avg = Averager()
    duration_avg = Averager()
    device = next(model.parameters()).device  # use same device as model parameters
    for e in range(1, epochs + 1):
        for i, (x, m, ids, x_lens, mel_lens) in enumerate(session.train_set, 1):
            start = time.time()
            model.train()
            x, m = x.to(device), m.to(device)
            m1_hat, m2_hat, attention = model(x, m)
            m1_loss = F.l1_loss(m1_hat, m)
            m2_loss = F.l1_loss(m2_hat, m)
            loss = m1_loss + m2_loss
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
            optimizer.step()
            loss_avg.add(loss.item())
            step = model.get_step()
            k = step // 1000
            duration_avg.add(time.time() - start)
            speed = 1. / duration_avg.get()
            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {loss_avg.get():#.4} ' \
                  f'| {speed:#.2} steps/s | Step: {k}k | '
            if step % hp.tts_checkpoint_every == 0:
                ckpt_name = f'taco_step{k}K'
                save_checkpoint('tts', self.paths, model, optimizer, name=ckpt_name, is_silent=True)
            if step % hp.tts_plot_every == 0:
                self.generate_plots(model, session)
            _, att_score = attention_score(attention, mel_lens)
            att_score = torch.mean(att_score)
            self.writer.add_scalar('Attention_Score/train', att_score, model.get_step())
            self.writer.add_scalar('Loss/train', loss, model.get_step())
            self.writer.add_scalar('Params/reduction_factor', session.r, model.get_step())
            self.writer.add_scalar('Params/batch_size', session.bs, model.get_step())
            self.writer.add_scalar('Params/learning_rate', session.lr, model.get_step())
            stream(msg)
        val_loss, val_att_score = self.evaluate(model, session.val_set)
        self.writer.add_scalar('Loss/val', val_loss, model.get_step())
        self.writer.add_scalar('Attention_Score/val', val_att_score, model.get_step())
        save_checkpoint('tts', self.paths, model, optimizer, is_silent=True)
        loss_avg.reset()
        duration_avg.reset()
        print(' ')
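# `Averager` smooths the logged loss and step duration above. A minimal sketch of
# such a running-mean helper (an assumption, not the project's class):
class AveragerSketch:

    def __init__(self) -> None:
        self.reset()

    def add(self, value: float) -> None:
        self.total += value
        self.count += 1

    def get(self) -> float:
        return self.total / self.count if self.count > 0 else 0.

    def reset(self) -> None:
        self.total = 0.
        self.count = 0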
summary_manager.display_mel(mel=output['final_output'][0], tag=f'Train/predicted_mel')
residual = abs(output['mel_linear'] - output['final_output'])
summary_manager.display_mel(mel=residual[0], tag=f'Train/conv-linear_residual')
summary_manager.display_mel(mel=mel[0], tag=f'Train/target_mel')
summary_manager.display_audio(tag=f'Train/prediction', mel=output['final_output'][0])
summary_manager.display_audio(tag=f'Train/target', mel=mel[0])
for layer, k in enumerate(output['decoder_attention'].keys()):
    mel_lens = mel_lengths(mel_batch=mel, padding_value=0) // model.r  # [N]
    phon_len = phoneme_lengths(phonemes)
    loc_score, peak_score, diag_measure = attention_score(att=output['decoder_attention'][k],
                                                          mel_len=mel_lens,
                                                          phon_len=phon_len,
                                                          r=model.r)
    loc_score = tf.reduce_mean(loc_score, axis=0)
    peak_score = tf.reduce_mean(peak_score, axis=0)
    diag_measure = tf.reduce_mean(diag_measure, axis=0)
    for i in range(tf.shape(loc_score)[0]):
        summary_manager.display_scalar(
            tag=f'TrainDecoderAttentionJumpiness/layer{layer}_head{i}',
            scalar_value=tf.reduce_mean(loc_score[i]))
        summary_manager.display_scalar(
            tag=f'TrainDecoderAttentionPeakiness/layer{layer}_head{i}',
            scalar_value=tf.reduce_mean(peak_score[i]))
        summary_manager.display_scalar(
            tag=f'TrainDecoderAttentionDiagonality/layer{layer}_head{i}',
            scalar_value=tf.reduce_mean(diag_measure[i]))
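# `mel_lengths` and `phoneme_lengths` are referenced throughout but not shown.
# Minimal sketches, assuming mels are [N, mel_T, channels] padded with constant
# `padding_value` frames and phonemes are [N, phon_T] zero-padded index sequences
# (both bodies are assumptions):
import tensorflow as tf

def mel_lengths_sketch(mel_batch, padding_value=0.):
    # a frame counts as padding when every channel equals the padding value
    not_padding = tf.reduce_any(tf.not_equal(mel_batch, padding_value), axis=-1)
    return tf.reduce_sum(tf.cast(not_padding, tf.int32), axis=-1)

def phoneme_lengths_sketch(phonemes, padding_token=0):
    return tf.reduce_sum(tf.cast(tf.not_equal(phonemes, padding_token), tf.int32), axis=-1)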
def get_durations_from_alignment(batch_alignments,
                                 mels,
                                 phonemes,
                                 weighted=False,
                                 binary=False,
                                 fill_gaps=False,
                                 fix_jumps=False,
                                 fill_mode='max'):
    """
    :param batch_alignments: attention weights from the autoregressive model.
    :param mels: mel spectrograms.
    :param phonemes: phoneme sequence.
    :param weighted: if True use a weighted average of the durations of all heads, else the best head.
    :param binary: if True take the maximum attention peak, else the sum.
    :param fill_gaps: if True fills zero durations with ones.
    :param fix_jumps: if True, tries to scan alignments for attention jumps and interpolate.
    :param fill_mode: used only if fill_gaps is True. Either 'max' or 'next'. Defines where
        to take the duration needed to fill the gap: 'next' takes it from the next non-zero
        duration value, 'max' from the sequence maximum.
    :return:
    """
    assert (binary is True) or (fix_jumps is False), 'Cannot fix jumps in non-binary attention.'
    # mel_len - 1 because we remove the last timestep, which is the end vector.
    # The start vector is not predicted (or is removed from GTA).
    mel_len = mel_lengths(mels, padding_value=0.) - 1  # [N]
    # phonemes contain start and end tokens (start will be removed later)
    phon_len = phoneme_lengths(phonemes) - 1
    jumpiness, peakiness, diag_measure = attention_score(att=batch_alignments,
                                                         mel_len=mel_len,
                                                         phon_len=phon_len,
                                                         r=1)
    attn_scores = diag_measure + jumpiness + peakiness
    durations = []
    final_alignment = []
    for batch_num, al in enumerate(batch_alignments):
        unpad_mel_len = mel_len[batch_num]
        unpad_phon_len = phon_len[batch_num]
        unpad_alignments = al[:, :unpad_mel_len, 1:unpad_phon_len]  # first dim is heads
        scored_attention = unpad_alignments * attn_scores[batch_num][:, None, None]
        if weighted:
            ref_attention_weights = np.sum(scored_attention, axis=0)
        else:
            best_head = np.argmax(attn_scores[batch_num])
            ref_attention_weights = unpad_alignments[best_head]

        if binary:  # pick the max attention for each mel timestep
            binary_attn = binary_attention(ref_attention_weights)
            binary_attn_score = attention_score(tf.cast(binary_attn, tf.float32)[None, None, :, :],
                                                mel_len=unpad_mel_len[None],
                                                phon_len=unpad_phon_len[None] - 1,
                                                r=1)
            binary_score = tf.reduce_mean(binary_attn_score)
            if fix_jumps:
                binary_attn = fix_attention_jumps(binary_attn=binary_attn,
                                                  mel_len=unpad_mel_len[None],
                                                  phon_len=unpad_phon_len[None] - 1,
                                                  binary_score=binary_score)
            integer_durations = binary_attn.sum(axis=0)
        else:  # take the actual attention values and normalize to mel_len
            attention_durations = np.sum(ref_attention_weights, axis=0)
            normalized_durations = attention_durations * (unpad_mel_len / np.sum(attention_durations))
            integer_durations = np.round(normalized_durations)
            tot_duration = np.sum(integer_durations)
            duration_diff = tot_duration - unpad_mel_len
            while duration_diff != 0:
                rounding_diff = integer_durations - normalized_durations
                if duration_diff > 0:  # too long -> reduce the highest (positive) rounding error
                    max_error_idx = np.argmax(rounding_diff)
                    integer_durations[max_error_idx] -= 1
                elif duration_diff < 0:  # too short -> increase the lowest (negative) rounding error
                    min_error_idx = np.argmin(rounding_diff)
                    integer_durations[min_error_idx] += 1
                tot_duration = np.sum(integer_durations)
                duration_diff = tot_duration - unpad_mel_len

        if fill_gaps:  # fill zero durations
            integer_durations = fill_zeros(integer_durations, take_from=fill_mode)

        assert np.sum(integer_durations) == mel_len[batch_num], \
            f'{np.sum(integer_durations)} vs {mel_len[batch_num]}'

        new_alignment = duration_to_alignment_matrix(integer_durations.astype(int))
        best_head = np.argmax(attn_scores[batch_num])
        best_attention = unpad_alignments[best_head]
        final_alignment.append(best_attention.T + new_alignment)
        durations.append(integer_durations)
    return durations, final_alignment, jumpiness, peakiness, diag_measure
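# Hypothetical usage of the extractor above. The tensor names and shapes are
# assumptions: batch_alignments [N, heads, mel_T, phon_T], mels [N, mel_T, n_mels],
# phonemes [N, phon_T].
durations, final_alignment, jumpiness, peakiness, diag_measure = get_durations_from_alignment(
    batch_alignments=attention_weights,
    mels=mel_batch,
    phonemes=phoneme_batch,
    weighted=False,   # use the best-scoring head rather than a weighted head average
    binary=True,      # one-hot the attention before summing durations
    fix_jumps=True,   # requires binary=True, see the assert above
    fill_gaps=True,   # replace zero durations with ones...
    fill_mode='max')  # ...taking the stolen frame from the longest duration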