def update_symbols_and_text(sentences: SentenceList, sents_new_symbols: List[List[str]]):
  symbols = SymbolIdDict.init_from_symbols(get_unique_items(sents_new_symbols))
  for sentence, new_symbols in zip(sentences.items(), sents_new_symbols):
    sentence.serialized_symbols = symbols.get_serialized_ids(new_symbols)
    sentence.text = SymbolIdDict.symbols_to_text(new_symbols)
    assert len(sentence.get_symbol_ids()) == len(new_symbols)
    assert len(sentence.get_accent_ids()) == len(new_symbols)
  return symbols, sentences
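
A minimal usage sketch for the helper above, assuming a SentenceList named sentences already exists (e.g. built by add_text further down in this listing) and that sents_new_symbols holds exactly one symbol list per sentence; the concrete symbol lists are placeholders.

# Hypothetical usage; "sentences" must already be a populated SentenceList.
sents_new_symbols = [
  ["h", "e", "l", "l", "o"],
  ["w", "o", "r", "l", "d"],
]
symbols, sentences = update_symbols_and_text(sentences, sents_new_symbols)
# symbols is a freshly built SymbolIdDict covering only the new symbols;
# each sentence's serialized_symbols and text now reflect them.
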
Example #2
  def test_sims_to_csv(self):
    emb = torch.ones(size=(2, 3))
    torch.nn.init.zeros_(emb[0])
    sims = get_similarities(emb)
    symbols = SymbolIdDict.init_from_symbols({"2", "1"})
    res = sims_to_csv(sims, symbols)
    self.assertEqual(2, len(res.index))
    self.assertListEqual(['2', '<=>', '1', 'nan'], list(res.values[0]))
    self.assertListEqual(['1', '<=>', '2', 'nan'], list(res.values[1]))
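
The 'nan' values asserted above stem from zeroing the first embedding row: assuming get_similarities is cosine-based (an assumption, not something this listing states), comparing anything against an all-zero vector divides by zero.

# Sketch of the 0/0 case under that cosine assumption.
import numpy as np

u, v = np.zeros(3), np.ones(3)
print(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))  # nan (plus a RuntimeWarning)
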
Example #3
  def test_plot_embeddings(self):
    emb = torch.ones(size=(2, 3))
    symbols = SymbolIdDict.init_from_symbols({})
    text, plot2d, plot3d = plot_embeddings(symbols, emb, logging.getLogger())

    self.assertEqual(2, len(text.index))
    self.assertListEqual(['a', '<=>', 'b', '1.00'], list(text.values[0]))
    self.assertListEqual(['b', '<=>', 'a', '1.00'], list(text.values[1]))
    self.assertEqual("2D-Embeddings", plot2d.layout.title.text)
    self.assertEqual("3D-Embeddings", plot3d.layout.title.text)
Example #4
  def test_get_similarities_is_sorted(self):
    symbols = SymbolIdDict.init_from_symbols({"a", "b", "c"})
    emb = np.zeros(shape=(3, 3))
    emb[symbols.get_id("a")] = [0.5, 1.0, 0]
    emb[symbols.get_id("b")] = [1.0, 0.6, 0]
    emb[symbols.get_id("c")] = [1.0, 0.5, 0]

    sims = get_similarities(emb)

    self.assertEqual(3, len(sims))
    self.assertEqual(symbols.get_id("b"), sims[symbols.get_id("a")][0][0])
    self.assertEqual(symbols.get_id("c"), sims[symbols.get_id("a")][1][0])
    self.assertEqual(symbols.get_id("b"), sims[symbols.get_id("c")][0][0])
    self.assertEqual(symbols.get_id("a"), sims[symbols.get_id("c")][1][0])
    self.assertEqual(symbols.get_id("c"), sims[symbols.get_id("b")][0][0])
    self.assertEqual(symbols.get_id("a"), sims[symbols.get_id("b")][1][0])
def add_text(text: str, lang: Language, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
  res = SentenceList()
  # each line is treated as at least one sentence.
  lines = text.split("\n")

  all_sents = []
  for line in lines:
    sents = text_to_sentences(
      text=line,
      lang=lang,
      logger=logger,
    )
    all_sents.extend(sents)

  default_accent_id = 0
  ipa_settings = IPAExtractionSettings(
    ignore_tones=False,
    ignore_arcs=False,
    replace_unknown_ipa_by=DEFAULT_PADDING_SYMBOL,
  )

  sents_symbols: List[List[str]] = [text_to_symbols(
    sent,
    lang=lang,
    ipa_settings=ipa_settings,
    logger=logger,
  ) for sent in all_sents]
  symbols = SymbolIdDict.init_from_symbols(get_unique_items(sents_symbols))
  for i, sent_symbols in enumerate(sents_symbols):
    sentence = Sentence(
      sent_id=i + 1,
      lang=lang,
      serialized_symbols=symbols.get_serialized_ids(sent_symbols),
      serialized_accents=serialize_list([default_accent_id] * len(sent_symbols)),
      text=SymbolIdDict.symbols_to_text(sent_symbols),
      original_text=SymbolIdDict.symbols_to_text(sent_symbols),
      orig_lang=lang,
    )
    res.append(sentence)
  return symbols, res
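
A hedged usage sketch for add_text; Language.ENG and the accessor calls are taken from other examples in this listing, and the project-specific imports (add_text, Language, SymbolIdDict, ...) are omitted because their module paths are not shown here.

# Illustrative call only.
import logging

symbols, sentences = add_text(
  text="Hello world.\nSecond line, second sentence.",
  lang=Language.ENG,
  logger=logging.getLogger(__name__),
)
print(len(sentences.items()))     # number of detected sentences
print(symbols.get_all_symbols())  # unique symbols across all sentences
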
Example #6
    def test_preprocess(self):
        datasets = {
            "ljs": (
                DsDataList([
                    DsData(0, "basename0", "speaker0", 0, "text0", "wavpath0", Language.ENG),
                    DsData(1, "basename0", "speaker0", 0, "text0", "wavpath0", Language.ENG),
                    DsData(2, "basename0", "speaker1", 1, "text0", "wavpath0", Language.ENG),
                ]),
                TextDataList([
                    TextData(0, "text_pre0", "1,2,3", Language.CHN),
                    TextData(1, "text_pre0", "1,2,3", Language.CHN),
                    TextData(2, "text_pre0", "1,2,3", Language.CHN),
                ]),
                WavDataList([
                    WavData(0, "wavpath_pre0", 7.89, 22050),
                    WavData(1, "wavpath_pre0", 7.89, 22050),
                    WavData(2, "wavpath_pre0", 7.89, 22050),
                ]),
                MelDataList([
                    MelData(0, "melpath_pre0", 80),
                    MelData(1, "melpath_pre0", 80),
                    MelData(2, "melpath_pre0", 80),
                ]),
                ["speaker0", "speaker1"],
                SymbolIdDict.init_from_symbols({"a", "b"}),
            ),
            "thchs": (
                DsDataList([
                    DsData(0, "basename0", "speaker0", 0, "text0", "wavpath0", Language.ENG),
                    DsData(1, "basename0", "speaker1", 1, "text0", "wavpath0", Language.ENG),
                    DsData(2, "basename0", "speaker1", 1, "text0", "wavpath0", Language.ENG),
                ]),
                TextDataList([
                    TextData(0, "text_pre0", "1,2,3", Language.CHN),
                    TextData(1, "text_pre0", "1,2,3", Language.CHN),
                    TextData(2, "text_pre0", "1,2,3", Language.CHN),
                ]),
                WavDataList([
                    WavData(0, "wavpath_pre0", 7.89, 22050),
                    WavData(1, "wavpath_pre0", 7.89, 22050),
                    WavData(2, "wavpath_pre0", 7.89, 22050),
                ]),
                MelDataList([
                    MelData(0, "melpath_pre0", 80),
                    MelData(1, "melpath_pre0", 80),
                    MelData(2, "melpath_pre0", 80),
                ]),
                ["speaker0", "speaker1"],
                SymbolIdDict.init_from_symbols({"b", "c"}),
            ),
        }
        ds_speakers = {
            ("ljs", "speaker0"),
            ("thchs", "speaker1"),
        }

        whole, conv, speakers_id_dict = merge(datasets,
                                              ds_speakers,
                                              speakers_as_accents=False)

        self.assertEqual(4, len(whole))
        self.assertEqual(set({"a", "b", "c"}), set(conv.get_all_symbols()))
        # TODO
        self.assertEqual("1,2,3", whole[0].serialized_symbol_ids)
        self.assertEqual("1,2,3", whole[1].serialized_symbol_ids)
        self.assertEqual("1,3,4", whole[2].serialized_symbol_ids)
        self.assertEqual("1,3,4", whole[3].serialized_symbol_ids)
        self.assertEqual(["ljs,speaker0", "thchs,speaker1"],
                         speakers_id_dict.get_speakers())
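
The symbol assertions above amount to rebuilding one SymbolIdDict from the union of the per-dataset symbol sets. A sketch of that step, using only calls that appear elsewhere in this listing and assuming this is effectively what merge does internally:

ljs_symbols = SymbolIdDict.init_from_symbols({"a", "b"})
thchs_symbols = SymbolIdDict.init_from_symbols({"b", "c"})
merged = SymbolIdDict.init_from_symbols(
    set(ljs_symbols.get_all_symbols()) | set(thchs_symbols.get_all_symbols()))
print(set(merged.get_all_symbols()))  # expected: {'a', 'b', 'c'}, as asserted above
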
Example #7
    def test_ds_dataset_to_merged_dataset(self):
        ds = DsDataset(
            name="ds",
            accent_ids=AccentsDict.init_from_accents({"c", "d"}),
            data=DsDataList([
                DsData(
                    entry_id=1,
                    basename="basename",
                    speaker_id=2,
                    serialized_symbols="0,1",
                    gender=Gender.MALE,
                    lang=Language.GER,
                    serialized_accents="0,0",
                    wav_path="wav",
                    speaker_name="speaker",
                    text="text",
                )
            ]),
            mels=MelDataList([
                MelData(
                    entry_id=1,
                    relative_mel_path="mel",
                    n_mel_channels=5,
                )
            ]),
            speakers=SpeakersDict.fromlist(["sp1", "sp2", "sp3"]),
            symbol_ids=SymbolIdDict.init_from_symbols({"a", "b"}),
            texts=TextDataList([
                TextData(
                    entry_id=1,
                    lang=Language.IPA,
                    serialized_accent_ids="1,0",
                    serialized_symbol_ids="1,1",
                    text="text_new",
                )
            ]),
            wavs=WavDataList([
                WavData(
                    entry_id=1,
                    duration=15,
                    sr=22050,
                    relative_wav_path="wav_new",
                )
            ]),
        )

        res = ds_dataset_to_merged_dataset(ds)

        self.assertEqual("ds", res.name)
        self.assertEqual(2, len(res.accent_ids))
        self.assertEqual(3, len(res.speakers))
        self.assertEqual(2, len(res.symbol_ids))
        self.assertEqual(1, len(res.items()))
        first_entry = res.items()[0]
        self.assertEqual(1, first_entry.entry_id)
        self.assertEqual(Gender.MALE, first_entry.gender)
        self.assertEqual("basename", first_entry.basename)
        self.assertEqual(2, first_entry.speaker_id)
        self.assertEqual(Language.IPA, first_entry.lang)
        self.assertEqual("1,0", first_entry.serialized_accents)
        self.assertEqual("1,1", first_entry.serialized_symbols)
        self.assertEqual("wav_new", first_entry.wav_path)
        self.assertEqual(15, first_entry.duration)
        self.assertEqual(22050, first_entry.sampling_rate)
        self.assertEqual("mel", first_entry.mel_path)
        self.assertEqual(5, first_entry.n_mel_channels)