Пример #1
0
 def __init__(self, in_dir, out_dir):
     """Remember input/output directories, the book list and label-trimming settings.

     NOTE(review): ``hparams`` is not a parameter of this method; it must be a
     module-level name in the file this snippet was taken from.
     """
     self.in_dir, self.out_dir = in_dir, out_dir
     # Book names to read from in_dir (presumably one sub-directory each — confirm).
     self.books = [
         "ATrampAbroad",
         "TheManThatCorruptedHadleyburg",
         "LifeOnTheMississippi",
         "TheAdventuresOfTomSawyer",
     ]
     # Tuning constants; presumably seconds of padding and a minimum alignment
     # confidence score — confirm against the code that reads them.
     self._end_buffer = 0.05
     self._min_confidence = 90
     self.audio = Audio(hparams)
Пример #2
0
 def __init__(self,
              in_dir,
              out_dir,
              hparams,
              speaker_info_filename='speaker-info.txt'):
     """Store directories, build the Audio helper and an optional Flite G2P converter.

     A ``Flite`` converter is created only when ``hparams.phoneme == 'flite'``;
     otherwise ``self.g2p`` is ``None``.
     """
     self.in_dir, self.out_dir = in_dir, out_dir
     self.audio = Audio(hparams)
     if hparams.phoneme == 'flite':
         self.g2p = Flite(hparams.flite_binary_path, hparams.phoneset_path)
     else:
         self.g2p = None
     self.speaker_info_filename = speaker_info_filename
Пример #3
0
def predict(hparams,
            model_dir, postnet_model_dir,
            test_source_files, test_target_files):
    """Run the Tacotron V1 model and its post-net over the test set, saving WAV files.

    Stage one (``SingleSpeakerTacotronV1Model`` with ``model_dir``) predicts mel
    spectrograms from the test source/target TFRecords. Stage two
    (``TacotronV1PostNetModel`` with ``postnet_model_dir``) consumes those
    predictions together with the test targets and produces ``audio``. Each result
    is written to ``<postnet_model_dir>/<id>.wav``.
    """
    audio = Audio(hparams)

    def first_stage_input_fn():
        # Paired source/target records, one example per batch.
        sources = tf.data.TFRecordDataset(list(test_source_files))
        targets = tf.data.TFRecordDataset(list(test_target_files))
        zipped = DatasetSource(sources, targets, hparams).prepare_and_zip()
        return zipped.filter_by_max_output_length().group_by_batch(batch_size=1).dataset

    tacotron = SingleSpeakerTacotronV1Model(hparams, model_dir)

    def to_predicted_mel(p):
        mel = p["mel"]
        return PredictedMel(p["id"], mel, mel.shape[1], mel.shape[0],
                            p["alignment"], p["source"], p["text"])

    # NOTE(review): ``map`` is a one-shot iterator and the generator callable below
    # returns that same object on every call, so the post-net dataset can only be
    # iterated once.
    predictions = map(to_predicted_mel, tacotron.predict(first_stage_input_fn))

    def postnet_input_fn():
        prediction_types = PredictedMel(tf.int64, tf.float32, tf.int64, tf.int64,
                                        tf.float32, tf.int64, tf.string)
        prediction_dataset = tf.data.Dataset.from_generator(
            lambda: predictions, output_types=prediction_types)
        targets = tf.data.TFRecordDataset(list(test_target_files))
        pairs = PostNetDatasetSource(targets, hparams).create_source_and_target()
        combined = pairs.filter_by_max_output_length().combine_with_prediction(prediction_dataset)
        return combined.expand_batch_dim().dataset

    postnet = TacotronV1PostNetModel(hparams, audio, postnet_model_dir)

    for result in postnet.predict(postnet_input_fn):
        wav_path = os.path.join(postnet_model_dir, "{}.wav".format(result['id']))
        audio.save_wav(result["audio"], wav_path)
