def train_loop(device, model, optimizer, data_loader, log_dir):
    global global_step, global_epoch
    criterion = torch.nn.MSELoss().to(device)
    while global_epoch < hp.epochs:
        running_loss = 0
        for i, (x, input_lengths, y) in enumerate(data_loader):
            # Sort the batch by length (descending), as packed RNN inputs require
            sorted_lengths, indices = torch.sort(input_lengths.view(-1), dim=0, descending=True)
            sorted_lengths = sorted_lengths.long().numpy()
            # Reorder the batch to match the sorted lengths
            x, y = x[indices], y[indices]
            x, y = x.to(device), y.to(device)

            y_hat = model(x, sorted_lengths)
            loss = criterion(y_hat, y)
            # Log the unscaled loss; only the backward pass uses the scaled one
            running_loss += loss.item()

            # Calculate the current learning rate and apply it to the optimizer
            if hp.fixed_learning_rate:
                current_lr = hp.fixed_learning_rate
            elif hp.lr_schedule_type == 'step':
                current_lr = step_learning_rate_decay(hp.init_learning_rate, global_step,
                                                      hp.step_gamma, hp.lr_step_interval)
            else:
                current_lr = noam_learning_rate_decay(hp.init_learning_rate, global_step,
                                                      hp.noam_warm_up_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            # Gradient accumulation: scale the loss so gradients average over the
            # accumulation window, and step the optimizer once per window
            (loss / hp.accumulation_steps).backward()
            if (i + 1) % hp.accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), hp.grad_norm)  # clip gradient norm
                optimizer.step()       # update network parameters
                optimizer.zero_grad()  # clear accumulated gradients

            avg_loss = running_loss / (i + 1)

            # Save a checkpoint periodically
            if global_step != 0 and global_step % hp.checkpoint_interval == 0:
                # pruner.prune(global_step)
                save_checkpoint(device, model, optimizer, global_step, global_epoch, log_dir)

            global_step += 1

        print("epoch:{:4d} [running_loss={:.5f}, avg_loss={:.5f}, current_lr={}]".format(
            global_epoch, running_loss, avg_loss, current_lr))
        global_epoch += 1
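# save_checkpoint is referenced above but defined elsewhere in the project. A
# minimal sketch of what it might look like, matching the argument order used in
# the loop above; the filename pattern and checkpoint keys are illustrative
# assumptions, not the project's actual format.
import os
import torch

def save_checkpoint(device, model, optimizer, global_step, global_epoch, log_dir):
    # Serialize model/optimizer state plus the counters needed to resume training
    checkpoint_path = os.path.join(log_dir, "checkpoint_step{:09d}.pth".format(global_step))
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "global_step": global_step,
        "global_epoch": global_epoch,
    }, checkpoint_path)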
def train_loop(device, model, data_loader, optimizer, checkpoint_dir):
    """Main training loop."""
    # Create the loss that matches the configured input type
    if hp.input_type == 'raw':
        if hp.distribution == 'beta':
            criterion = beta_mle_loss
        elif hp.distribution == 'gaussian':
            criterion = gaussian_loss
        else:
            raise ValueError("distribution:{} not supported".format(hp.distribution))
    elif hp.input_type == 'mixture':
        criterion = discretized_mix_logistic_loss
    elif hp.input_type in ["bits", "mulaw"]:
        criterion = nll_loss
    else:
        raise ValueError("input_type:{} not supported".format(hp.input_type))

    global global_step, global_epoch, global_test_step
    while global_epoch < hp.nepochs:
        running_loss = 0
        for i, (x, m, y) in enumerate(tqdm(data_loader)):
            x, m, y = x.to(device), m.to(device), y.to(device)

            y_hat = model(x, m)
            y = y.unsqueeze(-1)
            loss = criterion(y_hat, y)

            # Calculate the current learning rate and apply it to the optimizer
            if hp.fix_learning_rate:
                current_lr = hp.fix_learning_rate
            elif hp.lr_schedule_type == 'step':
                current_lr = step_learning_rate_decay(hp.initial_learning_rate, global_step,
                                                      hp.step_gamma, hp.lr_step_interval)
            else:
                current_lr = noam_learning_rate_decay(hp.initial_learning_rate, global_step,
                                                      hp.noam_warm_up_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm before the optimizer step
            nn.utils.clip_grad_norm_(model.parameters(), hp.grad_norm)
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / (i + 1)

            # Save a checkpoint if needed
            if global_step != 0 and global_step % hp.save_every_step == 0:
                save_checkpoint(device, model, optimizer, global_step, checkpoint_dir, global_epoch)

            # Evaluate the model if needed
            if global_step != 0 and not global_test_step and global_step % hp.evaluate_every_step == 0:
                print("step {}, evaluating model: generating wav from mel...".format(global_step))
                evaluate_model(model, data_loader, checkpoint_dir)
                print("evaluation finished, resuming training...")

            # Reset global_test_step status after evaluation
            if global_test_step is True:
                global_test_step = False
            global_step += 1

        print("epoch:{}, running loss:{}, average loss:{}, current lr:{}".format(
            global_epoch, running_loss, avg_loss, current_lr))
        global_epoch += 1
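# For the "bits"/"mulaw" input types above, the model predicts a categorical
# distribution over quantized amplitude classes. A minimal sketch of what
# nll_loss could look like, assuming y_hat holds per-class log-probabilities of
# shape (batch, time, classes) and y holds class indices; the real shapes in the
# project may differ.
import torch.nn.functional as F

def nll_loss(y_hat, y):
    # F.nll_loss expects (batch, classes, time), so move the class axis forward
    return F.nll_loss(y_hat.transpose(1, 2), y.squeeze(-1))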
def train(self, global_step=0, global_epoch=1):
    while global_epoch < self.epoch:
        running_loss = 0.
        for step, (melX, melY, lengths) in enumerate(tqdm(self.train_loader)):
            self.model.train()

            # Learning rate scheduler
            current_lr = noam_learning_rate_decay(self.args.learn_rate, global_step)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = current_lr
            self.optimizer.zero_grad()

            # Move data to the target device
            melX = melX.to(self.device)
            melY = melY.to(self.device)
            lengths = lengths.to(self.device)
            # Mask out padded frames so they do not contribute to the loss
            target_mask = sequence_mask(lengths, max_len=melY.size(1)).unsqueeze(-1)

            # Apply model
            melX_output = self.model(melX)  # TODO : code model

            # Losses: weighted sum of masked L1 and binary divergence
            mel_l1_loss, mel_binary_div = self.spec_loss(melX_output, melY, target_mask)
            loss = (1 - self.w) * mel_l1_loss + self.w * mel_binary_div

            # Update
            loss.backward()
            self.optimizer.step()

            # Logs
            self.writer.add_scalar("loss", float(loss.item()), global_step)
            self.writer.add_scalar("mel_l1_loss", float(mel_l1_loss.item()), global_step)
            self.writer.add_scalar("mel_binary_div_loss", float(mel_binary_div.item()), global_step)
            self.writer.add_scalar("learning rate", current_lr, global_step)

            global_step += 1
            running_loss += loss.item()

        # Per-epoch checkpointing and evaluation
        if global_epoch % self.checkpoint_interval == 0:
            self.save_checkpoint(global_step, global_epoch)
        if global_epoch % self.eval_interval == 0:
            self.save_states(global_epoch, melX_output, melX, melY, lengths)
            self.eval_model(global_epoch)

        avg_loss = running_loss / len(self.train_loader)
        self.writer.add_scalar("train loss (per epoch)", avg_loss, global_epoch)
        print("Train Loss: {}".format(avg_loss))
        global_epoch += 1
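# sequence_mask above builds a (batch, max_len) mask that is 1 for valid frames
# and 0 for padding, so the spectrogram loss ignores padded positions. A minimal
# sketch, assuming lengths is a 1-D integer tensor; the project's version may
# differ in detail.
import torch

def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = int(sequence_length.max().item())
    seq_range = torch.arange(max_len, device=sequence_length.device)
    # Broadcast-compare positions against each sequence's length
    return (seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)).float()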
def train_loop(device, model, optimizer, data_loader, log_dir):
    global global_step, global_epoch
    criterion = torch.nn.MSELoss().to(device)
    while global_epoch < hp.epochs:
        running_loss = 0
        for i, (x, y) in enumerate(data_loader):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            # Calculate the current learning rate and apply it to the optimizer
            if hp.fixed_learning_rate:
                current_lr = hp.fixed_learning_rate
            elif hp.lr_schedule_type == 'step':
                current_lr = step_learning_rate_decay(hp.init_learning_rate, global_step,
                                                      hp.step_gamma, hp.lr_step_interval)
            else:
                current_lr = noam_learning_rate_decay(hp.init_learning_rate, global_step,
                                                      hp.noam_warm_up_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.zero_grad()  # clear accumulated gradients
            loss.backward()        # backpropagate
            torch.nn.utils.clip_grad_norm_(model.parameters(), hp.grad_norm)  # clip gradient norm
            optimizer.step()       # update network parameters

            running_loss += loss.item()
            avg_loss = running_loss / (i + 1)

            # Save a checkpoint periodically
            if global_step != 0 and global_step % hp.checkpoint_interval == 0:
                # pruner.prune(global_step)
                save_checkpoint(device, model, optimizer, global_step, global_epoch, log_dir)

            global_step += 1

        print("epoch:{:4d} [running_loss={:.5f}, avg_loss={:.5f}, current_lr={}]".format(
            global_epoch, running_loss, avg_loss, current_lr))
        global_epoch += 1
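# A hypothetical driver showing how the loop above is expected to be invoked;
# MyRegressionModel, make_dataloader, and the hp fields used here are
# illustrative stand-ins, not names taken from the project.
import torch

if __name__ == "__main__":
    global_step, global_epoch = 0, 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MyRegressionModel().to(device)            # hypothetical model class
    optimizer = torch.optim.Adam(model.parameters(), lr=hp.init_learning_rate)
    data_loader = make_dataloader(hp.data_root)       # hypothetical loader factory
    train_loop(device, model, optimizer, data_loader, hp.log_dir)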
def get_learning_rate(global_step, n_iters):
    if hp.fix_learning_rate:
        current_lr = hp.fix_learning_rate
    elif hp.lr_schedule_type == 'step':
        current_lr = step_learning_rate_decay(hp.initial_learning_rate, global_step,
                                              hp.step_gamma, hp.lr_step_interval)
    elif hp.lr_schedule_type == 'one-cycle':
        # Triangular one-cycle schedule: ramp from min_lr up to max_lr and back
        # over the first hp.fine_tune fraction of training, then decay linearly
        # from min_lr down to 0.01 * min_lr for the remaining steps
        max_iters = n_iters * hp.nepochs
        cycle_width = int(max_iters * hp.fine_tune)
        step_size = cycle_width // 2
        if global_step < cycle_width:
            cycle = np.floor(1 + global_step / (2 * step_size))
            x = abs(global_step / step_size - 2 * cycle + 1)
            current_lr = hp.min_lr + (hp.max_lr - hp.min_lr) * max(0, (1 - x))
        else:
            x = (max_iters - global_step) / (max_iters - cycle_width)
            current_lr = 0.01 * hp.min_lr + 0.99 * hp.min_lr * x
    else:
        current_lr = noam_learning_rate_decay(hp.initial_learning_rate, global_step,
                                              hp.noam_warm_up_steps)
    return current_lr
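# step_learning_rate_decay and noam_learning_rate_decay are called throughout
# this file but defined elsewhere. Minimal sketches consistent with the call
# sites above: a geometric step decay, and the Noam warmup-then-inverse-sqrt
# schedule from "Attention Is All You Need". Treat these as assumptions about
# the helpers, not their exact definitions.
def step_learning_rate_decay(init_lr, global_step, anneal_rate=0.98, anneal_interval=30000):
    # Multiply the learning rate by anneal_rate once per anneal_interval steps
    return init_lr * anneal_rate ** (global_step // anneal_interval)

def noam_learning_rate_decay(init_lr, global_step, warmup_steps=4000):
    # Linear warmup for warmup_steps, then decay proportional to 1/sqrt(step)
    step = global_step + 1.0
    return init_lr * warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)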
def train(self, train_seq2seq, train_postnet, global_epoch=1, global_step=0):
    while global_epoch < self.epoch:
        running_loss = 0.
        running_linear_loss = 0.
        running_mel_loss = 0.
        for step, (ling, mel, linear, lengths, speaker_ids) in enumerate(tqdm(self.train_loader)):
            self.model.train()
            ismultispeaker = speaker_ids is not None

            # Learning rate scheduler
            current_lr = noam_learning_rate_decay(self.hparams.initial_learning_rate, global_step)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = current_lr
            self.optimizer.zero_grad()

            # Move data to the target device
            if train_seq2seq:
                ling = ling.to(self.device)
                mel = mel.to(self.device)
            if train_postnet:
                linear = linear.to(self.device)
            lengths = lengths.to(self.device)
            speaker_ids = speaker_ids.to(self.device) if ismultispeaker else None
            target_mask = sequence_mask(lengths, max_len=mel.size(1)).unsqueeze(-1)

            # Apply model
            if train_seq2seq and train_postnet:
                _, mel_outputs, linear_outputs = self.model(ling, mel, speaker_ids=speaker_ids)
            # elif train_seq2seq:
            #     mel_style = self.model.gst(tmel)
            #     style_embed = mel_style.expand_as(smel)
            #     mel_input = smel + style_embed
            #     mel_outputs = self.model.seq2seq(mel_input)
            #     linear_outputs = None
            # elif train_postnet:
            #     linear_outputs = self.model.postnet(smel)
            #     mel_outputs = None

            # Losses: weighted sums of masked L1 and binary divergence
            if train_seq2seq:
                mel_l1_loss, mel_binary_div = self.spec_loss(mel_outputs, mel, target_mask)
                mel_loss = (1 - self.w) * mel_l1_loss + self.w * mel_binary_div
            if train_postnet:
                linear_l1_loss, linear_binary_div = self.spec_loss(linear_outputs, linear, target_mask)
                linear_loss = (1 - self.w) * linear_l1_loss + self.w * linear_binary_div

            # Combine losses
            if train_seq2seq and train_postnet:
                loss = mel_loss + linear_loss
            elif train_seq2seq:
                loss = mel_loss
            elif train_postnet:
                loss = linear_loss

            # Update
            loss.backward()
            self.optimizer.step()

            # Logs
            if train_seq2seq:
                self.writer.add_scalar("mel loss", float(mel_loss.item()), global_step)
                self.writer.add_scalar("mel_l1_loss", float(mel_l1_loss.item()), global_step)
                self.writer.add_scalar("mel_binary_div_loss", float(mel_binary_div.item()), global_step)
            if train_postnet:
                self.writer.add_scalar("linear_loss", float(linear_loss.item()), global_step)
                self.writer.add_scalar("linear_l1_loss", float(linear_l1_loss.item()), global_step)
                self.writer.add_scalar("linear_binary_div_loss", float(linear_binary_div.item()), global_step)
            self.writer.add_scalar("loss", float(loss.item()), global_step)
            self.writer.add_scalar("learning rate", current_lr, global_step)

            global_step += 1
            running_loss += loss.item()
            # Guard the per-loss accumulators so single-stage training does not
            # reference an unbound loss
            if train_postnet:
                running_linear_loss += linear_loss.item()
            if train_seq2seq:
                running_mel_loss += mel_loss.item()

        # Per-epoch checkpointing and evaluation
        if global_epoch % self.checkpoint_interval == 0:
            self.save_checkpoint(global_step, global_epoch)
        if global_epoch % self.eval_interval == 0:
            self.save_states(global_epoch, mel_outputs, linear_outputs, ling, mel, linear, lengths)
            self.eval_model(global_epoch, train_seq2seq, train_postnet)

        avg_loss = running_loss / len(self.train_loader)
        avg_linear_loss = running_linear_loss / len(self.train_loader)
        avg_mel_loss = running_mel_loss / len(self.train_loader)
        self.writer.add_scalar("train loss (per epoch)", avg_loss, global_epoch)
        self.writer.add_scalar("train linear loss (per epoch)", avg_linear_loss, global_epoch)
        self.writer.add_scalar("train mel loss (per epoch)", avg_mel_loss, global_epoch)
        print("Train Loss: {}".format(avg_loss))
        global_epoch += 1
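# spec_loss above returns a masked L1 term and a "binary divergence" term, a
# DeepVoice3-style spectrogram loss that treats normalized magnitudes in [0, 1]
# as Bernoulli parameters. In the source it is a method on the trainer; a
# minimal standalone sketch under those assumptions (masking and averaging
# details may differ in the real implementation):
import torch

def spec_loss(y_hat, y, mask):
    n_valid = mask.sum()
    # Masked L1: mean absolute error over unmasked positions only
    l1_loss = (torch.abs(y_hat - y) * mask).sum() / n_valid
    # Binary cross-entropy between target and predicted magnitudes, masked the same way
    eps = 1e-8
    bce = -(y * torch.log(y_hat + eps) + (1 - y) * torch.log(1 - y_hat + eps))
    binary_div = (bce * mask).sum() / n_valid
    return l1_loss, binary_div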
def train_loop(device, model, data_loader, optimizer, checkpoint_dir):
    # Create the loss that matches the configured input type
    if hp.input_type == 'raw':
        if hp.distribution == 'beta':
            criterion = beta_mle_loss
        elif hp.distribution == 'gaussian':
            criterion = gaussian_loss
        else:
            raise ValueError("distribution:{} not supported".format(hp.distribution))
    elif hp.input_type == 'mixture':
        criterion = discretized_mix_logistic_loss
    elif hp.input_type in ["bits", "mulaw"]:
        criterion = nll_loss
    else:
        raise ValueError("input_type:{} not supported".format(hp.input_type))

    # Pruner for reducing memory footprint
    layers = [(model.I, hp.sparsity_target),
              (model.rnn1, hp.sparsity_target_rnn),
              (model.fc1, hp.sparsity_target),
              # (model.fc2, hp.sparsity_target),
              (model.fc3, hp.sparsity_target)]
    pruner = Pruner(layers, hp.start_prune, hp.prune_steps, hp.sparsity_target)

    global global_step, global_epoch, global_test_step
    while global_epoch < hp.nepochs:
        running_loss = 0
        for i, (x, m, y) in enumerate(tqdm(data_loader)):
            x, m, y = x.to(device), m.to(device), y.to(device)

            y_hat = model(x, m)
            y = y.unsqueeze(-1)
            loss = criterion(y_hat, y)

            # Calculate the current learning rate and apply it to the optimizer
            if hp.fix_learning_rate:
                current_lr = hp.fix_learning_rate
            elif hp.lr_schedule_type == 'step':
                current_lr = step_learning_rate_decay(hp.initial_learning_rate, global_step,
                                                      hp.step_gamma, hp.lr_step_interval)
            else:
                current_lr = noam_learning_rate_decay(hp.initial_learning_rate, global_step,
                                                      hp.noam_warm_up_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.zero_grad()
            loss.backward()
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), hp.grad_norm)
            optimizer.step()

            # Prune weights toward the target sparsity
            num_pruned, z = pruner.prune(global_step)

            running_loss += loss.item()
            avg_loss = running_loss / (i + 1)

            writer.add_scalar("loss", float(loss.item()), global_step)
            writer.add_scalar("avg_loss", float(avg_loss), global_step)
            writer.add_scalar("learning_rate", float(current_lr), global_step)
            writer.add_scalar("grad_norm", float(grad_norm), global_step)
            writer.add_scalar("num_pruned", float(num_pruned), global_step)
            writer.add_scalar("fraction_pruned", z, global_step)

            # Save a checkpoint if needed (prune first so the saved weights are sparse)
            if global_step != 0 and global_step % hp.save_every_step == 0:
                pruner.prune(global_step)
                save_checkpoint(device, model, optimizer, global_step, checkpoint_dir, global_epoch)

            # Evaluate the model if needed
            if global_step != 0 and not global_test_step and global_step % hp.evaluate_every_step == 0:
                pruner.prune(global_step)
                print("step {}, evaluating model: generating wav from mel...".format(global_step))
                evaluate_model(model, data_loader, checkpoint_dir)
                print("evaluation finished, resuming training...")

            # Reset global_test_step status after evaluation
            if global_test_step is True:
                global_test_step = False
            global_step += 1

        print("epoch:{}, running loss:{}, average loss:{}, current lr:{}, num_pruned:{} ({}%)".format(
            global_epoch, running_loss, avg_loss, current_lr, num_pruned, z))
        global_epoch += 1
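# The Pruner used above is not defined in this file. The usual schedule for this
# kind of sparse WaveRNN (Kalchbrenner et al., "Efficient Neural Audio
# Synthesis") ramps sparsity cubically from 0 to the target between start_prune
# and start_prune + prune_steps, zeroing the smallest-magnitude weights. A
# sketch of that schedule, as an assumption about what Pruner.prune computes:
def sparsity_at_step(step, start_prune, prune_steps, sparsity_target):
    if step < start_prune:
        return 0.0
    t = min((step - start_prune) / prune_steps, 1.0)
    # Cubic ramp: fast early pruning, slowing as the target is approached
    return sparsity_target * (1.0 - (1.0 - t) ** 3)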