Example 1
def set_accent(sentences: SentenceList, accent_ids: AccentsDict, accent: str) -> SentenceList:
  # Assign the given accent to every symbol of every sentence, in place.
  accent_id = accent_ids.get_id(accent)
  for sentence in sentences.items():
    new_accent_ids = [accent_id] * len(sentence.get_accent_ids())
    sentence.serialized_accents = serialize_list(new_accent_ids)
    assert len(sentence.get_accent_ids()) == len(sentence.get_symbol_ids())
  return sentences
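The serialize_list and deserialize_list helpers are not shown in this listing. Judging by the comma-separated ID strings used in the tests further down (e.g. "0,1" and "1,0"), they presumably behave like this minimal sketch:

from typing import List

def serialize_list(ids: List[int]) -> str:
  # assumed behavior: [0, 1, 1] -> "0,1,1"
  return ",".join(str(i) for i in ids)

def deserialize_list(serialized: str) -> List[int]:
  # assumed behavior: "0,1,1" -> [0, 1, 1]
  return [int(i) for i in serialized.split(",")]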
Example 2
 def get_formatted(self, accent_id_dict: AccentsDict, pairs_per_line=170, space_length=0):
   return get_formatted_core(
     sent_id=self.sent_id,
     symbols=self.symbols,
     accent_ids=accent_id_dict.get_ids(self.accents),
     accent_id_dict=accent_id_dict,
     space_length=space_length,
     max_pairs_per_line=pairs_per_line
   )
Example 3
def sents_accent_apply(sentences: SentenceList, accented_symbols: AccentedSymbolList, accent_ids: AccentsDict) -> SentenceList:
  # Consume the flat accented-symbol list sentence by sentence and write the
  # (possibly edited) accents back into each sentence's serialized form.
  current_index = 0
  for sent in sentences.items():
    accent_ids_count = len(deserialize_list(sent.serialized_accents))
    assert len(accented_symbols) >= current_index + accent_ids_count
    accented_symbol_selection: List[AccentedSymbol] = accented_symbols[current_index:current_index + accent_ids_count]
    current_index += accent_ids_count
    new_accent_ids = accent_ids.get_ids([x.accent for x in accented_symbol_selection])
    sent.serialized_accents = serialize_list(new_accent_ids)
    assert len(sent.get_accent_ids()) == len(sent.get_symbol_ids())
  return sentences
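sents_accent_apply pairs with sents_accent_template from Example 6 below: the template exports a flat, editable symbol/accent list, and this function writes it back. A hypothetical round trip (it assumes AccentedSymbolList is iterable and its entries are mutable, which this listing does not show):

template = sents_accent_template(sentences, symbol_ids, accent_ids)
for accented_symbol in template:
  if accented_symbol.symbol == "a":
    accented_symbol.accent = "stressed"  # "stressed" must exist in accent_ids
sentences = sents_accent_apply(sentences, template, accent_ids)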
Example 4
 def from_sentences(cls, sentences: SentenceList, accents: AccentsDict, symbols: SymbolIdDict):
   res = cls()
   for sentence in sentences.items():
     infer_sent = InferSentence(
       sent_id=sentence.sent_id,
       symbols=symbols.get_symbols(sentence.serialized_symbols),
       accents=accents.get_accents(sentence.serialized_accents)
     )
     assert len(infer_sent.symbols) == len(infer_sent.accents)
     res.append(infer_sent)
   return res
Example 5
 def from_instances(cls, model: Tacotron2, optimizer: Adam,
                    hparams: HParams, iteration: int, symbols: SymbolIdDict,
                    accents: AccentsDict, speakers: SpeakersDict):
     result = cls(state_dict=model.state_dict(),
                  optimizer=optimizer.state_dict(),
                  learning_rate=hparams.learning_rate,
                  iteration=iteration,
                  hparams=asdict(hparams),
                  symbols=symbols.raw(),
                  accents=accents.raw(),
                  speakers=speakers.raw())
     return result
Example 6
def sents_accent_template(sentences: SentenceList, text_symbols: SymbolIdDict, accent_ids: AccentsDict) -> AccentedSymbolList:
  res = AccentedSymbolList()
  for i, sent in enumerate(sentences.items()):
    symbols = text_symbols.get_symbols(sent.serialized_symbols)
    accents = accent_ids.get_accents(sent.serialized_accents)
    for j, (symbol, accent) in enumerate(zip(symbols, accents)):
      accented_symbol = AccentedSymbol(
        position=f"{i}-{j}",
        symbol=symbol,
        accent=accent
      )
      res.append(accented_symbol)
  return res
Example 7
def _get_ds_data(l: PreDataList, speakers_dict: SpeakersDict,
                 accents: AccentsDict, symbols: SymbolIdDict) -> DsDataList:
    result = [
        DsData(entry_id=i,
               basename=values.name,
               speaker_name=values.speaker_name,
               speaker_id=speakers_dict[values.speaker_name],
               text=values.text,
               serialized_symbols=symbols.get_serialized_ids(values.symbols),
               serialized_accents=accents.get_serialized_ids(values.accents),
               wav_path=values.wav_path,
               lang=values.lang,
               gender=values.gender) for i, values in enumerate(l.items())
    ]
    return DsDataList(result)
Example 8
    def make_common_accent_ids(self) -> AccentsDict:
        # Merge the accent dicts of all datasets into one shared dict (with a
        # pad entry) and re-serialize every entry's accent IDs against it.
        all_accents: Set[str] = set()
        for ds in self.data:
            all_accents |= ds.accent_ids.get_all_accents()
        new_accent_ids = AccentsDict.init_from_accents_with_pad(all_accents)

        for ds in self.data:
            for entry in ds.data.items():
                original_accents = ds.accent_ids.get_accents(
                    entry.serialized_accent_ids)
                entry.serialized_accent_ids = new_accent_ids.get_serialized_ids(
                    original_accents)
            ds.accent_ids = new_accent_ids

        return new_accent_ids
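The re-serialization step can be illustrated with plain dicts. A self-contained sketch of the same re-mapping (all names and values are hypothetical, not taken from the library):

old_ids = {0: "a", 1: "c"}                      # per-dataset accent dict
merged = {0: "<pad>", 1: "a", 2: "b", 3: "c"}   # merged dict, pad at index 0
merged_rev = {accent: i for i, accent in merged.items()}

serialized = "0,1,1"                            # "a", "c", "c" under old_ids
accents = [old_ids[int(i)] for i in serialized.split(",")]
reserialized = ",".join(str(merged_rev[a]) for a in accents)
assert reserialized == "1,3,3"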
Example 9
def convert_v1_to_v2_model(old_model_path: str,
                           custom_hparams: Optional[Dict[str, str]],
                           speakers: SpeakersDict, accents: AccentsDict,
                           symbols: SymbolIdDict):
    checkpoint_dict = torch.load(old_model_path, map_location='cpu')
    hparams = HParams(n_speakers=len(speakers),
                      n_accents=len(accents),
                      n_symbols=len(symbols))

    hparams = overwrite_custom_hparams(hparams, custom_hparams)

    chp = CheckpointTacotron(state_dict=checkpoint_dict["state_dict"],
                             optimizer=checkpoint_dict["optimizer"],
                             learning_rate=checkpoint_dict["learning_rate"],
                             iteration=checkpoint_dict["iteration"] + 1,
                             hparams=asdict(hparams),
                             speakers=speakers.raw(),
                             symbols=symbols.raw(),
                             accents=accents.raw())

    new_model_path = f"{old_model_path}_{get_pytorch_filename(chp.iteration)}"

    chp.save(new_model_path, logging.getLogger())
Example 10
    def test_init_from_accents_is_sorted(self):
        res = AccentsDict.init_from_accents({"c", "a", "b"})

        self.assertEqual("a", res.get_accent(0))
        self.assertEqual("b", res.get_accent(1))
        self.assertEqual("c", res.get_accent(2))
    def test_init_from_accents_with_pad_uses_pad_const(self):
        res = AccentsDict.init_from_accents_with_pad({"b", "a"})

        self.assertEqual(PADDING_ACCENT, res.get_accent(0))
        self.assertEqual("a", res.get_accent(1))
        self.assertEqual("b", res.get_accent(2))
Example 12
def load_prep_accents_ids(prep_dir: str) -> AccentsDict:
    path = os.path.join(prep_dir, _prepared_accents_json)
    return AccentsDict.load(path)
Example 13
def save_prep_accents_ids(prep_dir: str, data: AccentsDict):
    path = os.path.join(prep_dir, _prepared_accents_json)
    data.save(path)
Example 14
def _save_accents_json(ds_dir: str, data: AccentsDict):
    path = os.path.join(ds_dir, _ds_accents_json)
    data.save(path)
Example 15
def load_accents_json(ds_dir: str) -> AccentsDict:
    path = os.path.join(ds_dir, _ds_accents_json)
    return AccentsDict.load(path)
Example 16
    def test_ds_dataset_to_merged_dataset(self):
        ds = DsDataset(
            name="ds",
            accent_ids=AccentsDict.init_from_accents({"c", "d"}),
            data=DsDataList([
                DsData(
                    entry_id=1,
                    basename="basename",
                    speaker_id=2,
                    serialized_symbols="0,1",
                    gender=Gender.MALE,
                    lang=Language.GER,
                    serialized_accents="0,0",
                    wav_path="wav",
                    speaker_name="speaker",
                    text="text",
                )
            ]),
            mels=MelDataList(
                [MelData(
                    entry_id=1,
                    mel_path="mel",
                    n_mel_channels=5,
                )]),
            speakers=SpeakersDict.fromlist(["sp1", "sp2", "sp3"]),
            symbol_ids=SymbolIdDict.init_from_symbols({"a", "b"}),
            texts=TextDataList([
                TextData(
                    entry_id=1,
                    lang=Language.IPA,
                    serialized_accent_ids="1,0",
                    serialized_symbol_ids="1,1",
                    text="text_new",
                )
            ]),
            wavs=WavDataList(
                [WavData(
                    entry_id=1,
                    duration=15,
                    sr=22050,
                    wav="wav_new",
                )]),
        )

        res = ds_dataset_to_merged_dataset(ds)

        self.assertEqual("ds", res.name)
        self.assertEqual(2, len(res.accent_ids))
        self.assertEqual(3, len(res.speakers))
        self.assertEqual(2, len(res.symbol_ids))
        self.assertEqual(1, len(res.items()))
        first_entry = res.items()[0]
        self.assertEqual(1, first_entry.entry_id)
        self.assertEqual(Gender.MALE, first_entry.gender)
        self.assertEqual("basename", first_entry.basename)
        self.assertEqual(2, first_entry.speaker_id)
        self.assertEqual(Language.IPA, first_entry.lang)
        self.assertEqual("1,0", first_entry.serialized_accents)
        self.assertEqual("1,1", first_entry.serialized_symbols)
        self.assertEqual("wav_new", first_entry.wav_path)
        self.assertEqual(15, first_entry.duration)
        self.assertEqual(22050, first_entry.sampling_rate)
        self.assertEqual("mel", first_entry.mel_path)
        self.assertEqual(5, first_entry.n_mel_channels)
Example 17
    def test_init_from_accents_adds_no_accents(self):
        res = AccentsDict.init_from_accents({"a", "b", "c"})

        self.assertEqual(3, len(res))
Example 18
    def test_init_from_accents_with_pad_has_pad_at_idx_zero(self):
        res = AccentsDict.init_from_accents_with_pad({"b", "a"}, "xx")

        self.assertEqual("xx", res.get_accent(0))
        self.assertEqual("a", res.get_accent(1))
        self.assertEqual("b", res.get_accent(2))
    def test_init_from_accents_with_pad_ignores_existing_pad(self):
        res = AccentsDict.init_from_accents_with_pad({"b", "a", "xx"}, "xx")

        self.assertEqual("xx", res.get_accent(0))
        self.assertEqual("a", res.get_accent(1))
        self.assertEqual("b", res.get_accent(2))
Example 20
 def get_accents(self) -> AccentsDict:
     return AccentsDict.from_raw(self.accents)
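Together with from_instances in Example 5, this suggests that raw() and from_raw() form a serialization round trip through a plain-Python representation (e.g. for embedding the dict in a checkpoint). A hypothetical usage sketch:

raw_accents = accents.raw()                  # plain form, as stored in Example 5
restored = AccentsDict.from_raw(raw_accents)
assert len(restored) == len(accents)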
Example 21
def _train(custom_hparams: Optional[Dict[str, str]], taco_logger: Tacotron2Logger,
           trainset: PreparedDataList, valset: PreparedDataList, save_callback: Any,
           speakers: SpeakersDict, accents: AccentsDict, symbols: SymbolIdDict,
           checkpoint: Optional[CheckpointTacotron], warm_model: Optional[CheckpointTacotron],
           weights_checkpoint: Optional[CheckpointTacotron], weights_map: SymbolsMap,
           map_from_speaker_name: Optional[str], logger: Logger, checkpoint_logger: Logger):
  """Training and validation logging results to tensorboard and stdout
  Params
  ------
  output_directory (string): directory to save checkpoints
  log_directory (string) directory to save tensorboard logs
  checkpoint_path(string): checkpoint path
  n_gpus (int): number of gpus
  rank (int): rank of current gpu
  hparams (object): comma separated list of "name=value" pairs.
  """

  complete_start = time.time()

  if checkpoint is not None:
    hparams = checkpoint.get_hparams(logger)
  else:
    hparams = HParams(
      n_accents=len(accents),
      n_speakers=len(speakers),
      n_symbols=len(symbols)
    )
  # is it problematic to change the batch size on a trained model?
  hparams = overwrite_custom_hparams(hparams, custom_hparams)

  assert hparams.n_accents > 0
  assert hparams.n_speakers > 0
  assert hparams.n_symbols > 0

  if hparams.use_saved_learning_rate and checkpoint is not None:
    hparams.learning_rate = checkpoint.learning_rate

  log_hparams(hparams, logger)
  init_torch(hparams)

  model, optimizer = load_model_and_optimizer(
    hparams=hparams,
    checkpoint=checkpoint,
    logger=logger,
  )

  iteration = get_iteration(checkpoint)

  if checkpoint is None:
    if warm_model is not None:
      logger.info("Loading states from pretrained model...")
      warm_start_model(model, warm_model, hparams, logger)

    if weights_checkpoint is not None:
      logger.info("Mapping symbol embeddings...")

      pretrained_symbol_weights = get_mapped_symbol_weights(
        model_symbols=symbols,
        trained_weights=weights_checkpoint.get_symbol_embedding_weights(),
        trained_symbols=weights_checkpoint.get_symbols(),
        custom_mapping=weights_map,
        hparams=hparams,
        logger=logger
      )

      update_weights(model.embedding, pretrained_symbol_weights)

      map_speaker_weights = map_from_speaker_name is not None
      if map_speaker_weights:
        logger.info("Mapping speaker embeddings...")
        pretrained_speaker_weights = get_mapped_speaker_weights(
          model_speakers=speakers,
          trained_weights=weights_checkpoint.get_speaker_embedding_weights(),
          trained_speaker=weights_checkpoint.get_speakers(),
          map_from_speaker_name=map_from_speaker_name,
          hparams=hparams,
          logger=logger,
        )

        update_weights(model.speakers_embedding, pretrained_speaker_weights)

  log_symbol_weights(model, logger)

  collate_fn = SymbolsMelCollate(
    n_frames_per_step=hparams.n_frames_per_step,
    padding_symbol_id=symbols.get_id(PADDING_SYMBOL),
    padding_accent_id=accents.get_id(PADDING_ACCENT)
  )

  val_loader = prepare_valloader(hparams, collate_fn, valset, logger)
  train_loader = prepare_trainloader(hparams, collate_fn, trainset, logger)

  batch_iterations = len(train_loader)
  enough_traindata = batch_iterations > 0
  if not enough_traindata:
    msg = "Not enough trainingdata."
    logger.error(msg)
    raise Exception(msg)

  save_it_settings = SaveIterationSettings(
    epochs=hparams.epochs,
    batch_iterations=batch_iterations,
    save_first_iteration=True,
    save_last_iteration=True,
    iters_per_checkpoint=hparams.iters_per_checkpoint,
    epochs_per_checkpoint=hparams.epochs_per_checkpoint
  )

  criterion = Tacotron2Loss()
  batch_durations: List[float] = []

  train_start = time.perf_counter()
  start = train_start
  model.train()
  continue_epoch = get_continue_epoch(iteration, batch_iterations)
  for epoch in range(continue_epoch, hparams.epochs):
    next_batch_iteration = get_continue_batch_iteration(iteration, batch_iterations)
    skip_bar = None
    if next_batch_iteration > 0:
      logger.debug(f"Current batch is {next_batch_iteration} of {batch_iterations}")
      logger.debug("Skipping batches...")
      skip_bar = tqdm(total=next_batch_iteration)
    for batch_iteration, batch in enumerate(train_loader):
      need_to_skip_batch = skip_batch(
        batch_iteration=batch_iteration,
        continue_batch_iteration=next_batch_iteration
      )
      if need_to_skip_batch:
        assert skip_bar is not None
        skip_bar.update(1)
        #debug_logger.debug(f"Skipped batch {batch_iteration + 1}/{next_batch_iteration + 1}.")
        continue
      # debug_logger.debug(f"Current batch: {batch[0][0]}")

      # update_learning_rate_optimizer(optimizer, hparams.learning_rate)

      model.zero_grad()
      x, y = parse_batch(batch)
      y_pred = model(x)

      loss = criterion(y_pred, y)
      reduced_loss = loss.item()

      loss.backward()

      grad_norm = torch.nn.utils.clip_grad_norm_(
        parameters=model.parameters(),
        max_norm=hparams.grad_clip_thresh
      )

      optimizer.step()

      iteration += 1

      end = time.perf_counter()
      duration = end - start
      start = end

      batch_durations.append(duration)
      avg_batch_dur = np.mean(batch_durations)
      avg_epoch_dur = avg_batch_dur * batch_iterations
      remaining_its = hparams.epochs * batch_iterations - iteration
      estimated_remaining_duration = avg_batch_dur * remaining_its

      next_it = get_next_save_it(iteration, save_it_settings)
      next_checkpoint_save_time = 0
      if next_it is not None:
        next_checkpoint_save_time = (next_it - iteration) * avg_batch_dur

      logger.info(" | ".join([
        f"Epoch: {get_formatted_current_total(epoch + 1, hparams.epochs)}",
        f"It.: {get_formatted_current_total(batch_iteration + 1, batch_iterations)}",
        f"Tot. it.: {get_formatted_current_total(iteration, hparams.epochs * batch_iterations)} ({iteration / (hparams.epochs * batch_iterations) * 100:.2f}%)",
        f"Loss: {reduced_loss:.6f}",
        f"Grad norm: {grad_norm:.6f}",
        #f"Dur.: {duration:.2f}s/it",
        f"Avg. dur.: {avg_batch_dur:.2f}s/it & {avg_epoch_dur / 60:.0f}min/epoch",
        f"Tot. dur.: {(time.perf_counter() - train_start) / 60 / 60:.2f}h/{estimated_remaining_duration / 60 / 60:.0f}h ({estimated_remaining_duration / 60 / 60 / 24:.1f}days)",
        f"Next checkpoint: {next_checkpoint_save_time / 60:.0f}min",
      ]))

      taco_logger.log_training(reduced_loss, grad_norm, hparams.learning_rate,
                               duration, iteration)

      save_it = check_save_it(epoch, iteration, save_it_settings)
      if save_it:
        checkpoint = CheckpointTacotron.from_instances(
          model=model,
          optimizer=optimizer,
          hparams=hparams,
          iteration=iteration,
          symbols=symbols,
          accents=accents,
          speakers=speakers
        )

        save_callback(checkpoint)

        valloss = validate(model, criterion, val_loader, iteration, taco_logger, logger)

        # if rank == 0:
        log_checkpoint_score(iteration, grad_norm,
                             reduced_loss, valloss, epoch, batch_iteration, checkpoint_logger)

  duration_s = time.time() - complete_start
  logger.info(f'Finished training. Total duration: {duration_s / 60:.2f}min')
Example 22
def _get_all_accents(l: PreDataList) -> AccentsDict:
    accents = set()
    for x in l.items():
        accents |= set(x.accents)
    return AccentsDict.init_from_accents(accents)
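The same union can also be written as a single set comprehension; an equivalent sketch:

def _get_all_accents_compact(l: PreDataList) -> AccentsDict:
    # equivalent aggregation of every entry's accents in one expression
    return AccentsDict.init_from_accents(
        {accent for entry in l.items() for accent in entry.accents})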