Example #1
File: predict.py  Project: pursu/wavenet
def predict(hparams, model_dir, checkpoint_path, output_dir, test_files):
    audio = Audio(hparams)

    def predict_input_fn():
        records = tf.data.TFRecordDataset(list(test_files))
        dataset = DatasetSource(records, hparams)
        batched = dataset.make_source_and_target().group_by_batch(
            batch_size=1).arrange_for_prediction()
        return batched.dataset

    estimator = WaveNetModel(hparams, model_dir)

    predictions = map(
        lambda p: PredictedAudio(p["id"], p["key"], p["predicted_waveform"],
                                 p["ground_truth_waveform"], p["mel"], p["text"]),
        estimator.predict(predict_input_fn, checkpoint_path=checkpoint_path))

    for v in predictions:
        key = v.key.decode('utf-8')
        audio_filename = f"{key}.wav"
        audio_filepath = os.path.join(output_dir, audio_filename)
        tf.logging.info(f"Saving {audio_filepath}")
        audio.save_wav(v.predicted_waveform, audio_filepath)
        png_filename = f"{key}.png"
        png_filepath = os.path.join(output_dir, png_filename)
        tf.logging.info(f"Saving {png_filepath}")
        # ToDo: pass global step
        plot_wav(png_filepath, v.predicted_waveform, v.ground_truth_waveform,
                 key, 0, v.text.decode('utf-8'), hparams.sample_rate)
class LJSpeech:
    def __init__(self, in_dir, mel_out_dir, wav_out_dir, hparams):
        self.in_dir = in_dir
        self.mel_out_dir = mel_out_dir
        self.wav_out_dir = wav_out_dir
        self.audio = Audio(hparams)

    @property
    def record_ids(self):
        return map(str, range(1, 13101))

    def record_file_path(self, record_id, kind):
        assert kind in ["source", "target"]
        return os.path.join(self.mel_out_dir,
                            f"ljspeech-{kind}-{int(record_id):05d}.tfrecord")

    def text_and_path_rdd(self, sc: SparkContext):
        return sc.parallelize(self._extract_all_text_and_path())

    def process_wav(self, rdd: RDD):
        return rdd.mapValues(self._process_wav)

    def _extract_text_and_path(self, line, index):
        parts = line.strip().split('|')
        key = parts[0]
        text = parts[2]
        wav_path = os.path.join(self.in_dir, 'wavs', '%s.wav' % key)
        return TextAndPath(index, key, wav_path, None, text)

    def _extract_all_text_and_path(self):
        with open(os.path.join(self.in_dir, 'metadata.csv'),
                  mode='r',
                  encoding='utf-8') as f:
            for index, line in enumerate(f):
                extracted = self._extract_text_and_path(line, index)
                if extracted is not None:
                    yield (index, extracted)

    def _process_wav(self, paths: TextAndPath):
        wav = self.audio.load_wav(paths.wav_path)
        mel_spectrogram = self.audio.melspectrogram(wav).astype(np.float32).T
        mel_spectrogram = self.audio.normalize_mel(mel_spectrogram)

        mel_filepath = os.path.join(self.mel_out_dir, f"{paths.key}.mfbsp")
        wav_filepath = os.path.join(self.wav_out_dir, f"{paths.key}.wav")

        mel_spectrogram.tofile(mel_filepath, format="<f4")
        self.audio.save_wav(wav, wav_filepath)
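
A minimal usage sketch for the two pieces above, assuming a local SparkContext and the project's hparams object; the directory names, glob pattern and checkpoint path are illustrative placeholders, and the step that packs the extracted features into TFRecord files is not shown in this excerpt:

# Hypothetical driver (illustrative only; `hparams` comes from the project's hparams module).
from glob import glob
from pyspark import SparkContext

def preprocess_and_predict(hparams):
    # Preprocess LJSpeech: write normalized mel spectrograms (.mfbsp) and copies of the wavs.
    sc = SparkContext(master="local[*]", appName="ljspeech-preprocess")
    corpus = LJSpeech(in_dir="LJSpeech-1.1", mel_out_dir="mels",
                      wav_out_dir="wavs", hparams=hparams)
    corpus.process_wav(corpus.text_and_path_rdd(sc)).collect()  # force the side-effecting writes

    # Run prediction over TFRecord files produced by a packing step not shown in this excerpt.
    test_files = sorted(glob("mels/ljspeech-target-*.tfrecord"))  # placeholder pattern
    predict(hparams,
            model_dir="logs/wavenet",
            checkpoint_path="logs/wavenet/model.ckpt-100000",  # placeholder checkpoint
            output_dir="predictions",
            test_files=test_files)
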
Example #3
def predict(hparams,
            model_dir, checkpoint_path, output_dir,
            test_source_files, test_target_files):
    if hparams.half_precision:
        backend.set_floatx(tf.float16.name)
        backend.set_epsilon(1e-4)

    audio = Audio(hparams)

    def predict_input_fn():
        source = tf.data.TFRecordDataset(list(test_source_files))
        target = tf.data.TFRecordDataset(list(test_target_files))
        dataset = dataset_factory(source, target, hparams)
        batched = dataset.prepare_and_zip().group_by_batch(
            batch_size=1).move_mel_to_source()
        return batched.dataset

    estimator = model_factory(hparams, model_dir, None)

    predictions = map(
        lambda p: PredictedMel(p["id"], p["key"], p["mel"], p.get("mel_postnet"), p["mel"].shape[1], p["mel"].shape[0],
                               p["ground_truth_mel"],
                               p["alignment"], p.get("alignment2"), p.get("alignment3"), p.get("alignment4"),
                               p.get("alignment5"), p.get("alignment6"), p.get("attention2_gate_activation"),
                               p["source"], p["text"], p.get("accent_type"), p),
        estimator.predict(predict_input_fn, checkpoint_path=checkpoint_path))

    for v in predictions:
        key = v.key.decode('utf-8')
        mel_filename = f"{key}.{hparams.predicted_mel_extension}"
        mel_filepath = os.path.join(output_dir, mel_filename)
        ground_truth_mel = v.ground_truth_mel.astype(np.float32)
        predicted_mel = v.predicted_mel.astype(np.float32)
        mel_denormalized = audio.denormalize_mel(predicted_mel)

        linear_spec = audio.logmelspc_to_linearspc(mel_denormalized)
        wav = audio.griffin_lim(linear_spec)
        audio.save_wav(wav, os.path.join(output_dir, f"{key}.wav"))

        assert mel_denormalized.shape[1] == hparams.num_mels
        mel_denormalized.tofile(mel_filepath, format='<f4')
        text = v.text.decode("utf-8")
        plot_filename = f"{key}.png"
        plot_filepath = os.path.join(output_dir, plot_filename)
        alignments = [x.astype(np.float32) for x in [
            v.alignment, v.alignment2, v.alignment3, v.alignment4, v.alignment5, v.alignment6] if x is not None]

        if hparams.model == "SSNTModel":
            ssnt_metrics.save_alignment_and_log_probs([v.alignment],
                                                      [v.all_fields["log_emit_and_shift_probs"]],
                                                      [v.all_fields["log_output_prob"]],
                                                      [None],
                                                      text, v.key, 0,
                                                      os.path.join(output_dir, f"{key}_probs.png"))
            ssnt_metrics.write_prediction_result(v.id, key, text,
                                                 v.all_fields["log_emit_and_shift_probs"],
                                                 v.all_fields["log_output_prob"],
                                                 os.path.join(output_dir, f"{key}_probs.tfrecord"))
        plot_predictions(alignments, ground_truth_mel, predicted_mel, text, v.key, plot_filepath)
        prediction_filename = f"{key}.tfrecord"
        prediction_filepath = os.path.join(output_dir, prediction_filename)
        write_prediction_result(v.id, key, alignments, mel_denormalized, audio.denormalize_mel(ground_truth_mel),
                                text, v.source, prediction_filepath)
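
The mel output above is denormalized, mapped to a linear spectrogram, and resynthesized with Griffin-Lim through the project's Audio helper. A rough standalone equivalent of that chain using librosa is sketched below; the sample rate, n_fft and hop_length are assumptions and would have to match the hparams the model was actually trained with:

# Standalone sketch of a mel -> linear -> Griffin-Lim chain (librosa approximation,
# not the project's Audio class).
import librosa

def mel_to_wav(log_mel_db, sr=22050, n_fft=1024, hop_length=256):
    mel_power = librosa.db_to_power(log_mel_db.T)  # (num_mels, frames); undo the (assumed) dB scaling
    linear = librosa.feature.inverse.mel_to_stft(mel_power, sr=sr, n_fft=n_fft)
    return librosa.griffinlim(linear, n_iter=60, hop_length=hop_length)
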
Example #4
class Synthesizer():
    def __init__(self, model_path, out_dir, text_file, sil_file,
                 use_griffin_lim, hparams):
        self.model_path = model_path
        self.out_dir = out_dir
        self.text_file = text_file
        self.sil_file = sil_file
        self.use_griffin_lim = use_griffin_lim
        self.hparams = hparams

        self.model = get_model(model_path, hparams)
        self.audio_class = Audio(hparams)

        if hparams.use_phone:
            from text.phones import Phones
            phone_class = Phones(hparams.phone_set_file)
            self.text_to_sequence = phone_class.text_to_sequence
        else:
            from text import text_to_sequence
            self.text_to_sequence = text_to_sequence

        # self.out_png_dir = os.path.join(self.out_dir, 'png')
        # os.makedirs(self.out_png_dir, exist_ok=True)

        self.out_wav_dir = os.path.join(self.out_dir, 'wav')
        os.makedirs(self.out_wav_dir, exist_ok=True)

    def get_inputs(self, meta_data):
        hparams = self.hparams
        SEQUENCE_ = []
        SPEAKERID_ = []
        STYLE_ID_ = []
        FILENAME_ = []
        # Prepare text input
        for line in meta_data:
            parts = line.strip().split('|')

            filename = parts[1]
            print('Filename=', filename)

            phone_text = parts[-1]
            print('Text=', phone_text)

            speaker_id = int(parts[-2])
            print('SpeakerID=', speaker_id)

            sequence = np.array(
                self.text_to_sequence(phone_text, ['english_cleaners']))  # [None, :]
            print(sequence)

            sequence = torch.autograd.Variable(
                torch.from_numpy(sequence)).to(device).long()

            speaker_id = torch.LongTensor(
                [speaker_id]).to(device) if hparams.is_multi_speakers else None

            style_id = torch.LongTensor(
                [int(parts[1])]).to(device) if hparams.is_multi_styles else None

            SEQUENCE_.append(sequence)
            SPEAKERID_.append(speaker_id)
            STYLE_ID_.append(style_id)
            FILENAME_.append(filename)

        return SEQUENCE_, SPEAKERID_, STYLE_ID_, FILENAME_

    def gen_mel(self, meta_data):
        SEQUENCE_, SPEAKERID_, STYLE_ID_, FILENAME_ = self.get_inputs(
            meta_data)
        MEL_OUTOUTS_ = []
        FILENAME_NEW_ = []
        # Decode text input and plot results
        with torch.no_grad():
            for i in range(len(SEQUENCE_)):
                mel_outputs, _, _ = self.model.inference(
                    text=SEQUENCE_[i],
                    spk_ids=SPEAKERID_[i],
                    utt_mels=TextMelLoader_refine(self.hparams.training_files,
                                                  self.hparams).utt_mels)
                MEL_OUTOUTS_.append(
                    mel_outputs.transpose(
                        0, 1).float().data.cpu().numpy())  # (dim, length)
                FILENAME_NEW_.append('spkid_' + str(SPEAKERID_[i].item()) +
                                     '_filenum_' + FILENAME_[i])
                print('mel_outputs.shape=', mel_outputs.shape)
        # image_path = os.path.join(self.out_png_dir, "{}_spec_stop.png".format(filename))
        # fig, axes = plt.subplots(1, 1, figsize=(8,8))
        # axes.imshow(mel_outputs, aspect='auto', origin='bottom', interpolation='none')
        # plt.savefig(image_path, format='png')
        # plt.close()
        return MEL_OUTOUTS_, FILENAME_NEW_

    def gen_wav_griffin_lim(self, mel_outputs, filename):
        grf_wav = self.audio_class.inv_mel_spectrogram(mel_outputs)
        grf_wav = self.audio_class.inv_preemphasize(grf_wav)
        wav_path = os.path.join(self.out_wav_dir, "{}-gl.wav".format(filename))
        self.audio_class.save_wav(grf_wav, wav_path)

    def inference_f(self):
        # print(meta_data['n'])
        meta_data = _read_meta_yyh(self.text_file)
        MEL_OUTOUTS_, FILENAME_NEW_ = self.gen_mel(meta_data)
        for i in range(len(MEL_OUTOUTS_)):
            np.save(
                os.path.join(
                    self.out_wav_dir,
                    "{}.npy".format(FILENAME_NEW_[i] + '_' +
                                    self.model_path.split('/')[-1])),
                MEL_OUTOUTS_[i].transpose(1, 0))
            if self.use_griffin_lim:
                self.gen_wav_griffin_lim(
                    MEL_OUTOUTS_[i],
                    FILENAME_NEW_[i] + '_' + self.model_path.split('/')[-1])

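        # Hand the generated mel .npy files to the external MelGAN vocoder pipeline:
        # copy them into gen_npy, run the MelGAN generation script, convert the
        # resulting wavs to 16 kHz, and finally copy the 16 kHz wavs back into out_wav_dir.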
        source_dir = self.out_wav_dir
        target_dir_npy = r'../../../Melgan_pipeline/melgan_file/gen_npy'
        target_dir_wav = r'../../../Melgan_pipeline/melgan_file/gen_wav'
        target_dir_wav_16k = r'../../../Melgan_pipeline/melgan_file/gen_wav_16k'

        for file in sorted(os.listdir(source_dir)):
            temp_file_dir = source_dir + os.sep + file
            shutil.copy(temp_file_dir, target_dir_npy)
            os.remove(temp_file_dir)

        os.system(
            'cd ../../../Melgan_pipeline/melgan_file/scripts/ && python generate_from_folder_mels_4-4_npy.py --load_path=../../personalvoice/melgan16k --folder=../gen_npy --save_path=../gen_wav'
        )
        os.system(
            'cd ../../../Tools/16kConverter/ && python Convert16k.py --wave_path=../../Melgan_pipeline/melgan_file/gen_wav --output_dir=../../Melgan_pipeline/melgan_file/gen_wav_16k'
        )

        for file in sorted(os.listdir(target_dir_npy)):
            temp_file_dir = target_dir_npy + os.sep + file
            os.remove(temp_file_dir)

        for file in sorted(os.listdir(target_dir_wav)):
            temp_file_dir = target_dir_wav + os.sep + file
            os.remove(temp_file_dir)

        for file in sorted(os.listdir(target_dir_wav_16k)):
            temp_file_dir = target_dir_wav_16k + os.sep + file
            shutil.copy(temp_file_dir, source_dir)
            os.remove(temp_file_dir)

        print('finished')
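
# Hypothetical usage sketch for the Synthesizer above (not part of the original example).
# The paths are placeholders and `hparams` is assumed to be the project's hyperparameter object.
def _demo_synthesize(hparams):
    synth = Synthesizer(model_path="checkpoints/model.pt",
                        out_dir="outputs",
                        text_file="meta/test_meta.txt",
                        sil_file="assets/silence.wav",
                        use_griffin_lim=True,
                        hparams=hparams)
    synth.inference_f()
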
class Synthesizer():
    def __init__(self, model_path, out_dir, text_file, sil_file,
                 use_griffin_lim, gen_wavenet_fea, hparams):
        self.out_dir = out_dir
        self.text_file = text_file
        self.sil_file = sil_file
        self.use_griffin_lim = use_griffin_lim
        self.gen_wavenet_fea = gen_wavenet_fea
        self.hparams = hparams

        self.model = get_model(model_path, hparams)
        self.audio_class = Audio(hparams)

        if hparams.use_phone:
            from text.phones import Phones
            phone_class = Phones(hparams.phone_set_file)
            self.text_to_sequence = phone_class.text_to_sequence
        else:
            from text import text_to_sequence
            self.text_to_sequence = text_to_sequence

        if hparams.is_multi_speakers and not hparams.use_pretrained_spkemb:
            self.speaker_id_dict = gen_speaker_id_dict(hparams)

        self.out_png_dir = os.path.join(self.out_dir, 'png')
        os.makedirs(self.out_png_dir, exist_ok=True)
        if self.use_griffin_lim:
            self.out_wav_dir = os.path.join(self.out_dir, 'wav')
            os.makedirs(self.out_wav_dir, exist_ok=True)
        if self.gen_wavenet_fea:
            self.out_mel_dir = os.path.join(self.out_dir, 'mel')
            os.makedirs(self.out_mel_dir, exist_ok=True)

    def get_mel_gt(self, wavname):
        hparams = self.hparams
        if not hparams.load_mel:
            if hparams.use_hdf5:
                with h5py.File(hparams.hdf5_file, 'r') as h5:
                    audio = h5[wavname][:]
            else:
                filename = os.path.join(hparams.wav_dir, wavname + '.wav')
                sr_t, audio = wavread(filename)
                assert sr_t == hparams.sample_rate
            audio_norm = audio / hparams.max_wav_value
            wav = self.audio_class._preemphasize(audio_norm)
            melspec = self.audio_class.melspectrogram(wav, clip_norm=True)
            melspec = torch.FloatTensor(melspec.astype(np.float32))
        else:
            if hparams.use_zip:
                with zipfile.ZipFile(hparams.zip_path, 'r') as f:
                    data = f.read(wavname)
                    melspec = np.load(io.BytesIO(data))
                melspec = torch.FloatTensor(melspec.astype(np.float32))
            elif hparams.use_hdf5:
                with h5py.File(hparams.hdf5_file, 'r') as h5:
                    melspec = h5[wavname][:]
                melspec = torch.FloatTensor(melspec.astype(np.float32))
            else:
                filename = os.path.join(hparams.wav_dir, wavname + '.npy')
                melspec = torch.from_numpy(np.load(filename))
        melspec = torch.unsqueeze(melspec, 0)
        return melspec

    def get_inputs(self, meta_data):
        hparams = self.hparams
        # Prepare text input
        # filename = meta_data['n']
        # filename = os.path.splitext(os.path.basename(filename))[0]
        parts = meta_data[0].strip().split('|')
        filename = parts[0]
        print(parts[-1])
        print(parts[1])
        sequence = np.array(
            self.text_to_sequence(parts[-1], ['english_cleaners']))  # [None, :]
        # sequence = torch.autograd.Variable(
        #     torch.from_numpy(sequence)).cuda().long()
        print(sequence)
        sequence = torch.autograd.Variable(
            torch.from_numpy(sequence)).to(device).long()

        if hparams.is_multi_speakers:
            if hparams.use_pretrained_spkemb:
                ref_file = meta_data['r']
                spk_embedding = np.array(np.load(ref_file))
                spk_embedding = torch.autograd.Variable(
                    torch.from_numpy(spk_embedding)).to(device).float()
                inputs = (sequence, spk_embedding)
            else:
                speaker_name = filename.split('_')[0]
                speaker_id = self.speaker_id_dict[speaker_name]
                speaker_id = np.array([speaker_id])
                # speaker_id = torch.autograd.Variable(
                #     torch.from_numpy(speaker_id)).cuda().long()
                speaker_id = torch.autograd.Variable(
                    torch.from_numpy(speaker_id)).to(device).long()
                inputs = (sequence, speaker_id)

        if hparams.is_multi_styles:
            style_id = np.array([int(meta_data[0].strip().split('|')[1])])
            style_id = torch.autograd.Variable(
                torch.from_numpy(style_id)).to(device).long()
            inputs = (sequence, style_id)

        elif hparams.use_vqvae:
            ref_file = meta_data['r']
            spk_ref = self.get_mel_gt(ref_file)
            inputs = (sequence, spk_ref)
        elif not hparams.is_multi_speakers:
            inputs = (sequence)

        return inputs, filename

    def gen_mel(self, meta_data):
        inputs, filename = self.get_inputs(meta_data)
        speaker_id = None
        style_id = None
        spk_embedding = None
        spk_ref = None
        if self.hparams.is_multi_speakers:
            if self.hparams.use_pretrained_spkemb:
                sequence, spk_embedding = inputs
            else:
                sequence, speaker_id = inputs
        elif self.hparams.use_vqvae:
            sequence, spk_ref = inputs
        else:
            sequence = inputs

        if self.hparams.is_multi_styles:
            sequence, style_id = inputs

        # Decode text input and plot results
        with torch.no_grad():
            mel_outputs, gate_outputs, att_ws = self.model.inference(
                sequence,
                self.hparams,
                spk_id=speaker_id,
                style_id=style_id,
                spemb=spk_embedding,
                spk_ref=spk_ref)

            duration_list = DurationCalculator._calculate_duration(att_ws)
            print('att_ws.shape=', att_ws.shape)
            print('duration=', duration_list)
            print('duration_sum=', torch.sum(duration_list))
            print('focus_rete=',
                  DurationCalculator._calculate_focus_rete(att_ws))
            # print(mel_outputs.shape) # (length, dim)

            mel_outputs = mel_outputs.transpose(
                0, 1).float().data.cpu().numpy()  # (dim, length)
            mel_outputs_with_duration = get_duration_matrix(
                char_text_dir=self.text_file,
                duration_tensor=duration_list,
                save_mode='phone').transpose(0, 1).float().data.cpu().numpy()
            gate_outputs = gate_outputs.float().data.cpu().numpy()
            att_ws = att_ws.float().data.cpu().numpy()

        image_path = os.path.join(self.out_png_dir,
                                  "{}_att.png".format(filename))
        _plot_and_save(att_ws, image_path)

        image_path = os.path.join(self.out_png_dir,
                                  "{}_spec_stop.png".format(filename))
        fig, axes = plt.subplots(3, 1, figsize=(8, 8))
        axes[0].imshow(mel_outputs,
                       aspect='auto',
                       origin='lower',
                       interpolation='none')
        axes[1].imshow(mel_outputs_with_duration,
                       aspect='auto',
                       origin='lower',
                       interpolation='none')
        axes[2].scatter(range(len(gate_outputs)),
                        gate_outputs,
                        alpha=0.5,
                        color='red',
                        marker='.',
                        s=5,
                        label='predicted')
        plt.savefig(image_path, format='png')
        plt.close()

        return mel_outputs, filename

    def gen_wav_griffin_lim(self, mel_outputs, filename):
        grf_wav = self.audio_class.inv_mel_spectrogram(mel_outputs)
        grf_wav = self.audio_class.inv_preemphasize(grf_wav)
        wav_path = os.path.join(self.out_wav_dir, "{}-gl.wav".format(filename))
        self.audio_class.save_wav(grf_wav, wav_path)

    def gen_wavenet_feature(self, mel_outputs, filename, add_end_sil=True):
        # denormalize
        mel = self.audio_class._denormalize(mel_outputs)
        # normalize to 0-1
        mel = np.clip(((mel - self.audio_class.hparams.min_level_db) /
                       (-self.audio_class.hparams.min_level_db)), 0, 1)

        mel = mel.T.astype(np.float32)

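        # Prepend (and optionally append) ~0.3 s of silence taken from the reference
        # silence wav: 0.3 s * 16000 Hz = 4800 samples, i.e. 4800 / 200 = 24 frames.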
        frame_size = 200
        SILSEG = 0.3
        SAMPLING = 16000
        sil_samples = int(SILSEG * SAMPLING)
        sil_frames = int(sil_samples / frame_size)
        sil_data, _ = soundfile.read(self.sil_file)
        sil_data = sil_data[:sil_samples]

        sil_mel_spec, _ = self.audio_class._magnitude_spectrogram(
            sil_data, clip_norm=True)
        sil_mel_spec = (sil_mel_spec + 4.0) / 8.0

        pad_mel_data = np.concatenate((sil_mel_spec[:sil_frames], mel), axis=0)
        if add_end_sil:
            pad_mel_data = np.concatenate(
                (pad_mel_data, sil_mel_spec[:sil_frames]), axis=0)
        out_mel_file = os.path.join(self.out_mel_dir,
                                    '{}-wn.mel'.format(filename))
        save_htk_data(pad_mel_data, out_mel_file)

    def inference_f(self):
        # print(meta_data['n'])
        meta_data = _read_meta_yyh(self.text_file)
        mel_outputs, filename = self.gen_mel(meta_data)
        print('my_mel_outputs=', mel_outputs)
        print('my_mel_outputs_max=', np.max(mel_outputs))
        print('my_mel_outputs_min=', np.min(mel_outputs))
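        # NOTE: the block below replaces the freshly generated mel with one loaded from
        # ../out_0.npy, rescaled from [0, 1] back to [-4, 4], apparently for debugging /
        # comparison against another system's output.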
        mel_outputs = np.load(r'../out_0.npy').transpose(1, 0)
        mel_outputs = mel_outputs * 8.0 - 4.0
        print('his_mel_outputs=', mel_outputs)
        print('his_mel_outputs_max=', np.max(mel_outputs))
        print('his_mel_outputs_min=', np.min(mel_outputs))
        if self.use_griffin_lim:
            self.gen_wav_griffin_lim(mel_outputs, filename)
        if self.gen_wavenet_fea:
            self.gen_wavenet_feature(mel_outputs, filename)
        return filename

    def inference(self):
        all_meta_data = _read_meta(self.text_file, self.hparams.meta_format)
        list(map(self.inference_f, all_meta_data))
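
A similar hedged sketch for this second Synthesizer variant, which can additionally write WaveNet-style mel features; again the paths are placeholders and `hparams` is assumed to come from the project's hparams module:

# Hypothetical usage sketch (illustrative only).
def _demo_synthesize_with_wavenet_features(hparams):
    synth = Synthesizer(model_path="checkpoints/model.pt",
                        out_dir="outputs",
                        text_file="meta/test_meta.txt",
                        sil_file="assets/silence_16k.wav",
                        use_griffin_lim=True,
                        gen_wavenet_fea=True,
                        hparams=hparams)
    synth.inference_f()  # reads self.text_file and synthesizes a single utterance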