Example #1
    def test_merge_prepared_data(self):
        prep_list = [
            (PreparedDataList([
                PreparedData(0, 1, "", "", "", 0, "0,1,2", 0, 0, "", 0, ""),
            ]), SymbolIdDict({
                0: (0, "a"),
                1: (0, "b"),
                2: (0, "c"),
            })),
            (PreparedDataList([
                PreparedData(0, 2, "", "", "", 0, "0,1,2", 0, 0, "", 0, ""),
            ]), SymbolIdDict({
                0: (0, "b"),
                1: (0, "a"),
                2: (0, "d"),
            })),
        ]

        res, conv = merge_prepared_data(prep_list)

        self.assertEqual(6, len(conv))
        self.assertEqual("todo", conv.get_symbol(0))
        self.assertEqual("todo", conv.get_symbol(1))
        self.assertEqual("a", conv.get_symbol(2))
        self.assertEqual("b", conv.get_symbol(3))
        self.assertEqual("c", conv.get_symbol(4))
        self.assertEqual("d", conv.get_symbol(5))

        self.assertEqual(2, len(res))
        self.assertEqual(0, res[0].i)
        self.assertEqual(1, res[1].i)
        self.assertEqual(1, res[0].entry_id)
        self.assertEqual(2, res[1].entry_id)
        self.assertEqual("2,3,4", res[0].serialized_symbol_ids)
        self.assertEqual("3,2,5", res[1].serialized_symbol_ids)
Example #2
def update_symbols_and_text(sentences: SentenceList, sents_new_symbols: List[List[str]]):
  symbols = SymbolIdDict.init_from_symbols(get_unique_items(sents_new_symbols))
  for sentence, new_symbols in zip(sentences.items(), sents_new_symbols):
    sentence.serialized_symbols = symbols.get_serialized_ids(new_symbols)
    sentence.text = SymbolIdDict.symbols_to_text(new_symbols)
    assert len(sentence.get_symbol_ids()) == len(new_symbols)
    assert len(sentence.get_accent_ids()) == len(new_symbols)
  return symbols, sentences
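
The "0,1,2" strings used throughout these examples suggest a simple comma-separated id serialization. A minimal sketch of that convention (an assumption, not necessarily the project's serialize_list/deserialize_list):

def serialize(ids):
  # join ids into the "0,1,2" form seen in the examples
  return ",".join(str(i) for i in ids)

def deserialize(serialized):
  # empty string maps back to an empty id list
  return [int(x) for x in serialized.split(",")] if serialized else []

assert serialize([0, 1, 2]) == "0,1,2"
assert deserialize("0,1,2") == [0, 1, 2]
assert deserialize("") == []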
Example #3
def sents_map(sentences: SentenceList, text_symbols: SymbolIdDict, symbols_map: SymbolsMap, ignore_arcs: bool, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
  sents_new_symbols = []
  result = SentenceList()
  new_sent_id = 0

  ipa_settings = IPAExtractionSettings(
    ignore_tones=False,
    ignore_arcs=ignore_arcs,
    replace_unknown_ipa_by=DEFAULT_PADDING_SYMBOL,
  )

  for sentence in sentences.items():
    symbols = text_symbols.get_symbols(sentence.serialized_symbols)
    accent_ids = deserialize_list(sentence.serialized_accents)

    mapped_symbols = symbols_map.apply_to_symbols(symbols)

    text = SymbolIdDict.symbols_to_text(mapped_symbols)
    # an empty resulting text is not a problem here
    sents = text_to_sentences(
      text=text,
      lang=sentence.lang,
      logger=logger,
    )

    for new_sent_text in sents:
      new_symbols = text_to_symbols(
        new_sent_text,
        lang=sentence.lang,
        ipa_settings=ipa_settings,
        logger=logger,
      )

      if len(accent_ids) > 0:
        new_accent_ids = [accent_ids[0]] * len(new_symbols)
      else:
        new_accent_ids = []

      assert len(new_accent_ids) == len(new_symbols)

      new_sent_id += 1
      tmp = Sentence(
        sent_id=new_sent_id,
        text=new_sent_text,
        lang=sentence.lang,
        orig_lang=sentence.orig_lang,
        # not strictly correct, but the closest approximation currently possible
        original_text=sentence.original_text,
        serialized_accents=serialize_list(new_accent_ids),
        serialized_symbols=""
      )
      sents_new_symbols.append(new_symbols)

      assert len(tmp.get_accent_ids()) == len(new_symbols)
      result.append(tmp)

  return update_symbols_and_text(result, sents_new_symbols)
Example #4
 def replace_unknown_symbols(self, model_symbols: SymbolIdDict, logger: Logger) -> bool:
   unknown_symbols_exist = False
   for sentence in self.items():
     if model_symbols.has_unknown_symbols(sentence.symbols):
       sentence.symbols = model_symbols.replace_unknown_symbols_with_pad(
         sentence.symbols, pad_symbol=DEFAULT_PADDING_SYMBOL)
       text = SymbolIdDict.symbols_to_text(sentence.symbols)
       logger.info(f"Sentence {sentence.sent_id} contains unknown symbols: {text}")
       unknown_symbols_exist = True
       assert len(sentence.symbols) == len(sentence.accents)
   return unknown_symbols_exist
Example #5
def get_mapped_symbol_weights(model_symbols: SymbolIdDict,
                              trained_weights: Tensor,
                              trained_symbols: SymbolIdDict,
                              custom_mapping: Optional[SymbolsMap],
                              hparams: HParams, logger: Logger) -> Tensor:
    symbols_count_mismatch = trained_weights.shape[0] != len(trained_symbols)
    if symbols_count_mismatch:
        msg = (f"Weights mapping: the symbol count of the pretrained model "
               f"({trained_weights.shape[0]}) does not match the number of "
               f"symbols ({len(trained_symbols)}).")
        logger.exception(msg)
        raise Exception(msg)

    if custom_mapping is None:
        symbols_map = SymbolsMap.from_intersection(
            map_to=model_symbols.get_all_symbols(),
            map_from=trained_symbols.get_all_symbols(),
        )
    else:
        symbols_map = custom_mapping
        symbols_map.remove_unknown_symbols(
            known_to_symbol=model_symbols.get_all_symbols(),
            known_from_symbols=trained_symbols.get_all_symbols())

    # Remove all empty mappings
    symbols_wo_mapping = symbols_map.get_symbols_with_empty_mapping()
    symbols_map.pop_batch(symbols_wo_mapping)

    symbols_id_map = symbols_map.convert_to_symbols_ids_map(
        to_symbols=model_symbols,
        from_symbols=trained_symbols,
    )

    model_symbols_id_map = symbols_ids_map_to_model_symbols_ids_map(
        symbols_id_map,
        hparams.n_accents,
        n_symbols=hparams.n_symbols,
        accents_use_own_symbols=hparams.accents_use_own_symbols)

    model_weights = get_symbol_weights(hparams)

    map_weights(model_symbols_id_map=model_symbols_id_map,
                model_weights=model_weights,
                trained_weights=trained_weights,
                logger=logger)

    not_existing_symbols = model_symbols.get_all_symbols() - symbols_map.keys()
    no_mapping = symbols_wo_mapping | not_existing_symbols
    if len(no_mapping) > 0:
        logger.warning(f"Following symbols were not mapped: {no_mapping}")
    else:
        logger.info("All symbols were mapped.")

    return model_weights
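
A standalone sketch of the row-copy idea presumably behind map_weights (an assumption about its behavior, not the project's implementation): every model symbol id that has a mapped trained id receives the corresponding trained embedding row, while unmapped rows keep their fresh initialization.

def copy_mapped_rows(model_weights, trained_weights, model_to_trained_id):
  # model_to_trained_id: model symbol id -> trained symbol id
  for model_id, trained_id in model_to_trained_id.items():
    model_weights[model_id] = trained_weights[trained_id]
  return model_weights

new = copy_mapped_rows([[0.0], [0.0], [0.0]], [[1.0], [2.0]], {0: 1, 2: 0})
assert new == [[2.0], [0.0], [1.0]]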
Example #6
def update_symbols(data: MergedDataset, symbols: SymbolIdDict) -> SymbolIdDict:
    new_symbols: Set[str] = {
        x
        for y in data.items()
        for x in symbols.get_symbols(y.serialized_symbol_ids)
    }
    new_symbol_ids = SymbolIdDict.init_from_symbols_with_pad(
        new_symbols, pad_symbol=DEFAULT_PADDING_SYMBOL)
    if new_symbol_ids.get_all_symbols() != symbols.get_all_symbols():
        for entry in data.items():
            original_symbols = symbols.get_symbols(entry.serialized_symbol_ids)
            entry.serialized_symbol_ids = new_symbol_ids.get_serialized_ids(
                original_symbols)
    return new_symbol_ids
Example #7
def sents_convert_to_ipa(sentences: SentenceList, text_symbols: SymbolIdDict, ignore_tones: bool, ignore_arcs: bool, mode: Optional[EngToIpaMode], consider_ipa_annotations: bool, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:

  sents_new_symbols = []
  for sentence in sentences.items(True):
    if sentence.lang == Language.ENG and mode is None:
      ex = "Please specify the ipa conversion mode."
      logger.exception(ex)
      raise Exception(ex)
    new_symbols, new_accent_ids = symbols_to_ipa(
      symbols=text_symbols.get_symbols(sentence.serialized_symbols),
      lang=sentence.lang,
      accent_ids=deserialize_list(sentence.serialized_accents),
      ignore_arcs=ignore_arcs,
      ignore_tones=ignore_tones,
      mode=mode,
      replace_unknown_with=DEFAULT_PADDING_SYMBOL,
      consider_ipa_annotations=consider_ipa_annotations,
      logger=logger,
    )
    assert len(new_symbols) == len(new_accent_ids)
    sentence.lang = Language.IPA
    sentence.serialized_accents = serialize_list(new_accent_ids)
    sents_new_symbols.append(new_symbols)
    assert len(sentence.get_accent_ids()) == len(new_symbols)

  return update_symbols_and_text(sentences, sents_new_symbols)
Example #8
 def get_formatted(self, symbol_id_dict: SymbolIdDict, accent_id_dict: AccentsDict, pairs_per_line=170, space_length=0):
   return get_formatted_core(
     sent_id=self.sent_id,
     symbols=symbol_id_dict.get_symbols(self.serialized_symbols),
     accent_ids=self.get_accent_ids(),
     accent_id_dict=accent_id_dict,
     space_length=space_length,
     max_pairs_per_line=pairs_per_line
   )
Example #9
 def test_sims_to_csv(self):
   emb = torch.ones(size=(2, 3))
   torch.nn.init.zeros_(emb[0])
   sims = get_similarities(emb)
   symbols = SymbolIdDict.init_from_symbols({"2", "1"})
   res = sims_to_csv(sims, symbols)
   self.assertEqual(2, len(res.index))
   self.assertListEqual(['2', '<=>', '1', 'nan'], list(res.values[0]))
   self.assertListEqual(['1', '<=>', '2', 'nan'], list(res.values[1]))
Example #10
  def test_plot_embeddings(self):
    emb = torch.ones(size=(2, 3))
    symbols = SymbolIdDict.init_from_symbols({"a", "b"})  # two symbols, matching the 2-row embedding and the assertions below
    text, plot2d, plot3d = plot_embeddings(symbols, emb, logging.getLogger())

    self.assertEqual(2, len(text.index))
    self.assertListEqual(['a', '<=>', 'b', '1.00'], list(text.values[0]))
    self.assertListEqual(['b', '<=>', 'a', '1.00'], list(text.values[1]))
    self.assertEqual("2D-Embeddings", plot2d.layout.title.text)
    self.assertEqual("3D-Embeddings", plot3d.layout.title.text)
Example #11
def get_ngram_rarity(data: PreparedDataList, corpus: PreparedDataList,
                     symbols: SymbolIdDict,
                     ngram: int) -> OrderedDictType[int, float]:
    data_symbols_dict = OrderedDict({
        x.entry_id: symbols.get_symbols(x.serialized_symbol_ids)
        for x in data.items()
    })
    corpus_symbols_dict = OrderedDict({
        x.entry_id: symbols.get_symbols(x.serialized_symbol_ids)
        for x in corpus.items()
    })

    rarity = get_rarity_ngrams(
        data=data_symbols_dict,
        corpus=corpus_symbols_dict,
        n_gram=ngram,
        ignore_symbols=None,
    )

    return rarity
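
A standalone sketch of the n-gram windowing assumed for get_ngrams (not the project's implementation): consecutive windows of length n over a symbol sequence.

def ngrams(symbols, n):
  # sliding windows of length n
  return [tuple(symbols[i:i + n]) for i in range(len(symbols) - n + 1)]

assert ngrams(["a", "b", "c"], n=2) == [("a", "b"), ("b", "c")]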
Example #12
def filter_symbols(data: MergedDataset, symbols: SymbolIdDict,
                   accent_ids: AccentsDict, speakers: SpeakersDict,
                   allowed_symbol_ids: Set[int],
                   logger: Logger) -> MergedDatasetContainer:
    # maybe check beforehand that all symbol ids are valid
    allowed_symbols = [symbols.get_symbol(x) for x in allowed_symbol_ids]
    not_allowed_symbols = [
        symbols.get_symbol(x) for x in symbols.get_all_symbol_ids()
        if x not in allowed_symbol_ids
    ]
    logger.info(
        f"Keep utterances with these symbols: {' '.join(allowed_symbols)}")
    logger.info(
        f"Remove utterances with these symbols: {' '.join(not_allowed_symbols)}"
    )
    logger.info("Statistics before filtering:")
    log_stats(data, symbols, accent_ids, speakers, logger)
    result = MergedDataset([
        x for x in data.items() if contains_only_allowed_symbols(
            deserialize_list(x.serialized_symbol_ids), allowed_symbol_ids)
    ])
    if len(result) > 0:
        logger.info(
            f"Removed {len(data) - len(result)} from {len(data)} total entries and got {len(result)} entries ({len(result)/len(data)*100:.2f}%)."
        )
    else:
        logger.info("Removed all utterances!")
    new_symbol_ids = update_symbols(result, symbols)
    new_accent_ids = update_accents(result, accent_ids)
    new_speaker_ids = update_speakers(result, speakers)
    logger.info("Statistics after filtering:")
    log_stats(result, new_symbol_ids, new_accent_ids, new_speaker_ids, logger)

    res = MergedDatasetContainer(
        name=None,
        data=result,
        accent_ids=new_accent_ids,
        speaker_ids=new_speaker_ids,
        symbol_ids=new_symbol_ids,
    )
    return res
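
A standalone sketch of the filter predicate used above (not necessarily the project's contains_only_allowed_symbols): an utterance is kept only if every deserialized symbol id is in the allowed set.

def only_allowed(symbol_ids, allowed_ids):
  # True only when no id falls outside the allowed set
  return all(symbol_id in allowed_ids for symbol_id in symbol_ids)

assert only_allowed([0, 1, 2], {0, 1, 2, 5})
assert not only_allowed([0, 3], {0, 1, 2})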
Example #13
 def from_instances(cls, model: Tacotron2, optimizer: Adam,
                    hparams: HParams, iteration: int, symbols: SymbolIdDict,
                    accents: AccentsDict, speakers: SpeakersDict):
     result = cls(state_dict=model.state_dict(),
                  optimizer=optimizer.state_dict(),
                  learning_rate=hparams.learning_rate,
                  iteration=iteration,
                  hparams=asdict(hparams),
                  symbols=symbols.raw(),
                  accents=accents.raw(),
                  speakers=speakers.raw())
     return result
Example #14
def sims_to_csv(sims: Dict[int, List[Tuple[int, float]]],
                symbols: SymbolIdDict) -> pd.DataFrame:
    lines = []
    assert len(sims) == len(symbols)
    for symbol_id, similarities in sims.items():
        line = [f"{symbols.get_symbol(symbol_id)}", "<=>"]
        for other_symbol_id, similarity in similarities:
            line.append(symbols.get_symbol(other_symbol_id))
            line.append(f"{similarity:.2f}")
        lines.append(line)
    df = pd.DataFrame(lines)
    return df
Example #15
 def from_sentences(cls, sentences: SentenceList, accents: AccentsDict, symbols: SymbolIdDict):
   res = cls()
   for sentence in sentences.items():
     infer_sent = InferSentence(
       sent_id=sentence.sent_id,
       symbols=symbols.get_symbols(sentence.serialized_symbols),
       accents=accents.get_accents(sentence.serialized_accents),
       original_text=sentence.original_text,
     )
     assert len(infer_sent.symbols) == len(infer_sent.accents)
     res.append(infer_sent)
   return res
Example #16
def add_text(text: str, lang: Language, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
  res = SentenceList()
  # each line is treated as at least one sentence.
  lines = text.split("\n")

  all_sents = []
  for line in lines:
    sents = text_to_sentences(
      text=line,
      lang=lang,
      logger=logger,
    )
    all_sents.extend(sents)

  default_accent_id = 0
  ipa_settings = IPAExtractionSettings(
    ignore_tones=False,
    ignore_arcs=False,
    replace_unknown_ipa_by=DEFAULT_PADDING_SYMBOL,
  )

  sents_symbols: List[List[str]] = [text_to_symbols(
    sent,
    lang=lang,
    ipa_settings=ipa_settings,
    logger=logger,
  ) for sent in all_sents]
  symbols = SymbolIdDict.init_from_symbols(get_unique_items(sents_symbols))
  for i, sent_symbols in enumerate(sents_symbols):
    sentence = Sentence(
      sent_id=i + 1,
      lang=lang,
      serialized_symbols=symbols.get_serialized_ids(sent_symbols),
      serialized_accents=serialize_list([default_accent_id] * len(sent_symbols)),
      text=SymbolIdDict.symbols_to_text(sent_symbols),
      original_text=SymbolIdDict.symbols_to_text(sent_symbols),
      orig_lang=lang,
    )
    res.append(sentence)
  return symbols, res
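
A hedged usage sketch for add_text, assuming the function and the Language enum are importable from the project (Language.ENG appears in another example in this list); the text and attribute accesses are illustrative only.

import logging

symbols, sentences = add_text(
  text="First sentence. Second sentence.\nThird line.",
  lang=Language.ENG,
  logger=logging.getLogger(__name__),
)
for sentence in sentences.items():
  print(sentence.sent_id, sentence.text, sentence.serialized_symbols)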
Example #17
def sents_accent_template(sentences: SentenceList, text_symbols: SymbolIdDict, accent_ids: AccentsDict) -> AccentedSymbolList:
  res = AccentedSymbolList()
  for i, sent in enumerate(sentences.items()):
    symbols = text_symbols.get_symbols(sent.serialized_symbols)
    accents = accent_ids.get_accents(sent.serialized_accents)
    for j, symbol_accent in enumerate(zip(symbols, accents)):
      symbol, accent = symbol_accent
      accented_symbol = AccentedSymbol(
        position=f"{i}-{j}",
        symbol=symbol,
        accent=accent
      )
      res.append(accented_symbol)
  return res
Example #18
def sents_normalize(sentences: SentenceList, text_symbols: SymbolIdDict, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
  # Maybe add info if something was unknown
  sents_new_symbols = []
  for sentence in sentences.items():
    new_symbols, new_accent_ids = symbols_normalize(
      symbols=text_symbols.get_symbols(sentence.serialized_symbols),
      lang=sentence.lang,
      accent_ids=deserialize_list(sentence.serialized_accents),
      logger=logger,
    )
    # TODO: check if new sentences resulted and then split them.
    sentence.serialized_accents = serialize_list(new_accent_ids)
    sents_new_symbols.append(new_symbols)

  return update_symbols_and_text(sentences, sents_new_symbols)
Example #19
def get_ngram_stats_df(symbols: SymbolIdDict, trainset: PreparedDataList,
                       valset: PreparedDataList, testset: PreparedDataList,
                       restset: PreparedDataList, n: int, logger: Logger):
    total_set = get_total_set(trainset, valset, testset, restset)
    logger.info(f"Getting all {n}-gram stats...")
    tot_symbols = [
        symbols.get_symbols(x.serialized_symbol_ids)
        for x in total_set.items()
    ]
    tot_symbols_one_gram = [get_ngrams(x, n=n) for x in tot_symbols]
    symbol_order = list(sorted({x for y in tot_symbols_one_gram for x in y}))

    ngram_stats = _get_ngram_stats_df_core(
        symbol_order=symbol_order,
        symbols=symbols,
        trainset=trainset,
        valset=valset,
        testset=testset,
        restset=restset,
        n=n,
        logger=logger,
    )
    (occurences_count_df, occurrences_percent_df,
     occurrences_distribution_percent_df, utterance_occurrences_count_df,
     utterance_occurrences_percent_df, uniform_occurrences_count_df,
     uniform_occurrences_percent_df) = ngram_stats

    symbol_dfs = [
        occurences_count_df,
        occurrences_percent_df,
        occurrences_distribution_percent_df,
        utterance_occurrences_count_df,
        utterance_occurrences_percent_df,
        uniform_occurrences_count_df,
        uniform_occurrences_percent_df,
    ]

    for i in range(1, len(symbol_dfs)):
        symbol_dfs[i] = symbol_dfs[i].loc[:, symbol_dfs[i].columns != FIRST_COL_NAME]

    symbol_stats = pd.concat(
        symbol_dfs,
        axis=1,
        join='inner',
    )

    # symbol_stats = symbol_stats.round(decimals=2)
    symbol_stats = symbol_stats.sort_values(by='TOTAL_OCCURRENCES_COUNT',
                                            ascending=False)
    print(symbol_stats)
    return symbol_stats
Example #20
def log_stats(data: MergedDataset, symbols: SymbolIdDict,
              accent_ids: AccentsDict, speakers: SpeakersDict, logger: Logger):
    logger.info(
        f"Speakers ({len(speakers)}): {', '.join(sorted(speakers.get_all_speakers()))}"
    )
    logger.info(
        f"Symbols ({len(symbols)}): {' '.join(sorted(symbols.get_all_symbols()))}"
    )
    logger.info(
        f"Accents ({len(accent_ids)}): {', '.join(sorted(accent_ids.get_all_accents()))}"
    )
    logger.info(
        f"Entries ({len(data)}): {data.get_total_duration_s()/60:.2f}m")
    symbol_counter = get_counter(
        [symbols.get_symbols(x.serialized_symbol_ids) for x in data.items()])
    logger.info(symbol_counter)
Example #21
    def make_common_symbol_ids(self) -> SymbolIdDict:
        all_symbols: Set[str] = set()
        for ds in self.data:
            all_symbols |= ds.symbol_ids.get_all_symbols()
        new_symbol_ids = SymbolIdDict.init_from_symbols_with_pad(
            all_symbols, pad_symbol=DEFAULT_PADDING_SYMBOL)

        for ds in self.data:
            for entry in ds.data.items():
                original_symbols = ds.symbol_ids.get_symbols(
                    entry.serialized_symbol_ids)
                entry.serialized_symbol_ids = new_symbol_ids.get_serialized_ids(
                    original_symbols)
            ds.symbol_ids = new_symbol_ids

        return new_symbol_ids
Example #22
def plot_embeddings(
        symbols: SymbolIdDict, emb: torch.Tensor,
        logger: Logger) -> Tuple[pd.DataFrame, go.Figure, go.Figure]:
    assert emb.shape[0] == len(symbols)

    logger.info(f"Emb size {emb.shape}")
    logger.info(f"Sym len {len(symbols)}")

    sims = get_similarities(emb.numpy())
    df = sims_to_csv(sims, symbols)
    all_symbols_sorted = [symbols.get_symbol(x) for x in range(len(symbols))]
    emb_normed = norm2emb(emb)
    fig_2d = emb_plot_2d(emb_normed, all_symbols_sorted)
    fig_3d = emb_plot_3d(emb_normed, all_symbols_sorted)

    return df, fig_2d, fig_3d
Example #23
  def test_get_similarities_is_sorted(self):
    symbols = SymbolIdDict.init_from_symbols({"a", "b", "c"})
    emb = np.zeros(shape=(3, 3))
    emb[symbols.get_id("a")] = [0.5, 1.0, 0]
    emb[symbols.get_id("b")] = [1.0, 0.6, 0]
    emb[symbols.get_id("c")] = [1.0, 0.5, 0]

    sims = get_similarities(emb)

    self.assertEqual(3, len(sims))
    self.assertEqual(symbols.get_id("b"), sims[symbols.get_id("a")][0][0])
    self.assertEqual(symbols.get_id("c"), sims[symbols.get_id("a")][1][0])
    self.assertEqual(symbols.get_id("b"), sims[symbols.get_id("c")][0][0])
    self.assertEqual(symbols.get_id("a"), sims[symbols.get_id("c")][1][0])
    self.assertEqual(symbols.get_id("c"), sims[symbols.get_id("b")][0][0])
    self.assertEqual(symbols.get_id("a"), sims[symbols.get_id("b")][1][0])
Example #24
def convert_v1_to_v2_model(old_model_path: str,
                           custom_hparams: Optional[Dict[str, str]],
                           speakers: SpeakersDict, accents: AccentsDict,
                           symbols: SymbolIdDict):
    checkpoint_dict = torch.load(old_model_path, map_location='cpu')
    hparams = HParams(n_speakers=len(speakers),
                      n_accents=len(accents),
                      n_symbols=len(symbols))

    hparams = overwrite_custom_hparams(hparams, custom_hparams)

    chp = CheckpointTacotron(state_dict=checkpoint_dict["state_dict"],
                             optimizer=checkpoint_dict["optimizer"],
                             learning_rate=checkpoint_dict["learning_rate"],
                             iteration=checkpoint_dict["iteration"] + 1,
                             hparams=asdict(hparams),
                             speakers=speakers.raw(),
                             symbols=symbols.raw(),
                             accents=accents.raw())

    new_model_path = f"{old_model_path}_{get_pytorch_filename(chp.iteration)}"

    chp.save(new_model_path, logging.getLogger())
Example #25
def _get_ngram_stats_df_core(symbol_order: List[str], symbols: SymbolIdDict,
                             trainset: PreparedDataList,
                             valset: PreparedDataList,
                             testset: PreparedDataList,
                             restset: PreparedDataList, n: int,
                             logger: Logger):
    logger.info(f"Get {n}-grams...")
    trn_symbols = [
        symbols.get_symbols(x.serialized_symbol_ids) for x in trainset.items()
    ]
    val_symbols = [
        symbols.get_symbols(x.serialized_symbol_ids) for x in valset.items()
    ]
    tst_symbols = [
        symbols.get_symbols(x.serialized_symbol_ids) for x in testset.items()
    ]
    rst_symbols = [
        symbols.get_symbols(x.serialized_symbol_ids) for x in restset.items()
    ]

    trn_symbols_one_gram = [get_ngrams(x, n=n) for x in trn_symbols]
    val_symbols_one_gram = [get_ngrams(x, n=n) for x in val_symbols]
    tst_symbols_one_gram = [get_ngrams(x, n=n) for x in tst_symbols]
    rst_symbols_one_gram = [get_ngrams(x, n=n) for x in rst_symbols]
    logger.info("Get stats...")

    occurences_count_df = get_occ_df_of_all_symbols(
        symbols=symbol_order,
        data_trn=trn_symbols_one_gram,
        data_val=val_symbols_one_gram,
        data_tst=tst_symbols_one_gram,
        data_rst=rst_symbols_one_gram,
    )
    occurences_count_df.columns = [
        FIRST_COL_NAME, 'TRAIN_OCCURRENCES_COUNT', 'VAL_OCCURRENCES_COUNT',
        'TEST_OCCURRENCES_COUNT', 'REST_OCCURRENCES_COUNT',
        'TOTAL_OCCURRENCES_COUNT'
    ]
    print(occurences_count_df)

    occurrences_percent_df = get_rel_occ_df_of_all_symbols(occurences_count_df)
    occurrences_percent_df.columns = [
        FIRST_COL_NAME, 'TRAIN_OCCURRENCES_PERCENT', 'VAL_OCCURRENCES_PERCENT',
        'TEST_OCCURRENCES_PERCENT', 'REST_OCCURRENCES_PERCENT'
    ]
    print(occurrences_percent_df)

    occurrences_distribution_percent_df = get_dist_among_other_symbols_df_of_all_symbols(
        occs_df=occurences_count_df,
        data_trn=trn_symbols_one_gram,
        data_val=val_symbols_one_gram,
        data_tst=tst_symbols_one_gram,
        data_rst=rst_symbols_one_gram,
    )
    occurrences_distribution_percent_df.columns = [
        FIRST_COL_NAME, 'TRAIN_OCCURRENCES_DISTRIBUTION_PERCENT',
        'VAL_OCCURRENCES_DISTRIBUTION_PERCENT',
        'TEST_OCCURRENCES_DISTRIBUTION_PERCENT',
        'REST_OCCURRENCES_DISTRIBUTION_PERCENT',
        'TOTAL_OCCURRENCES_DISTRIBUTION_PERCENT'
    ]
    print(occurrences_distribution_percent_df)

    utterance_occurrences_count_df = get_utter_occ_df_of_all_symbols(
        symbols=symbol_order,
        data_trn=trn_symbols_one_gram,
        data_val=val_symbols_one_gram,
        data_tst=tst_symbols_one_gram,
        data_rst=rst_symbols_one_gram,
    )
    utterance_occurrences_count_df.columns = [
        FIRST_COL_NAME, 'TRAIN_UTTERANCE_OCCURRENCES_COUNT',
        'VAL_UTTERANCE_OCCURRENCES_COUNT', 'TEST_UTTERANCE_OCCURRENCES_COUNT',
        'REST_UTTERANCE_OCCURRENCES_COUNT', 'TOTAL_UTTERANCE_OCCURRENCES_COUNT'
    ]
    print(utterance_occurrences_count_df)

    utterance_occurrences_percent_df = get_rel_utter_occ_df_of_all_symbols(
        utterance_occurrences_count_df)
    utterance_occurrences_percent_df.columns = [
        FIRST_COL_NAME, 'TRAIN_UTTERANCE_OCCURRENCES_PERCENT',
        'VAL_UTTERANCE_OCCURRENCES_PERCENT',
        'TEST_UTTERANCE_OCCURRENCES_PERCENT',
        'REST_UTTERANCE_OCCURRENCES_PERCENT'
    ]
    print(utterance_occurrences_percent_df)

    uniform_occurrences_count_df = get_uniform_distr_df_for_occs(
        symbols=symbol_order,
        occ_df=occurences_count_df,
    )
    uniform_occurrences_count_df.columns = [
        FIRST_COL_NAME, 'TRAIN_UNIFORM_OCCURRENCES_COUNT',
        'VAL_UNIFORM_OCCURRENCES_COUNT', 'TEST_UNIFORM_OCCURRENCES_COUNT',
        'REST_UNIFORM_OCCURRENCES_COUNT', 'TOTAL_UNIFORM_OCCURRENCES_COUNT'
    ]
    print(uniform_occurrences_count_df)

    uniform_occurrences_percent_df = get_rel_uniform_distr_df_for_occs(
        symbols=symbol_order)
    uniform_occurrences_percent_df.columns = [
        FIRST_COL_NAME, 'UNIFORM_OCCURRENCES_PERCENT'
    ]
    print(uniform_occurrences_percent_df)

    return (occurences_count_df, occurrences_percent_df,
            occurrences_distribution_percent_df, utterance_occurrences_count_df,
            utterance_occurrences_percent_df, uniform_occurrences_count_df,
            uniform_occurrences_percent_df)
Example #26
 def get_symbols(self) -> SymbolIdDict:
     return SymbolIdDict.from_raw(self.symbols)
Example #27
def prep_data_list_to_dict_with_symbols(l: PreparedDataList, symbols: SymbolIdDict) -> OrderedDictType[int, List[str]]:
  res = OrderedDict({x.entry_id: symbols.get_symbols(x.serialized_symbol_ids) for x in l.items()})
  return res
Example #28
def get_shard_size(symbols: SymbolIdDict) -> int:
  all_symbols = symbols.get_all_symbols()
  count = len(all_symbols)
  if DEFAULT_PADDING_SYMBOL in all_symbols:
    count -= 1
  return count
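
A worked illustration of the same shard-size logic with a plain set, assuming DEFAULT_PADDING_SYMBOL is "_" (an assumption for this sketch): three symbols including the pad give a shard size of 2.

all_symbols = {"_", "a", "b"}
count = len(all_symbols) - (1 if "_" in all_symbols else 0)  # exclude the padding symbol
assert count == 2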
Example #29
 def get_formatted_v2(self, symbol_id_dict: SymbolIdDict):
   return get_formatted_core_v2(
     sent_id=self.sent_id,
     symbols=symbol_id_dict.get_symbols(self.serialized_symbols),
     original_text=self.original_text,
   )
Example #30
    def test_preprocess(self):
        datasets = {
            "ljs": (
                DsDataList([
                    DsData(0, "basename0", "speaker0", 0, "text0", "wavpath0", Language.ENG),
                    DsData(1, "basename0", "speaker0", 0, "text0", "wavpath0", Language.ENG),
                    DsData(2, "basename0", "speaker1", 1, "text0", "wavpath0", Language.ENG),
                ]),
                TextDataList([
                    TextData(0, "text_pre0", "1,2,3", Language.CHN),
                    TextData(1, "text_pre0", "1,2,3", Language.CHN),
                    TextData(2, "text_pre0", "1,2,3", Language.CHN),
                ]),
                WavDataList([
                    WavData(0, "wavpath_pre0", 7.89, 22050),
                    WavData(1, "wavpath_pre0", 7.89, 22050),
                    WavData(2, "wavpath_pre0", 7.89, 22050),
                ]),
                MelDataList([
                    MelData(0, "melpath_pre0", 80),
                    MelData(1, "melpath_pre0", 80),
                    MelData(2, "melpath_pre0", 80),
                ]),
                ["speaker0", "speaker1"],
                SymbolIdDict.init_from_symbols({"a", "b"}),
            ),
            "thchs": (
                DsDataList([
                    DsData(0, "basename0", "speaker0", 0, "text0", "wavpath0", Language.ENG),
                    DsData(1, "basename0", "speaker1", 1, "text0", "wavpath0", Language.ENG),
                    DsData(2, "basename0", "speaker1", 1, "text0", "wavpath0", Language.ENG),
                ]),
                TextDataList([
                    TextData(0, "text_pre0", "1,2,3", Language.CHN),
                    TextData(1, "text_pre0", "1,2,3", Language.CHN),
                    TextData(2, "text_pre0", "1,2,3", Language.CHN),
                ]),
                WavDataList([
                    WavData(0, "wavpath_pre0", 7.89, 22050),
                    WavData(1, "wavpath_pre0", 7.89, 22050),
                    WavData(2, "wavpath_pre0", 7.89, 22050),
                ]),
                MelDataList([
                    MelData(0, "melpath_pre0", 80),
                    MelData(1, "melpath_pre0", 80),
                    MelData(2, "melpath_pre0", 80),
                ]),
                ["speaker0", "speaker1"],
                SymbolIdDict.init_from_symbols({"b", "c"}),
            ),
        }
        ds_speakers = {
            ("ljs", "speaker0"),
            ("thchs", "speaker1"),
        }

        whole, conv, speakers_id_dict = merge(datasets,
                                              ds_speakers,
                                              speakers_as_accents=False)

        self.assertEqual(4, len(whole))
        self.assertEqual(set({"a", "b", "c"}), set(conv.get_all_symbols()))
        # TODO
        self.assertEqual("1,2,3", whole[0].serialized_symbol_ids)
        self.assertEqual("1,2,3", whole[1].serialized_symbol_ids)
        self.assertEqual("1,3,4", whole[2].serialized_symbol_ids)
        self.assertEqual("1,3,4", whole[3].serialized_symbol_ids)
        self.assertEqual(["ljs,speaker0", "thchs,speaker1"],
                         speakers_id_dict.get_speakers())