Example No. 1
def update_accents(data: MergedDataset,
                   accent_ids: AccentsDict) -> AccentsDict:
    # Collect every accent still referenced by the dataset entries.
    new_accents: Set[str] = {
        x
        for y in data.items()
        for x in accent_ids.get_accents(y.serialized_accent_ids)
    }
    new_accent_ids = AccentsDict.init_from_accents_with_pad(
        new_accents, pad_accent=DEFAULT_PADDING_ACCENT)
    # If the rebuilt mapping differs, re-encode every entry's accent IDs.
    if new_accent_ids.get_all_accents() != accent_ids.get_all_accents():
        for entry in data.items():
            original_accents = accent_ids.get_accents(
                entry.serialized_accent_ids)
            entry.serialized_accent_ids = new_accent_ids.get_serialized_ids(
                original_accents)
    return new_accent_ids
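The remapping above decodes each entry's IDs with the old dictionary and re-encodes them with the rebuilt one. A minimal, self-contained sketch of that decode/re-encode step, using plain dicts in place of AccentsDict (all names below are illustrative, not the project's API):

# Plain-dict sketch of the remap performed by update_accents.
old_ids = {"pad": 0, "c": 1, "d": 2}        # old accent -> ID mapping
new_ids = {"pad": 0, "d": 1}                # rebuilt after "c" fell out of use
old_accents = {v: k for k, v in old_ids.items()}

entry = [2, 0, 2]                           # IDs under the old mapping
decoded = [old_accents[i] for i in entry]   # back to accent strings
remapped = [new_ids[a] for a in decoded]    # forward to the new mapping
assert remapped == [1, 0, 1]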
Example No. 2
def set_accent(sentences: SentenceList, accent_ids: AccentsDict, accent: str) -> SentenceList:
  accent_id = accent_ids.get_id(accent)
  for sentence in sentences.items():
    # Assign the same accent ID to every symbol of the sentence.
    new_accent_ids = [accent_id] * len(sentence.get_accent_ids())
    sentence.serialized_accents = serialize_list(new_accent_ids)
    assert len(sentence.get_accent_ids()) == len(sentence.get_symbol_ids())
  return sentences
Example No. 3
 def get_formatted_old(self, accent_id_dict: AccentsDict, pairs_per_line=170, space_length=0):
   return get_formatted_core(
     sent_id=self.sent_id,
     symbols=self.symbols,
     accent_ids=accent_id_dict.get_ids(self.accents),
     accent_id_dict=accent_id_dict,
     space_length=space_length,
     max_pairs_per_line=pairs_per_line
   )
Example No. 4
def sents_accent_apply(sentences: SentenceList, accented_symbols: AccentedSymbolList, accent_ids: AccentsDict) -> SentenceList:
  current_index = 0
  for sent in sentences.items():
    # Consume as many accented symbols as this sentence has accent IDs.
    accent_ids_count = len(deserialize_list(sent.serialized_accents))
    assert len(accented_symbols) >= current_index + accent_ids_count
    accented_symbol_selection: List[AccentedSymbol] = accented_symbols[current_index:current_index + accent_ids_count]
    current_index += accent_ids_count
    new_accent_ids = accent_ids.get_ids([x.accent for x in accented_symbol_selection])
    sent.serialized_accents = serialize_list(new_accent_ids)
    assert len(sent.get_accent_ids()) == len(sent.get_symbol_ids())
  return sentences
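serialize_list and deserialize_list themselves are not shown in these examples; judging by the "0,1"-style strings in Example No. 11, they appear to store ID lists as comma-separated strings. A hypothetical re-implementation, for illustration only:

from typing import List

def serialize_list(ids: List[int]) -> str:
  # Assumed format: comma-separated IDs, e.g. "1,0" as in Example No. 11.
  return ",".join(str(i) for i in ids)

def deserialize_list(serialized: str) -> List[int]:
  return [int(i) for i in serialized.split(",")]

assert serialize_list([1, 0]) == "1,0"
assert deserialize_list("1,0") == [1, 0]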
Example No. 5
 @classmethod
 def from_instances(cls, model: Tacotron2, optimizer: Adam,
                    hparams: HParams, iteration: int, symbols: SymbolIdDict,
                    accents: AccentsDict, speakers: SpeakersDict):
     result = cls(state_dict=model.state_dict(),
                  optimizer=optimizer.state_dict(),
                  learning_rate=hparams.learning_rate,
                  iteration=iteration,
                  hparams=asdict(hparams),
                  symbols=symbols.raw(),
                  accents=accents.raw(),
                  speakers=speakers.raw())
     return result
Example No. 6
 @classmethod
 def from_sentences(cls, sentences: SentenceList, accents: AccentsDict, symbols: SymbolIdDict):
   res = cls()
   for sentence in sentences.items():
     infer_sent = InferSentence(
       sent_id=sentence.sent_id,
       symbols=symbols.get_symbols(sentence.serialized_symbols),
       accents=accents.get_accents(sentence.serialized_accents),
       original_text=sentence.original_text,
     )
     assert len(infer_sent.symbols) == len(infer_sent.accents)
     res.append(infer_sent)
   return res
Example No. 7
def sents_accent_template(sentences: SentenceList, text_symbols: SymbolIdDict, accent_ids: AccentsDict) -> AccentedSymbolList:
  res = AccentedSymbolList()
  for i, sent in enumerate(sentences.items()):
    symbols = text_symbols.get_symbols(sent.serialized_symbols)
    accents = accent_ids.get_accents(sent.serialized_accents)
    for j, (symbol, accent) in enumerate(zip(symbols, accents)):
      accented_symbol = AccentedSymbol(
        position=f"{i}-{j}",
        symbol=symbol,
        accent=accent
      )
      res.append(accented_symbol)
  return res
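Examples No. 4 and No. 7 form a round trip: sents_accent_template exports one editable (position, symbol, accent) row per symbol, and sents_accent_apply reads the edited rows back in order. A self-contained sketch of that pattern, with plain tuples standing in for AccentedSymbol (everything here is illustrative):

# Flatten to one editable row per symbol ("template" step).
sentences = [["a", "b"], ["c"]]                 # symbols per sentence
accents = [["acc1", "acc1"], ["acc1"]]          # parallel accents
template = [(f"{i}-{j}", s, a)
            for i, (syms, accs) in enumerate(zip(sentences, accents))
            for j, (s, a) in enumerate(zip(syms, accs))]

# Edit the template, e.g. re-accent every "b".
edited = [(p, s, "acc2" if s == "b" else a) for p, s, a in template]

# Consume the rows back sentence by sentence ("apply" step).
index = 0
for i, syms in enumerate(sentences):
    rows = edited[index:index + len(syms)]
    index += len(syms)
    accents[i] = [a for _, _, a in rows]
assert accents == [["acc1", "acc2"], ["acc1"]]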
Example No. 8
    def make_common_accent_ids(self) -> AccentsDict:
        # Union of all accents across the datasets being merged.
        all_accents: Set[str] = set()
        for ds in self.data:
            all_accents |= ds.accent_ids.get_all_accents()
        new_accent_ids = AccentsDict.init_from_accents_with_pad(
            all_accents, pad_accent=DEFAULT_PADDING_ACCENT)

        for ds in self.data:
            for entry in ds.data.items():
                original_accents = ds.accent_ids.get_accents(
                    entry.serialized_accent_ids)
                entry.serialized_accent_ids = new_accent_ids.get_serialized_ids(
                    original_accents)
            ds.accent_ids = new_accent_ids

        return new_accent_ids
Example No. 9
def convert_v1_to_v2_model(old_model_path: str,
                           custom_hparams: Optional[Dict[str, str]],
                           speakers: SpeakersDict, accents: AccentsDict,
                           symbols: SymbolIdDict):
    checkpoint_dict = torch.load(old_model_path, map_location='cpu')
    hparams = HParams(n_speakers=len(speakers),
                      n_accents=len(accents),
                      n_symbols=len(symbols))

    hparams = overwrite_custom_hparams(hparams, custom_hparams)

    chp = CheckpointTacotron(state_dict=checkpoint_dict["state_dict"],
                             optimizer=checkpoint_dict["optimizer"],
                             learning_rate=checkpoint_dict["learning_rate"],
                             iteration=checkpoint_dict["iteration"] + 1,
                             hparams=asdict(hparams),
                             speakers=speakers.raw(),
                             symbols=symbols.raw(),
                             accents=accents.raw())

    new_model_path = f"{old_model_path}_{get_pytorch_filename(chp.iteration)}"

    chp.save(new_model_path, logging.getLogger())
Example No. 10
 def get_accents(self) -> AccentsDict:
     return AccentsDict.from_raw(self.accents)
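Together with Example No. 5, this suggests raw() and from_raw form a serialization round trip: the checkpoint stores plain data, and the dictionary is rebuilt from it on load. A hypothetical stand-in class sketching that pattern (not the project's actual implementation):

class AccentsDictSketch:
    def __init__(self, accent_to_id):
        self._accent_to_id = accent_to_id

    def raw(self):
        # Plain, picklable data suitable for storing in a checkpoint.
        return dict(self._accent_to_id)

    @classmethod
    def from_raw(cls, raw):
        return cls(raw)

d = AccentsDictSketch({"pad": 0, "acc1": 1})
assert AccentsDictSketch.from_raw(d.raw()).raw() == d.raw()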
Example No. 11
    def test_ds_dataset_to_merged_dataset(self):
        ds = DsDataset(
            name="ds",
            accent_ids=AccentsDict.init_from_accents({"c", "d"}),
            data=DsDataList([
                DsData(
                    entry_id=1,
                    basename="basename",
                    speaker_id=2,
                    serialized_symbols="0,1",
                    gender=Gender.MALE,
                    lang=Language.GER,
                    serialized_accents="0,0",
                    wav_path="wav",
                    speaker_name="speaker",
                    text="text",
                )
            ]),
            mels=MelDataList([
                MelData(
                    entry_id=1,
                    relative_mel_path="mel",
                    n_mel_channels=5,
                )
            ]),
            speakers=SpeakersDict.fromlist(["sp1", "sp2", "sp3"]),
            symbol_ids=SymbolIdDict.init_from_symbols({"a", "b"}),
            texts=TextDataList([
                TextData(
                    entry_id=1,
                    lang=Language.IPA,
                    serialized_accent_ids="1,0",
                    serialized_symbol_ids="1,1",
                    text="text_new",
                )
            ]),
            wavs=WavDataList([
                WavData(
                    entry_id=1,
                    duration=15,
                    sr=22050,
                    relative_wav_path="wav_new",
                )
            ]),
        )

        res = ds_dataset_to_merged_dataset(ds)

        self.assertEqual("ds", res.name)
        self.assertEqual(2, len(res.accent_ids))
        self.assertEqual(3, len(res.speakers))
        self.assertEqual(2, len(res.symbol_ids))
        self.assertEqual(1, len(res.items()))
        first_entry = res.items()[0]
        self.assertEqual(1, first_entry.entry_id)
        self.assertEqual(Gender.MALE, first_entry.gender)
        self.assertEqual("basename", first_entry.basename)
        self.assertEqual(2, first_entry.speaker_id)
        self.assertEqual(Language.IPA, first_entry.lang)
        self.assertEqual("1,0", first_entry.serialized_accents)
        self.assertEqual("1,1", first_entry.serialized_symbols)
        self.assertEqual("wav_new", first_entry.wav_path)
        self.assertEqual(15, first_entry.duration)
        self.assertEqual(22050, first_entry.sampling_rate)
        self.assertEqual("mel", first_entry.mel_path)
        self.assertEqual(5, first_entry.n_mel_channels)
Example No. 12
def _train(custom_hparams: Optional[Dict[str, str]],
           taco_logger: Tacotron2Logger, trainset: PreparedDataList,
           valset: PreparedDataList, save_callback: Callable[[CheckpointTacotron], None],
           speakers: SpeakersDict, accents: AccentsDict, symbols: SymbolIdDict,
           checkpoint: Optional[CheckpointTacotron],
           warm_model: Optional[CheckpointTacotron],
           weights_checkpoint: Optional[CheckpointTacotron],
           weights_map: SymbolsMap, map_from_speaker_name: Optional[str],
           logger: Logger, checkpoint_logger: Logger):
    """Training and validation logging results to tensorboard and stdout
  Params
  ------
  output_directory (string): directory to save checkpoints
  log_directory (string) directory to save tensorboard logs
  checkpoint_path(string): checkpoint path
  n_gpus (int): number of gpus
  rank (int): rank of current gpu
  hparams (object): comma separated list of "name=value" pairs.
  """

    complete_start = time.time()

    if checkpoint is not None:
        hparams = checkpoint.get_hparams(logger)
    else:
        hparams = HParams(n_accents=len(accents),
                          n_speakers=len(speakers),
                          n_symbols=len(symbols))
    # TODO: changing the batch size on a trained model should be discouraged
    hparams = overwrite_custom_hparams(hparams, custom_hparams)

    assert hparams.n_accents > 0
    assert hparams.n_speakers > 0
    assert hparams.n_symbols > 0

    if hparams.use_saved_learning_rate and checkpoint is not None:
        hparams.learning_rate = checkpoint.learning_rate

    log_hparams(hparams, logger)
    init_global_seeds(hparams.seed)
    init_torch(hparams)

    model, optimizer = load_model_and_optimizer(
        hparams=hparams,
        checkpoint=checkpoint,
        logger=logger,
    )

    iteration = get_iteration(checkpoint)

    if checkpoint is None:
        if warm_model is not None:
            logger.info("Loading states from pretrained model...")
            warm_start_model(model, warm_model, hparams, logger)

        if weights_checkpoint is not None:
            logger.info("Mapping symbol embeddings...")

            pretrained_symbol_weights = get_mapped_symbol_weights(
                model_symbols=symbols,
                trained_weights=weights_checkpoint.get_symbol_embedding_weights(),
                trained_symbols=weights_checkpoint.get_symbols(),
                custom_mapping=weights_map,
                hparams=hparams,
                logger=logger)

            update_weights(model.embedding, pretrained_symbol_weights)

            logger.info("Checking if mapping speaker embeddings...")
            weights_checkpoint_hparams = weights_checkpoint.get_hparams(logger)
            map_speaker_weights = (hparams.use_speaker_embedding
                                   and weights_checkpoint_hparams.use_speaker_embedding
                                   and map_from_speaker_name is not None)
            if map_speaker_weights:
                logger.info("Mapping speaker embeddings...")
                pretrained_speaker_weights = get_mapped_speaker_weights(
                    model_speakers=speakers,
                    trained_weights=weights_checkpoint.get_speaker_embedding_weights(),
                    trained_speaker=weights_checkpoint.get_speakers(),
                    map_from_speaker_name=map_from_speaker_name,
                    hparams=hparams,
                    logger=logger,
                )

                update_weights(model.speakers_embedding,
                               pretrained_speaker_weights)
            logger.info(f"Done. Mapped speaker weights: {map_speaker_weights}")

    log_symbol_weights(model, logger)

    collate_fn = SymbolsMelCollate(
        n_frames_per_step=hparams.n_frames_per_step,
        padding_symbol_id=symbols.get_id(DEFAULT_PADDING_SYMBOL),
        padding_accent_id=accents.get_id(DEFAULT_PADDING_ACCENT))

    val_loader = prepare_valloader(hparams, collate_fn, valset, logger)
    train_loader = prepare_trainloader(hparams, collate_fn, trainset, logger)

    batch_iterations = len(train_loader)
    enough_traindata = batch_iterations > 0
    if not enough_traindata:
        msg = "Not enough trainingdata."
        logger.error(msg)
        raise Exception(msg)

    save_it_settings = SaveIterationSettings(
        epochs=hparams.epochs,
        iterations=hparams.iterations,
        batch_iterations=batch_iterations,
        save_first_iteration=True,
        save_last_iteration=True,
        iters_per_checkpoint=hparams.iters_per_checkpoint,
        epochs_per_checkpoint=hparams.epochs_per_checkpoint)

    last_iteration = get_last_iteration(hparams.epochs, batch_iterations,
                                        hparams.iterations)
    last_epoch_one_based = iteration_to_epoch(last_iteration,
                                              batch_iterations) + 1

    criterion = Tacotron2Loss()
    batch_durations: List[float] = []

    train_start = time.perf_counter()
    start = train_start
    model.train()
    continue_epoch = get_continue_epoch(iteration, batch_iterations)
    for epoch in range(continue_epoch, last_epoch_one_based):
        # logger.debug("==new epoch==")
        next_batch_iteration = get_continue_batch_iteration(
            iteration, batch_iterations)
        skip_bar = None
        if next_batch_iteration > 0:
            logger.debug(
                f"Current batch is {next_batch_iteration} of {batch_iterations}"
            )
            logger.debug("Skipping batches...")
            skip_bar = tqdm(total=next_batch_iteration)
        for batch_iteration, batch in enumerate(train_loader):
            # logger.debug(f"Used batch with fingerprint: {sum(batch[0][0])}")
            need_to_skip_batch = skip_batch(
                batch_iteration=batch_iteration,
                continue_batch_iteration=next_batch_iteration)

            if need_to_skip_batch:
                assert skip_bar is not None
                skip_bar.update(1)
                #debug_logger.debug(f"Skipped batch {batch_iteration + 1}/{next_batch_iteration + 1}.")
                continue
            # debug_logger.debug(f"Current batch: {batch[0][0]}")

            # update_learning_rate_optimizer(optimizer, hparams.learning_rate)

            model.zero_grad()
            x, y = parse_batch(batch)
            y_pred = model(x)

            loss = criterion(y_pred, y)
            reduced_loss = loss.item()

            loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                parameters=model.parameters(),
                max_norm=hparams.grad_clip_thresh)

            optimizer.step()

            iteration += 1

            end = time.perf_counter()
            duration = end - start
            start = end

            batch_durations.append(duration)
            avg_batch_dur = np.mean(batch_durations)
            avg_epoch_dur = avg_batch_dur * batch_iterations
            remaining_its = last_iteration - iteration
            estimated_remaining_duration = avg_batch_dur * remaining_its

            next_it = get_next_save_it(iteration, save_it_settings)
            next_checkpoint_save_time = 0
            if next_it is not None:
                next_checkpoint_save_time = (next_it -
                                             iteration) * avg_batch_dur

            logger.info(" | ".join([
                f"Ep: {get_formatted_current_total(epoch + 1, last_epoch_one_based)}",
                f"It.: {get_formatted_current_total(batch_iteration + 1, batch_iterations)}",
                f"Tot. it.: {get_formatted_current_total(iteration, last_iteration)} ({iteration / last_iteration * 100:.2f}%)",
                f"Utts.: {iteration * hparams.batch_size}",
                f"Loss: {reduced_loss:.6f}",
                f"Grad norm: {grad_norm:.6f}",
                #f"Dur.: {duration:.2f}s/it",
                f"Avg. dur.: {avg_batch_dur:.2f}s/it & {avg_epoch_dur / 60:.0f}m/epoch",
                f"Tot. dur.: {(time.perf_counter() - train_start) / 60 / 60:.2f}h/{estimated_remaining_duration / 60 / 60:.0f}h ({estimated_remaining_duration / 60 / 60 / 24:.1f}days)",
                f"Next ckp.: {next_checkpoint_save_time / 60:.0f}m",
            ]))

            taco_logger.log_training(reduced_loss, grad_norm,
                                     hparams.learning_rate, duration,
                                     iteration)

            save_it = check_save_it(epoch, iteration, save_it_settings)
            if save_it:
                checkpoint = CheckpointTacotron.from_instances(
                    model=model,
                    optimizer=optimizer,
                    hparams=hparams,
                    iteration=iteration,
                    symbols=symbols,
                    accents=accents,
                    speakers=speakers)

                save_callback(checkpoint)

                valloss = validate(model, criterion, val_loader, iteration,
                                   taco_logger, logger)

                # if rank == 0:
                log_checkpoint_score(iteration=iteration,
                                     gradloss=grad_norm,
                                     trainloss=reduced_loss,
                                     valloss=valloss,
                                     epoch_one_based=epoch + 1,
                                     batch_it_one_based=batch_iteration + 1,
                                     batch_size=hparams.batch_size,
                                     checkpoint_logger=checkpoint_logger)

            is_last_it = iteration == last_iteration
            if is_last_it:
                break

    duration_s = time.time() - complete_start
    logger.info(f'Finished training. Total duration: {duration_s / 60:.2f}m')
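The resume logic in the epoch loop derives, from the global iteration counter and the number of batches per epoch, which epoch to continue in and how many batches of that epoch to skip. A minimal sketch of that arithmetic; the function names mirror the ones used above but are re-implemented here as assumptions:

def get_continue_epoch(iteration: int, batch_iterations: int) -> int:
    # Zero-based epoch the next iteration falls into.
    return iteration // batch_iterations

def get_continue_batch_iteration(iteration: int, batch_iterations: int) -> int:
    # Batches of that epoch already consumed before the interruption.
    return iteration % batch_iterations

# Resuming after 250 iterations with 100 batches per epoch lands in
# epoch 2 (zero-based) with the first 50 batches skipped.
assert get_continue_epoch(250, 100) == 2
assert get_continue_batch_iteration(250, 100) == 50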