Example No. 1
    def test_merge_prepared_data(self):
        prep_list = [
            (PreparedDataList([
                PreparedData(0, 1, "", "", "", 0, "0,1,2", 0, 0, "", 0, ""),
            ]), SymbolIdDict({
                0: (0, "a"),
                1: (0, "b"),
                2: (0, "c"),
            })),
            (PreparedDataList([
                PreparedData(0, 2, "", "", "", 0, "0,1,2", 0, 0, "", 0, ""),
            ]), SymbolIdDict({
                0: (0, "b"),
                1: (0, "a"),
                2: (0, "d"),
            })),
        ]

        res, conv = merge_prepared_data(prep_list)

        self.assertEqual(6, len(conv))
        self.assertEqual("todo", conv.get_symbol(0))
        self.assertEqual("todo", conv.get_symbol(1))
        self.assertEqual("a", conv.get_symbol(2))
        self.assertEqual("b", conv.get_symbol(3))
        self.assertEqual("c", conv.get_symbol(4))
        self.assertEqual("d", conv.get_symbol(5))

        self.assertEqual(2, len(res))
        self.assertEqual(0, res[0].i)
        self.assertEqual(1, res[1].i)
        self.assertEqual(1, res[0].entry_id)
        self.assertEqual(2, res[1].entry_id)
        self.assertEqual("2,3,4", res[0].serialized_symbol_ids)
        self.assertEqual("3,2,5", res[1].serialized_symbol_ids)
Example No. 2
def update_symbols_and_text(sentences: SentenceList, sents_new_symbols: List[List[str]]):
  symbols = SymbolIdDict.init_from_symbols(get_unique_items(sents_new_symbols))
  for sentence, new_symbols in zip(sentences.items(), sents_new_symbols):
    sentence.serialized_symbols = symbols.get_serialized_ids(new_symbols)
    sentence.text = SymbolIdDict.symbols_to_text(new_symbols)
    assert len(sentence.get_symbol_ids()) == len(new_symbols)
    assert len(sentence.get_accent_ids()) == len(new_symbols)
  return symbols, sentences
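
The helper above leans on the SymbolIdDict round trip used throughout these snippets: build a dictionary from the unique symbols, serialize each symbol list to an ID string, and render plain text. A minimal sketch of that assumption, restricted to calls that appear in these examples (the import path for SymbolIdDict is not shown here, and the concrete results are illustrative, not guaranteed by the library):

symbols = SymbolIdDict.init_from_symbols({"a", "b", "c"})
serialized = symbols.get_serialized_ids(["a", "b", "a"])  # e.g. "0,1,0" (assumed comma-separated format)
text = SymbolIdDict.symbols_to_text(["a", "b", "a"])      # e.g. "aba" (assumed plain concatenation)
assert symbols.get_symbol(symbols.get_id("a")) == "a"     # id <-> symbol round trip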
Example No. 3
 def replace_unknown_symbols(self, model_symbols: SymbolIdDict, logger: Logger) -> bool:
   unknown_symbols_exist = False
   for sentence in self.items():
     if model_symbols.has_unknown_symbols(sentence.symbols):
       sentence.symbols = model_symbols.replace_unknown_symbols_with_pad(sentence.symbols)
       text = SymbolIdDict.symbols_to_text(sentence.symbols)
       logger.info(f"Sentence {sentence.sent_id} contains unknown symbols: {text}")
       unknown_symbols_exist = True
       assert len(sentence.symbols) == len(sentence.accents)
   return unknown_symbols_exist
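
A hedged usage sketch for the method above: assuming it lives on the SentenceList used elsewhere in these examples, it can act as a guard before inference, with model_symbols taken from the checkpoint accessors shown in the training example further down (the checkpoint variable is illustrative):

import logging

model_symbols = checkpoint.get_symbols()  # SymbolIdDict stored in the trained checkpoint (illustrative)
had_unknown = sentences.replace_unknown_symbols(model_symbols, logging.getLogger())
if had_unknown:
  logging.getLogger().warning("Some symbols were unknown and replaced with the padding symbol.")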
Example No. 4
def get_mapped_symbol_weights(model_symbols: SymbolIdDict,
                              trained_weights: Tensor,
                              trained_symbols: SymbolIdDict,
                              custom_mapping: Optional[SymbolsMap],
                              hparams: HParams, logger: Logger) -> Tensor:
    symbols_count_mismatch = trained_weights.shape[0] != len(trained_symbols)
    if symbols_count_mismatch:
        msg = (f"Weights mapping: symbol count of the pretrained model "
               f"({trained_weights.shape[0]}) does not match the number of "
               f"symbols ({len(trained_symbols)}).")
        logger.error(msg)
        raise Exception(msg)

    if custom_mapping is None:
        symbols_map = SymbolsMap.from_intersection(
            map_to=model_symbols.get_all_symbols(),
            map_from=trained_symbols.get_all_symbols(),
        )
    else:
        symbols_map = custom_mapping
        symbols_map.remove_unknown_symbols(
            known_to_symbol=model_symbols.get_all_symbols(),
            known_from_symbols=trained_symbols.get_all_symbols())

    # Remove all empty mappings
    symbols_wo_mapping = symbols_map.get_symbols_with_empty_mapping()
    symbols_map.pop_batch(symbols_wo_mapping)

    symbols_id_map = symbols_map.convert_to_symbols_ids_map(
        to_symbols=model_symbols,
        from_symbols=trained_symbols,
    )

    model_symbols_id_map = symbols_ids_map_to_model_symbols_ids_map(
        symbols_id_map,
        hparams.n_accents,
        n_symbols=hparams.n_symbols,
        accents_use_own_symbols=hparams.accents_use_own_symbols)

    model_weights = get_symbol_weights(hparams)

    map_weights(model_symbols_id_map=model_symbols_id_map,
                model_weights=model_weights,
                trained_weights=trained_weights,
                logger=logger)

    not_existing_symbols = model_symbols.get_all_symbols() - symbols_map.keys()
    no_mapping = symbols_wo_mapping | not_existing_symbols
    if len(no_mapping) > 0:
        logger.warning(f"Following symbols were not mapped: {no_mapping}")
    else:
        logger.info("All symbols were mapped.")

    return model_weights
Example No. 5
  def convert_to_symbols_ids_map(self, to_symbols: SymbolIdDict, from_symbols: SymbolIdDict) -> OrderedDictType[int, int]:
    result: OrderedDictType[int, int] = OrderedDict()

    for map_to_symbol, map_from_symbol in self.items():
      assert to_symbols.symbol_exists(map_to_symbol)
      assert from_symbols.symbol_exists(map_from_symbol)

      map_from_symbol_id = from_symbols.get_id(map_from_symbol)
      map_to_symbol_id = to_symbols.get_id(map_to_symbol)
      result[map_to_symbol_id] = map_from_symbol_id
      # logger.info(
      #  f"Mapped symbol '{map_from_symbol}' ({map_from_symbol_id}) to symbol '{map_to_symbol}' ({map_to_symbol_id})")

    return result
Example No. 6
  def test_get_symbols_id_mapping_without_map(self):
    from_symbols = {"b", "c"}
    to_symbols = {"a", "b"}
    from_conv = SymbolIdDict.init_from_symbols(from_symbols)
    to_conv = SymbolIdDict.init_from_symbols(to_symbols)
    mapping = SymbolsMap.from_intersection(from_symbols, to_symbols)
    mapping.update_existing_to_mappings({"a": "c"})

    symbols_id_map = mapping.convert_to_symbols_ids_map(
      to_symbols=to_conv,
      from_symbols=from_conv,
    )

    self.assertEqual(symbols_id_map[from_conv.get_id("b")], to_conv.get_id("b"))
    self.assertEqual(symbols_id_map[from_conv.get_id("c")], to_conv.get_id("a"))
    self.assertEqual(2, len(symbols_id_map))
Example No. 7
 def get_formatted(self, symbol_id_dict: SymbolIdDict, accent_id_dict: AccentsDict, pairs_per_line=170, space_length=0):
   return get_formatted_core(
     sent_id=self.sent_id,
     symbols=symbol_id_dict.get_symbols(self.serialized_symbols),
     accent_ids=self.get_accent_ids(),
     accent_id_dict=accent_id_dict,
     space_length=space_length,
     max_pairs_per_line=pairs_per_line
   )
Example No. 8
def _prepare_data(
    processed_data: List[Tuple[int, List[str], List[int], Language]]
) -> Tuple[TextDataList, SymbolIdDict, AccentsDict, SymbolsDict]:
    result = TextDataList()
    symbol_counter = get_counter([x[1] for x in processed_data])
    symbols_dict = SymbolsDict.fromcounter(symbol_counter)
    conv = SymbolIdDict.init_from_symbols(set(symbols_dict.keys()))

    for entry_id, symbols, accent_ids, lang in processed_data:
        assert len(accent_ids) == len(symbols)
        text = SymbolIdDict.symbols_to_text(symbols)
        serialized_symbol_ids = conv.get_serialized_ids(symbols)
        serialized_accent_ids = serialize_list(accent_ids)
        data = TextData(entry_id, text, serialized_symbol_ids,
                        serialized_accent_ids, lang)
        result.append(data)

    return result, conv, symbols_dict
Example No. 9
 def test_sims_to_csv(self):
   emb = torch.ones(size=(2, 3))
   torch.nn.init.zeros_(emb[0])
   sims = get_similarities(emb)
   symbols = SymbolIdDict.init_from_symbols({"2", "1"})
   res = sims_to_csv(sims, symbols)
   self.assertEqual(2, len(res.index))
   self.assertListEqual(['2', '<=>', '1', 'nan'], list(res.values[0]))
   self.assertListEqual(['1', '<=>', '2', 'nan'], list(res.values[1]))
Example No. 10
  def test_plot_embeddings(self):
    emb = torch.ones(size=(2, 3))
    symbols = SymbolIdDict.init_from_symbols({"a", "b"})
    text, plot2d, plot3d = plot_embeddings(symbols, emb, logging.getLogger())

    self.assertEqual(2, len(text.index))
    self.assertListEqual(['a', '<=>', 'b', '1.00'], list(text.values[0]))
    self.assertListEqual(['b', '<=>', 'a', '1.00'], list(text.values[1]))
    self.assertEqual("2D-Embeddings", plot2d.layout.title.text)
    self.assertEqual("3D-Embeddings", plot3d.layout.title.text)
Example No. 11
def sents_map(sentences: SentenceList, text_symbols: SymbolIdDict, symbols_map: SymbolsMap, ignore_arcs: bool) -> Tuple[SymbolIdDict, SentenceList]:
  sents_new_symbols = []
  result = SentenceList()
  new_sent_id = 0
  for sentence in sentences.items():
    symbols = text_symbols.get_symbols(sentence.serialized_symbols)
    accent_ids = deserialize_list(sentence.serialized_accents)

    mapped_symbols = symbols_map.apply_to_symbols(symbols)

    text = SymbolIdDict.symbols_to_text(mapped_symbols)
    # an empty resulting text is not a problem
    sents = split_sentences(text, sentence.lang)
    for new_sent_text in sents:
      new_symbols = text_to_symbols(
        new_sent_text,
        lang=sentence.lang,
        ignore_tones=False,
        ignore_arcs=ignore_arcs
      )

      if len(accent_ids) > 0:
        new_accent_ids = [accent_ids[0]] * len(new_symbols)
      else:
        new_accent_ids = []

      assert len(new_accent_ids) == len(new_symbols)

      new_sent_id += 1
      tmp = Sentence(
        sent_id=new_sent_id,
        text=new_sent_text,
        lang=sentence.lang,
        serialized_accents=serialize_list(new_accent_ids),
        serialized_symbols=""
      )
      sents_new_symbols.append(new_symbols)

      assert len(tmp.get_accent_ids()) == len(new_symbols)
      result.append(tmp)

  return update_symbols_and_text(result, sents_new_symbols)
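
A short, hedged sketch of driving sents_map: symbols_map is assumed to be a SymbolsMap, for example built with SymbolsMap.from_intersection as in the weight-mapping example above or loaded as a custom mapping; the mapping direction is not spelled out in these snippets:

new_symbols, new_sentences = sents_map(
  sentences=sentences,        # a SentenceList, e.g. from add_text further down
  text_symbols=text_symbols,  # the SymbolIdDict the sentences were serialized with
  symbols_map=symbols_map,
  ignore_arcs=False,
)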
Example No. 12
def sims_to_csv(sims: Dict[int, List[Tuple[int, float]]], symbols: SymbolIdDict) -> pd.DataFrame:
  lines = []
  assert len(sims) == len(symbols)
  for symbol_id, similarities in sims.items():
    line = [f"{symbols.get_symbol(symbol_id)}", "<=>"]
    for other_symbol_id, similarity in similarities:
      line.append(symbols.get_symbol(other_symbol_id))
      line.append(f"{similarity:.2f}")
    lines.append(line)
  df = pd.DataFrame(lines)
  return df
Example No. 13
 def from_sentences(cls, sentences: SentenceList, accents: AccentsDict, symbols: SymbolIdDict):
   res = cls()
   for sentence in sentences.items():
     infer_sent = InferSentence(
       sent_id=sentence.sent_id,
       symbols=symbols.get_symbols(sentence.serialized_symbols),
       accents=accents.get_accents(sentence.serialized_accents)
     )
     assert len(infer_sent.symbols) == len(infer_sent.accents)
     res.append(infer_sent)
   return res
Example No. 14
 def from_instances(cls, model: Tacotron2, optimizer: Adam,
                    hparams: HParams, iteration: int, symbols: SymbolIdDict,
                    accents: AccentsDict, speakers: SpeakersDict):
     result = cls(state_dict=model.state_dict(),
                  optimizer=optimizer.state_dict(),
                  learning_rate=hparams.learning_rate,
                  iteration=iteration,
                  hparams=asdict(hparams),
                  symbols=symbols.raw(),
                  accents=accents.raw(),
                  speakers=speakers.raw())
     return result
Example No. 15
def add_text(text: str, lang: Language) -> Tuple[SymbolIdDict, SentenceList]:
  res = SentenceList()
  sents = split_sentences(text, lang)
  default_accent_id = 0
  sents_symbols: List[List[str]] = [text_to_symbols(
    sent,
    lang=lang,
    ignore_tones=False,
    ignore_arcs=False
  ) for sent in sents]
  symbols = SymbolIdDict.init_from_symbols(get_unique_items(sents_symbols))
  for i, sent_symbols in enumerate(sents_symbols):
    sentence = Sentence(
      sent_id=i + 1,
      lang=lang,
      serialized_symbols=symbols.get_serialized_ids(sent_symbols),
      serialized_accents=serialize_list([default_accent_id] * len(sent_symbols)),
      text=SymbolIdDict.symbols_to_text(sent_symbols),
    )
    res.append(sentence)
  return symbols, res
Example No. 16
def preprocess(
    data: DsDataList, symbol_ids: SymbolIdDict
) -> Tuple[TextDataList, SymbolIdDict, SymbolsDict]:
    processed_data: List[Tuple[int, List[str], List[int], Language]] = []

    values: DsData
    for values in tqdm(data):
        symbols: List[str] = symbol_ids.get_symbols(
            deserialize_list(values.serialized_symbols))
        accents: List[int] = deserialize_list(values.serialized_accents)
        processed_data.append((values.entry_id, symbols, accents, values.lang))

    return _prepare_data(processed_data)
Example No. 17
def plot_embeddings(symbols: SymbolIdDict, emb: torch.Tensor, logger: Logger) -> Tuple[pd.DataFrame, go.Figure, go.Figure]:
  assert emb.shape[0] == len(symbols)

  logger.info(f"Emb size {emb.shape}")
  logger.info(f"Sym len {len(symbols)}")

  sims = get_similarities(emb.numpy())
  df = sims_to_csv(sims, symbols)
  all_symbols = symbols.get_all_symbols()
  emb_normed = norm2emb(emb)
  fig_2d = emb_plot_2d(emb_normed, all_symbols)
  fig_3d = emb_plot_3d(emb_normed, all_symbols)

  return df, fig_2d, fig_3d
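
A hedged sketch of feeding plot_embeddings with a trained checkpoint's symbol embedding, reusing the accessors that the training example below already uses (get_symbol_embedding_weights and get_symbols); the checkpoint variable and the output file names are illustrative:

import logging

emb = checkpoint.get_symbol_embedding_weights()  # torch.Tensor with one row per symbol
symbols = checkpoint.get_symbols()               # the SymbolIdDict stored in the checkpoint
sims_df, fig_2d, fig_3d = plot_embeddings(symbols, emb, logging.getLogger())
fig_2d.write_html("symbol_embeddings_2d.html")   # plotly figures support write_html
fig_3d.write_html("symbol_embeddings_3d.html")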
Example No. 18
def sents_normalize(sentences: SentenceList, text_symbols: SymbolIdDict) -> Tuple[SymbolIdDict, SentenceList]:
  # Maybe add info if something was unknown
  sents_new_symbols = []
  for sentence in sentences.items():
    new_symbols, new_accent_ids = symbols_normalize(
      symbols=text_symbols.get_symbols(sentence.serialized_symbols),
      lang=sentence.lang,
      accent_ids=deserialize_list(sentence.serialized_accents)
    )
    # TODO: check if new sentences resulted and then split them.
    sentence.serialized_accents = serialize_list(new_accent_ids)
    sents_new_symbols.append(new_symbols)

  return update_symbols_and_text(sentences, sents_new_symbols)
Example No. 19
def sents_accent_template(sentences: SentenceList, text_symbols: SymbolIdDict, accent_ids: AccentsDict) -> AccentedSymbolList:
  res = AccentedSymbolList()
  for i, sent in enumerate(sentences.items()):
    symbols = text_symbols.get_symbols(sent.serialized_symbols)
    accents = accent_ids.get_accents(sent.serialized_accents)
    for j, symbol_accent in enumerate(zip(symbols, accents)):
      symbol, accent = symbol_accent
      accented_symbol = AccentedSymbol(
        position=f"{i}-{j}",
        symbol=symbol,
        accent=accent
      )
      res.append(accented_symbol)
  return res
Example No. 20
    def make_common_symbol_ids(self) -> SymbolIdDict:
        all_symbols: Set[str] = set()
        for ds in self.data:
            all_symbols |= ds.symbol_ids.get_all_symbols()
        new_symbol_ids = SymbolIdDict.init_from_symbols_with_pad(all_symbols)

        for ds in self.data:
            for entry in ds.data.items():
                original_symbols = ds.symbol_ids.get_symbols(
                    entry.serialized_symbol_ids)
                entry.serialized_symbol_ids = new_symbol_ids.get_serialized_ids(
                    original_symbols)
            ds.symbol_ids = new_symbol_ids

        return new_symbol_ids
Example No. 21
def _get_ds_data(l: PreDataList, speakers_dict: SpeakersDict,
                 accents: AccentsDict, symbols: SymbolIdDict) -> DsDataList:
    result = [
        DsData(entry_id=i,
               basename=values.name,
               speaker_name=values.speaker_name,
               speaker_id=speakers_dict[values.speaker_name],
               text=values.text,
               serialized_symbols=symbols.get_serialized_ids(values.symbols),
               serialized_accents=accents.get_serialized_ids(values.accents),
               wav_path=values.wav_path,
               lang=values.lang,
               gender=values.gender) for i, values in enumerate(l.items())
    ]
    return DsDataList(result)
Example No. 22
def symbols_convert_to_ipa(symbols: List[str], lang: Language,
                           accent_ids: List[str], ignore_tones: bool,
                           ignore_arcs: bool) -> Tuple[List[str], List[int]]:
    assert len(symbols) == len(accent_ids)
    # Note: also applied to IPA symbols so that arcs and tones can be removed
    orig_text = SymbolIdDict.symbols_to_text(symbols)
    ipa = convert_to_ipa(orig_text, lang)
    new_symbols: List[str] = text_to_symbols(ipa, Language.IPA, ignore_tones,
                                             ignore_arcs)
    if len(accent_ids) > 0:
        new_accent_ids = [accent_ids[0]] * len(new_symbols)
    else:
        new_accent_ids = []
    assert len(new_symbols) == len(new_accent_ids)
    return new_symbols, new_accent_ids
Example No. 23
  def test_get_similarities_is_sorted(self):
    symbols = SymbolIdDict.init_from_symbols({"a", "b", "c"})
    emb = np.zeros(shape=(3, 3))
    emb[symbols.get_id("a")] = [0.5, 1.0, 0]
    emb[symbols.get_id("b")] = [1.0, 0.6, 0]
    emb[symbols.get_id("c")] = [1.0, 0.5, 0]

    sims = get_similarities(emb)

    self.assertEqual(3, len(sims))
    self.assertEqual(symbols.get_id("b"), sims[symbols.get_id("a")][0][0])
    self.assertEqual(symbols.get_id("c"), sims[symbols.get_id("a")][1][0])
    self.assertEqual(symbols.get_id("b"), sims[symbols.get_id("c")][0][0])
    self.assertEqual(symbols.get_id("a"), sims[symbols.get_id("c")][1][0])
    self.assertEqual(symbols.get_id("c"), sims[symbols.get_id("b")][0][0])
    self.assertEqual(symbols.get_id("a"), sims[symbols.get_id("b")][1][0])
Example No. 24
def symbols_normalize(symbols: List[str], lang: Language,
                      accent_ids: List[str]) -> Tuple[List[str], List[int]]:
    assert len(symbols) == len(accent_ids)
    orig_text = SymbolIdDict.symbols_to_text(symbols)
    text = normalize(orig_text, lang)
    new_symbols: List[str] = text_to_symbols(text, lang)
    if lang != Language.IPA:
        if len(accent_ids) > 0:
            new_accent_ids = [accent_ids[0]] * len(new_symbols)
        else:
            new_accent_ids = []
    else:
        # because no replacement is done during IPA normalization
        # TODO: maybe support removing whitespace
        new_accent_ids = accent_ids
    assert len(new_symbols) == len(new_accent_ids)
    return new_symbols, new_accent_ids
Example No. 25
def convert_to_ipa(
        data: TextDataList, symbol_converter: SymbolIdDict, ignore_tones: bool,
        ignore_arcs: bool) -> Tuple[TextDataList, SymbolIdDict, SymbolsDict]:
    processed_data: List[Tuple[int, List[str], List[int], Language]] = []

    values: TextData
    for values in tqdm(data.items()):
        new_symbols, new_accent_ids = symbols_convert_to_ipa(
            symbols=symbol_converter.get_symbols(values.serialized_symbol_ids),
            lang=values.lang,
            accent_ids=deserialize_list(values.serialized_accent_ids),
            ignore_arcs=ignore_arcs,
            ignore_tones=ignore_tones)
        processed_data.append(
            (values.entry_id, new_symbols, new_accent_ids, Language.IPA))

    return _prepare_data(processed_data)
Example No. 26
def normalize(
    data: TextDataList, symbol_converter: SymbolIdDict
) -> Tuple[TextDataList, SymbolIdDict, SymbolsDict]:
    processed_data: List[Tuple[int, List[str], List[int], Language]] = []

    values: TextData
    for values in tqdm(data):
        new_symbols, new_accent_ids = symbols_normalize(
            symbols=symbol_converter.get_symbols(values.serialized_symbol_ids),
            lang=values.lang,
            accent_ids=deserialize_list(values.serialized_accent_ids),
        )

        processed_data.append(
            (values.entry_id, new_symbols, new_accent_ids, values.lang))

    return _prepare_data(processed_data)
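
The dataset-level functions preprocess, normalize and convert_to_ipa defined in these examples share the (TextDataList, SymbolIdDict, SymbolsDict) return triple, so they can be chained; a hedged sketch of one possible chaining (the input variables are illustrative):

# ds_data: DsDataList and ds_symbol_ids: SymbolIdDict from the dataset preprocessing step
text_data, symbol_ids, symbols_dict = preprocess(ds_data, ds_symbol_ids)
text_data, symbol_ids, symbols_dict = normalize(text_data, symbol_ids)
text_data, symbol_ids, symbols_dict = convert_to_ipa(
  text_data, symbol_ids, ignore_tones=False, ignore_arcs=True)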
Example No. 27
def sents_convert_to_ipa(sentences: SentenceList, text_symbols: SymbolIdDict, ignore_tones: bool, ignore_arcs: bool) -> Tuple[SymbolIdDict, SentenceList]:

  sents_new_symbols = []
  for sentence in sentences.items(True):
    new_symbols, new_accent_ids = symbols_convert_to_ipa(
      symbols=text_symbols.get_symbols(sentence.serialized_symbols),
      lang=sentence.lang,
      accent_ids=deserialize_list(sentence.serialized_accents),
      ignore_arcs=ignore_arcs,
      ignore_tones=ignore_tones
    )
    assert len(new_symbols) == len(new_accent_ids)
    sentence.lang = Language.IPA
    sentence.serialized_accents = serialize_list(new_accent_ids)
    sents_new_symbols.append(new_symbols)
    assert len(sentence.get_accent_ids()) == len(new_symbols)

  return update_symbols_and_text(sentences, sents_new_symbols)
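
An end-to-end, hedged sketch for inference text that chains the sentence-level helpers from these examples; the owner class of the from_sentences classmethod shown earlier is written here as a hypothetical InferSentenceList, and the Language member and accents dictionary are placeholders:

symbols, sentences = add_text("An example sentence.", lang=lang)  # lang: a Language member (placeholder)
symbols, sentences = sents_normalize(sentences, symbols)
symbols, sentences = sents_convert_to_ipa(
  sentences, symbols, ignore_tones=False, ignore_arcs=True)
infer_sentences = InferSentenceList.from_sentences(sentences, accents=accents, symbols=symbols)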
Example No. 28
def convert_v1_to_v2_model(old_model_path: str,
                           custom_hparams: Optional[Dict[str, str]],
                           speakers: SpeakersDict, accents: AccentsDict,
                           symbols: SymbolIdDict):
    checkpoint_dict = torch.load(old_model_path, map_location='cpu')
    hparams = HParams(n_speakers=len(speakers),
                      n_accents=len(accents),
                      n_symbols=len(symbols))

    hparams = overwrite_custom_hparams(hparams, custom_hparams)

    chp = CheckpointTacotron(state_dict=checkpoint_dict["state_dict"],
                             optimizer=checkpoint_dict["optimizer"],
                             learning_rate=checkpoint_dict["learning_rate"],
                             iteration=checkpoint_dict["iteration"] + 1,
                             hparams=asdict(hparams),
                             speakers=speakers.raw(),
                             symbols=symbols.raw(),
                             accents=accents.raw())

    new_model_path = f"{old_model_path}_{get_pytorch_filename(chp.iteration)}"

    chp.save(new_model_path, logging.getLogger())
Example No. 29
def _train(custom_hparams: Optional[Dict[str, str]], taco_logger: Tacotron2Logger, trainset: PreparedDataList, valset: PreparedDataList, save_callback: Any, speakers: SpeakersDict, accents: AccentsDict, symbols: SymbolIdDict, checkpoint: Optional[CheckpointTacotron], warm_model: Optional[CheckpointTacotron], weights_checkpoint: Optional[CheckpointTacotron], weights_map: SymbolsMap, map_from_speaker_name: Optional[str], logger: Logger, checkpoint_logger: Logger):
  """Training and validation logging results to tensorboard and stdout
  Params
  ------
  output_directory (string): directory to save checkpoints
  log_directory (string) directory to save tensorboard logs
  checkpoint_path(string): checkpoint path
  n_gpus (int): number of gpus
  rank (int): rank of current gpu
  hparams (object): comma separated list of "name=value" pairs.
  """

  complete_start = time.time()

  if checkpoint is not None:
    hparams = checkpoint.get_hparams(logger)
  else:
    hparams = HParams(
      n_accents=len(accents),
      n_speakers=len(speakers),
      n_symbols=len(symbols)
    )
  # is it problematic to change the batch size on a trained model?
  hparams = overwrite_custom_hparams(hparams, custom_hparams)

  assert hparams.n_accents > 0
  assert hparams.n_speakers > 0
  assert hparams.n_symbols > 0

  if hparams.use_saved_learning_rate and checkpoint is not None:
    hparams.learning_rate = checkpoint.learning_rate

  log_hparams(hparams, logger)
  init_torch(hparams)

  model, optimizer = load_model_and_optimizer(
    hparams=hparams,
    checkpoint=checkpoint,
    logger=logger,
  )

  iteration = get_iteration(checkpoint)

  if checkpoint is None:
    if warm_model is not None:
      logger.info("Loading states from pretrained model...")
      warm_start_model(model, warm_model, hparams, logger)

    if weights_checkpoint is not None:
      logger.info("Mapping symbol embeddings...")

      pretrained_symbol_weights = get_mapped_symbol_weights(
        model_symbols=symbols,
        trained_weights=weights_checkpoint.get_symbol_embedding_weights(),
        trained_symbols=weights_checkpoint.get_symbols(),
        custom_mapping=weights_map,
        hparams=hparams,
        logger=logger
      )

      update_weights(model.embedding, pretrained_symbol_weights)

      map_speaker_weights = map_from_speaker_name is not None
      if map_speaker_weights:
        logger.info("Mapping speaker embeddings...")
        pretrained_speaker_weights = get_mapped_speaker_weights(
          model_speakers=speakers,
          trained_weights=weights_checkpoint.get_speaker_embedding_weights(),
          trained_speaker=weights_checkpoint.get_speakers(),
          map_from_speaker_name=map_from_speaker_name,
          hparams=hparams,
          logger=logger,
        )

        update_weights(model.speakers_embedding, pretrained_speaker_weights)

  log_symbol_weights(model, logger)

  collate_fn = SymbolsMelCollate(
    n_frames_per_step=hparams.n_frames_per_step,
    padding_symbol_id=symbols.get_id(PADDING_SYMBOL),
    padding_accent_id=accents.get_id(PADDING_ACCENT)
  )

  val_loader = prepare_valloader(hparams, collate_fn, valset, logger)
  train_loader = prepare_trainloader(hparams, collate_fn, trainset, logger)

  batch_iterations = len(train_loader)
  enough_traindata = batch_iterations > 0
  if not enough_traindata:
    msg = "Not enough trainingdata."
    logger.error(msg)
    raise Exception(msg)

  save_it_settings = SaveIterationSettings(
    epochs=hparams.epochs,
    batch_iterations=batch_iterations,
    save_first_iteration=True,
    save_last_iteration=True,
    iters_per_checkpoint=hparams.iters_per_checkpoint,
    epochs_per_checkpoint=hparams.epochs_per_checkpoint
  )

  criterion = Tacotron2Loss()
  batch_durations: List[float] = []

  train_start = time.perf_counter()
  start = train_start
  model.train()
  continue_epoch = get_continue_epoch(iteration, batch_iterations)
  for epoch in range(continue_epoch, hparams.epochs):
    next_batch_iteration = get_continue_batch_iteration(iteration, batch_iterations)
    skip_bar = None
    if next_batch_iteration > 0:
      logger.debug(f"Current batch is {next_batch_iteration} of {batch_iterations}")
      logger.debug("Skipping batches...")
      skip_bar = tqdm(total=next_batch_iteration)
    for batch_iteration, batch in enumerate(train_loader):
      need_to_skip_batch = skip_batch(
        batch_iteration=batch_iteration,
        continue_batch_iteration=next_batch_iteration
      )
      if need_to_skip_batch:
        assert skip_bar is not None
        skip_bar.update(1)
        #debug_logger.debug(f"Skipped batch {batch_iteration + 1}/{next_batch_iteration + 1}.")
        continue
      # debug_logger.debug(f"Current batch: {batch[0][0]}")

      # update_learning_rate_optimizer(optimizer, hparams.learning_rate)

      model.zero_grad()
      x, y = parse_batch(batch)
      y_pred = model(x)

      loss = criterion(y_pred, y)
      reduced_loss = loss.item()

      loss.backward()

      grad_norm = torch.nn.utils.clip_grad_norm_(
        parameters=model.parameters(),
        max_norm=hparams.grad_clip_thresh
      )

      optimizer.step()

      iteration += 1

      end = time.perf_counter()
      duration = end - start
      start = end

      batch_durations.append(duration)
      avg_batch_dur = np.mean(batch_durations)
      avg_epoch_dur = avg_batch_dur * batch_iterations
      remaining_its = hparams.epochs * batch_iterations - iteration
      estimated_remaining_duration = avg_batch_dur * remaining_its

      next_it = get_next_save_it(iteration, save_it_settings)
      next_checkpoint_save_time = 0
      if next_it is not None:
        next_checkpoint_save_time = (next_it - iteration) * avg_batch_dur

      logger.info(" | ".join([
        f"Epoch: {get_formatted_current_total(epoch + 1, hparams.epochs)}",
        f"It.: {get_formatted_current_total(batch_iteration + 1, batch_iterations)}",
        f"Tot. it.: {get_formatted_current_total(iteration, hparams.epochs * batch_iterations)} ({iteration / (hparams.epochs * batch_iterations) * 100:.2f}%)",
        f"Loss: {reduced_loss:.6f}",
        f"Grad norm: {grad_norm:.6f}",
        #f"Dur.: {duration:.2f}s/it",
        f"Avg. dur.: {avg_batch_dur:.2f}s/it & {avg_epoch_dur / 60:.0f}min/epoch",
        f"Tot. dur.: {(time.perf_counter() - train_start) / 60 / 60:.2f}h/{estimated_remaining_duration / 60 / 60:.0f}h ({estimated_remaining_duration / 60 / 60 / 24:.1f}days)",
        f"Next checkpoint: {next_checkpoint_save_time / 60:.0f}min",
      ]))

      taco_logger.log_training(reduced_loss, grad_norm, hparams.learning_rate,
                               duration, iteration)

      save_it = check_save_it(epoch, iteration, save_it_settings)
      if save_it:
        checkpoint = CheckpointTacotron.from_instances(
          model=model,
          optimizer=optimizer,
          hparams=hparams,
          iteration=iteration,
          symbols=symbols,
          accents=accents,
          speakers=speakers
        )

        save_callback(checkpoint)

        valloss = validate(model, criterion, val_loader, iteration, taco_logger, logger)

        # if rank == 0:
        log_checkpoint_score(iteration, grad_norm,
                             reduced_loss, valloss, epoch, batch_iteration, checkpoint_logger)

  duration_s = time.time() - complete_start
  logger.info(f'Finished training. Total duration: {duration_s / 60:.2f}min')
Example No. 30
def save_prep_symbol_converter(prep_dir: str, data: SymbolIdDict):
    path = os.path.join(prep_dir, _prepared_symbols_json)
    data.save(path)