def train_and_evaluate(hparams, model_dir, train_target_files,
                       eval_target_files):
    """Train the Tacotron V1 post-net with ``tf.estimator.train_and_evaluate``.

    Both training and evaluation read target TFRecords in a freshly shuffled file
    order. Training batches come from ``group_by_batch()`` after an extra
    ``shuffle(hparams.suffle_buffer_size)``; evaluation uses batches of one.
    Evaluation step count, throttle and start delay come from ``hparams``.
    """
    audio = Audio(hparams)

    def shuffled_records(files):
        # Shuffle a private copy so the caller's sequence is left untouched.
        ordered = list(files)
        shuffle(ordered)
        return tf.data.TFRecordDataset(ordered)

    def train_input_fn():
        source = PostNetDatasetSource(shuffled_records(train_target_files), hparams)
        pipeline = source.create_source_and_target().filter_by_max_output_length()
        pipeline = pipeline.repeat().shuffle(hparams.suffle_buffer_size).group_by_batch()
        return pipeline.dataset

    def eval_input_fn():
        source = PostNetDatasetSource(shuffled_records(eval_target_files), hparams)
        pipeline = source.create_source_and_target().filter_by_max_output_length()
        return pipeline.repeat().group_by_batch(batch_size=1).dataset

    estimator = TacotronV1PostNetModel(
        hparams,
        audio,
        model_dir,
        config=tf.estimator.RunConfig(
            save_summary_steps=hparams.save_summary_steps,
            log_step_count_steps=hparams.log_step_count_steps))

    tf.estimator.train_and_evaluate(
        estimator,
        tf.estimator.TrainSpec(input_fn=train_input_fn),
        tf.estimator.EvalSpec(input_fn=eval_input_fn,
                              steps=hparams.num_evaluation_steps,
                              throttle_secs=hparams.eval_throttle_secs,
                              start_delay_secs=hparams.eval_start_delay_secs))
