Example #1
0
def restore_model(base_dir: str, train_name: str, checkpoint_dir: str) -> None:
    """Copy the most recent checkpoint from *checkpoint_dir* into the train dir.

    Creates the train directory (and a fresh log file) if it does not exist yet.
    """
    target_train_dir = get_train_dir(base_dir, train_name, create=True)
    log_dir = get_train_logs_dir(target_train_dir)
    log = prepare_logger(get_train_log_file(log_dir), reset=True)
    destination_dir = get_checkpoints_dir(target_train_dir)
    checkpoint_file, it = get_last_checkpoint(checkpoint_dir)
    log.info(f"Restoring checkpoint {it} from {checkpoint_dir}...")
    # copy2 preserves file metadata (timestamps) alongside the contents
    shutil.copy2(checkpoint_file, destination_dir)
    log.info("Restoring done.")
Example #2
0
def try_load_checkpoint(base_dir: str, train_name: Optional[str],
                        checkpoint: Optional[int],
                        logger: Logger) -> Optional[CheckpointTacotron]:
    """Load a checkpoint for *train_name*, or return None if no name was given.

    *checkpoint* selects a specific iteration; None picks the latest one.
    """
    # Guard: an empty or missing train name means "no model requested".
    if not train_name:
        return None
    train_dir = get_train_dir(base_dir, train_name, False)
    path, _ = get_custom_or_last_checkpoint(
        get_checkpoints_dir(train_dir), checkpoint)
    loaded = CheckpointTacotron.load(path, logger)
    logger.info(f"Using warm start model: {path}")
    return loaded
Example #3
0
def plot_embeddings(base_dir: str,
                    train_name: str,
                    custom_checkpoint: Optional[int] = None):
    """Render symbol-embedding plots for a checkpoint and save them to the analysis dir."""
    train_dir = get_train_dir(base_dir, train_name, create=False)
    assert os.path.isdir(train_dir)

    logger = prepare_logger()
    analysis_dir = get_analysis_root_dir(train_dir)

    ckpt_path, ckpt_iteration = get_custom_or_last_checkpoint(
        get_checkpoints_dir(train_dir), custom_checkpoint)
    ckpt = CheckpointTacotron.load(ckpt_path, logger)

    # pylint: disable=no-member
    similarities_text, plot_2d, plot_3d = plot_embeddings_core(
        symbols=ckpt.get_symbols(),
        emb=ckpt.get_symbol_embedding_weights(),
        logger=logger)

    # Persist all three artifacts, keyed by the checkpoint iteration.
    _save_similarities_csv(analysis_dir, ckpt_iteration, similarities_text)
    _save_2d_plot(analysis_dir, ckpt_iteration, plot_2d)
    _save_3d_plot(analysis_dir, ckpt_iteration, plot_3d)
    logger.info(f"Saved analysis to: {analysis_dir}")
Example #4
0
def continue_train(base_dir: str,
                   train_name: str,
                   custom_hparams: Optional[Dict[str, str]] = None) -> None:
    """Resume an existing training run from its most recent checkpoint."""
    train_dir = get_train_dir(base_dir, train_name, create=False)
    assert os.path.isdir(train_dir)

    # Set up the three loggers; logs are appended (no reset) when continuing.
    logs_dir = get_train_logs_dir(train_dir)
    taco_logger = Tacotron2Logger(logs_dir)
    logger = prepare_logger(get_train_log_file(logs_dir))
    checkpoint_logger = prepare_logger(
        log_file_path=get_train_checkpoints_log_file(logs_dir),
        logger=logging.getLogger("checkpoint-logger"))

    # Resume from the newest checkpoint in this run's checkpoint directory.
    checkpoints_dir = get_checkpoints_dir(train_dir)
    last_path, _ = get_last_checkpoint(checkpoints_dir)
    resume_checkpoint = CheckpointTacotron.load(last_path, logger)

    # Reload the datasets from the preparation settings stored with the run.
    ttsp_dir, merge_name, prep_name = load_prep_settings(train_dir)
    merge_dir = get_merged_dir(ttsp_dir, merge_name, create=False)
    prep_dir = get_prep_dir(merge_dir, prep_name, create=False)
    trainset = load_trainset(prep_dir)
    valset = load_valset(prep_dir)

    save_callback = partial(
        save_checkpoint,
        save_checkpoint_dir=checkpoints_dir,
        logger=logger,
    )

    continue_train_core(checkpoint=resume_checkpoint,
                        custom_hparams=custom_hparams,
                        taco_logger=taco_logger,
                        trainset=trainset,
                        valset=valset,
                        logger=logger,
                        checkpoint_logger=checkpoint_logger,
                        save_callback=save_callback)
