def test_merge_prepared_data(self):
    prep_list = [
        (PreparedDataList([
            PreparedData(0, 1, "", "", "", 0, "0,1,2", 0, 0, "", 0, ""),
        ]), SymbolIdDict({
            0: (0, "a"),
            1: (0, "b"),
            2: (0, "c"),
        })),
        (PreparedDataList([
            PreparedData(0, 2, "", "", "", 0, "0,1,2", 0, 0, "", 0, ""),
        ]), SymbolIdDict({
            0: (0, "b"),
            1: (0, "a"),
            2: (0, "d"),
        })),
    ]

    res, conv = merge_prepared_data(prep_list)

    self.assertEqual(6, len(conv))
    self.assertEqual("todo", conv.get_symbol(0))
    self.assertEqual("todo", conv.get_symbol(1))
    self.assertEqual("a", conv.get_symbol(2))
    self.assertEqual("b", conv.get_symbol(3))
    self.assertEqual("c", conv.get_symbol(4))
    self.assertEqual("d", conv.get_symbol(5))
    self.assertEqual(2, len(res))
    self.assertEqual(0, res[0].i)
    self.assertEqual(1, res[1].i)
    self.assertEqual(1, res[0].entry_id)
    self.assertEqual(2, res[1].entry_id)
    self.assertEqual("2,3,4", res[0].serialized_symbol_ids)
    self.assertEqual("3,2,5", res[1].serialized_symbol_ids)

def update_symbols_and_text(sentences: SentenceList, sents_new_symbols: List[List[str]]):
    symbols = SymbolIdDict.init_from_symbols(get_unique_items(sents_new_symbols))
    for sentence, new_symbols in zip(sentences.items(), sents_new_symbols):
        sentence.serialized_symbols = symbols.get_serialized_ids(new_symbols)
        sentence.text = SymbolIdDict.symbols_to_text(new_symbols)
        assert len(sentence.get_symbol_ids()) == len(new_symbols)
        assert len(sentence.get_accent_ids()) == len(new_symbols)
    return symbols, sentences

