예제 #1
0
파일: ljspeech.py 프로젝트: pursu/wavenet
class LJSpeech:
    """Preprocess the LJSpeech corpus into one TFRecord file per utterance.

    Utterances are listed from ``<in_dir>/metadata.csv``; each one's wav,
    mel spectrogram and text are written to ``<out_dir>/<key>.tfrecord``.
    """

    def __init__(self, in_dir, out_dir, hparams):
        """Store the corpus/output directories and build the ``Audio`` helper.

        Args:
            in_dir: Corpus root containing ``metadata.csv`` and ``wavs/``.
            out_dir: Directory that receives the ``.tfrecord`` files.
            hparams: Hyper-parameters passed to ``Audio``.
        """
        self.in_dir = in_dir
        self.out_dir = out_dir
        self.audio = Audio(hparams)

    def text_and_audio_path_rdd(self, sc: SparkContext):
        """Return an RDD of ``(index, TextAndAudioPath)`` pairs."""
        return sc.parallelize(self._extract_all_text_and_audio_path())

    def process_data(self, rdd: RDD):
        """Write every utterance to disk.

        Values of the returned RDD are ``SourceAndTargetMetaData``.
        """
        return rdd.mapValues(self._process_source_and_target)

    def _extract_text_and_path(self, line, index):
        """Parse one pipe-separated ``metadata.csv`` line.

        Field 0 is the utterance key (also the wav file stem under
        ``wavs/``); field 2 is the text that is kept.

        Returns:
            A ``TextAndAudioPath``, or None for a blank line so that the
            caller can skip it.
        """
        line = line.strip()
        if not line:
            # Blank line (e.g. a trailing empty line): nothing to parse.
            return None
        parts = line.split('|')
        key = parts[0]
        wav_path = os.path.join(self.in_dir, 'wavs', '%s.wav' % key)
        text = parts[2]
        return TextAndAudioPath(index, key, wav_path, text)

    def _extract_all_text_and_audio_path(self):
        """Yield ``(index, TextAndAudioPath)`` for every usable metadata line.

        Indices start at 1 and only advance for lines that were kept, so they
        stay contiguous even when blank lines are skipped.
        """
        index = 1
        with open(os.path.join(self.in_dir, 'metadata.csv'),
                  mode='r',
                  encoding='utf-8') as f:
            for line in f:
                extracted = self._extract_text_and_path(line, index)
                if extracted is not None:
                    yield (index, extracted)
                    index += 1

    def _process_source_and_target(self, paths: TextAndAudioPath):
        """Write one utterance's wav, mel spectrogram and text to a TFRecord.

        Returns:
            ``SourceAndTargetMetaData`` with the id, key, number of wav
            samples, number of mel frames and the written file path.
        """
        wav = self.audio.load_wav(paths.wav_path)
        n_samples = len(wav)
        # Transposed so that axis 0 indexes frames (n_frames relies on this).
        mel_spectrogram = self.audio.melspectrogram(wav).astype(np.float32).T
        n_frames = mel_spectrogram.shape[0]
        filename = f"{paths.key}.tfrecord"
        filepath = os.path.join(self.out_dir, filename)
        tfrecord.write_preprocessed_data(paths.id, paths.key, wav,
                                         mel_spectrogram, paths.text, filepath)
        return SourceAndTargetMetaData(paths.id, paths.key, n_samples,
                                       n_frames, filepath)

    def _process_mel(self, paths: TextAndAudioPath):
        """Return ``MelMetaData`` with the frame count and per-frame sums.

        Nothing is written to disk.
        """
        wav = self.audio.load_wav(paths.wav_path)
        mel_spectrogram = self.audio.melspectrogram(wav).astype(np.float32).T
        # Axis 0 is frames after the transpose, so this sums within each frame.
        sum_mel_powers = np.sum(mel_spectrogram, axis=1)
        n_frames = mel_spectrogram.shape[0]
        return MelMetaData(n_frames, sum_mel_powers)
