import logging
import time

import numpy as np
import torch

# NOTE: the helpers below are project-local; the module paths are an
# assumption and may need adjusting to this repository's actual layout.
from params.params import Params as hp
from utils import audio, to_gpu, cos_decay, lengths_to_mask
from utils.logging import Logger


def train(logging_start_epoch, epoch, data, model, criterion, optimizer):
    """Main training procedure.

    Arguments:
        logging_start_epoch -- number of the first epoch to be logged
        epoch -- current epoch
        data -- DataLoader which can provide batches for an epoch
        model -- model to be trained
        criterion -- instance of loss function to be optimized
        optimizer -- instance of optimizer which will be used for parameter updates
    """

    text_logger = logging.getLogger(__name__)
    model.train()

    # initialize counters, meters, etc.
    learning_rate = optimizer.param_groups[0]['lr']
    cla = 0
    done, start_time = 0, time.time()
    total_loss = AverageMeter('Total Loss', ':.4e')
    mel_pre_loss = AverageMeter('Mel Pre Loss', ':.4e')
    mel_post_loss = AverageMeter('Mel Post Loss', ':.4e')
    lang_class_acc = AverageMeter('Lang Class Acc', ':.4e')
    progress = ProgressMeter(len(data), total_loss, mel_pre_loss,
                             mel_post_loss, lang_class_acc,
                             prefix="Epoch: [{}]".format(epoch),
                             logger=text_logger)

    # loop through epoch batches
    for i, batch in enumerate(data):

        global_step = done + epoch * len(data)
        optimizer.zero_grad()

        # parse batch
        batch = list(map(to_gpu, batch))
        src, src_len, trg_mel, trg_lin, trg_len, stop_trg, spkrs, langs = batch

        # get teacher forcing ratio
        if hp.constant_teacher_forcing:
            tf = hp.teacher_forcing
        else:
            tf = cos_decay(
                max(global_step - hp.teacher_forcing_start_steps, 0),
                hp.teacher_forcing_steps)

        # run the model
        post_pred, pre_pred, stop_pred, alignment, spkrs_pred, enc_output = model(
            src, src_len, trg_mel, trg_len, spkrs, langs, tf)

        # evaluate loss function
        post_trg = trg_lin if hp.predict_linear else trg_mel
        classifier = model._reversal_classifier if hp.reversal_classifier else None
        loss, batch_losses = criterion(src_len, trg_len, pre_pred, trg_mel,
                                       post_pred, post_trg, stop_pred, stop_trg,
                                       alignment, spkrs, spkrs_pred, enc_output,
                                       classifier)

        # track running losses; loss.item() keeps the meter from retaining the graph
        total_loss.update(loss.item(), src.size(0))
        mel_pre_loss.update(batch_losses['mel_pre'], src.size(0))
        mel_post_loss.update(batch_losses['mel_pos'], src.size(0))

        # evaluate adversarial classifier accuracy, if present
        if hp.reversal_classifier:
            input_mask = lengths_to_mask(src_len)
            trg_spkrs = torch.zeros_like(input_mask, dtype=torch.int64)
            for s in range(hp.speaker_number):
                speaker_mask = (spkrs == s)
                trg_spkrs[speaker_mask] = s
            matches = (trg_spkrs == torch.argmax(
                torch.nn.functional.softmax(spkrs_pred, dim=-1), dim=-1))
            matches[~input_mask] = False
            cla = torch.sum(matches).item() / torch.sum(input_mask).item()
            lang_class_acc.update(cla, src.size(0))

        # compute gradients and make a step
        loss.backward()
        gradient = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                  hp.gradient_clipping)
        optimizer.step()

        # log training progress
        if epoch >= logging_start_epoch:
            Logger.training(global_step, batch_losses, gradient, learning_rate,
                            time.time() - start_time, cla)
            progress.print(i)

        # update criterion states (parameters, loss decay, and so on)
        criterion.update_states()

        start_time = time.time()
        done += 1
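
# AverageMeter and ProgressMeter are assumed to be project-local helpers in
# the style of the PyTorch ImageNet example. The sketch below only mirrors
# the interface used in this file (constructor arguments, update(), print());
# it is an assumption, not the project's actual implementation.
class AverageMeter:
    """Tracks the most recent value and the running average of a scalar."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        # `val` is the per-batch value, `n` the number of samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter:
    """Formats all tracked meters for a given batch and emits one log line."""

    def __init__(self, num_batches, *meters, prefix="", logger=None):
        self.num_batches = num_batches
        self.meters = meters
        self.prefix = prefix
        self.logger = logger

    def print(self, batch):
        entries = [self.prefix + '[{}/{}]'.format(batch, self.num_batches)]
        entries += [str(meter) for meter in self.meters]
        line = '\t'.join(entries)
        if self.logger is not None:
            self.logger.info(line)
        else:
            print(line)
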
def evaluate(epoch, data, model, criterion):
    """Main evaluation procedure.

    Arguments:
        epoch -- current epoch
        data -- DataLoader which can provide validation batches
        model -- model to be evaluated
        criterion -- instance of loss function to measure performance
    """

    text_logger = logging.getLogger(__name__)
    model.eval()

    # initialize counters, meters, etc.
    mcd, mcd_count = 0, 0
    cla, cla_count = 0, 0
    eval_losses = {}
    total_loss = AverageMeter('Total Loss', ':.4e')
    mel_pre_loss = AverageMeter('Mel Pre Loss', ':.4e')
    mel_post_loss = AverageMeter('Mel Post Loss', ':.4e')
    lang_class_acc = AverageMeter('Lang Class Acc', ':.4e')
    progress = ProgressMeter(len(data), total_loss, mel_pre_loss,
                             mel_post_loss, lang_class_acc,
                             prefix="Epoch: [{}]".format(epoch),
                             logger=text_logger)

    # loop through epoch batches
    with torch.no_grad():
        for i, batch in enumerate(data):

            # parse batch
            batch = list(map(to_gpu, batch))
            src, src_len, trg_mel, trg_lin, trg_len, stop_trg, spkrs, langs = batch

            # run the model (twice, with and without teacher forcing)
            post_pred, pre_pred, stop_pred, alignment, spkrs_pred, enc_output = model(
                src, src_len, trg_mel, trg_len, spkrs, langs, 1.0)
            post_pred_0, _, stop_pred_0, alignment_0, _, _ = model(
                src, src_len, trg_mel, trg_len, spkrs, langs, 0.0)
            stop_pred_probs = torch.sigmoid(stop_pred_0)

            # evaluate loss function
            post_trg = trg_lin if hp.predict_linear else trg_mel
            classifier = model._reversal_classifier if hp.reversal_classifier else None
            loss, batch_losses = criterion(src_len, trg_len, pre_pred, trg_mel,
                                           post_pred, post_trg, stop_pred, stop_trg,
                                           alignment, spkrs, spkrs_pred, enc_output,
                                           classifier)
            total_loss.update(loss.item(), src.size(0))
            mel_pre_loss.update(batch_losses['mel_pre'], src.size(0))
            mel_post_loss.update(batch_losses['mel_pos'], src.size(0))

            # compute mel cepstral distortion of the free-running predictions
            for j, (gen, ref, stop) in enumerate(
                    zip(post_pred_0, trg_mel, stop_pred_probs)):
                # truncate the generated spectrogram at the predicted stop token
                stop_idxes = np.where(stop.cpu().numpy() > 0.5)[0]
                stop_idx = min(np.min(stop_idxes) + hp.stop_frames,
                               gen.size(1)) if len(stop_idxes) > 0 else gen.size(1)
                gen = gen[:, :stop_idx].data.cpu().numpy()
                ref = ref[:, :trg_len[j]].data.cpu().numpy()
                if hp.normalize_spectrogram:
                    gen = audio.denormalize_spectrogram(gen, not hp.predict_linear)
                    ref = audio.denormalize_spectrogram(ref, True)
                if hp.predict_linear:
                    gen = audio.linear_to_mel(gen)
                # running average of the MCD over all validation utterances
                mcd = (mcd_count * mcd + audio.mel_cepstral_distorision(
                    gen, ref, 'dtw')) / (mcd_count + 1)
                mcd_count += 1

            # compute adversarial classifier accuracy
            if hp.reversal_classifier:
                input_mask = lengths_to_mask(src_len)
                trg_spkrs = torch.zeros_like(input_mask, dtype=torch.int64)
                for s in range(hp.speaker_number):
                    speaker_mask = (spkrs == s)
                    trg_spkrs[speaker_mask] = s
                matches = (trg_spkrs == torch.argmax(
                    torch.nn.functional.softmax(spkrs_pred, dim=-1), dim=-1))
                matches[~input_mask] = False
                cla = (cla_count * cla + torch.sum(matches).item() /
                       torch.sum(input_mask).item()) / (cla_count + 1)
                cla_count += 1
                lang_class_acc.update(cla, src.size(0))

            # add batch losses to epoch losses
            for k, v in batch_losses.items():
                eval_losses[k] = (v + eval_losses[k]) if k in eval_losses else v

    # normalize loss per batch
    for k in eval_losses.keys():
        eval_losses[k] /= len(data)

    # log evaluation (uses the predictions of the last validation batch)
    progress.print(i)
    Logger.evaluation(epoch + 1, eval_losses, mcd, src_len, trg_len, src,
                      post_trg, post_pred, post_pred_0, stop_pred_probs,
                      stop_trg, alignment_0, cla)

    return sum(eval_losses.values())
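
# A minimal sketch of how these two procedures might be driven from an outer
# loop. `run_training` and its arguments are hypothetical; the real entry
# point (model/criterion/DataLoader construction, checkpointing, LR
# scheduling) lives elsewhere in the project.
def run_training(model, criterion, train_data, eval_data, epochs,
                 logging_start_epoch=0, learning_rate=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    best_eval = float('inf')
    for epoch in range(epochs):
        train(logging_start_epoch, epoch, train_data, model, criterion, optimizer)
        eval_loss = evaluate(epoch, eval_data, model, criterion)
        # keep the checkpoint with the lowest summed validation loss
        if eval_loss < best_eval:
            best_eval = eval_loss
            torch.save(model.state_dict(), 'best_model.pt')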