Example #5
0
def train(base_dir: str,
          ttsp_dir: str,
          train_name: str,
          merge_name: str,
          prep_name: str,
          warm_start_train_name: Optional[str] = None,
          warm_start_checkpoint: Optional[int] = None,
          custom_hparams: Optional[Dict[str, str]] = None,
          weights_train_name: Optional[str] = None,
          weights_checkpoint: Optional[int] = None,
          use_weights_map: Optional[bool] = None,
          map_from_speaker: Optional[str] = None) -> None:
    """Start a Tacotron training run from scratch.

    Optionally warm-starts from another run's checkpoint
    (`warm_start_train_name`/`warm_start_checkpoint`) and/or transfers
    weights from another run (`weights_train_name`/`weights_checkpoint`),
    with an optional symbol weights map (`use_weights_map`).
    """
    merge_dir = get_merged_dir(ttsp_dir, merge_name, create=False)
    prep_dir = get_prep_dir(merge_dir, prep_name, create=False)

    train_dir = get_train_dir(base_dir, train_name, create=True)
    logs_dir = get_train_logs_dir(train_dir)

    # reset=True: a fresh run truncates any log files left over from a
    # previous run with the same name.
    taco_logger = Tacotron2Logger(logs_dir)
    logger = prepare_logger(get_train_log_file(logs_dir), reset=True)
    checkpoint_logger = prepare_logger(
        log_file_path=get_train_checkpoints_log_file(logs_dir),
        logger=logging.getLogger("checkpoint-logger"),
        reset=True)

    # Record prep settings so continue_train/validate can reload the data later.
    save_prep_settings(train_dir, ttsp_dir, merge_name, prep_name)

    trainset = load_trainset(prep_dir)
    valset = load_valset(prep_dir)

    weights_model = try_load_checkpoint(base_dir=base_dir,
                                        train_name=weights_train_name,
                                        checkpoint=weights_checkpoint,
                                        logger=logger)

    weights_map = None
    # Idiom fix: `x is not None and x` is plain truthiness for an Optional[bool].
    if use_weights_map:
        # NOTE(review): assumes weights_train_name is set whenever
        # use_weights_map is True — confirm against callers.
        weights_train_dir = get_train_dir(base_dir, weights_train_name, False)
        _, weights_merge_name, _ = load_prep_settings(weights_train_dir)
        weights_map = load_weights_map(merge_dir, weights_merge_name)

    warm_model = try_load_checkpoint(base_dir=base_dir,
                                     train_name=warm_start_train_name,
                                     checkpoint=warm_start_checkpoint,
                                     logger=logger)

    save_callback = partial(
        save_checkpoint,
        save_checkpoint_dir=get_checkpoints_dir(train_dir),
        logger=logger,
    )

    train_core(
        custom_hparams=custom_hparams,
        taco_logger=taco_logger,
        symbols=load_merged_symbol_converter(merge_dir),
        speakers=load_merged_speakers_json(merge_dir),
        accents=load_merged_accents_ids(merge_dir),
        trainset=trainset,
        valset=valset,
        save_callback=save_callback,
        weights_map=weights_map,
        weights_checkpoint=weights_model,
        warm_model=warm_model,
        map_from_speaker_name=map_from_speaker,
        logger=logger,
        checkpoint_logger=checkpoint_logger,
    )
