Code example #1
def preprocess(
    data: DsDataList, symbol_ids: SymbolIdDict
) -> Tuple[TextDataList, SymbolIdDict, SymbolsDict]:
    processed_data: List[Tuple[int, List[str], List[int], Language]] = []

    values: DsData
    for values in tqdm(data):
        symbols: List[str] = symbol_ids.get_symbols(
            deserialize_list(values.serialized_symbols))
        accents: List[int] = deserialize_list(values.serialized_accents)
        processed_data.append((values.entry_id, symbols, accents, values.lang))

    return _prepare_data(processed_data)
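Every example on this page revolves around a serialize_list/deserialize_list helper pair whose definitions are not shown. A minimal sketch of what they likely look like, assuming the ids are stored as a comma-separated string (the separator and signatures are assumptions, not the project's confirmed code):

from typing import List

def serialize_list(ids: List[int]) -> str:
    # e.g. [1, 2, 3] -> "1,2,3"; the comma separator is assumed.
    return ",".join(str(i) for i in ids)

def deserialize_list(serialized: str) -> List[int]:
    # Inverse of serialize_list; an empty string yields an empty list.
    if not serialized:
        return []
    return [int(part) for part in serialized.split(",")]

assert deserialize_list(serialize_list([1, 2, 3])) == [1, 2, 3]
assert deserialize_list("") == []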
Code example #2
def get_accents(self, accent_ids: Union[str, List[int]]) -> List[str]:
    if isinstance(accent_ids, str):
        accent_ids = deserialize_list(accent_ids)
    elif not isinstance(accent_ids, list):
        assert False  # unexpected input type
    accents = [self.get_accent(accent_id) for accent_id in accent_ids]
    return accents
Code example #3
def get_symbols(self, symbol_ids: Union[str, List[int]]) -> List[str]:
    if isinstance(symbol_ids, str):
        symbol_ids = deserialize_list(symbol_ids)
    elif not isinstance(symbol_ids, list):
        assert False  # unexpected input type
    symbols = [self.get_symbol(s_id) for s_id in symbol_ids]
    return symbols
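Examples #2 and #3 share one dispatch pattern: the lookup accepts either the serialized string or an already-deserialized id list. The same pattern in isolation, reusing the deserialize_list sketch above; lookup_symbols and the plain dict are hypothetical stand-ins for SymbolIdDict:

from typing import Dict, List, Union

def lookup_symbols(id_to_symbol: Dict[int, str],
                   symbol_ids: Union[str, List[int]]) -> List[str]:
    # Accept both input forms, mirroring get_symbols()/get_accents() above.
    if isinstance(symbol_ids, str):
        symbol_ids = deserialize_list(symbol_ids)
    return [id_to_symbol[i] for i in symbol_ids]

assert lookup_symbols({0: "a", 1: "b"}, "0,1,0") == ["a", "b", "a"]
assert lookup_symbols({0: "a", 1: "b"}, [1]) == ["b"]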
Code example #4
File: merge_ds.py Project: stefantaubert/tacotron2
def _remove_unused_accents(self) -> None:
    all_accent_ids: Set[int] = set()
    for entry in self.data.items():
        all_accent_ids |= set(deserialize_list(
            entry.serialized_accent_ids))
    unused_accent_ids = self.accent_ids.get_all_ids().difference(
        all_accent_ids)
    # unused_symbols = unused_symbols.difference({PADDING_SYMBOL})
    self.accent_ids.remove_ids(unused_accent_ids)
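The pruning in example #4 is plain set arithmetic: union up every accent id still referenced by some entry, subtract that from the full id set, and remove the remainder. The same logic on bare sets, with made-up ids:

all_ids = {0, 1, 2, 3}            # every id the dictionary knows
used = set([0, 2]) | set([2, 3])  # union over all entries' deserialized ids
unused = all_ids.difference(used)
assert unused == {1}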
Code example #5
def sents_accent_apply(sentences: SentenceList, accented_symbols: AccentedSymbolList, accent_ids: AccentsDict) -> SentenceList:
  current_index = 0
  for sent in sentences.items():
    accent_ids_count = len(deserialize_list(sent.serialized_accents))
    assert len(accented_symbols) >= current_index + accent_ids_count
    accented_symbol_selection: List[AccentedSymbol] = accented_symbols[current_index:current_index + accent_ids_count]
    current_index += accent_ids_count
    new_accent_ids = accent_ids.get_ids([x.accent for x in accented_symbol_selection])
    sent.serialized_accents = serialize_list(new_accent_ids)
    assert len(sent.get_accent_ids()) == len(sent.get_symbol_ids())
  return sentences
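sents_accent_apply treats accented_symbols as one flat list spanning all sentences and walks it with current_index, taking exactly one slice per sentence. The bookkeeping in isolation (the data is made up for illustration):

flat = ["acc_a", "acc_b", "acc_c", "acc_d", "acc_e"]  # one entry per symbol
lengths = [2, 3]                                      # symbols per sentence

current_index = 0
per_sentence = []
for n in lengths:
    assert len(flat) >= current_index + n
    per_sentence.append(flat[current_index:current_index + n])
    current_index += n

assert per_sentence == [["acc_a", "acc_b"], ["acc_c", "acc_d", "acc_e"]]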
Code example #6
    def __init__(self, prepare_ds_ms_data: PreparedDataList, hparams: HParams,
                 logger: Logger):
        data = prepare_ds_ms_data

        random.seed(hparams.seed)
        random.shuffle(data)
        self.use_saved_mels: bool = hparams.use_saved_mels
        if not hparams.use_saved_mels:
            self.mel_parser = TacotronSTFT(hparams, logger)

        logger.info("Reading files...")
        self.data: Dict[int, Tuple[IntTensor, IntTensor, str, int]] = {}
        for i, values in enumerate(data.items(True)):
            symbol_ids = deserialize_list(values.serialized_symbol_ids)
            accent_ids = deserialize_list(values.serialized_accent_ids)

            model_symbol_ids = get_model_symbol_ids(
                symbol_ids, accent_ids, hparams.n_symbols,
                hparams.accents_use_own_symbols)

            symbols_tensor = IntTensor(model_symbol_ids)
            accents_tensor = IntTensor(accent_ids)

            if hparams.use_saved_mels:
                self.data[i] = (symbols_tensor, accents_tensor,
                                values.mel_path, values.speaker_id)
            else:
                self.data[i] = (symbols_tensor, accents_tensor,
                                values.wav_path, values.speaker_id)

        if hparams.use_saved_mels and hparams.cache_mels:
            logger.info("Loading mels into memory...")
            self.cache: Dict[int, Tensor] = {}
            vals: tuple
            for i, vals in tqdm(self.data.items()):
                mel_tensor = torch.load(vals[2], map_location='cpu')  # vals[2] is the mel path; vals[1] is the accents tensor
                self.cache[i] = mel_tensor
        self.use_cache: bool = hparams.cache_mels
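Example #6 converts symbol ids to "model symbol ids" with get_model_symbol_ids, whose body is not shown here. A plausible sketch, assuming that with accents_use_own_symbols every accent gets its own block of the embedding table, offset by n_symbols (this offset scheme is an assumption about the implementation, not confirmed code):

from typing import List

def get_model_symbol_ids_sketch(symbol_ids: List[int],
                                accent_ids: List[int], n_symbols: int,
                                accents_use_own_symbols: bool) -> List[int]:
    # Hypothetical: shift each symbol id into a per-accent block;
    # without the flag, ids pass through unchanged.
    if not accents_use_own_symbols:
        return list(symbol_ids)
    return [s + a * n_symbols for s, a in zip(symbol_ids, accent_ids)]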
Code example #7
def sents_normalize(sentences: SentenceList, text_symbols: SymbolIdDict) -> Tuple[SymbolIdDict, SentenceList]:
  # Maybe add info if something was unknown
  sents_new_symbols = []
  for sentence in sentences.items():
    new_symbols, new_accent_ids = symbols_normalize(
      symbols=text_symbols.get_symbols(sentence.serialized_symbols),
      lang=sentence.lang,
      accent_ids=deserialize_list(sentence.serialized_accents)
    )
    # TODO: check if new sentences resulted and then split them.
    sentence.serialized_accents = serialize_list(new_accent_ids)
    sents_new_symbols.append(new_symbols)

  return update_symbols_and_text(sentences, sents_new_symbols)
Code example #8
def normalize(
    data: TextDataList, symbol_converter: SymbolIdDict
) -> Tuple[TextDataList, SymbolIdDict, SymbolsDict]:
    processed_data: List[Tuple[int, List[str], List[int], Language]] = []

    values: TextData
    for values in tqdm(data):
        new_symbols, new_accent_ids = symbols_normalize(
            symbols=symbol_converter.get_symbols(values.serialized_symbol_ids),
            lang=values.lang,
            accent_ids=deserialize_list(values.serialized_accent_ids),
        )

        processed_data.append(
            (values.entry_id, new_symbols, new_accent_ids, values.lang))

    return _prepare_data(processed_data)
Code example #9
def convert_to_ipa(
        data: TextDataList, symbol_converter: SymbolIdDict, ignore_tones: bool,
        ignore_arcs: bool) -> Tuple[TextDataList, SymbolIdDict, SymbolsDict]:
    processed_data: List[Tuple[int, List[str], List[int], Language]] = []

    values: TextData
    for values in tqdm(data.items()):
        new_symbols, new_accent_ids = symbols_convert_to_ipa(
            symbols=symbol_converter.get_symbols(values.serialized_symbol_ids),
            lang=values.lang,
            accent_ids=deserialize_list(values.serialized_accent_ids),
            ignore_arcs=ignore_arcs,
            ignore_tones=ignore_tones)
        processed_data.append(
            (values.entry_id, new_symbols, new_accent_ids, Language.IPA))

    return _prepare_data(processed_data)
Code example #10
def sents_convert_to_ipa(sentences: SentenceList, text_symbols: SymbolIdDict, ignore_tones: bool, ignore_arcs: bool) -> Tuple[SymbolIdDict, SentenceList]:
  sents_new_symbols = []
  for sentence in sentences.items(True):
    new_symbols, new_accent_ids = symbols_convert_to_ipa(
      symbols=text_symbols.get_symbols(sentence.serialized_symbols),
      lang=sentence.lang,
      accent_ids=deserialize_list(sentence.serialized_accents),
      ignore_arcs=ignore_arcs,
      ignore_tones=ignore_tones
    )
    assert len(new_symbols) == len(new_accent_ids)
    sentence.lang = Language.IPA
    sentence.serialized_accents = serialize_list(new_accent_ids)
    sents_new_symbols.append(new_symbols)
    assert len(sentence.get_accent_ids()) == len(new_symbols)

  return update_symbols_and_text(sentences, sents_new_symbols)
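Examples #7 through #10 all depend on one invariant, made explicit by the asserts above: the accent id list stays aligned one-to-one with the symbol list through every transform. A toy transform showing the contract (the transform itself is invented; only the invariant comes from the source):

from typing import List, Tuple

def toy_transform(symbols: List[str],
                  accent_ids: List[int]) -> Tuple[List[str], List[int]]:
    # Drop a symbol and its accent id together, so alignment survives,
    # as symbols_normalize/symbols_convert_to_ipa must also guarantee.
    kept = [(s, a) for s, a in zip(symbols, accent_ids) if s != " "]
    new_symbols = [s for s, _ in kept]
    new_accent_ids = [a for _, a in kept]
    assert len(new_symbols) == len(new_accent_ids)
    return new_symbols, new_accent_ids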
Code example #11
def sents_map(sentences: SentenceList, text_symbols: SymbolIdDict, symbols_map: SymbolsMap, ignore_arcs: bool) -> Tuple[SymbolIdDict, SentenceList]:
  sents_new_symbols = []
  result = SentenceList()
  new_sent_id = 0
  for sentence in sentences.items():
    symbols = text_symbols.get_symbols(sentence.serialized_symbols)
    accent_ids = deserialize_list(sentence.serialized_accents)

    mapped_symbols = symbols_map.apply_to_symbols(symbols)

    text = SymbolIdDict.symbols_to_text(mapped_symbols)
    # an empty resulting text causes no problems
    sents = split_sentences(text, sentence.lang)
    for new_sent_text in sents:
      new_symbols = text_to_symbols(
        new_sent_text,
        lang=sentence.lang,
        ignore_tones=False,
        ignore_arcs=ignore_arcs
      )

      if len(accent_ids) > 0:
        new_accent_ids = [accent_ids[0]] * len(new_symbols)
      else:
        new_accent_ids = []

      assert len(new_accent_ids) == len(new_symbols)

      new_sent_id += 1
      tmp = Sentence(
        sent_id=new_sent_id,
        text=new_sent_text,
        lang=sentence.lang,
        serialized_accents=serialize_list(new_accent_ids),
        serialized_symbols=""
      )
      sents_new_symbols.append(new_symbols)

      assert len(tmp.get_accent_ids()) == len(new_symbols)
      result.append(tmp)

  return update_symbols_and_text(result, sents_new_symbols)
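When sents_map splits the remapped text into new sentences, the original per-symbol accents can no longer be aligned, so it falls back to replicating the first accent id across every new symbol. That fallback in isolation:

accent_ids = [4, 4, 7]                  # accents of the original sentence
new_symbols = ["h", "e", "l", "l", "o"]

# Replicate the first accent over all new symbols; an empty accent list
# stays empty, matching the two branches in sents_map above.
new_accent_ids = [accent_ids[0]] * len(new_symbols) if accent_ids else []
assert len(new_accent_ids) == len(new_symbols)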
Code example #12
def deserialize_symbol_ids(serialized_str: str) -> List[int]:
    return deserialize_list(serialized_str)
Code example #13
def get_accent_ids(self) -> List[int]:
  return deserialize_list(self.serialized_accents)
Code example #14
def get_symbol_ids(self) -> List[int]:
  return deserialize_list(self.serialized_symbols)