class LJSpeech:
    """Convert LJSpeech utterances into raw mel-feature files and resaved wavs.

    For each utterance listed in ``<in_dir>/metadata.csv``, the normalized
    mel spectrogram is written to ``<mel_out_dir>/<key>.mfbsp`` and the
    waveform to ``<wav_out_dir>/<key>.wav``.
    """

    def __init__(self, in_dir, mel_out_dir, wav_out_dir, hparams):
        """Store the input/output directories and build the ``Audio`` helper.

        Args:
            in_dir: Corpus root containing ``metadata.csv`` and ``wavs/``.
            mel_out_dir: Directory for ``.mfbsp`` mel files.
            wav_out_dir: Directory for the resaved ``.wav`` files.
            hparams: Hyper-parameters passed to ``Audio``.
        """
        self.in_dir = in_dir
        self.mel_out_dir = mel_out_dir
        self.wav_out_dir = wav_out_dir
        self.audio = Audio(hparams)

    @property
    def record_ids(self):
        """Ids "1" .. "13100" as strings; a fresh iterator on every access."""
        return map(str, range(1, 13101))

    def record_file_path(self, record_id, kind):
        """Return ``<mel_out_dir>/ljspeech-<kind>-<id:05d>.tfrecord``.

        ``kind`` must be "source" or "target".
        """
        # NOTE(review): _process_wav writes "<key>.mfbsp" files, not this
        # naming scheme; confirm which files callers expect here.
        assert kind in ["source", "target"]
        return os.path.join(self.mel_out_dir,
                            f"ljspeech-{kind}-{int(record_id):05d}.tfrecord")

    def text_and_path_rdd(self, sc: SparkContext):
        """Return an RDD of ``(index, TextAndPath)`` pairs from metadata.csv."""
        return sc.parallelize(self._extract_all_text_and_path())

    def process_wav(self, rdd: RDD):
        """Write mel/wav files for every value of ``rdd`` (values become None)."""
        return rdd.mapValues(self._process_wav)

    def _extract_text_and_path(self, line, index):
        """Parse one pipe-separated ``metadata.csv`` line.

        Field 0 is the utterance key (also the wav file stem under
        ``wavs/``); field 2 is the text that is kept.

        Returns:
            A ``TextAndPath``, or None for a blank line so that the caller
            can skip it.
        """
        line = line.strip()
        if not line:
            # Blank line (e.g. a trailing empty line): nothing to parse.
            return None
        parts = line.split('|')
        key = parts[0]
        text = parts[2]
        wav_path = os.path.join(self.in_dir, 'wavs', '%s.wav' % key)
        return TextAndPath(index, key, wav_path, None, text)

    def _extract_all_text_and_path(self):
        """Yield ``(index, TextAndPath)`` for every usable metadata line.

        ``index`` is the 0-based line number, so skipped lines leave gaps.
        """
        with open(os.path.join(self.in_dir, 'metadata.csv'),
                  mode='r',
                  encoding='utf-8') as f:
            for index, line in enumerate(f):
                extracted = self._extract_text_and_path(line, index)
                if extracted is not None:
                    yield (index, extracted)

    def _process_wav(self, paths: TextAndPath):
        """Write the normalized mel spectrogram and a copy of the waveform.

        The mel file is raw little-endian float32 with no header.
        """
        wav = self.audio.load_wav(paths.wav_path)
        mel_spectrogram = self.audio.melspectrogram(wav).astype(np.float32).T
        mel_spectrogram = self.audio.normalize_mel(mel_spectrogram)

        mel_filepath = os.path.join(self.mel_out_dir, f"{paths.key}.mfbsp")
        wav_filepath = os.path.join(self.wav_out_dir, f"{paths.key}.wav")

        # ndarray.tofile ignores `format` in binary mode (empty `sep`) and
        # writes the array's own dtype in native byte order, so convert to
        # little-endian float32 explicitly. normalize_mel may also change
        # the dtype.
        mel_spectrogram.astype("<f4", copy=False).tofile(mel_filepath)
        self.audio.save_wav(wav, wav_filepath)