Пример #5
0
class VCTK:
    """Preprocessor for VCTK: pairs transcripts with ``*_mic2.flac`` audio and writes TFRecords.

    Source records (character ids, optional Flite phone ids, speaker info) and
    target records (mel spectrograms) are written to ``out_dir`` as
    ``<key>.source.tfrecord`` / ``<key>.target.tfrecord``.
    """

    def __init__(self,
                 in_dir,
                 out_dir,
                 hparams,
                 speaker_info_filename='speaker-info.txt'):
        self.in_dir = in_dir
        self.out_dir = out_dir
        self.audio = Audio(hparams)
        # Grapheme-to-phoneme converter, only when Flite phonemes are configured.
        self.g2p = Flite(
            hparams.flite_binary_path,
            hparams.phoneset_path) if hparams.phoneme == 'flite' else None
        self.speaker_info_filename = speaker_info_filename

    def list_files(self):
        """Return a TxtWavRecord for every (transcript, audio) pair in the corpus.

        Records are ordered by speaker (speaker-info file order), then by sorted file
        name, and get consecutive ids starting at 0. Paths are joined with ``in_dir``.

        Raises:
            AssertionError: if positional pairing of transcripts and audio files
                yields two different utterance keys.
        """
        # Transcripts that are excluded from the corpus.
        missing = {"s5_052.txt", "s5_219.txt"}

        def speaker_dir(kind, speaker_info: SpeakerInfo):
            # Speaker 5 lives in "s5"; every other speaker in "p<id>".
            prefix = "s" if speaker_info.id == 5 else "p"
            return os.path.join(self.in_dir, f"{kind}/{prefix}{speaker_info.id}")

        def wav_files(speaker_info: SpeakerInfo):
            wav_dir = speaker_dir("wav48", speaker_info)
            return [
                os.path.join(wav_dir, name)
                for name in sorted(os.listdir(wav_dir))
                if name.endswith('_mic2.flac')
            ]

        def text_files(speaker_info: SpeakerInfo):
            txt_dir = speaker_dir("txt", speaker_info)
            return [
                os.path.join(txt_dir, name)
                for name in sorted(os.listdir(txt_dir))
                if name.endswith('.txt') and name not in missing
            ]

        def pairs():
            for si in self._load_speaker_info():
                for txt_f, wav_f in zip(text_files(si), wav_files(si)):
                    yield txt_f, wav_f, si

        records = []
        for index, (txt_f, wav_f, si) in enumerate(pairs()):
            key1 = os.path.basename(wav_f).replace("_mic2.flac", "")
            key2 = os.path.basename(txt_f).replace(".txt", "")
            # Files are paired positionally; make sure they are the same utterance.
            assert key1 == key2, f"{key1} != {key2}"
            records.append(TxtWavRecord(index, key1, txt_f, wav_f, si))
        return records

    def process_sources(self, rdd: RDD):
        return rdd.map(self._process_txt)

    def process_targets(self, rdd: RDD):
        return TargetRDD(
            rdd.map(self._process_wav).persist(StorageLevel.MEMORY_AND_DISK))

    def _load_speaker_info(self):
        """Yield SpeakerInfo(id, age, gender) for each speaker in the speaker-info file.

        The first line is skipped as a header and blank lines are ignored. The id is
        column 0 without its first character, age is column 1, and gender (column 2)
        is encoded as 0 for 'F' and 1 otherwise. Speakers 315 and 362 are skipped.
        """
        path = os.path.join(self.in_dir, self.speaker_info_filename)
        # Read everything up front so the file is closed before the generator yields.
        with open(path, mode='r', encoding='utf8') as f:
            lines = f.readlines()[1:]
        excluded = {"315", "362"}  # FixMe: Why 315 is missing?
        for line in lines:
            fields = line.split()
            if not fields:
                continue
            speaker_id = fields[0][1:]
            if speaker_id in excluded:
                continue
            gender = 0 if fields[2] == 'F' else 1
            yield SpeakerInfo(int(speaker_id), int(fields[1]), gender)

    def _process_wav(self, record: TxtWavRecord):
        """Write the trimmed mel spectrogram of ``record.wav_path`` and return its statistics.

        The spectrogram is transposed (``.T``) before use; all statistics are reduced
        over axis 0, and ``length`` is the size of that axis.
        """
        wav = self.audio.load_wav(record.wav_path)
        wav = self.audio.trim(wav)
        mel_spectrogram = self.audio.melspectrogram(wav).astype(np.float32).T
        file_path = os.path.join(self.out_dir, f"{record.key}.target.tfrecord")
        write_preprocessed_target_data(record.id, record.key, mel_spectrogram,
                                       file_path)
        return MelStatistics(id=record.id,
                             key=record.key,
                             min=np.min(mel_spectrogram, axis=0),
                             max=np.max(mel_spectrogram, axis=0),
                             sum=np.sum(mel_spectrogram, axis=0),
                             length=len(mel_spectrogram),
                             moment2=np.sum(np.square(mel_spectrogram),
                                            axis=0))

    def _process_txt(self, record: TxtWavRecord):
        """Convert the first line of ``record.txt_path`` to a source TFRecord; return the key.

        ``txt_path`` is used as-is, exactly like ``wav_path`` in ``_process_wav``:
        ``list_files`` already prefixes it with ``in_dir``, and joining ``in_dir``
        again duplicated the prefix whenever ``in_dir`` was a relative path.
        """
        with open(record.txt_path, mode='r', encoding='utf8') as f:
            txt = f.readline().rstrip("\n")
        sequence, clean_text = text_to_sequence(txt, basic_cleaners)
        if self.g2p is not None:
            phone_ids, phone_txt = self.g2p.convert_to_phoneme(clean_text)
        else:
            phone_ids, phone_txt = None, None
        source = np.array(sequence, dtype=np.int64)
        if phone_ids is not None:
            phone_ids = np.array(phone_ids, dtype=np.int64)
        file_path = os.path.join(self.out_dir,
                                 f"{record.key}.source.tfrecord")
        write_preprocessed_source_data(record.id, record.key, source,
                                       clean_text, phone_ids, phone_txt,
                                       record.speaker_info.id,
                                       record.speaker_info.age,
                                       record.speaker_info.gender,
                                       file_path)
        return record.key
