def _train(custom_hparams: Optional[Dict[str, str]], taco_logger: Tacotron2Logger, trainset: PreparedDataList, valset: PreparedDataList, save_callback: Any, speakers: SpeakersDict, accents: AccentsDict, symbols: SymbolIdDict, checkpoint: Optional[CheckpointTacotron], warm_model: Optional[CheckpointTacotron], weights_checkpoint: Optional[CheckpointTacotron], weights_map: SymbolsMap, map_from_speaker_name: Optional[str], logger: Logger, checkpoint_logger: Logger):
  """Runs training and validation, logging results to tensorboard and stdout.

  Params
  ------
  custom_hparams (dict): "name" -> "value" pairs that overwrite the hyperparameters
  taco_logger (Tacotron2Logger): tensorboard logger for training metrics
  trainset, valset (PreparedDataList): training and validation data
  save_callback (callable): invoked with every checkpoint that is to be saved
  checkpoint (CheckpointTacotron): checkpoint to continue training from, if any
  warm_model (CheckpointTacotron): pretrained model used for warm-starting
  weights_checkpoint (CheckpointTacotron): model whose embedding weights are transferred
  weights_map (SymbolsMap): custom symbol mapping used for the embedding transfer
  map_from_speaker_name (str): speaker whose embedding weights are mapped over
  """
  complete_start = time.time()

  if checkpoint is not None:
    hparams = checkpoint.get_hparams(logger)
  else:
    hparams = HParams(
      n_accents=len(accents),
      n_speakers=len(speakers),
      n_symbols=len(symbols)
    )
  # is it problematic to change the batch size on a trained model?
  hparams = overwrite_custom_hparams(hparams, custom_hparams)

  assert hparams.n_accents > 0
  assert hparams.n_speakers > 0
  assert hparams.n_symbols > 0

  if hparams.use_saved_learning_rate and checkpoint is not None:
    hparams.learning_rate = checkpoint.learning_rate

  log_hparams(hparams, logger)
  init_torch(hparams)

  model, optimizer = load_model_and_optimizer(
    hparams=hparams,
    checkpoint=checkpoint,
    logger=logger,
  )

  iteration = get_iteration(checkpoint)

  if checkpoint is None:
    if warm_model is not None:
      logger.info("Loading states from pretrained model...")
      warm_start_model(model, warm_model, hparams, logger)

    if weights_checkpoint is not None:
      logger.info("Mapping symbol embeddings...")
      pretrained_symbol_weights = get_mapped_symbol_weights(
        model_symbols=symbols,
        trained_weights=weights_checkpoint.get_symbol_embedding_weights(),
        trained_symbols=weights_checkpoint.get_symbols(),
        custom_mapping=weights_map,
        hparams=hparams,
        logger=logger
      )
      update_weights(model.embedding, pretrained_symbol_weights)

      map_speaker_weights = map_from_speaker_name is not None
      if map_speaker_weights:
        logger.info("Mapping speaker embeddings...")
        pretrained_speaker_weights = get_mapped_speaker_weights(
          model_speakers=speakers,
          trained_weights=weights_checkpoint.get_speaker_embedding_weights(),
          trained_speaker=weights_checkpoint.get_speakers(),
          map_from_speaker_name=map_from_speaker_name,
          hparams=hparams,
          logger=logger,
        )
        update_weights(model.speakers_embedding, pretrained_speaker_weights)

  log_symbol_weights(model, logger)

  collate_fn = SymbolsMelCollate(
    n_frames_per_step=hparams.n_frames_per_step,
    padding_symbol_id=symbols.get_id(PADDING_SYMBOL),
    padding_accent_id=accents.get_id(PADDING_ACCENT)
  )

  val_loader = prepare_valloader(hparams, collate_fn, valset, logger)
  train_loader = prepare_trainloader(hparams, collate_fn, trainset, logger)

  batch_iterations = len(train_loader)
  enough_traindata = batch_iterations > 0
  if not enough_traindata:
    msg = "Not enough training data."
    logger.error(msg)
    raise Exception(msg)

  save_it_settings = SaveIterationSettings(
    epochs=hparams.epochs,
    batch_iterations=batch_iterations,
    save_first_iteration=True,
    save_last_iteration=True,
    iters_per_checkpoint=hparams.iters_per_checkpoint,
    epochs_per_checkpoint=hparams.epochs_per_checkpoint
  )

  criterion = Tacotron2Loss()
  batch_durations: List[float] = []

  train_start = time.perf_counter()
  start = train_start
  model.train()
  continue_epoch = get_continue_epoch(iteration, batch_iterations)

  for epoch in range(continue_epoch, hparams.epochs):
    next_batch_iteration = get_continue_batch_iteration(iteration, batch_iterations)
    skip_bar = None
    if next_batch_iteration > 0:
      logger.debug(f"Current batch is {next_batch_iteration} of {batch_iterations}")
      logger.debug("Skipping batches...")
      skip_bar = tqdm(total=next_batch_iteration)

    for batch_iteration, batch in enumerate(train_loader):
      need_to_skip_batch = skip_batch(
        batch_iteration=batch_iteration,
        continue_batch_iteration=next_batch_iteration
      )

      if need_to_skip_batch:
        assert skip_bar is not None
        skip_bar.update(1)
        # debug_logger.debug(f"Skipped batch {batch_iteration + 1}/{next_batch_iteration + 1}.")
        continue
      # debug_logger.debug(f"Current batch: {batch[0][0]}")

      # update_learning_rate_optimizer(optimizer, hparams.learning_rate)
      model.zero_grad()
      x, y = parse_batch(batch)
      y_pred = model(x)

      loss = criterion(y_pred, y)
      reduced_loss = loss.item()

      loss.backward()

      grad_norm = torch.nn.utils.clip_grad_norm_(
        parameters=model.parameters(),
        max_norm=hparams.grad_clip_thresh
      )

      optimizer.step()

      iteration += 1

      end = time.perf_counter()
      duration = end - start
      start = end

      batch_durations.append(duration)
      avg_batch_dur = np.mean(batch_durations)
      avg_epoch_dur = avg_batch_dur * batch_iterations
      remaining_its = hparams.epochs * batch_iterations - iteration
      estimated_remaining_duration = avg_batch_dur * remaining_its
      next_it = get_next_save_it(iteration, save_it_settings)
      next_checkpoint_save_time = 0
      if next_it is not None:
        next_checkpoint_save_time = (next_it - iteration) * avg_batch_dur

      logger.info(" | ".join([
        f"Epoch: {get_formatted_current_total(epoch + 1, hparams.epochs)}",
        f"It.: {get_formatted_current_total(batch_iteration + 1, batch_iterations)}",
        f"Tot. it.: {get_formatted_current_total(iteration, hparams.epochs * batch_iterations)} ({iteration / (hparams.epochs * batch_iterations) * 100:.2f}%)",
        f"Loss: {reduced_loss:.6f}",
        f"Grad norm: {grad_norm:.6f}",
        # f"Dur.: {duration:.2f}s/it",
        f"Avg. dur.: {avg_batch_dur:.2f}s/it & {avg_epoch_dur / 60:.0f}min/epoch",
        f"Tot. dur.: {(time.perf_counter() - train_start) / 60 / 60:.2f}h/{estimated_remaining_duration / 60 / 60:.0f}h ({estimated_remaining_duration / 60 / 60 / 24:.1f}days)",
        f"Next checkpoint: {next_checkpoint_save_time / 60:.0f}min",
      ]))

      taco_logger.log_training(reduced_loss, grad_norm, hparams.learning_rate, duration, iteration)

      save_it = check_save_it(epoch, iteration, save_it_settings)
      if save_it:
        checkpoint = CheckpointTacotron.from_instances(
          model=model,
          optimizer=optimizer,
          hparams=hparams,
          iteration=iteration,
          symbols=symbols,
          accents=accents,
          speakers=speakers
        )

        save_callback(checkpoint)

        valloss = validate(model, criterion, val_loader, iteration, taco_logger, logger)

        # if rank == 0:
        log_checkpoint_score(iteration, grad_norm, reduced_loss, valloss, epoch, batch_iteration, checkpoint_logger)

  duration_s = time.time() - complete_start
  logger.info(f'Finished training. Total duration: {duration_s / 60:.2f}min')
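# The resume logic above depends on get_continue_batch_iteration and
# skip_batch, which are imported from the project's training utilities.
# Below is a minimal sketch of the behavior the loop assumes; this is a
# hypothetical reconstruction, not the repository implementation: the batch
# offset within the resumed epoch is the remainder of the completed
# iterations, and every batch before that offset is skipped.
def get_continue_batch_iteration(iteration: int, batch_iterations: int) -> int:
  # Offset of the next batch within the current (resumed) epoch.
  return iteration % batch_iterations


def skip_batch(batch_iteration: int, continue_batch_iteration: int) -> bool:
  # Skip all batches that were already processed before the checkpoint.
  return batch_iteration < continue_batch_iteration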
def test_get_continue_epoch_cur5_tot2_is_2(self):
  res = get_continue_epoch(current_iteration=5, batch_iterations=2)
  self.assertEqual(2, res)

def test_get_continue_epoch_cur1_tot2_is_0(self):
  res = get_continue_epoch(current_iteration=1, batch_iterations=2)
  self.assertEqual(0, res)

def test_get_continue_epoch_cur3_tot2_is_1(self):
  res = get_continue_epoch(current_iteration=3, batch_iterations=2)
  self.assertEqual(1, res)
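# Taken together, the three tests pin get_continue_epoch down to floor
# division: with 2 batches per epoch, iterations 1, 3 and 5 resume in epochs
# 0, 1 and 2. A minimal implementation that satisfies them (a sketch, assuming
# the real function carries no extra validation):
def get_continue_epoch(current_iteration: int, batch_iterations: int) -> int:
  return current_iteration // batch_iterations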
def _train(custom_hparams: Optional[Dict[str, str]], logdir: str, trainset: PreparedDataList, valset: PreparedDataList, save_checkpoint_dir: str, checkpoint: Optional[CheckpointWaveglow], logger: Logger, warm_model: Optional[CheckpointWaveglow]):
  complete_start = time.time()
  wg_logger = WaveglowLogger(logdir)

  if checkpoint is not None:
    hparams = checkpoint.get_hparams(logger)
  else:
    hparams = HParams()
  # is it problematic to change the batch size?
  hparams = overwrite_custom_hparams(hparams, custom_hparams)

  log_hparams(hparams, logger)
  init_torch(hparams)

  model, optimizer = load_model_and_optimizer(
    hparams=hparams,
    checkpoint=checkpoint,
    logger=logger,
  )

  iteration = get_iteration(checkpoint)

  if checkpoint is None and warm_model is not None:
    logger.info("Loading states from pretrained model...")
    warm_start_model(model, warm_model)

  criterion = WaveGlowLoss(
    sigma=hparams.sigma
  )

  train_loader = prepare_trainloader(
    hparams=hparams,
    trainset=trainset,
    logger=logger
  )

  val_loader = prepare_valloader(
    hparams=hparams,
    valset=valset,
    logger=logger
  )

  batch_iterations = len(train_loader)
  if batch_iterations == 0:
    msg = "Not enough training data."
    logger.error(msg)
    raise Exception(msg)

  # Get shared output_directory ready
  # if rank == 0:
  #   if not os.path.isdir(output_directory):
  #     os.makedirs(output_directory)
  #     os.chmod(output_directory, 0o775)
  #   print("output directory", output_directory)

  model.train()

  train_start = time.perf_counter()
  start = train_start

  save_it_settings = SaveIterationSettings(
    epochs=hparams.epochs,
    batch_iterations=batch_iterations,
    save_first_iteration=True,
    save_last_iteration=True,
    iters_per_checkpoint=hparams.iters_per_checkpoint,
    epochs_per_checkpoint=hparams.epochs_per_checkpoint
  )

  # total_its = hparams.epochs * len(train_loader)
  # epoch_offset = max(0, int(iteration / len(train_loader)))
  # # ================ MAIN TRAINING LOOP! ===================
  # for epoch in range(epoch_offset, hparams.epochs):
  #   debug_logger.info("Epoch: {}".format(epoch))
  #   for i, batch in enumerate(train_loader):

  batch_durations: List[float] = []

  continue_epoch = get_continue_epoch(iteration, batch_iterations)
  for epoch in range(continue_epoch, hparams.epochs):
    next_batch_iteration = get_continue_batch_iteration(iteration, batch_iterations)
    skip_bar = None
    if next_batch_iteration > 0:
      logger.debug(f"Current batch is {next_batch_iteration} of {batch_iterations}")
      logger.debug("Skipping batches...")
      skip_bar = tqdm(total=next_batch_iteration)

    for batch_iteration, batch in enumerate(train_loader):
      need_to_skip_batch = skip_batch(
        batch_iteration=batch_iteration,
        continue_batch_iteration=next_batch_iteration
      )

      if need_to_skip_batch:
        assert skip_bar is not None
        skip_bar.update(1)
        # debug_logger.debug(f"Skipped batch {batch_iteration + 1}/{next_batch_iteration + 1}.")
        continue
      # debug_logger.debug(f"Current batch: {batch[0][0]}")

      model.zero_grad()
      x, y = parse_batch(batch)
      y_pred = model(x)

      loss = criterion(y_pred, y)
      reduced_loss = loss.item()

      loss.backward()
      optimizer.step()

      iteration += 1

      end = time.perf_counter()
      duration = end - start
      start = end
      batch_durations.append(duration)

      logger.info(" | ".join([
        f"Epoch: {get_formatted_current_total(epoch + 1, hparams.epochs)}",
        f"Iteration: {get_formatted_current_total(batch_iteration + 1, batch_iterations)}",
        f"Total iteration: {get_formatted_current_total(iteration, hparams.epochs * batch_iterations)}",
        f"Train loss: {reduced_loss:.6f}",
        f"Duration: {duration:.2f}s/it",
        f"Avg. duration: {np.mean(batch_durations):.2f}s/it",
        f"Total Duration: {(time.perf_counter() - train_start) / 60 / 60:.2f}h"
      ]))

      wg_logger.log_training(reduced_loss, hparams.learning_rate, duration, iteration)
      wg_logger.add_scalar('training_loss', reduced_loss, iteration)

      save_it = check_save_it(epoch, iteration, save_it_settings)
      if save_it:
        checkpoint = CheckpointWaveglow.from_instances(
          model=model,
          optimizer=optimizer,
          hparams=hparams,
          iteration=iteration,
        )

        checkpoint_path = os.path.join(
          save_checkpoint_dir, get_pytorch_filename(iteration))
        checkpoint.save(checkpoint_path, logger)

        validate(model, criterion, val_loader, iteration, wg_logger, logger)

  duration_s = time.time() - complete_start
  logger.info(f'Finished training. Total duration: {duration_s / 60:.2f}min')
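# Both trainers drive their checkpoint cadence through SaveIterationSettings,
# check_save_it and get_next_save_it. The sketch below is a hypothetical
# reconstruction of semantics consistent with how they are called above; the
# field names come from the constructor calls, but the decision logic itself
# is an assumption, not the repository code.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SaveIterationSettings:
  epochs: int
  batch_iterations: int
  save_first_iteration: bool
  save_last_iteration: bool
  iters_per_checkpoint: int
  epochs_per_checkpoint: int


def check_save_it(epoch: int, iteration: int, settings: SaveIterationSettings) -> bool:
  # Save at the very first and very last iteration if requested.
  if settings.save_first_iteration and iteration == 1:
    return True
  last_iteration = settings.epochs * settings.batch_iterations
  if settings.save_last_iteration and iteration == last_iteration:
    return True
  # Save every iters_per_checkpoint iterations.
  if settings.iters_per_checkpoint > 0 and iteration % settings.iters_per_checkpoint == 0:
    return True
  # Save at the end of every epochs_per_checkpoint-th epoch.
  epoch_finished = iteration % settings.batch_iterations == 0
  if settings.epochs_per_checkpoint > 0 and epoch_finished and (epoch + 1) % settings.epochs_per_checkpoint == 0:
    return True
  return False


def get_next_save_it(iteration: int, settings: SaveIterationSettings) -> Optional[int]:
  # Scan forward to the next iteration at which a checkpoint would be saved,
  # which the Tacotron trainer uses to estimate the time until the next save.
  last_iteration = settings.epochs * settings.batch_iterations
  for candidate in range(iteration + 1, last_iteration + 1):
    epoch = (candidate - 1) // settings.batch_iterations
    if check_save_it(epoch, candidate, settings):
      return candidate
  return None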