def get_final_ds_from_data(ds_data: DsDataList, text_data: TextDataList, wav_data: WavDataList, mel_data: MelDataList, wav_dir: Path, mel_dir: Path) -> FinalDsEntryList:
    """Merge the per-stage data lists into one list of final dataset entries.

    The four lists are walked in lockstep; entries at the same position must
    describe the same item. Relative wav/mel paths are resolved against
    ``wav_dir`` and ``mel_dir`` respectively.

    Args:
        ds_data: Original dataset entries (speaker info, original symbols and wav path).
        text_data: Processed symbols for each entry.
        wav_data: Processed wav metadata (relative path, duration, sampling rate).
        mel_data: Mel metadata (relative path, number of channels).
        wav_dir: Directory the relative wav paths are joined onto.
        mel_dir: Directory the relative mel paths are joined onto.

    Returns:
        A ``FinalDsEntryList`` with one ``FinalDsEntry`` per aligned position.

    Raises:
        AssertionError: If the entry ids of aligned entries differ, or if the
            symbol language of the dataset entry and the text entry differ.

    Note:
        ``zip`` stops at the shortest input, so surplus entries of longer
        lists are ignored without an error.
    """
    # The result holds FinalDsEntry objects, so it has to be the matching list
    # type (the function previously created a WavDataList here).
    res = FinalDsEntryList()
    aligned_entries = zip(
        ds_data.items(),
        text_data.items(),
        wav_data.items(),
        mel_data.items(),
    )
    for ds_data_entry, text_data_entry, wav_data_entry, mel_data_entry in aligned_entries:
        assert ds_data_entry.entry_id == text_data_entry.entry_id == wav_data_entry.entry_id == mel_data_entry.entry_id
        assert ds_data_entry.symbols_language == text_data_entry.symbols_language

        new_entry = FinalDsEntry(
            entry_id=ds_data_entry.entry_id,
            basename=ds_data_entry.basename,
            speaker_gender=ds_data_entry.speaker_gender,
            speaker_name=ds_data_entry.speaker_name,
            symbols_language=ds_data_entry.symbols_language,
            # The dataset entry keeps the untouched symbols, the text entry the processed ones.
            symbols_original=ds_data_entry.symbols,
            symbols_original_format=ds_data_entry.symbols_format,
            symbols=text_data_entry.symbols,
            symbols_format=text_data_entry.symbols_format,
            wav_original_absolute_path=ds_data_entry.wav_absolute_path,
            wav_absolute_path=wav_dir / wav_data_entry.wav_relative_path,
            wav_duration=wav_data_entry.wav_duration,
            wav_sampling_rate=wav_data_entry.wav_sampling_rate,
            mel_absolute_path=mel_dir / mel_data_entry.mel_relative_path,
            mel_n_channels=mel_data_entry.mel_n_channels,
        )
        res.append(new_entry)
    return res
