import time
from os.path import join

import torch

# Assumed to be provided elsewhere in this module: hparams, log,
# plot_alignment, ValueWindow, and the global_step / global_epoch counters.


def train(train_loader, model, device, mels_criterion, stop_criterion,
          optimizer, scheduler, writer, train_dir):
    batch_time = ValueWindow()
    data_time = ValueWindow()
    losses = ValueWindow()

    # switch to train mode
    model.train()

    end = time.time()
    global global_epoch
    global global_step
    for i, (txts, mels, stop_tokens, txt_lengths, mels_lengths) in enumerate(train_loader):
        scheduler.adjust_learning_rate(optimizer, global_step)

        # measure data loading time
        data_time.update(time.time() - end)

        if device > -1:
            txts = txts.cuda(device)
            mels = mels.cuda(device)
            stop_tokens = stop_tokens.cuda(device)
            txt_lengths = txt_lengths.cuda(device)
            mels_lengths = mels_lengths.cuda(device)

        # compute output
        frames, decoder_frames, stop_tokens_predict, alignment = model(
            txts, txt_lengths, mels)
        decoder_frames_loss = mels_criterion(
            decoder_frames, mels, lengths=mels_lengths)
        frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)
        stop_token_loss = stop_criterion(
            stop_tokens_predict, stop_tokens, lengths=mels_lengths)
        loss = decoder_frames_loss + frames_loss + stop_token_loss
        losses.update(loss.item())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if hparams.clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.get_trainable_parameters(), hparams.clip_thresh)
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % hparams.print_freq == 0:
            log('Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                    global_epoch, i, len(train_loader),
                    batch_time=batch_time, data_time=data_time, loss=losses))

            # Logs
            writer.add_scalar("loss", float(loss.item()), global_step)
            writer.add_scalar(
                "avg_loss in {} window".format(losses.get_window_size),
                float(losses.avg), global_step)
            writer.add_scalar("stop_token_loss",
                              float(stop_token_loss.item()), global_step)
            writer.add_scalar("decoder_frames_loss",
                              float(decoder_frames_loss.item()), global_step)
            writer.add_scalar("output_frames_loss",
                              float(frames_loss.item()), global_step)
            if hparams.clip_thresh > 0:
                writer.add_scalar("gradient norm", grad_norm, global_step)
            writer.add_scalar("learning rate",
                              optimizer.param_groups[0]['lr'], global_step)

        global_step += 1

    # plot the attention alignment of the last batch once per epoch
    dst_alignment_path = join(train_dir, "{}_alignment.png".format(global_step))
    alignment = alignment.cpu().detach().numpy()
    plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]],
                   dst_alignment_path,
                   info="{}, {}".format(hparams.builder, global_step))
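# Both train() and validate() lean on a ValueWindow helper that is not part of
# this excerpt. Below is a minimal sketch of what it is assumed to look like,
# inferred from the .update()/.val/.avg/.get_window_size accessors used above;
# the property names and the default window size are assumptions, not the
# project's actual implementation.
class ValueWindow:
    """Running window over the most recent `window_size` scalar values."""

    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def update(self, value):
        # keep only the most recent `window_size` values
        self._values = self._values[-(self._window_size - 1):] + [value]

    @property
    def val(self):
        # most recently recorded value
        return self._values[-1] if self._values else 0.0

    @property
    def avg(self):
        # mean over the current window
        return sum(self._values) / len(self._values) if self._values else 0.0

    @property
    def get_window_size(self):
        return self._window_size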
def validate(val_loader, model, device, mels_criterion, stop_criterion,
             writer, val_dir):
    batch_time = ValueWindow()
    losses = ValueWindow()

    # switch to evaluate mode
    model.eval()

    global global_epoch
    global global_step
    with torch.no_grad():
        end = time.time()
        for i, (txts, mels, stop_tokens, txt_lengths, mels_lengths) in enumerate(val_loader):
            if device > -1:
                txts = txts.cuda(device)
                mels = mels.cuda(device)
                stop_tokens = stop_tokens.cuda(device)
                txt_lengths = txt_lengths.cuda(device)
                mels_lengths = mels_lengths.cuda(device)

            # compute output
            frames, decoder_frames, stop_tokens_predict, alignment = model(
                txts, txt_lengths, mels)
            decoder_frames_loss = mels_criterion(
                decoder_frames, mels, lengths=mels_lengths)
            frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)
            stop_token_loss = stop_criterion(
                stop_tokens_predict, stop_tokens, lengths=mels_lengths)
            loss = decoder_frames_loss + frames_loss + stop_token_loss
            losses.update(loss.item())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % hparams.print_freq == 0:
                log('Epoch: [{0}]\t'
                    'Test: [{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                        global_epoch, i, len(val_loader),
                        batch_time=batch_time, loss=losses))

        # Logs: written once per validation pass, since global_step does not
        # advance here and repeated writes at the same step would overwrite
        writer.add_scalar("loss", float(loss.item()), global_step)
        writer.add_scalar(
            "avg_loss in {} window".format(losses.get_window_size),
            float(losses.avg), global_step)
        writer.add_scalar("stop_token_loss",
                          float(stop_token_loss.item()), global_step)
        writer.add_scalar("decoder_frames_loss",
                          float(decoder_frames_loss.item()), global_step)
        writer.add_scalar("output_frames_loss",
                          float(frames_loss.item()), global_step)

        # plot the attention alignment of the last validation batch
        dst_alignment_path = join(val_dir, "{}_alignment.png".format(global_step))
        alignment = alignment.cpu().detach().numpy()
        plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]],
                       dst_alignment_path,
                       info="{}, {}".format(hparams.builder, global_step))

    return losses.avg
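# mels_criterion and stop_criterion are passed in from outside this excerpt
# and are called with a `lengths` keyword, which suggests they mask out padded
# positions beyond each utterance's true length. The class below is a minimal
# sketch of such a masked L1 mel loss, assuming mel tensors of shape
# (batch, time, n_mels); the name MaskedMelLoss and its exact normalization
# are illustrative assumptions, not the project's actual criterion.
import torch
import torch.nn as nn


class MaskedMelLoss(nn.Module):
    """L1 loss over mel frames that ignores padded time steps."""

    def forward(self, predictions, targets, lengths):
        # Build a (batch, max_time) mask: 1 for valid frames, 0 for padding.
        max_time = targets.size(1)
        steps = torch.arange(max_time, device=targets.device)
        mask = (steps.unsqueeze(0) < lengths.unsqueeze(1)).float()
        mask = mask.unsqueeze(-1)  # broadcast over the mel-bin axis

        l1 = torch.abs(predictions - targets) * mask
        # Normalize by the number of valid elements, not the padded total.
        return l1.sum() / (mask.sum() * targets.size(-1))


# Under this assumption it would be wired up as, e.g.:
#   mels_criterion = MaskedMelLoss()
#   frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)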