Example #6
0
def infer(base_dir: str, train_name: str, text_name: str, speaker: str, sentence_ids: Optional[Set[int]] = None, custom_checkpoint: Optional[int] = None, full_run: bool = True, custom_hparams: Optional[Dict[str, str]] = None, max_decoder_steps: int = DEFAULT_MAX_DECODER_STEPS, seed: int = DEFAULT_SEED, copy_mel_info_to: Optional[str] = DEFAULT_SAVE_MEL_INFO_COPY_PATH):
  """Synthesize mels for the sentences of *text_name* with a trained checkpoint.

  Saves plots, stats and the inferred mel .npy path list into a run-specific
  inference directory; optionally copies the mel info file to
  *copy_mel_info_to*.
  """
  train_dir = get_train_dir(base_dir, train_name, create=False)
  assert os.path.isdir(train_dir)

  # Console logging first; the file handler is added once infer_dir is known.
  logger = get_default_logger()
  init_logger(logger)
  add_console_out_to_logger(logger)

  logger.info("Inferring...")

  # custom_checkpoint=None selects the latest checkpoint.
  checkpoint_path, iteration = get_custom_or_last_checkpoint(
    get_checkpoints_dir(train_dir), custom_checkpoint)
  taco_checkpoint = CheckpointTacotron.load(checkpoint_path, logger)

  ttsp_dir, merge_name, _ = load_prep_settings(train_dir)
  # merge_dir = get_merged_dir(ttsp_dir, merge_name, create=False)

  infer_sents = get_infer_sentences(ttsp_dir, merge_name, text_name)

  # Run name encodes the input text, speaker and checkpoint iteration so each
  # inference gets its own output directory.
  run_name = get_run_name(
    input_name=text_name,
    full_run=full_run,
    iteration=iteration,
    speaker_name=speaker,
  )

  infer_dir = get_infer_dir(
    train_dir=train_dir,
    run_name=run_name,
  )

  add_file_out_to_logger(logger, get_infer_log_new(infer_dir))

  # save_callback appends one entry per inferred sentence into this list.
  mel_postnet_npy_paths: List[Dict[str, Any]] = []
  save_callback = partial(save_results, infer_dir=infer_dir,
                          mel_postnet_npy_paths=mel_postnet_npy_paths)

  inference_results = infer_core(
    checkpoint=taco_checkpoint,
    sentences=infer_sents,
    custom_hparams=custom_hparams,
    full_run=full_run,
    save_callback=save_callback,
    sentence_ids=sentence_ids,
    speaker_name=speaker,
    train_name=train_name,
    max_decoder_steps=max_decoder_steps,
    seed=seed,
    logger=logger,
  )

  # Summary plots and statistics over the whole inference run.
  logger.info("Creating mel_postnet_v.png")
  save_mel_postnet_v_plot(infer_dir, inference_results)

  logger.info("Creating mel_postnet_h.png")
  save_mel_postnet_h_plot(infer_dir, inference_results)

  logger.info("Creating mel_v.png")
  save_mel_v_plot(infer_dir, inference_results)

  logger.info("Creating alignments_v.png")
  save_alignments_v_plot(infer_dir, inference_results)

  logger.info("Creating total.csv")
  save_stats(infer_dir, inference_results)

  npy_path = save_mel_postnet_npy_paths(
    infer_dir=infer_dir,
    name=run_name,
    mel_postnet_npy_paths=mel_postnet_npy_paths
  )

  logger.info("Wrote all inferred mel paths including sampling rate into these file(s):")
  logger.info(npy_path)

  # Optionally mirror the mel info file to an external location.
  if copy_mel_info_to is not None:
    create_parent_folder(copy_mel_info_to)
    copyfile(npy_path, copy_mel_info_to)
    logger.info(copy_mel_info_to)

  logger.info(f"Saved output to: {infer_dir}")