def sents_map(sentences: SentenceList, text_symbols: SymbolIdDict, symbols_map: SymbolsMap,
              ignore_arcs: bool, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
    sents_new_symbols = []
    result = SentenceList()
    new_sent_id = 0

    ipa_settings = IPAExtractionSettings(
        ignore_tones=False,
        ignore_arcs=ignore_arcs,
        replace_unknown_ipa_by=DEFAULT_PADDING_SYMBOL,
    )

    for sentence in sentences.items():
        symbols = text_symbols.get_symbols(sentence.serialized_symbols)
        accent_ids = deserialize_list(sentence.serialized_accents)

        mapped_symbols = symbols_map.apply_to_symbols(symbols)
        text = SymbolIdDict.symbols_to_text(mapped_symbols)
        # a resulting empty text causes no problems here
        sents = text_to_sentences(
            text=text,
            lang=sentence.lang,
            logger=logger,
        )

        for new_sent_text in sents:
            new_symbols = text_to_symbols(
                new_sent_text,
                lang=sentence.lang,
                ipa_settings=ipa_settings,
                logger=logger,
            )

            if len(accent_ids) > 0:
                new_accent_ids = [accent_ids[0]] * len(new_symbols)
            else:
                new_accent_ids = []
            assert len(new_accent_ids) == len(new_symbols)

            new_sent_id += 1
            tmp = Sentence(
                sent_id=new_sent_id,
                text=new_sent_text,
                lang=sentence.lang,
                orig_lang=sentence.orig_lang,  # not correct, but the closest approximation currently possible
                original_text=sentence.original_text,
                serialized_accents=serialize_list(new_accent_ids),
                serialized_symbols="",
            )
            sents_new_symbols.append(new_symbols)

            assert len(tmp.get_accent_ids()) == len(new_symbols)
            result.append(tmp)

    return update_symbols_and_text(result, sents_new_symbols)

def replace_unknown_symbols(self, model_symbols: SymbolIdDict, logger: Logger) -> bool:
    unknown_symbols_exist = False
    for sentence in self.items():
        if model_symbols.has_unknown_symbols(sentence.symbols):
            sentence.symbols = model_symbols.replace_unknown_symbols_with_pad(
                sentence.symbols, pad_symbol=DEFAULT_PADDING_SYMBOL)
            text = SymbolIdDict.symbols_to_text(sentence.symbols)
            logger.info(f"Sentence {sentence.sent_id} contains unknown symbols: {text}")
            unknown_symbols_exist = True
            assert len(sentence.symbols) == len(sentence.accents)
    return unknown_symbols_exist

def get_mapped_symbol_weights(model_symbols: SymbolIdDict, trained_weights: Tensor,
                              trained_symbols: SymbolIdDict, custom_mapping: Optional[SymbolsMap],
                              hparams: HParams, logger: Logger) -> Tensor:
    symbol_count_mismatch = trained_weights.shape[0] != len(trained_symbols)
    if symbol_count_mismatch:
        msg = (
            f"Weights mapping: symbol space of the pretrained model ({trained_weights.shape[0]}) "
            f"did not match the number of symbols ({len(trained_symbols)})."
        )
        logger.exception(msg)
        raise Exception(msg)

    if custom_mapping is None:
        symbols_map = SymbolsMap.from_intersection(
            map_to=model_symbols.get_all_symbols(),
            map_from=trained_symbols.get_all_symbols(),
        )
    else:
        symbols_map = custom_mapping
        symbols_map.remove_unknown_symbols(
            known_to_symbol=model_symbols.get_all_symbols(),
            known_from_symbols=trained_symbols.get_all_symbols())

    # Remove all empty mappings
    symbols_wo_mapping = symbols_map.get_symbols_with_empty_mapping()
    symbols_map.pop_batch(symbols_wo_mapping)

    symbols_id_map = symbols_map.convert_to_symbols_ids_map(
        to_symbols=model_symbols,
        from_symbols=trained_symbols,
    )

    model_symbols_id_map = symbols_ids_map_to_model_symbols_ids_map(
        symbols_id_map,
        hparams.n_accents,
        n_symbols=hparams.n_symbols,
        accents_use_own_symbols=hparams.accents_use_own_symbols,
    )

    model_weights = get_symbol_weights(hparams)

    map_weights(
        model_symbols_id_map=model_symbols_id_map,
        model_weights=model_weights,
        trained_weights=trained_weights,
        logger=logger,
    )

    not_existing_symbols = model_symbols.get_all_symbols() - symbols_map.keys()
    no_mapping = symbols_wo_mapping | not_existing_symbols
    if len(no_mapping) > 0:
        logger.warning(f"Following symbols were not mapped: {no_mapping}")
    else:
        logger.info("All symbols were mapped.")

    return model_weights

def update_symbols(data: MergedDataset, symbols: SymbolIdDict) -> SymbolIdDict:
    new_symbols: Set[str] = {
        x for y in data.items() for x in symbols.get_symbols(y.serialized_symbol_ids)
    }
    new_symbol_ids = SymbolIdDict.init_from_symbols_with_pad(
        new_symbols, pad_symbol=DEFAULT_PADDING_SYMBOL)
    if new_symbol_ids.get_all_symbols() != symbols.get_all_symbols():
        for entry in data.items():
            original_symbols = symbols.get_symbols(entry.serialized_symbol_ids)
            entry.serialized_symbol_ids = new_symbol_ids.get_serialized_ids(original_symbols)
    return new_symbol_ids

def sents_convert_to_ipa(sentences: SentenceList, text_symbols: SymbolIdDict, ignore_tones: bool,
                         ignore_arcs: bool, mode: Optional[EngToIpaMode],
                         consider_ipa_annotations: bool,
                         logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
    sents_new_symbols = []
    for sentence in sentences.items(True):
        if sentence.lang == Language.ENG and mode is None:
            ex = "Please specify the ipa conversion mode."
            logger.exception(ex)
            raise Exception(ex)
        new_symbols, new_accent_ids = symbols_to_ipa(
            symbols=text_symbols.get_symbols(sentence.serialized_symbols),
            lang=sentence.lang,
            accent_ids=deserialize_list(sentence.serialized_accents),
            ignore_arcs=ignore_arcs,
            ignore_tones=ignore_tones,
            mode=mode,
            replace_unknown_with=DEFAULT_PADDING_SYMBOL,
            consider_ipa_annotations=consider_ipa_annotations,
            logger=logger,
        )
        assert len(new_symbols) == len(new_accent_ids)
        sentence.lang = Language.IPA
        sentence.serialized_accents = serialize_list(new_accent_ids)
        sents_new_symbols.append(new_symbols)
        assert len(sentence.get_accent_ids()) == len(new_symbols)

    return update_symbols_and_text(sentences, sents_new_symbols)

def get_formatted(self, symbol_id_dict: SymbolIdDict, accent_id_dict: AccentsDict,
                  pairs_per_line=170, space_length=0):
    return get_formatted_core(
        sent_id=self.sent_id,
        symbols=symbol_id_dict.get_symbols(self.serialized_symbols),
        accent_ids=self.get_accent_ids(),
        accent_id_dict=accent_id_dict,
        space_length=space_length,
        max_pairs_per_line=pairs_per_line,
    )

def test_sims_to_csv(self):
    emb = torch.ones(size=(2, 3))
    torch.nn.init.zeros_(emb[0])
    sims = get_similarities(emb)
    symbols = SymbolIdDict.init_from_symbols({"2", "1"})

    res = sims_to_csv(sims, symbols)

    self.assertEqual(2, len(res.index))
    self.assertListEqual(['2', '<=>', '1', 'nan'], list(res.values[0]))
    self.assertListEqual(['1', '<=>', '2', 'nan'], list(res.values[1]))

def test_plot_embeddings(self):
    emb = torch.ones(size=(2, 3))
    symbols = SymbolIdDict.init_from_symbols({})

    text, plot2d, plot3d = plot_embeddings(symbols, emb, logging.getLogger())

    self.assertEqual(2, len(text.index))
    self.assertListEqual(['a', '<=>', 'b', '1.00'], list(text.values[0]))
    self.assertListEqual(['b', '<=>', 'a', '1.00'], list(text.values[1]))
    self.assertEqual("2D-Embeddings", plot2d.layout.title.text)
    self.assertEqual("3D-Embeddings", plot3d.layout.title.text)

def get_ngram_rarity(data: PreparedDataList, corpus: PreparedDataList, symbols: SymbolIdDict,
                     ngram: int) -> OrderedDictType[int, float]:
    data_symbols_dict = OrderedDict({
        x.entry_id: symbols.get_symbols(x.serialized_symbol_ids) for x in data.items()
    })
    corpus_symbols_dict = OrderedDict({
        x.entry_id: symbols.get_symbols(x.serialized_symbol_ids) for x in corpus.items()
    })

    rarity = get_rarity_ngrams(
        data=data_symbols_dict,
        corpus=corpus_symbols_dict,
        n_gram=ngram,
        ignore_symbols=None,
    )
    return rarity

def filter_symbols(data: MergedDataset, symbols: SymbolIdDict, accent_ids: AccentsDict,
                   speakers: SpeakersDict, allowed_symbol_ids: Set[int],
                   logger: Logger) -> MergedDatasetContainer:
    # TODO: maybe check that all symbol ids are valid beforehand
    allowed_symbols = [symbols.get_symbol(x) for x in allowed_symbol_ids]
    not_allowed_symbols = [
        symbols.get_symbol(x) for x in symbols.get_all_symbol_ids()
        if x not in allowed_symbol_ids
    ]
    logger.info(f"Keep utterances with these symbols: {' '.join(allowed_symbols)}")
    logger.info(f"Remove utterances with these symbols: {' '.join(not_allowed_symbols)}")

    logger.info("Statistics before filtering:")
    log_stats(data, symbols, accent_ids, speakers, logger)

    result = MergedDataset([
        x for x in data.items()
        if contains_only_allowed_symbols(
            deserialize_list(x.serialized_symbol_ids), allowed_symbol_ids)
    ])

    if len(result) > 0:
        logger.info(
            f"Removed {len(data) - len(result)} from {len(data)} total entries and got "
            f"{len(result)} entries ({len(result)/len(data)*100:.2f}%).")
    else:
        logger.info("Removed all utterances!")

    new_symbol_ids = update_symbols(result, symbols)
    new_accent_ids = update_accents(result, accent_ids)
    new_speaker_ids = update_speakers(result, speakers)

    logger.info("Statistics after filtering:")
    log_stats(result, new_symbol_ids, new_accent_ids, new_speaker_ids, logger)

    res = MergedDatasetContainer(
        name=None,
        data=result,
        accent_ids=new_accent_ids,
        speaker_ids=new_speaker_ids,
        symbol_ids=new_symbol_ids,
    )
    return res

def from_instances(cls, model: Tacotron2, optimizer: Adam, hparams: HParams, iteration: int,
                   symbols: SymbolIdDict, accents: AccentsDict, speakers: SpeakersDict):
    result = cls(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
        learning_rate=hparams.learning_rate,
        iteration=iteration,
        hparams=asdict(hparams),
        symbols=symbols.raw(),
        accents=accents.raw(),
        speakers=speakers.raw(),
    )
    return result

def sims_to_csv(sims: Dict[int, List[Tuple[int, float]]], symbols: SymbolIdDict) -> pd.DataFrame:
    lines = []
    assert len(sims) == len(symbols)
    for symbol_id, similarities in sims.items():
        # use a separate name for the row so the `sims` parameter is not shadowed
        line = [f"{symbols.get_symbol(symbol_id)}", "<=>"]
        for other_symbol_id, similarity in similarities:
            line.append(symbols.get_symbol(other_symbol_id))
            line.append(f"{similarity:.2f}")
        lines.append(line)
    df = pd.DataFrame(lines)
    return df

def from_sentences(cls, sentences: SentenceList, accents: AccentsDict, symbols: SymbolIdDict):
    res = cls()
    for sentence in sentences.items():
        infer_sent = InferSentence(
            sent_id=sentence.sent_id,
            symbols=symbols.get_symbols(sentence.serialized_symbols),
            accents=accents.get_accents(sentence.serialized_accents),
            original_text=sentence.original_text,
        )
        assert len(infer_sent.symbols) == len(infer_sent.accents)
        res.append(infer_sent)
    return res

def add_text(text: str, lang: Language, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
    res = SentenceList()
    # each line is regarded as at least one sentence
    lines = text.split("\n")
    all_sents = []
    for line in lines:
        sents = text_to_sentences(
            text=line,
            lang=lang,
            logger=logger,
        )
        all_sents.extend(sents)

    default_accent_id = 0
    ipa_settings = IPAExtractionSettings(
        ignore_tones=False,
        ignore_arcs=False,
        replace_unknown_ipa_by=DEFAULT_PADDING_SYMBOL,
    )

    sents_symbols: List[List[str]] = [
        text_to_symbols(
            sent,
            lang=lang,
            ipa_settings=ipa_settings,
            logger=logger,
        ) for sent in all_sents
    ]
    symbols = SymbolIdDict.init_from_symbols(get_unique_items(sents_symbols))
    for i, sent_symbols in enumerate(sents_symbols):
        sentence = Sentence(
            sent_id=i + 1,
            lang=lang,
            serialized_symbols=symbols.get_serialized_ids(sent_symbols),
            serialized_accents=serialize_list([default_accent_id] * len(sent_symbols)),
            text=SymbolIdDict.symbols_to_text(sent_symbols),
            original_text=SymbolIdDict.symbols_to_text(sent_symbols),
            orig_lang=lang,
        )
        res.append(sentence)
    return symbols, res

def sents_accent_template(sentences: SentenceList, text_symbols: SymbolIdDict,
                          accent_ids: AccentsDict) -> AccentedSymbolList:
    res = AccentedSymbolList()
    for i, sent in enumerate(sentences.items()):
        symbols = text_symbols.get_symbols(sent.serialized_symbols)
        accents = accent_ids.get_accents(sent.serialized_accents)
        for j, (symbol, accent) in enumerate(zip(symbols, accents)):
            accented_symbol = AccentedSymbol(
                position=f"{i}-{j}",
                symbol=symbol,
                accent=accent,
            )
            res.append(accented_symbol)
    return res

def sents_normalize(sentences: SentenceList, text_symbols: SymbolIdDict,
                    logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
    # Maybe add info if something was unknown
    sents_new_symbols = []
    for sentence in sentences.items():
        new_symbols, new_accent_ids = symbols_normalize(
            symbols=text_symbols.get_symbols(sentence.serialized_symbols),
            lang=sentence.lang,
            accent_ids=deserialize_list(sentence.serialized_accents),
            logger=logger,
        )
        # TODO: check if new sentences resulted and then split them.
        sentence.serialized_accents = serialize_list(new_accent_ids)
        sents_new_symbols.append(new_symbols)

    return update_symbols_and_text(sentences, sents_new_symbols)

def get_ngram_stats_df(symbols: SymbolIdDict, trainset: PreparedDataList, valset: PreparedDataList,
                       testset: PreparedDataList, restset: PreparedDataList, n: int, logger: Logger):
    total_set = get_total_set(trainset, valset, testset, restset)
    logger.info(f"Getting all {n}-gram stats...")
    tot_symbols = [
        symbols.get_symbols(x.serialized_symbol_ids) for x in total_set.items()
    ]
    tot_symbols_one_gram = [get_ngrams(x, n=n) for x in tot_symbols]
    symbol_order = list(sorted({x for y in tot_symbols_one_gram for x in y}))

    ngram_stats = _get_ngram_stats_df_core(
        symbol_order=symbol_order,
        symbols=symbols,
        trainset=trainset,
        valset=valset,
        testset=testset,
        restset=restset,
        n=n,
        logger=logger,
    )
    (occurences_count_df, occurrences_percent_df, occurrences_distribution_percent_df,
     utterance_occurrences_count_df, utterance_occurrences_percent_df,
     uniform_occurrences_count_df, uniform_occurrences_percent_df) = ngram_stats

    symbol_dfs = [
        occurences_count_df,
        occurrences_percent_df,
        occurrences_distribution_percent_df,
        utterance_occurrences_count_df,
        utterance_occurrences_percent_df,
        uniform_occurrences_count_df,
        uniform_occurrences_percent_df,
    ]
    for i in range(1, len(symbol_dfs)):
        symbol_dfs[i] = symbol_dfs[i].loc[:, symbol_dfs[i].columns != FIRST_COL_NAME]

    symbol_stats = pd.concat(
        symbol_dfs,
        axis=1,
        join='inner',
    )

    # symbol_stats = symbol_stats.round(decimals=2)
    symbol_stats = symbol_stats.sort_values(by='TOTAL_OCCURRENCES_COUNT', ascending=False)
    print(symbol_stats)
    return symbol_stats

def log_stats(data: MergedDataset, symbols: SymbolIdDict, accent_ids: AccentsDict,
              speakers: SpeakersDict, logger: Logger):
    logger.info(
        f"Speakers ({len(speakers)}): {', '.join(sorted(speakers.get_all_speakers()))}")
    logger.info(
        f"Symbols ({len(symbols)}): {' '.join(sorted(symbols.get_all_symbols()))}")
    logger.info(
        f"Accents ({len(accent_ids)}): {', '.join(sorted(accent_ids.get_all_accents()))}")
    logger.info(
        f"Entries ({len(data)}): {data.get_total_duration_s()/60:.2f}m")
    symbol_counter = get_counter(
        [symbols.get_symbols(x.serialized_symbol_ids) for x in data.items()])
    logger.info(symbol_counter)

def make_common_symbol_ids(self) -> SymbolIdDict:
    all_symbols: Set[str] = set()
    for ds in self.data:
        all_symbols |= ds.symbol_ids.get_all_symbols()
    new_symbol_ids = SymbolIdDict.init_from_symbols_with_pad(
        all_symbols, pad_symbol=DEFAULT_PADDING_SYMBOL)

    for ds in self.data:
        for entry in ds.data.items():
            original_symbols = ds.symbol_ids.get_symbols(entry.serialized_symbol_ids)
            entry.serialized_symbol_ids = new_symbol_ids.get_serialized_ids(original_symbols)
        ds.symbol_ids = new_symbol_ids

    return new_symbol_ids

def plot_embeddings(symbols: SymbolIdDict, emb: torch.Tensor,
                    logger: Logger) -> Tuple[pd.DataFrame, go.Figure, go.Figure]:
    assert emb.shape[0] == len(symbols)

    logger.info(f"Emb size {emb.shape}")
    logger.info(f"Sym len {len(symbols)}")

    sims = get_similarities(emb.numpy())
    df = sims_to_csv(sims, symbols)
    all_symbols_sorted = [symbols.get_symbol(x) for x in range(len(symbols))]
    emb_normed = norm2emb(emb)
    fig_2d = emb_plot_2d(emb_normed, all_symbols_sorted)
    fig_3d = emb_plot_3d(emb_normed, all_symbols_sorted)

    return df, fig_2d, fig_3d

def test_get_similarities_is_sorted(self):
    symbols = SymbolIdDict.init_from_symbols({"a", "b", "c"})
    emb = np.zeros(shape=(3, 3))
    emb[symbols.get_id("a")] = [0.5, 1.0, 0]
    emb[symbols.get_id("b")] = [1.0, 0.6, 0]
    emb[symbols.get_id("c")] = [1.0, 0.5, 0]

    sims = get_similarities(emb)

    self.assertEqual(3, len(sims))
    self.assertEqual(symbols.get_id("b"), sims[symbols.get_id("a")][0][0])
    self.assertEqual(symbols.get_id("c"), sims[symbols.get_id("a")][1][0])
    self.assertEqual(symbols.get_id("b"), sims[symbols.get_id("c")][0][0])
    self.assertEqual(symbols.get_id("a"), sims[symbols.get_id("c")][1][0])
    self.assertEqual(symbols.get_id("c"), sims[symbols.get_id("b")][0][0])
    self.assertEqual(symbols.get_id("a"), sims[symbols.get_id("b")][1][0])

def convert_v1_to_v2_model(old_model_path: str, custom_hparams: Optional[Dict[str, str]],
                           speakers: SpeakersDict, accents: AccentsDict, symbols: SymbolIdDict):
    checkpoint_dict = torch.load(old_model_path, map_location='cpu')
    hparams = HParams(
        n_speakers=len(speakers),
        n_accents=len(accents),
        n_symbols=len(symbols),
    )
    hparams = overwrite_custom_hparams(hparams, custom_hparams)

    chp = CheckpointTacotron(
        state_dict=checkpoint_dict["state_dict"],
        optimizer=checkpoint_dict["optimizer"],
        learning_rate=checkpoint_dict["learning_rate"],
        iteration=checkpoint_dict["iteration"] + 1,
        hparams=asdict(hparams),
        speakers=speakers.raw(),
        symbols=symbols.raw(),
        accents=accents.raw(),
    )

    new_model_path = f"{old_model_path}_{get_pytorch_filename(chp.iteration)}"
    chp.save(new_model_path, logging.getLogger())

def _get_ngram_stats_df_core(symbol_order: List[str], symbols: SymbolIdDict,
                             trainset: PreparedDataList, valset: PreparedDataList,
                             testset: PreparedDataList, restset: PreparedDataList,
                             n: int, logger: Logger):
    logger.info(f"Get {n}-grams...")
    trn_symbols = [symbols.get_symbols(x.serialized_symbol_ids) for x in trainset.items()]
    val_symbols = [symbols.get_symbols(x.serialized_symbol_ids) for x in valset.items()]
    tst_symbols = [symbols.get_symbols(x.serialized_symbol_ids) for x in testset.items()]
    rst_symbols = [symbols.get_symbols(x.serialized_symbol_ids) for x in restset.items()]

    trn_symbols_one_gram = [get_ngrams(x, n=n) for x in trn_symbols]
    val_symbols_one_gram = [get_ngrams(x, n=n) for x in val_symbols]
    tst_symbols_one_gram = [get_ngrams(x, n=n) for x in tst_symbols]
    rst_symbols_one_gram = [get_ngrams(x, n=n) for x in rst_symbols]

    logger.info("Get stats...")

    occurences_count_df = get_occ_df_of_all_symbols(
        symbols=symbol_order,
        data_trn=trn_symbols_one_gram,
        data_val=val_symbols_one_gram,
        data_tst=tst_symbols_one_gram,
        data_rst=rst_symbols_one_gram,
    )
    occurences_count_df.columns = [
        FIRST_COL_NAME, 'TRAIN_OCCURRENCES_COUNT', 'VAL_OCCURRENCES_COUNT',
        'TEST_OCCURRENCES_COUNT', 'REST_OCCURRENCES_COUNT', 'TOTAL_OCCURRENCES_COUNT'
    ]
    print(occurences_count_df)

    occurrences_percent_df = get_rel_occ_df_of_all_symbols(occurences_count_df)
    occurrences_percent_df.columns = [
        FIRST_COL_NAME, 'TRAIN_OCCURRENCES_PERCENT', 'VAL_OCCURRENCES_PERCENT',
        'TEST_OCCURRENCES_PERCENT', 'REST_OCCURRENCES_PERCENT'
    ]
    print(occurrences_percent_df)

    occurrences_distribution_percent_df = get_dist_among_other_symbols_df_of_all_symbols(
        occs_df=occurences_count_df,
        data_trn=trn_symbols_one_gram,
        data_val=val_symbols_one_gram,
        data_tst=tst_symbols_one_gram,
        data_rst=rst_symbols_one_gram,
    )
    occurrences_distribution_percent_df.columns = [
        FIRST_COL_NAME, 'TRAIN_OCCURRENCES_DISTRIBUTION_PERCENT',
        'VAL_OCCURRENCES_DISTRIBUTION_PERCENT', 'TEST_OCCURRENCES_DISTRIBUTION_PERCENT',
        'REST_OCCURRENCES_DISTRIBUTION_PERCENT', 'TOTAL_OCCURRENCES_DISTRIBUTION_PERCENT'
    ]
    print(occurrences_distribution_percent_df)

    utterance_occurrences_count_df = get_utter_occ_df_of_all_symbols(
        symbols=symbol_order,
        data_trn=trn_symbols_one_gram,
        data_val=val_symbols_one_gram,
        data_tst=tst_symbols_one_gram,
        data_rst=rst_symbols_one_gram,
    )
    utterance_occurrences_count_df.columns = [
        FIRST_COL_NAME, 'TRAIN_UTTERANCE_OCCURRENCES_COUNT', 'VAL_UTTERANCE_OCCURRENCES_COUNT',
        'TEST_UTTERANCE_OCCURRENCES_COUNT', 'REST_UTTERANCE_OCCURRENCES_COUNT',
        'TOTAL_UTTERANCE_OCCURRENCES_COUNT'
    ]
    print(utterance_occurrences_count_df)

    utterance_occurrences_percent_df = get_rel_utter_occ_df_of_all_symbols(
        utterance_occurrences_count_df)
    utterance_occurrences_percent_df.columns = [
        FIRST_COL_NAME, 'TRAIN_UTTERANCE_OCCURRENCES_PERCENT', 'VAL_UTTERANCE_OCCURRENCES_PERCENT',
        'TEST_UTTERANCE_OCCURRENCES_PERCENT', 'REST_UTTERANCE_OCCURRENCES_PERCENT'
    ]
    print(utterance_occurrences_percent_df)

    uniform_occurrences_count_df = get_uniform_distr_df_for_occs(
        symbols=symbol_order,
        occ_df=occurences_count_df,
    )
    uniform_occurrences_count_df.columns = [
        FIRST_COL_NAME, 'TRAIN_UNIFORM_OCCURRENCES_COUNT', 'VAL_UNIFORM_OCCURRENCES_COUNT',
        'TEST_UNIFORM_OCCURRENCES_COUNT', 'REST_UNIFORM_OCCURRENCES_COUNT',
        'TOTAL_UNIFORM_OCCURRENCES_COUNT'
    ]
    print(uniform_occurrences_count_df)

    uniform_occurrences_percent_df = get_rel_uniform_distr_df_for_occs(
        symbols=symbol_order,
    )
    uniform_occurrences_percent_df.columns = [FIRST_COL_NAME, 'UNIFORM_OCCURRENCES_PERCENT']
    print(uniform_occurrences_percent_df)

    return (occurences_count_df, occurrences_percent_df, occurrences_distribution_percent_df,
            utterance_occurrences_count_df, utterance_occurrences_percent_df,
            uniform_occurrences_count_df, uniform_occurrences_percent_df)

def get_symbols(self) -> SymbolIdDict:
    return SymbolIdDict.from_raw(self.symbols)

def prep_data_list_to_dict_with_symbols(l: PreparedDataList,
                                        symbols: SymbolIdDict) -> OrderedDictType[int, List[str]]:
    res = OrderedDict({
        x.entry_id: symbols.get_symbols(x.serialized_symbol_ids) for x in l.items()
    })
    return res

def get_shard_size(symbols: SymbolIdDict) -> int:
    all_symbols = symbols.get_all_symbols()
    count = len(all_symbols)
    if DEFAULT_PADDING_SYMBOL in all_symbols:
        count -= 1
    return count

def get_formatted_v2(self, symbol_id_dict: SymbolIdDict):
    return get_formatted_core_v2(
        sent_id=self.sent_id,
        symbols=symbol_id_dict.get_symbols(self.serialized_symbols),
        original_text=self.original_text,
    )

def test_preprocess(self):
    datasets = {
        "ljs": (DsDataList([
            DsData(0, "basename0", "speaker0", 0, "text0", "wavpath0", Language.ENG),
            DsData(1, "basename0", "speaker0", 0, "text0", "wavpath0", Language.ENG),
            DsData(2, "basename0", "speaker1", 1, "text0", "wavpath0", Language.ENG),
        ]), TextDataList([
            TextData(0, "text_pre0", "1,2,3", Language.CHN),
            TextData(1, "text_pre0", "1,2,3", Language.CHN),
            TextData(2, "text_pre0", "1,2,3", Language.CHN),
        ]), WavDataList([
            WavData(0, "wavpath_pre0", 7.89, 22050),
            WavData(1, "wavpath_pre0", 7.89, 22050),
            WavData(2, "wavpath_pre0", 7.89, 22050),
        ]), MelDataList([
            MelData(0, "melpath_pre0", 80),
            MelData(1, "melpath_pre0", 80),
            MelData(2, "melpath_pre0", 80),
        ]), ["speaker0", "speaker1"], SymbolIdDict.init_from_symbols({"a", "b"})),
        "thchs": (DsDataList([
            DsData(0, "basename0", "speaker0", 0, "text0", "wavpath0", Language.ENG),
            DsData(1, "basename0", "speaker1", 1, "text0", "wavpath0", Language.ENG),
            DsData(2, "basename0", "speaker1", 1, "text0", "wavpath0", Language.ENG),
        ]), TextDataList([
            TextData(0, "text_pre0", "1,2,3", Language.CHN),
            TextData(1, "text_pre0", "1,2,3", Language.CHN),
            TextData(2, "text_pre0", "1,2,3", Language.CHN),
        ]), WavDataList([
            WavData(0, "wavpath_pre0", 7.89, 22050),
            WavData(1, "wavpath_pre0", 7.89, 22050),
            WavData(2, "wavpath_pre0", 7.89, 22050),
        ]), MelDataList([
            MelData(0, "melpath_pre0", 80),
            MelData(1, "melpath_pre0", 80),
            MelData(2, "melpath_pre0", 80),
        ]), ["speaker0", "speaker1"], SymbolIdDict.init_from_symbols({"b", "c"})),
    }

    ds_speakers = {
        ("ljs", "speaker0"),
        ("thchs", "speaker1"),
    }

    whole, conv, speakers_id_dict = merge(
        datasets, ds_speakers, speakers_as_accents=False)

    self.assertEqual(4, len(whole))
    self.assertEqual({"a", "b", "c"}, set(conv.get_all_symbols()))  # TODO
    self.assertEqual("1,2,3", whole[0].serialized_symbol_ids)
    self.assertEqual("1,2,3", whole[1].serialized_symbol_ids)
    self.assertEqual("1,3,4", whole[2].serialized_symbol_ids)
    self.assertEqual("1,3,4", whole[3].serialized_symbol_ids)
    self.assertEqual(["ljs,speaker0", "thchs,speaker1"], speakers_id_dict.get_speakers())