def _save_speaker_examples(ds_dir: Path, examples: DsDataList, logger: Logger) -> None:
    """Copy the wav file of every entry in ``examples`` into the examples directory.

    The target directory is obtained from ``get_ds_examples_dir(ds_dir)`` and is
    created (including parents) when it does not exist yet. Each copy is named
    ``<index>-<gender>-<ascii speaker name>.wav`` where the index starts at 1.

    Args:
        ds_dir: Dataset directory passed to ``get_ds_examples_dir``.
        examples: Entries whose ``wav_absolute_path`` is copied
            (presumably one entry per speaker -- confirm against the caller).
        logger: Logger that receives the progress message.
    """
    logger.info("Saving examples for each speaker...")

    target_dir = get_ds_examples_dir(ds_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    for index, entry in enumerate(examples.items(True), 1):
        gender = str(entry.speaker_gender)
        ascii_name = convert_to_ascii(entry.speaker_name)
        target_path = target_dir / f"{index}-{gender}-{ascii_name}.wav"
        copyfile(entry.wav_absolute_path, target_path)
def preprocess(data: DsDataList, dest_dir: Path, n_jobs: int) -> WavDataList:
    """Run ``preprocess_entry`` for every dataset entry on a thread pool.

    Args:
        data: Dataset entries to process.
        dest_dir: Existing directory handed to ``preprocess_entry`` as ``dest_dir``.
        n_jobs: Maximum number of worker threads.

    Returns:
        A ``WavDataList`` built from the worker results, in input order
        (``Executor.map`` yields results in the order of its inputs).

    Raises:
        AssertionError: If ``dest_dir`` is not an existing directory.
    """
    assert dest_dir.is_dir()

    total = len(data)
    # Bind the arguments that are identical for every entry.
    worker = partial(preprocess_entry, dest_dir=dest_dir, entries_count=total)

    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        # The progress bar advances as results are consumed by the list constructor.
        progress = tqdm(pool.map(worker, data.items()), total=total)
        result = WavDataList(progress)

    return result
def log_stats(ds_data: DsDataList, text_data: TextDataList):
    """Log a table with symbol-count statistics, overall and per speaker.

    For every group the table lists the number of entries and the minimum,
    maximum, average and total number of symbols. Rows are sorted by the
    total in descending order.

    Args:
        ds_data: Dataset entries; supplies the ``speaker_name`` of each entry.
        text_data: Text entries aligned by position with ``ds_data``;
            supplies the ``symbols`` of each entry.

    Returns:
        None. The table is emitted through ``logger.info``; when there are
        no text entries only a short notice is logged.
    """
    logger = getLogger(__name__)

    text_lengths = [len(x.symbols) for x in text_data.items()]
    if not text_lengths:
        # min()/max()/mean() all raise on an empty sequence, so bail out early.
        logger.info("No entries available, skipping text statistics.")
        return

    def _stats_row(label, lengths):
        # (name, #entries, min, max, avg, total) -- matches the column order below.
        return (
            label,
            len(lengths),
            min(lengths),
            max(lengths),
            mean(lengths),
            sum(lengths),
        )

    stats: List[tuple] = [_stats_row("Overall", text_lengths)]

    # Group the symbol counts by speaker; entries are aligned by position.
    speakers_text_lengths: Dict[Speaker, List[int]] = {}
    for ds_entry, text_length in zip(ds_data.items(), text_lengths):
        speakers_text_lengths.setdefault(ds_entry.speaker_name, []).append(text_length)

    stats.extend(
        _stats_row(speaker, speaker_text_lengths)
        for speaker, speaker_text_lengths in speakers_text_lengths.items()
    )

    # Largest total first; "Overall" can never be smaller than a single speaker.
    stats.sort(key=lambda x: x[-1], reverse=True)

    stats_csv = pd.DataFrame(stats, columns=[
        "Speaker",
        "# Entries",
        "# Min",
        "# Max",
        "# Avg",
        "# Total",
    ])

    # Lift pandas' display limits so the whole table is rendered.
    with pd.option_context(
        'display.max_rows', None,
        'display.max_columns', None,
        'display.width', None,
        'display.precision', 0,
    ):
        logger.info(stats_csv)
def preprocess(data: DsDataList) -> TextDataList:
    """Create the initial text data from the dataset entries.

    Every dataset entry yields one ``TextData`` that carries over the entry id
    together with the symbols, their format and their language unchanged.

    Args:
        data: Dataset entries to convert.

    Returns:
        A ``TextDataList`` with one entry per dataset entry, in input order.
    """
    def _as_text_data(ds_entry) -> TextData:
        # Plain field-by-field copy; no symbol conversion happens here.
        return TextData(
            entry_id=ds_entry.entry_id,
            symbols=ds_entry.symbols,
            symbols_format=ds_entry.symbols_format,
            symbols_language=ds_entry.symbols_language,
        )

    text_entries = TextDataList()
    for ds_entry in data.items(True):
        text_entries.append(_as_text_data(ds_entry))
    return text_entries
def log_stats(ds_data: DsDataList, wav_data: WavDataList):
    """Report duration statistics of the wav entries, overall and per speaker.

    The sampling rate of the first wav entry is logged, then a table is
    printed to stdout listing, per group, the number of entries, the minimum,
    maximum and average duration in seconds and the total duration in minutes
    and hours. Rows are sorted by total duration in descending order.

    Args:
        ds_data: Dataset entries; supplies the ``speaker_name`` of each entry.
        wav_data: Wav entries aligned by position with ``ds_data``; supplies
            ``wav_duration`` and ``wav_sampling_rate``.

    Returns:
        None. When there are no wav entries only a short notice is logged.
    """
    logger = getLogger(__name__)

    durations = [x.wav_duration for x in wav_data.items()]
    if not durations:
        # min()/max()/mean() all raise on an empty sequence, so bail out early.
        logger.info("No entries available, skipping wav statistics.")
        return

    logger.info(f"Sampling rate: {wav_data.items()[0].wav_sampling_rate}")

    def _stats_row(label, count, group_durations):
        # (name, #entries, min, max, avg, total min, total h) -- matches the columns below.
        total_seconds = sum(group_durations)
        return (
            label,
            count,
            min(group_durations),
            max(group_durations),
            mean(group_durations),
            total_seconds / 60,
            total_seconds / 3600,
        )

    stats: List[tuple] = [_stats_row("Overall", len(wav_data), durations)]

    # Group the durations by speaker; entries are aligned by position.
    durations_by_speaker: Dict[Speaker, List[float]] = {}
    for ds_entry, wav_entry in zip(ds_data.items(), wav_data.items()):
        durations_by_speaker.setdefault(ds_entry.speaker_name, []).append(wav_entry.wav_duration)

    # Distinct loop names: the loop variable must not shadow the dict it iterates.
    for speaker_name, speaker_durations in durations_by_speaker.items():
        stats.append(_stats_row(speaker_name, len(speaker_durations), speaker_durations))

    # Sort by total minutes (second to last column), largest first.
    stats.sort(key=lambda x: x[-2], reverse=True)

    stats_csv = pd.DataFrame(stats, columns=[
        "Speaker",
        "# Entries",
        "Min (s)",
        "Max (s)",
        "Avg (s)",
        "Total (min)",
        "Total (h)",
    ])

    # Lift pandas' display limits so the whole table is rendered.
    with pd.option_context(
        'display.max_rows', None,
        'display.max_columns', None,
        'display.width', None,
        'display.precision', 4,
    ):
        print(stats_csv)
def process(data: WavDataList, ds: DsDataList, wav_dir: Path, custom_hparams: Optional[Dict[str, str]], save_callback: Callable[[WavData, DsData], Path]) -> List[Path]:
    """Compute a mel tensor for every wav entry and hand it to ``save_callback``.

    Args:
        data: Wav entries; their ``wav_relative_path`` is resolved against ``wav_dir``.
        ds: Dataset entries aligned by position with ``data``.
        wav_dir: Directory the relative wav paths are joined onto.
        custom_hparams: Overrides applied to a fresh ``TSTFTHParams`` via
            ``overwrite_custom_hparams``.
        save_callback: Invoked once per entry with the keyword arguments
            ``wav_entry``, ``ds_entry`` and ``mel_tensor``; its return value
            is collected.

    Returns:
        The values returned by ``save_callback``, in input order.
    """
    stft_hparams = overwrite_custom_hparams(TSTFTHParams(), custom_hparams)
    stft = TacotronSTFT(stft_hparams, logger=getLogger())

    def _convert(wav_entry, ds_entry) -> Path:
        # Resolve the wav file, compute its mel tensor and let the caller persist it.
        wav_path = wav_dir / wav_entry.wav_relative_path
        mel_tensor = stft.get_mel_tensor_from_file(wav_path)
        return save_callback(wav_entry=wav_entry, ds_entry=ds_entry, mel_tensor=mel_tensor)

    return [
        _convert(wav_entry, ds_entry)
        for wav_entry, ds_entry in zip(data.items(True), ds.items(True))
    ]
def test_get_final_ds_from_data():
    """A single aligned entry is merged into one final entry with all fields carried over."""
    # Arrange: one entry per stage, all sharing the same entry id.
    original_symbols = ("a", "b",)
    processed_symbols = ("ʊ", "ɪ",)

    ds_entry = DsData(
        entry_id=1,
        basename="basename",
        speaker_name="Speaker 1",
        speaker_gender=Gender.FEMALE,
        symbols_language=Language.ENG,
        symbols_format=SymbolFormat.GRAPHEMES,
        symbols=original_symbols,
        wav_absolute_path=Path("test.wav"),
    )
    text_entry = TextData(
        entry_id=ds_entry.entry_id,
        symbols_language=ds_entry.symbols_language,
        symbols_format=SymbolFormat.PHONEMES_IPA,
        symbols=processed_symbols,
    )
    wav_entry = WavData(
        entry_id=ds_entry.entry_id,
        wav_relative_path=Path("test2.wav"),
        wav_sampling_rate=22050,
        wav_duration=7.5,
    )
    mel_entry = MelData(
        entry_id=ds_entry.entry_id,
        mel_relative_path=Path("test2.pt"),
        mel_n_channels=80,
    )

    # Act
    result = get_final_ds_from_data(
        ds_data=DsDataList([ds_entry]),
        text_data=TextDataList([text_entry]),
        wav_data=WavDataList([wav_entry]),
        mel_data=MelDataList([mel_entry]),
        wav_dir=Path("wavdir"),
        mel_dir=Path("meldir"),
    )

    # Assert
    assert len(result) == 1
    final_entry = result.items()[0]
    assert final_entry.entry_id == 1
    assert final_entry.basename == "basename"
    assert final_entry.speaker_gender == Gender.FEMALE
    assert final_entry.speaker_name == "Speaker 1"
    assert final_entry.symbols_format == SymbolFormat.PHONEMES_IPA
    assert final_entry.symbols == processed_symbols
    assert final_entry.symbols_original == original_symbols
    assert final_entry.symbols_original_format == SymbolFormat.GRAPHEMES
    assert final_entry.wav_duration == 7.5
    assert final_entry.wav_sampling_rate == 22050
    # Relative paths are resolved against the given directories.
    assert final_entry.wav_absolute_path == Path("wavdir/test2.wav")
    assert final_entry.wav_original_absolute_path == Path("test.wav")
    assert final_entry.mel_absolute_path == Path("meldir/test2.pt")
    assert final_entry.mel_n_channels == 80