def train(device, model, data_loader, optimizer, writer,
          init_lr=0.002, checkpoint_dir=None, checkpoint_interval=None,
          nepochs=None, max_clip=100, clip_thresh=1.0):
    """Run the full training loop for the text-to-speech model.

    Each batch goes through: learning-rate scheduling, forward pass,
    mel / done / WORLD-feature losses, backprop with gradient clipping,
    TensorBoard logging, and periodic checkpointing / evaluation.

    Args:
        device: torch device the model inputs are moved to.
        model: the TTS model; must expose ``get_trainable_parameters()``.
        data_loader: yields ``(x, input_lengths, mel, y, positions, done,
            target_lengths, speaker_ids)`` batches; ``y`` is the WORLD
            feature tuple ``(voiced, f0, sp, ap)``.
        optimizer: torch optimizer over the model parameters.
        writer: TensorBoard ``SummaryWriter`` for scalar logs.
        init_lr: initial learning rate fed to the LR schedule.
        checkpoint_dir: directory for checkpoints / saved states.
        checkpoint_interval: save every this many global steps.
        nepochs: stop once ``global_epoch`` reaches this count.
        max_clip: max-norm passed to ``clip_grad_norm_``.
        clip_thresh: per-element value clip; <= 0 disables all clipping.

    Raises:
        RuntimeError: if an input/decoder sequence exceeds
            ``hparams.max_positions``.
    """
    r = hparams.outputs_per_step
    current_lr = init_lr
    binary_criterion = nn.BCELoss()
    l1 = nn.L1Loss()
    test_flag = False

    global global_step, global_epoch
    while global_epoch < nepochs:
        running_loss = 0.
        print("{}epoch:".format(global_epoch))
        for step, (x, input_lengths, mel, y, positions, done, target_lengths,
                   speaker_ids) in tqdm(enumerate(data_loader)):
            model.train()
            ismultispeaker = speaker_ids is not None

            # Learning rate schedule
            if hparams.lr_schedule is not None:
                lr_schedule_f = getattr(lrschedule, hparams.lr_schedule)
                current_lr = lr_schedule_f(
                    init_lr, global_step, **hparams.lr_schedule_kwargs)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr

            optimizer.zero_grad()

            # Used for position encoding
            text_positions, frame_positions = positions

            # Lengths (decoder length is in reduced frames: r frames per step)
            input_lengths = input_lengths.long().numpy()
            decoder_lengths = target_lengths.long().numpy() // r

            max_seq_len = max(input_lengths.max(), decoder_lengths.max())
            if max_seq_len >= hparams.max_positions:
                raise RuntimeError(
                    """max_seq_len ({}) >= max_positions ({})
Input text or decoder target length exceeded the maximum length.
Please set a larger value for ``max_positions`` in hyper parameters.""".format(
                        max_seq_len, hparams.max_positions))

            # Transform data to CUDA device
            x = x.to(device)
            text_positions = text_positions.to(device)
            frame_positions = frame_positions.to(device)
            voiced, f0, sp, ap = y
            f0 = f0.to(device)
            sp = sp.to(device)
            ap = ap.to(device)
            voiced = voiced.to(device)
            mel, done = mel.to(device), done.to(device)
            target_lengths = target_lengths.to(device)
            speaker_ids = speaker_ids.to(device) if ismultispeaker else None

            # Apply model
            mel_outputs, world_outputs, attn, done_hat = model(
                x, mel, speaker_ids=speaker_ids,
                text_positions=text_positions,
                frame_positions=frame_positions,
                input_lengths=input_lengths)

            # reshape to (batch, time, mel_dim)
            mel_outputs = mel_outputs.view(len(mel), -1, mel.size(-1))

            # Losses: predictions are compared against targets shifted by one
            # reduction window (r mel frames, rw WORLD frames).
            # mel:
            mel_loss = l1(mel_outputs[:, :-r, :], mel[:, r:, :])
            # done:
            done_loss = binary_criterion(done_hat, done)

            rw = int(r * hparams.world_upsample)
            vo_hat, f0_outputs, sp_outputs, ap_outputs = world_outputs
            f0_loss = l1(f0_outputs[:, :-rw], f0[:, rw:])
            sp_loss = l1(sp_outputs[:, :-rw, :], sp[:, rw:, :])
            ap_loss = l1(ap_outputs[:, :-rw, :], ap[:, rw:, :])
            voiced_loss = binary_criterion(vo_hat[:, :-rw], voiced[:, rw:])

            # Combine losses
            loss = (mel_loss + done_loss + f0_loss + sp_loss
                    + ap_loss + voiced_loss)

            # Snapshot the untrained model once so the run is inspectable
            # from step 0.
            if global_epoch == 0 and global_step == 0:
                tm.save_states(
                    global_step, writer, mel_outputs, world_outputs, attn,
                    mel, y, input_lengths, checkpoint_dir)
                tm.save_checkpoint(
                    model, optimizer, global_step, checkpoint_dir,
                    global_epoch)

            # Update: norm-clip to max_clip, then value-clip to clip_thresh.
            loss.backward()
            if clip_thresh > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.get_trainable_parameters(), max_clip)
                torch.nn.utils.clip_grad_value_(
                    model.get_trainable_parameters(), clip_thresh)
            optimizer.step()

            # Logs
            writer.add_scalar("loss", float(loss.item()), global_step)
            writer.add_scalar("done_loss", float(done_loss.item()), global_step)
            writer.add_scalar("mel_l1_loss", float(mel_loss.item()), global_step)
            writer.add_scalar("f0_l1_loss", float(f0_loss.item()), global_step)
            writer.add_scalar("sp_l1_loss", float(sp_loss.item()), global_step)
            writer.add_scalar("ap_l1_loss", float(ap_loss.item()), global_step)
            writer.add_scalar("voiced_loss", float(voiced_loss.item()), global_step)
            if clip_thresh > 0:
                writer.add_scalar("gradient norm", grad_norm, global_step)
            writer.add_scalar("learning rate", current_lr, global_step)

            global_step += 1
            running_loss += loss.item()

            if global_step > 0 and global_step % checkpoint_interval == 0:
                tm.save_states(
                    global_step, writer, mel_outputs, world_outputs, attn,
                    mel, y, input_lengths, checkpoint_dir)
                tm.save_checkpoint(
                    model, optimizer, global_step, checkpoint_dir,
                    global_epoch)

            # Evaluate once at the very start (test_flag), then every
            # eval_interval steps after 100k steps.  Parentheses make the
            # original operator precedence (and before or) explicit.
            if (not test_flag) or (global_step > 1e5
                                   and global_step % hparams.eval_interval == 0):
                tm.eval_model(global_step, writer, device, model,
                              checkpoint_dir, ismultispeaker)
                test_flag = True

        averaged_loss = running_loss / len(data_loader)
        writer.add_scalar("loss (per epoch)", averaged_loss, global_epoch)
        print("Loss: {}".format(averaged_loss))
        global_epoch += 1
# Optionally warm-start the text embedding from a pre-trained file.
if load_embedding is not None:
    print("Loading embedding from {}".format(load_embedding))
    tm._load_embedding(load_embedding, model)

# Set up the TensorBoard event directory; default to a timestamped path.
# Windows forbids ':' in file names, so strip it there as well.
if log_event_path is None:
    timestamp = str(datetime.now()).replace(" ", "_")
    if platform.system() == "Windows":
        timestamp = timestamp.replace(":", "_")
    log_event_path = "log/run-test" + timestamp
print("Log event path: {}".format(log_event_path))
writer = SummaryWriter(log_event_path)

# Train!  Ctrl-C saves a final checkpoint before exiting.
try:
    train(device, model, data_loader, optimizer, writer,
          init_lr=hparams.initial_learning_rate,
          checkpoint_dir=checkpoint_dir,
          checkpoint_interval=hparams.checkpoint_interval,
          nepochs=hparams.nepochs,
          max_clip=hparams.max_clip,
          clip_thresh=hparams.clip_thresh)
except KeyboardInterrupt:
    tm.save_checkpoint(
        model, optimizer, global_step, checkpoint_dir, global_epoch)

print("Finished")
sys.exit(0)