Пример #6
0
class Synthesize:
    """Prepare ``*.txt`` prompts in ``in_dir`` for synthesis by a single given speaker.

    Writes source TFRecords from the text and target TFRecords from a dummy
    (all-zero) waveform, both into ``out_dir``.
    """

    def __init__(self,
                 in_dir,
                 out_dir,
                 hparams,
                 gender,
                 speakerID,
                 speaker_info_filename='speaker-info.txt'):
        self.in_dir = in_dir
        self.out_dir = out_dir
        self.audio = Audio(hparams)
        # Grapheme-to-phoneme converter, only when Flite phonemes are configured.
        self.g2p = Flite(
            hparams.flite_binary_path,
            hparams.phoneset_path) if hparams.phoneme == 'flite' else None
        self.speaker_info_filename = speaker_info_filename
        self.gender = gender
        self.speakerID = speakerID

    def list_files(self):
        """Return one TxtWavRecord per ``*.txt`` file directly inside ``in_dir``.

        Files are taken in sorted order with consecutive ids from 0. Every record is
        attributed to the speaker(s) from ``_load_speaker_info``. The wav path is the
        transcript path with ``.txt`` replaced by ``.wav``; it is never read
        (``_process_wav`` uses dummy audio). The key is the transcript's file stem.
        """
        def text_files():
            return [
                os.path.join(self.in_dir, name)
                for name in sorted(os.listdir(self.in_dir))
                if name.endswith('.txt')
            ]

        records = []
        for speaker_info in self._load_speaker_info():
            for txt_f in text_files():
                # splitext only strips the final extension. The old
                # txt_f.split('.')[0] cut at the first dot anywhere in the path
                # (e.g. "./in/a.txt" -> ".wav"), which broke the key check.
                stem_path = os.path.splitext(txt_f)[0]
                key = os.path.basename(stem_path)
                records.append(TxtWavRecord(len(records), key, txt_f,
                                            stem_path + '.wav', speaker_info))
        return records

    def process_sources(self, rdd: RDD):
        return rdd.map(self._process_txt)

    def process_targets(self, rdd: RDD):
        return TargetRDD(
            rdd.map(self._process_wav).persist(StorageLevel.MEMORY_AND_DISK))

    def _load_speaker_info(self):
        """Yield a single SpeakerInfo for the configured speaker.

        The id is ``speakerID`` without its first character. Gender is 0 for 'F' and
        1 otherwise.
        """
        ## filling in the age field with 100, arbitrarily, since this
        ## does not get used at all currently
        gender = 0 if self.gender == 'F' else 1
        yield SpeakerInfo(int(self.speakerID[1:]), 100, gender)

    def _process_wav(self, record: TxtWavRecord):
        """Write a target TFRecord built from 48000 samples of silence; return its statistics."""
        # Dummy audio: there is no recording at synthesis time.
        wav = np.zeros(48000, dtype=np.float32)
        mel_spectrogram = self.audio.melspectrogram(wav).astype(np.float32).T
        file_path = os.path.join(self.out_dir, f"{record.key}.target.tfrecord")
        write_preprocessed_target_data(record.id, record.key, mel_spectrogram,
                                       file_path)
        return MelStatistics(id=record.id,
                             key=record.key,
                             min=np.min(mel_spectrogram, axis=0),
                             max=np.max(mel_spectrogram, axis=0),
                             sum=np.sum(mel_spectrogram, axis=0),
                             length=len(mel_spectrogram),
                             moment2=np.sum(np.square(mel_spectrogram),
                                            axis=0))

    def _process_txt(self, record: TxtWavRecord):
        """Convert the first line of ``record.txt_path`` to a source TFRecord; return the key.

        ``txt_path`` already includes ``in_dir`` (see ``list_files``). Joining
        ``in_dir`` again duplicated the prefix whenever ``in_dir`` was relative.
        """
        with open(record.txt_path, mode='r', encoding='utf8') as f:
            txt = f.readline().rstrip("\n")
        sequence, clean_text = text_to_sequence(txt, basic_cleaners)
        if self.g2p is not None:
            phone_ids, phone_txt = self.g2p.convert_to_phoneme(clean_text)
        else:
            phone_ids, phone_txt = None, None
        source = np.array(sequence, dtype=np.int64)
        if phone_ids is not None:
            phone_ids = np.array(phone_ids, dtype=np.int64)
        file_path = os.path.join(self.out_dir,
                                 f"{record.key}.source.tfrecord")
        write_preprocessed_source_data(record.id, record.key, source,
                                       clean_text, phone_ids, phone_txt,
                                       record.speaker_info.id,
                                       record.speaker_info.age,
                                       record.speaker_info.gender,
                                       file_path)
        return record.key