예제 #3
0
class VCTK:
    """Preprocess the VCTK corpus into per-utterance source/target TFRecords.

    Sources hold the cleaned text plus speaker id/age/gender; targets hold the
    mel spectrogram of the trimmed waveform. Both are written to ``out_dir``.
    """

    def __init__(self,
                 in_dir,
                 out_dir,
                 hparams,
                 speaker_info_filename='speaker-info.txt'):
        """Store directories and build the ``Audio`` helper.

        Args:
            in_dir: Corpus root with ``wav48/p<id>/``, ``txt/p<id>/`` and the
                speaker-info file.
            out_dir: Directory that receives the ``.tfrecord`` files.
            hparams: Hyper-parameters passed to ``Audio``.
            speaker_info_filename: Speaker table file name inside ``in_dir``.
        """
        self.in_dir = in_dir
        self.out_dir = out_dir
        self.speaker_info_filename = speaker_info_filename
        self.audio = Audio(hparams)

    def list_files(self):
        """Pair each speaker's transcripts with their wav files.

        Each speaker directory is sorted by file name and the two lists are
        paired positionally. Extra files in the longer list are dropped,
        because ``zip`` stops at the shorter list.

        Returns:
            A list of ``TxtWavRecord`` with ids 0..N-1, in speaker-table
            order and then file-name order.

        Raises:
            AssertionError: if a txt/wav pair does not share the same stem.
        """

        def wav_files(speaker_info: SpeakerInfo):
            wav_dir = os.path.join(self.in_dir, f"wav48/p{speaker_info.id}")
            return [
                os.path.join(wav_dir, wav_file)
                for wav_file in sorted(os.listdir(wav_dir))
                if wav_file.endswith('.wav')
            ]

        def text_files(speaker_info: SpeakerInfo):
            txt_dir = os.path.join(self.in_dir, f"txt/p{speaker_info.id}")
            return [
                os.path.join(txt_dir, txt_file)
                for txt_file in sorted(os.listdir(txt_dir))
                if txt_file.endswith('.txt')
            ]

        def stem(path):
            # splitext removes exactly the extension. str.strip('.wav') would
            # instead remove any of the characters '.', 'w', 'a', 'v' from
            # both ends of the name.
            return os.path.splitext(os.path.basename(path))[0]

        def text_and_wav_records(file_pairs, speaker_info):
            def create_record(txt_f, wav_f, speaker_info):
                key1 = stem(wav_f)
                key2 = stem(txt_f)
                assert key1 == key2, f"unpaired files: {txt_f} / {wav_f}"
                # id is a placeholder; records are renumbered below.
                return TxtWavRecord(0, key1, txt_f, wav_f, speaker_info)

            return [
                create_record(txt_f, wav_f, speaker_info)
                for txt_f, wav_f in file_pairs
            ]

        # Flatten the per-speaker lists in linear time (sum(lists, []) is
        # quadratic in the total number of records).
        records = [
            record
            for si in self._load_speaker_info()
            for record in text_and_wav_records(
                zip(text_files(si), wav_files(si)), si)
        ]
        return [
            TxtWavRecord(i, r.key, r.txt_path, r.wav_path, r.speaker_info)
            for i, r in enumerate(records)
        ]

    def process_sources(self, rdd: RDD):
        """Write a source TFRecord per record; the result RDD holds the keys."""
        return rdd.map(self._process_txt)

    def process_targets(self, rdd: RDD):
        """Write a target TFRecord per record.

        Returns the ``MelStatistics`` RDD, persisted to memory and disk and
        wrapped in ``TargetRDD``.
        """
        return TargetRDD(
            rdd.map(self._process_wav).persist(StorageLevel.MEMORY_AND_DISK))

    def _load_speaker_info(self):
        """Yield ``SpeakerInfo`` built from columns 0-2 of the speaker table.

        The first line is treated as a header and blank lines are skipped.
        Gender is encoded as 0 for 'F' and 1 for anything else.
        """
        with open(os.path.join(self.in_dir, self.speaker_info_filename),
                  mode='r',
                  encoding='utf8') as f:
            for line in f.readlines()[1:]:
                si = line.split()
                if not si:
                    continue  # blank line; indexing would raise IndexError
                gender = 0 if si[2] == 'F' else 1
                if si[0] != "315":  # FixMe: Why 315 is missing?
                    yield SpeakerInfo(int(si[0]), int(si[1]), gender)

    def _process_wav(self, record: TxtWavRecord):
        """Trim the wav, write its mel spectrogram, and return statistics.

        The statistics are min/max/sum/sum-of-squares reduced over axis 0,
        plus ``length = len(mel_spectrogram)``. Axis 0 is presumably frames
        after the ``.T``; confirm against ``Audio.melspectrogram``.
        """
        wav = self.audio.load_wav(record.wav_path)
        wav = self.audio.trim(wav)
        mel_spectrogram = self.audio.melspectrogram(wav).astype(np.float32).T
        file_path = os.path.join(self.out_dir, f"{record.key}.target.tfrecord")
        write_preprocessed_target_data(record.id, record.key, mel_spectrogram,
                                       file_path)
        return MelStatistics(id=record.id,
                             key=record.key,
                             min=np.min(mel_spectrogram, axis=0),
                             max=np.max(mel_spectrogram, axis=0),
                             sum=np.sum(mel_spectrogram, axis=0),
                             length=len(mel_spectrogram),
                             moment2=np.sum(np.square(mel_spectrogram),
                                            axis=0))

    def _process_txt(self, record: TxtWavRecord):
        """Encode the transcript's first line and write a source TFRecord.

        Returns:
            The record key.
        """
        # list_files builds txt_path as os.path.join(self.in_dir, ...), so
        # it already carries the in_dir prefix. Joining in_dir again would
        # double that prefix whenever in_dir is a relative path.
        with open(record.txt_path, mode='r', encoding='utf8') as f:
            txt = f.readline().rstrip("\n")
        sequence, clean_text = text_to_sequence(txt, basic_cleaners)
        source = np.array(sequence, dtype=np.int64)
        file_path = os.path.join(self.out_dir,
                                 f"{record.key}.source.tfrecord")
        write_preprocessed_source_data(record.id, record.key, source,
                                       clean_text, record.speaker_info.id,
                                       record.speaker_info.age,
                                       record.speaker_info.gender,
                                       file_path)
        return record.key