Example #7
0
def validate(base_dir: str, train_name: str, entry_ids: Optional[Set[int]] = None, speaker: Optional[str] = None, ds: str = "val", custom_checkpoints: Optional[Set[int]] = None, custom_hparams: Optional[Dict[str, str]] = None, full_run: bool = False, max_decoder_steps: int = DEFAULT_MAX_DECODER_STEPS, mcd_no_of_coeffs_per_frame: int = DEFAULT_MCD_NO_OF_COEFFS_PER_FRAME, copy_mel_info_to: Optional[str] = DEFAULT_SAVE_MEL_INFO_COPY_PATH, fast: bool = False, repetitions: int = DEFAULT_REPETITIONS, select_best_from: Optional[str] = None, seed: Optional[int] = DEFAULT_SEED) -> None:
  """Validate one or more checkpoints on the val/test split.

  custom_checkpoints semantics: None => only the last checkpoint;
  empty set => all checkpoints; otherwise exactly the given iterations.
  """
  # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
  if repetitions <= 0:
    raise ValueError("repetitions must be a positive integer")

  train_dir = get_train_dir(base_dir, train_name, create=False)
  assert os.path.isdir(train_dir)

  ttsp_dir, merge_name, prep_name = load_prep_settings(train_dir)
  merge_dir = get_merged_dir(ttsp_dir, merge_name, create=False)
  prep_dir = get_prep_dir(merge_dir, prep_name, create=False)

  if ds == "val":
    data = load_valset(prep_dir)
  elif ds == "test":
    data = load_testset(prep_dir)
  else:
    # Was `assert False` — raise a real error so bad input fails under -O too.
    raise ValueError(f"Unknown dataset split: {ds}")

  iterations: Set[int] = set()
  checkpoint_dir = get_checkpoints_dir(train_dir)

  if custom_checkpoints is None:
    _, last_it = get_last_checkpoint(checkpoint_dir)
    iterations.add(last_it)
  else:
    if len(custom_checkpoints) == 0:
      iterations = set(get_all_checkpoint_iterations(checkpoint_dir))
    else:
      iterations = custom_checkpoints

  run_name = get_run_name(
    ds=ds,
    entry_ids=entry_ids,
    full_run=full_run,
    iterations=iterations,
    speaker=speaker,
  )

  val_dir = get_val_dir(
    train_dir=train_dir,
    run_name=run_name,
  )

  val_log_path = os.path.join(val_dir, "log.txt")
  logger = prepare_logger(val_log_path)
  logger.info("Validating...")
  logger.info(f"Checkpoints: {','.join(str(x) for x in sorted(iterations))}")

  result = ValidationEntries()
  save_callback = None
  trainset = load_trainset(prep_dir)

  select_best_from_df = None
  if select_best_from is not None:
    select_best_from_df = pd.read_csv(select_best_from, sep="\t")

  # BUG FIX: this list was re-initialized inside the loop but consumed after
  # it, so with multiple checkpoints only the LAST iteration's mel paths were
  # saved. Accumulate across all checkpoints instead.
  # NOTE(review): annotated List[str] here but List[Dict[str, Any]] in infer()
  # for the same callback — confirm the element type against save_results.
  mel_postnet_npy_paths: List[str] = []

  for iteration in tqdm(sorted(iterations)):
    logger.info(f"Current checkpoint: {iteration}")
    checkpoint_path = get_checkpoint(checkpoint_dir, iteration)
    taco_checkpoint = CheckpointTacotron.load(checkpoint_path, logger)
    if not fast:
      save_callback = partial(save_results, val_dir=val_dir, iteration=iteration,
                              mel_postnet_npy_paths=mel_postnet_npy_paths)

    validation_entries = validate_core(
      checkpoint=taco_checkpoint,
      data=data,
      trainset=trainset,
      custom_hparams=custom_hparams,
      entry_ids=entry_ids,
      full_run=full_run,
      speaker_name=speaker,
      train_name=train_name,
      logger=logger,
      max_decoder_steps=max_decoder_steps,
      fast=fast,
      save_callback=save_callback,
      mcd_no_of_coeffs_per_frame=mcd_no_of_coeffs_per_frame,
      repetitions=repetitions,
      seed=seed,
      select_best_from=select_best_from_df,
    )

    result.extend(validation_entries)

  if len(result) == 0:
    return

  save_stats(val_dir, result)

  if not fast:
    logger.info("Wrote all inferred mel paths including sampling rate into these file(s):")
    npy_path = save_mel_postnet_npy_paths(
      val_dir=val_dir,
      name=run_name,
      mel_postnet_npy_paths=mel_postnet_npy_paths
    )
    logger.info(npy_path)

    if copy_mel_info_to is not None:
      create_parent_folder(copy_mel_info_to)
      copyfile(npy_path, copy_mel_info_to)
      logger.info(copy_mel_info_to)

  logger.info(f"Saved output to: {val_dir}")