Пример #7
0
class Blizzard2012(Corpus):
    """Blizzard Challenge 2012 preprocessor producing numbered source/target TFRecords.

    Utterances are read from each book's ``sentence_index.txt`` and numbered from 1.
    Record ids 1-10 form the test split, 11-320 validation and 321-23203 training.
    NOTE(review): ``hparams`` and ``eos`` are module-level names, not attributes.
    """

    def __init__(self, in_dir, out_dir):
        self.in_dir = in_dir
        self.out_dir = out_dir
        self.books = [
            'ATrampAbroad',
            'TheManThatCorruptedHadleyburg',
            'LifeOnTheMississippi',
            'TheAdventuresOfTomSawyer',
        ]
        # Added to the trailing-silence cut point (label-file time units).
        self._end_buffer = 0.05
        # Utterances whose confidence column is not above this are skipped.
        self._min_confidence = 90
        self.audio = Audio(hparams)

    def _record_files(self, kind, start, stop):
        """Return output paths ``blizzard2012-<kind>-NNNNN.tfrecord`` for ids in [start, stop)."""
        return [os.path.join(self.out_dir, f"blizzard2012-{kind}-{record_id:05d}.tfrecord")
                for record_id in range(start, stop)]

    @property
    def training_source_files(self):
        return self._record_files("source", 321, 23204)

    @property
    def training_target_files(self):
        return self._record_files("target", 321, 23204)

    @property
    def validation_source_files(self):
        return self._record_files("source", 11, 321)

    @property
    def validation_target_files(self):
        return self._record_files("target", 11, 321)

    @property
    def test_source_files(self):
        return self._record_files("source", 1, 11)

    @property
    def test_target_files(self):
        return self._record_files("target", 1, 11)

    def text_and_path_rdd(self, sc: SparkContext):
        return sc.parallelize(
            self._extract_all_text_and_path())

    def process_targets(self, rdd: RDD):
        return rdd.mapValues(self._process_target)

    def process_sources(self, rdd: RDD):
        return rdd.mapValues(self._process_source)

    def _aggregate_metadata(self, rdd, kind, to_tsv, length_of):
        """Write one metadata TSV per partition and return ``(total_count, max_length)``.

        Rows are written in key order. Every row is preceded by "\\n", so a non-empty
        file starts with an empty line. This is the same output the earlier
        ``reduce``-based code produced, but built in linear rather than quadratic time.
        """
        def map_fn(split_index, iterator):
            rows = []
            max_len = 0
            for kv in iterator:
                metadata = kv[1]
                rows.append(to_tsv(metadata))
                max_len = max(max_len, length_of(metadata))
            csv = "".join("\n" + row for row in rows)
            filename = f"blizzard2012-{kind}-metadata-{split_index:03d}.tsv"
            filepath = os.path.join(self.out_dir, filename)
            with open(filepath, mode="w") as f:
                f.write(csv)
            yield len(rows), max_len

        return rdd.sortByKey().mapPartitionsWithIndex(
            map_fn, preservesPartitioning=True).fold(
            (0, 0), lambda acc, xy: (acc[0] + xy[0], max(acc[1], xy[1])))

    def aggregate_source_metadata(self, rdd: RDD):
        """Write source metadata TSVs; return (count, longest text length)."""
        return self._aggregate_metadata(rdd, "source", source_metadata_to_tsv,
                                        lambda metadata: len(metadata.text))

    def aggregate_target_metadata(self, rdd: RDD):
        """Write target metadata TSVs; return (count, largest n_frames)."""
        return self._aggregate_metadata(rdd, "target", target_metadata_to_tsv,
                                        lambda metadata: metadata.n_frames)

    def _extract_text_and_path(self, book, line, index):
        """Parse one ``sentence_index.txt`` line into a TextAndPath, or return None.

        Comment lines ('#'), lines without exactly 8 tab-separated fields, and lines
        whose confidence (field 3) is not above ``_min_confidence`` are skipped.
        """
        parts = line.strip().split('\t')
        # startswith replaces `line[0] is not '#'`, which compared identity, not value.
        if not line.startswith('#') and len(parts) == 8 and float(parts[3]) > self._min_confidence:
            wav_path = os.path.join(self.in_dir, book, 'wav', '%s.wav' % parts[0])
            labels_path = os.path.join(self.in_dir, book, 'lab', '%s.lab' % parts[0])
            text = parts[5]
            return TextAndPath(index, wav_path, labels_path, text)
        return None

    def _extract_all_text_and_path(self):
        """Yield ``(index, TextAndPath)`` for every accepted line, numbering from 1."""
        index = 1
        for book in self.books:
            with open(os.path.join(self.in_dir, book, 'sentence_index.txt'), mode='r') as f:
                for line in f:
                    extracted = self._extract_text_and_path(book, line, index)
                    if extracted is not None:
                        yield (index, extracted)
                        index += 1

    def _load_labels(self, path):
        """Return ``(start, end)`` trim offsets from a label file.

        Each line with at least 3 space-separated fields contributes
        ``(float(field0), rest-from-field2)``. A leading 'sil' label moves ``start``
        to its time. A trailing 'sil' sets ``end`` to the time of the label before it
        plus ``_end_buffer``. Otherwise start is 0 and end is None. Files with too few
        labels to inspect are no longer an IndexError; they are simply not trimmed
        on that side.
        """
        labels = []
        with open(path) as f:
            for line in f:
                parts = line.strip().split(' ')
                if len(parts) >= 3:
                    labels.append((float(parts[0]), ' '.join(parts[2:])))
        start = 0
        end = None
        if labels and labels[0][1] == 'sil':
            start = labels[0][0]
        if len(labels) >= 2 and labels[-1][1] == 'sil':
            end = labels[-2][0] + self._end_buffer
        return (start, end)

    def _text_to_sequence(self, text):
        """Encode text as Unicode code points followed by the ``eos`` id (int64 array)."""
        sequence = [ord(c) for c in text] + [eos]
        sequence = np.array(sequence, dtype=np.int64)
        return sequence

    def _process_target(self, paths: TextAndPath):
        """Trim the utterance audio, write linear+mel spectrogram TFRecord, return metadata."""
        wav = self.audio.load_wav(paths.wav_path)
        start_offset, end_offset = self._load_labels(paths.labels_path)
        start = int(start_offset * hparams.sample_rate)
        # None keeps the audio through its last sample. The former -1 silently
        # dropped one sample whenever there was no trailing silence to cut.
        end = int(end_offset * hparams.sample_rate) if end_offset is not None else None
        wav = wav[start:end]
        spectrogram = self.audio.spectrogram(wav).astype(np.float32)
        n_frames = spectrogram.shape[1]
        mel_spectrogram = self.audio.melspectrogram(wav).astype(np.float32)
        filename = f"blizzard2012-target-{paths.id:05d}.tfrecord"
        filepath = os.path.join(self.out_dir, filename)
        tfrecord.write_preprocessed_target_data(paths.id, spectrogram.T, mel_spectrogram.T, filepath)
        return TargetMetaData(paths.id, filepath, n_frames)

    def _process_source(self, paths: TextAndPath):
        """Write the character-id source TFRecord for one utterance and return its metadata."""
        sequence = self._text_to_sequence(paths.text)
        filename = f"blizzard2012-source-{paths.id:05d}.tfrecord"
        filepath = os.path.join(self.out_dir, filename)
        tfrecord.write_preprocessed_source_data2(paths.id, paths.text, sequence, paths.text, sequence, filepath)
        return SourceMetaData(paths.id, filepath, paths.text)