Example #1
def _train(custom_hparams: Optional[Dict[str, str]], taco_logger: Tacotron2Logger, trainset: PreparedDataList, valset: PreparedDataList, save_callback: Any, speakers: SpeakersDict, accents: AccentsDict, symbols: SymbolIdDict, checkpoint: Optional[CheckpointTacotron], warm_model: Optional[CheckpointTacotron], weights_checkpoint: Optional[CheckpointTacotron], weights_map: SymbolsMap, map_from_speaker_name: Optional[str], logger: Logger, checkpoint_logger: Logger):
  """Training and validation logging results to tensorboard and stdout
  Params
  ------
  output_directory (string): directory to save checkpoints
  log_directory (string) directory to save tensorboard logs
  checkpoint_path(string): checkpoint path
  n_gpus (int): number of gpus
  rank (int): rank of current gpu
  hparams (object): comma separated list of "name=value" pairs.
  """

  complete_start = time.time()

  if checkpoint is not None:
    hparams = checkpoint.get_hparams(logger)
  else:
    hparams = HParams(
      n_accents=len(accents),
      n_speakers=len(speakers),
      n_symbols=len(symbols)
    )
  # is it problematic to change the batch size on a trained model?
  hparams = overwrite_custom_hparams(hparams, custom_hparams)

  assert hparams.n_accents > 0
  assert hparams.n_speakers > 0
  assert hparams.n_symbols > 0

  if hparams.use_saved_learning_rate and checkpoint is not None:
    hparams.learning_rate = checkpoint.learning_rate

  log_hparams(hparams, logger)
  init_torch(hparams)

  model, optimizer = load_model_and_optimizer(
    hparams=hparams,
    checkpoint=checkpoint,
    logger=logger,
  )

  iteration = get_iteration(checkpoint)

  if checkpoint is None:
    if warm_model is not None:
      logger.info("Loading states from pretrained model...")
      warm_start_model(model, warm_model, hparams, logger)

    if weights_checkpoint is not None:
      logger.info("Mapping symbol embeddings...")

      pretrained_symbol_weights = get_mapped_symbol_weights(
        model_symbols=symbols,
        trained_weights=weights_checkpoint.get_symbol_embedding_weights(),
        trained_symbols=weights_checkpoint.get_symbols(),
        custom_mapping=weights_map,
        hparams=hparams,
        logger=logger
      )

      update_weights(model.embedding, pretrained_symbol_weights)

      map_speaker_weights = map_from_speaker_name is not None
      if map_speaker_weights:
        logger.info("Mapping speaker embeddings...")
        pretrained_speaker_weights = get_mapped_speaker_weights(
          model_speakers=speakers,
          trained_weights=weights_checkpoint.get_speaker_embedding_weights(),
          trained_speaker=weights_checkpoint.get_speakers(),
          map_from_speaker_name=map_from_speaker_name,
          hparams=hparams,
          logger=logger,
        )

        update_weights(model.speakers_embedding, pretrained_speaker_weights)

  log_symbol_weights(model, logger)

  collate_fn = SymbolsMelCollate(
    n_frames_per_step=hparams.n_frames_per_step,
    padding_symbol_id=symbols.get_id(PADDING_SYMBOL),
    padding_accent_id=accents.get_id(PADDING_ACCENT)
  )

  val_loader = prepare_valloader(hparams, collate_fn, valset, logger)
  train_loader = prepare_trainloader(hparams, collate_fn, trainset, logger)

  batch_iterations = len(train_loader)
  enough_traindata = batch_iterations > 0
  if not enough_traindata:
    msg = "Not enough trainingdata."
    logger.error(msg)
    raise Exception(msg)

  save_it_settings = SaveIterationSettings(
    epochs=hparams.epochs,
    batch_iterations=batch_iterations,
    save_first_iteration=True,
    save_last_iteration=True,
    iters_per_checkpoint=hparams.iters_per_checkpoint,
    epochs_per_checkpoint=hparams.epochs_per_checkpoint
  )

  criterion = Tacotron2Loss()
  batch_durations: List[float] = []

  train_start = time.perf_counter()
  start = train_start
  model.train()
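  # Resume support: derive the starting epoch (and, inside the loop, the
  # in-epoch batch offset) from the global iteration counter, so an
  # interrupted run continues where it left off.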
  continue_epoch = get_continue_epoch(iteration, batch_iterations)
  for epoch in range(continue_epoch, hparams.epochs):
    next_batch_iteration = get_continue_batch_iteration(iteration, batch_iterations)
    skip_bar = None
    if next_batch_iteration > 0:
      logger.debug(f"Current batch is {next_batch_iteration} of {batch_iterations}")
      logger.debug("Skipping batches...")
      skip_bar = tqdm(total=next_batch_iteration)
    for batch_iteration, batch in enumerate(train_loader):
      need_to_skip_batch = skip_batch(
        batch_iteration=batch_iteration,
        continue_batch_iteration=next_batch_iteration
      )
      if need_to_skip_batch:
        assert skip_bar is not None
        skip_bar.update(1)
        #debug_logger.debug(f"Skipped batch {batch_iteration + 1}/{next_batch_iteration + 1}.")
        continue
      # debug_logger.debug(f"Current batch: {batch[0][0]}")

      # update_learning_rate_optimizer(optimizer, hparams.learning_rate)

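      # Standard training step: clear gradients, forward pass, compute loss,
      # backpropagate, clip gradients, and apply the optimizer update.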
      model.zero_grad()
      x, y = parse_batch(batch)
      y_pred = model(x)

      loss = criterion(y_pred, y)
      reduced_loss = loss.item()

      loss.backward()

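      # Clip the global gradient norm to stabilize training; clip_grad_norm_
      # returns the total norm computed before clipping, which is logged below.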
      grad_norm = torch.nn.utils.clip_grad_norm_(
        parameters=model.parameters(),
        max_norm=hparams.grad_clip_thresh
      )

      optimizer.step()

      iteration += 1

      end = time.perf_counter()
      duration = end - start
      start = end

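      # Running ETA bookkeeping: average batch duration extrapolated to the
      # remaining iterations and to the next scheduled checkpoint.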
      batch_durations.append(duration)
      avg_batch_dur = np.mean(batch_durations)
      avg_epoch_dur = avg_batch_dur * batch_iterations
      remaining_its = hparams.epochs * batch_iterations - iteration
      estimated_remaining_duration = avg_batch_dur * remaining_its

      next_it = get_next_save_it(iteration, save_it_settings)
      next_checkpoint_save_time = 0
      if next_it is not None:
        next_checkpoint_save_time = (next_it - iteration) * avg_batch_dur

      logger.info(" | ".join([
        f"Epoch: {get_formatted_current_total(epoch + 1, hparams.epochs)}",
        f"It.: {get_formatted_current_total(batch_iteration + 1, batch_iterations)}",
        f"Tot. it.: {get_formatted_current_total(iteration, hparams.epochs * batch_iterations)} ({iteration / (hparams.epochs * batch_iterations) * 100:.2f}%)",
        f"Loss: {reduced_loss:.6f}",
        f"Grad norm: {grad_norm:.6f}",
        #f"Dur.: {duration:.2f}s/it",
        f"Avg. dur.: {avg_batch_dur:.2f}s/it & {avg_epoch_dur / 60:.0f}min/epoch",
        f"Tot. dur.: {(time.perf_counter() - train_start) / 60 / 60:.2f}h/{estimated_remaining_duration / 60 / 60:.0f}h ({estimated_remaining_duration / 60 / 60 / 24:.1f}days)",
        f"Next checkpoint: {next_checkpoint_save_time / 60:.0f}min",
      ]))

      taco_logger.log_training(reduced_loss, grad_norm, hparams.learning_rate,
                               duration, iteration)

      save_it = check_save_it(epoch, iteration, save_it_settings)
      if save_it:
        checkpoint = CheckpointTacotron.from_instances(
          model=model,
          optimizer=optimizer,
          hparams=hparams,
          iteration=iteration,
          symbols=symbols,
          accents=accents,
          speakers=speakers
        )

        save_callback(checkpoint)

        valloss = validate(model, criterion, val_loader, iteration, taco_logger, logger)

        # if rank == 0:
        log_checkpoint_score(iteration, grad_norm,
                             reduced_loss, valloss, epoch, batch_iteration, checkpoint_logger)

  duration_s = time.time() - complete_start
  logger.info(f'Finished training. Total duration: {duration_s / 60:.2f}min')
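The checkpoint cadence in the loop above is driven entirely by SaveIterationSettings and check_save_it. Below is a minimal sketch of a cadence that is consistent with how they are used here; it is an assumption about the helpers, not the project's actual implementation, and the field names are taken from the call sites above.

# Hypothetical sketch (assumption): a checkpoint is written on the first and
# last iteration and at fixed iteration/epoch intervals; a value of 0 disables
# the corresponding interval.
from dataclasses import dataclass

@dataclass
class SaveIterationSettings:
  epochs: int
  batch_iterations: int
  save_first_iteration: bool
  save_last_iteration: bool
  iters_per_checkpoint: int
  epochs_per_checkpoint: int

def check_save_it(epoch: int, iteration: int, settings: SaveIterationSettings) -> bool:
  if settings.save_first_iteration and iteration == 1:
    return True
  last_iteration = settings.epochs * settings.batch_iterations
  if settings.save_last_iteration and iteration == last_iteration:
    return True
  if settings.iters_per_checkpoint > 0 and iteration % settings.iters_per_checkpoint == 0:
    return True
  epoch_finished = iteration % settings.batch_iterations == 0
  if settings.epochs_per_checkpoint > 0 and epoch_finished:
    return (epoch + 1) % settings.epochs_per_checkpoint == 0
  return False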
Example #2
def test_get_continue_iteration_cur4_tot2_is_0(self):
    res = get_continue_batch_iteration(iteration=4, batch_iterations=2)
    self.assertEqual(0, res)
Example #3
def test_get_continue_iteration_cur5_tot2_is_1(self):
    res = get_continue_batch_iteration(iteration=5, batch_iterations=2)
    self.assertEqual(1, res)
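The two tests above pin down the resume arithmetic used by the training loops in Examples #1 and #4: a global iteration counter is split into an epoch index and an in-epoch batch offset. A minimal sketch consistent with the tests (an assumption, not necessarily the project's implementation):

def get_continue_epoch(iteration: int, batch_iterations: int) -> int:
  # Number of fully completed epochs, e.g. iteration=4 with 2 batches/epoch -> epoch 2.
  return iteration // batch_iterations

def get_continue_batch_iteration(iteration: int, batch_iterations: int) -> int:
  # In-epoch offset, e.g. iteration=4 -> 0 and iteration=5 -> 1 with 2 batches/epoch,
  # matching the two tests above.
  return iteration % batch_iterations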
Example #4
def _train(custom_hparams: Optional[Dict[str, str]], logdir: str, trainset: PreparedDataList, valset: PreparedDataList, save_checkpoint_dir: str, checkpoint: Optional[CheckpointWaveglow], logger: Logger, warm_model: Optional[CheckpointWaveglow]):
  complete_start = time.time()
  wg_logger = WaveglowLogger(logdir)

  if checkpoint is not None:
    hparams = checkpoint.get_hparams(logger)
  else:
    hparams = HParams()
  # is it problematic to change the batch size?
  hparams = overwrite_custom_hparams(hparams, custom_hparams)

  log_hparams(hparams, logger)
  init_torch(hparams)

  model, optimizer = load_model_and_optimizer(
    hparams=hparams,
    checkpoint=checkpoint,
    logger=logger,
  )

  iteration = get_iteration(checkpoint)

  if checkpoint is None and warm_model is not None:
    logger.info("Loading states from pretrained model...")
    warm_start_model(model, warm_model)

  criterion = WaveGlowLoss(
    sigma=hparams.sigma
  )

  train_loader = prepare_trainloader(
    hparams=hparams,
    trainset=trainset,
    logger=logger
  )

  val_loader = prepare_valloader(
    hparams=hparams,
    valset=valset,
    logger=logger
  )

  batch_iterations = len(train_loader)
  if batch_iterations == 0:
    logger.error("Not enough trainingdata.")
    raise Exception()

  # Get shared output_directory ready
  # if rank == 0:
  #   if not os.path.isdir(output_directory):
  #     os.makedirs(output_directory)
  #     os.chmod(output_directory, 0o775)
  #   print("output directory", output_directory)

  model.train()

  train_start = time.perf_counter()
  start = train_start

  save_it_settings = SaveIterationSettings(
    epochs=hparams.epochs,
    batch_iterations=batch_iterations,
    save_first_iteration=True,
    save_last_iteration=True,
    iters_per_checkpoint=hparams.iters_per_checkpoint,
    epochs_per_checkpoint=hparams.epochs_per_checkpoint
  )

  # total_its = hparams.epochs * len(train_loader)
  # epoch_offset = max(0, int(iteration / len(train_loader)))
  # # ================ MAIN TRAINING LOOP! ===================
  # for epoch in range(epoch_offset, hparams.epochs):
  #   debug_logger.info("Epoch: {}".format(epoch))
  #   for i, batch in enumerate(train_loader):
  batch_durations: List[float] = []

  continue_epoch = get_continue_epoch(iteration, batch_iterations)
  for epoch in range(continue_epoch, hparams.epochs):
    next_batch_iteration = get_continue_batch_iteration(iteration, batch_iterations)
    skip_bar = None
    if next_batch_iteration > 0:
      logger.debug(f"Current batch is {next_batch_iteration} of {batch_iterations}")
      logger.debug("Skipping batches...")
      skip_bar = tqdm(total=next_batch_iteration)
    for batch_iteration, batch in enumerate(train_loader):
      need_to_skip_batch = skip_batch(
        batch_iteration=batch_iteration,
        continue_batch_iteration=next_batch_iteration
      )
      if need_to_skip_batch:
        assert skip_bar is not None
        skip_bar.update(1)
        #debug_logger.debug(f"Skipped batch {batch_iteration + 1}/{next_batch_iteration + 1}.")
        continue
      # debug_logger.debug(f"Current batch: {batch[0][0]}")

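      # Training step: clear gradients, forward pass, compute the WaveGlow loss,
      # backpropagate, and apply the optimizer update (unlike the Tacotron 2 loop
      # above, no gradient clipping is applied here).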
      model.zero_grad()
      x, y = parse_batch(batch)
      y_pred = model(x)

      loss = criterion(y_pred, y)
      reduced_loss = loss.item()

      loss.backward()

      optimizer.step()

      iteration += 1

      end = time.perf_counter()
      duration = end - start
      start = end

      batch_durations.append(duration)
      logger.info(" | ".join([
        f"Epoch: {get_formatted_current_total(epoch + 1, hparams.epochs)}",
        f"Iteration: {get_formatted_current_total(batch_iteration + 1, batch_iterations)}",
        f"Total iteration: {get_formatted_current_total(iteration, hparams.epochs * batch_iterations)}",
        f"Train loss: {reduced_loss:.6f}",
        f"Duration: {duration:.2f}s/it",
        f"Avg. duration: {np.mean(batch_durations):.2f}s/it",
        f"Total Duration: {(time.perf_counter() - train_start) / 60 / 60:.2f}h"
      ]))

      wg_logger.log_training(reduced_loss, hparams.learning_rate, duration, iteration)

      wg_logger.add_scalar('training_loss', reduced_loss, iteration)

      save_it = check_save_it(epoch, iteration, save_it_settings)
      if save_it:
        checkpoint = CheckpointWaveglow.from_instances(
          model=model,
          optimizer=optimizer,
          hparams=hparams,
          iteration=iteration,
        )

        checkpoint_path = os.path.join(
          save_checkpoint_dir, get_pytorch_filename(iteration))
        checkpoint.save(checkpoint_path, logger)

        validate(model, criterion, val_loader, iteration, wg_logger, logger)

  duration_s = time.time() - complete_start
  logger.info(f'Finished training. Total duration: {duration_s / 60:.2f}min')
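Both training loops format their progress lines with get_formatted_current_total. A plausible sketch of that helper (hypothetical; the exact padding behaviour is an assumption):

def get_formatted_current_total(current: int, total: int) -> str:
  # Right-align the current value to the width of the total, e.g. (7, 150) -> "  7/150".
  return f"{current:>{len(str(total))}}/{total}"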