예제 #4
0
class LJSpeech:
    """Preprocess LJSpeech into per-utterance source and target TFRecords.

    Sources hold the cleaned text as an int64 symbol sequence; targets hold
    the mel spectrogram. Both are written to ``out_dir``.
    """

    def __init__(self, in_dir, out_dir, hparams):
        """Store the corpus/output directories and build the ``Audio`` helper.

        Args:
            in_dir: Corpus root containing ``metadata.csv`` and ``wavs/``.
            out_dir: Directory that receives the ``.tfrecord`` files.
            hparams: Hyper-parameters passed to ``Audio``.
        """
        self.in_dir = in_dir
        self.out_dir = out_dir
        self.audio = Audio(hparams)

    @property
    def record_ids(self):
        """Ids "1" .. "13100" as strings; a fresh iterator on every access."""
        return map(str, range(1, 13101))

    def record_file_path(self, record_id, kind):
        """Return ``<out_dir>/ljspeech-<kind>-<id:05d>.tfrecord``.

        ``kind`` must be "source" or "target".
        """
        # NOTE(review): _process_source/_process_target write
        # "<key>.source.tfrecord" / "<key>.target.tfrecord", not this naming
        # scheme; confirm which files callers expect here.
        assert kind in ["source", "target"]
        return os.path.join(self.out_dir,
                            f"ljspeech-{kind}-{int(record_id):05d}.tfrecord")

    def text_and_path_rdd(self, sc: SparkContext):
        """Return an RDD of ``(index, TextAndPath)`` pairs from metadata.csv."""
        return sc.parallelize(self._extract_all_text_and_path())

    def process_targets(self, rdd: RDD):
        """Write a target TFRecord per value.

        Returns the ``MelStatistics`` RDD, persisted to memory and disk and
        wrapped in ``TargetRDD``.
        """
        return TargetRDD(
            rdd.mapValues(self._process_target).persist(
                StorageLevel.MEMORY_AND_DISK))

    def process_sources(self, rdd: RDD):
        """Write a source TFRecord per value; values become the record keys."""
        return rdd.mapValues(self._process_source)

    def _extract_text_and_path(self, line, index):
        """Parse one pipe-separated ``metadata.csv`` line.

        Field 0 is the utterance key (also the wav file stem under
        ``wavs/``); field 2 is the text that is kept.

        Returns:
            A ``TextAndPath``, or None for a blank line so that the caller
            can skip it.
        """
        line = line.strip()
        if not line:
            # Blank line (e.g. a trailing empty line): nothing to parse.
            return None
        parts = line.split('|')
        key = parts[0]
        text = parts[2]
        wav_path = os.path.join(self.in_dir, 'wavs', '%s.wav' % key)
        return TextAndPath(index, key, wav_path, None, text)

    def _extract_all_text_and_path(self):
        """Yield ``(index, TextAndPath)`` for every usable metadata line.

        ``index`` is the 0-based line number, so skipped lines leave gaps.
        """
        with open(os.path.join(self.in_dir, 'metadata.csv'),
                  mode='r',
                  encoding='utf-8') as f:
            for index, line in enumerate(f):
                extracted = self._extract_text_and_path(line, index)
                if extracted is not None:
                    yield (index, extracted)

    def _text_to_sequence(self, text):
        """Clean ``text`` with ``english_cleaners``.

        Returns:
            ``(int64 symbol array, cleaned text)``.
        """
        sequence, clean_text = text_to_sequence(text, english_cleaners)
        sequence = np.array(sequence, dtype=np.int64)
        return sequence, clean_text

    def _process_target(self, paths: TextAndPath):
        """Write the utterance's mel spectrogram and return its statistics.

        The statistics are min/max/sum/sum-of-squares reduced over axis 0,
        plus ``length = len(mel_spectrogram)``. Axis 0 is presumably frames
        after the ``.T``; confirm against ``Audio.melspectrogram``.
        """
        wav = self.audio.load_wav(paths.wav_path)
        mel_spectrogram = self.audio.melspectrogram(wav).astype(np.float32).T
        filename = f"{paths.key}.target.tfrecord"
        filepath = os.path.join(self.out_dir, filename)
        write_preprocessed_target_data(paths.id, paths.key, mel_spectrogram,
                                       filepath)
        return MelStatistics(id=paths.id,
                             key=paths.key,
                             min=np.min(mel_spectrogram, axis=0),
                             max=np.max(mel_spectrogram, axis=0),
                             sum=np.sum(mel_spectrogram, axis=0),
                             length=len(mel_spectrogram),
                             moment2=np.sum(np.square(mel_spectrogram),
                                            axis=0))

    def _process_source(self, paths: TextAndPath):
        """Encode the utterance text, write a source TFRecord, return the key."""
        sequence, clean_text = self._text_to_sequence(paths.text)
        filename = f"{paths.key}.source.tfrecord"
        filepath = os.path.join(self.out_dir, filename)
        write_preprocessed_source_data(paths.id, paths.key, sequence,
                                       clean_text, filepath)
        return paths.key