# Beispiel #1
# 0
class Synthesizer():
    """Batch TTS synthesizer.

    Reads a '|'-separated metadata file, generates mel spectrograms with
    the acoustic model, saves them as .npy (optionally rendering
    Griffin-Lim wavs), then vocodes them through the external MelGAN
    pipeline via os.system and collects the 16 kHz wavs back into
    ``out_dir/wav``.

    NOTE(review): a second class with the same name is defined later in
    this module and shadows this one at import time — confirm which
    definition callers actually use.
    """

    def __init__(self, model_path, out_dir, text_file, sil_file,
                 use_griffin_lim, hparams):
        # Paths and flags consumed by the inference entry points below.
        self.model_path = model_path
        self.out_dir = out_dir
        self.text_file = text_file
        self.sil_file = sil_file
        self.use_griffin_lim = use_griffin_lim
        self.hparams = hparams

        self.model = get_model(model_path, hparams)
        self.audio_class = Audio(hparams)

        # Select the text front-end: phone-based or raw-text cleaners.
        if hparams.use_phone:
            from text.phones import Phones
            phone_class = Phones(hparams.phone_set_file)
            self.text_to_sequence = phone_class.text_to_sequence
        else:
            from text import text_to_sequence
            self.text_to_sequence = text_to_sequence

        self.out_wav_dir = os.path.join(self.out_dir, 'wav')
        os.makedirs(self.out_wav_dir, exist_ok=True)

    def get_inputs(self, meta_data):
        """Parse metadata lines into model-ready tensors.

        Each line is '|'-separated; field [1] is the filename, field [-2]
        the speaker id and field [-1] the text/phones.

        Returns four parallel lists ``(sequences, speaker_ids, style_ids,
        filenames)``; speaker/style entries are ``None`` when the
        corresponding hparams flag is off.
        """
        hparams = self.hparams
        sequences = []
        speaker_ids = []
        style_ids = []
        filenames = []
        for line in meta_data:
            fields = line.strip().split('|')  # split once, reuse below
            filename = fields[1]
            print('Filename=', filename)

            phone_text = fields[-1]
            print('Text=', phone_text)

            speaker_id = int(fields[-2])
            print('SpeakerID=', speaker_id)

            sequence = np.array(
                self.text_to_sequence(phone_text, ['english_cleaners']))
            print(sequence)

            # torch.autograd.Variable is a deprecated no-op wrapper; the
            # bare tensor is equivalent.
            sequence = torch.from_numpy(sequence).to(device).long()

            speaker_id = torch.LongTensor(
                [speaker_id]).to(device) if hparams.is_multi_speakers else None

            # NOTE(review): the style id is parsed from field [1], the same
            # field used as filename above — confirm the metadata layout.
            style_id = torch.LongTensor([
                int(fields[1])
            ]).to(device) if hparams.is_multi_styles else None

            sequences.append(sequence)
            speaker_ids.append(speaker_id)
            style_ids.append(style_id)
            filenames.append(filename)

        return sequences, speaker_ids, style_ids, filenames

    def gen_mel(self, meta_data):
        """Run the acoustic model over all parsed inputs.

        Returns ``(mel_list, name_list)`` where each mel is a numpy array
        of shape (dim, length).
        """
        sequences, speaker_ids, _style_ids, filenames = self.get_inputs(
            meta_data)
        mels = []
        out_names = []
        # BUG FIX: hoist the reference-mel loading out of the loop — the
        # original rebuilt TextMelLoader_refine (re-reading all training
        # files) once per utterance, and relied on a module-level
        # `hparams` instead of self.hparams.
        utt_mels = TextMelLoader_refine(self.hparams.training_files,
                                        self.hparams).utt_mels
        with torch.no_grad():
            for sequence, speaker_id, filename in zip(sequences, speaker_ids,
                                                      filenames):
                mel_outputs, _, _ = self.model.inference(
                    text=sequence,
                    spk_ids=speaker_id,
                    utt_mels=utt_mels)
                mels.append(
                    mel_outputs.transpose(
                        0, 1).float().data.cpu().numpy())  # (dim, length)
                # BUG FIX: speaker_id is None when is_multi_speakers is
                # off; the original crashed on `.item()` in that case.
                spk_tag = ('spkid_' + str(speaker_id.item())
                           if speaker_id is not None else 'spkid_none')
                out_names.append(spk_tag + '_filenum_' + filename)
                print('mel_outputs.shape=', mel_outputs.shape)
        return mels, out_names

    def gen_wav_griffin_lim(self, mel_outputs, filename):
        """Invert a mel spectrogram with Griffin-Lim and save the wav."""
        grf_wav = self.audio_class.inv_mel_spectrogram(mel_outputs)
        grf_wav = self.audio_class.inv_preemphasize(grf_wav)
        wav_path = os.path.join(self.out_wav_dir, "{}-gl.wav".format(filename))
        self.audio_class.save_wav(grf_wav, wav_path)

    def inference_f(self):
        """End-to-end inference over the whole metadata file.

        Saves mels as .npy (and optional Griffin-Lim wavs), runs the
        external MelGAN vocoder and 16 kHz converter via os.system, then
        moves the resulting 16 kHz wavs back into ``self.out_wav_dir``.
        """
        meta_data = _read_meta_yyh(self.text_file)
        mels, out_names = self.gen_mel(meta_data)
        model_tag = self.model_path.split('/')[-1]
        for mel, name in zip(mels, out_names):
            np.save(
                os.path.join(self.out_wav_dir,
                             "{}.npy".format(name + '_' + model_tag)),
                mel.transpose(1, 0))
            if self.use_griffin_lim:
                self.gen_wav_griffin_lim(mel, name + '_' + model_tag)

        source_dir = self.out_wav_dir
        target_dir_npy = r'../../../Melgan_pipeline/melgan_file/gen_npy'
        target_dir_wav = r'../../../Melgan_pipeline/melgan_file/gen_wav'
        target_dir_wav_16k = r'../../../Melgan_pipeline/melgan_file/gen_wav_16k'

        # Move every generated file into the MelGAN input folder.
        for file in sorted(os.listdir(source_dir)):
            src = os.path.join(source_dir, file)
            shutil.copy(src, target_dir_npy)
            os.remove(src)

        # Run the external vocoder, then the 16 kHz sample-rate converter.
        os.system(
            'cd ../../../Melgan_pipeline/melgan_file/scripts/ && python generate_from_folder_mels_4-4_npy.py --load_path=../../personalvoice/melgan16k --folder=../gen_npy --save_path=../gen_wav'
        )
        os.system(
            'cd ../../../Tools/16kConverter/ && python Convert16k.py --wave_path=../../Melgan_pipeline/melgan_file/gen_wav --output_dir=../../Melgan_pipeline/melgan_file/gen_wav_16k'
        )

        # Clean the intermediate folders and pull the 16 kHz results back.
        for file in sorted(os.listdir(target_dir_npy)):
            os.remove(os.path.join(target_dir_npy, file))

        for file in sorted(os.listdir(target_dir_wav)):
            os.remove(os.path.join(target_dir_wav, file))

        for file in sorted(os.listdir(target_dir_wav_16k)):
            src = os.path.join(target_dir_wav_16k, file)
            shutil.copy(src, source_dir)
            os.remove(src)

        print('finished')
class Synthesizer():
    """Single-utterance TTS synthesizer.

    Builds model inputs from one metadata entry, synthesizes a mel
    spectrogram with attention/duration diagnostics and plots, and can
    render Griffin-Lim wavs and/or WaveNet feature files.

    NOTE(review): this class redefines (and shadows) the ``Synthesizer``
    declared earlier in this module — confirm which one is intended.
    """

    def __init__(self, model_path, out_dir, text_file, sil_file,
                 use_griffin_lim, gen_wavenet_fea, hparams):
        # Paths and flags consumed by the inference entry points below.
        self.out_dir = out_dir
        self.text_file = text_file
        self.sil_file = sil_file
        self.use_griffin_lim = use_griffin_lim
        self.gen_wavenet_fea = gen_wavenet_fea
        self.hparams = hparams

        self.model = get_model(model_path, hparams)
        self.audio_class = Audio(hparams)

        # Select the text front-end: phone-based or raw-text cleaners.
        if hparams.use_phone:
            from text.phones import Phones
            phone_class = Phones(hparams.phone_set_file)
            self.text_to_sequence = phone_class.text_to_sequence
        else:
            from text import text_to_sequence
            self.text_to_sequence = text_to_sequence

        if hparams.is_multi_speakers and not hparams.use_pretrained_spkemb:
            self.speaker_id_dict = gen_speaker_id_dict(hparams)

        self.out_png_dir = os.path.join(self.out_dir, 'png')
        os.makedirs(self.out_png_dir, exist_ok=True)
        if self.use_griffin_lim:
            self.out_wav_dir = os.path.join(self.out_dir, 'wav')
            os.makedirs(self.out_wav_dir, exist_ok=True)
        if self.gen_wavenet_fea:
            self.out_mel_dir = os.path.join(self.out_dir, 'mel')
            os.makedirs(self.out_mel_dir, exist_ok=True)

    def get_mel_gt(self, wavname):
        """Load (or compute) the ground-truth mel spectrogram for ``wavname``.

        Returns a FloatTensor with an unsqueezed leading batch axis.
        """
        hparams = self.hparams
        if not hparams.load_mel:
            # Compute the mel from raw audio (hdf5 or .wav on disk).
            if hparams.use_hdf5:
                with h5py.File(hparams.hdf5_file, 'r') as h5:
                    # BUG FIX: the original stored this slice into `data`
                    # and then used an undefined `audio`, raising
                    # NameError on this code path.
                    audio = h5[wavname][:]
            else:
                filename = os.path.join(hparams.wav_dir, wavname + '.wav')
                sr_t, audio = wavread(filename)
                assert sr_t == hparams.sample_rate
            audio_norm = audio / hparams.max_wav_value
            wav = self.audio_class._preemphasize(audio_norm)
            melspec = self.audio_class.melspectrogram(wav, clip_norm=True)
            melspec = torch.FloatTensor(melspec.astype(np.float32))
        else:
            # Load a precomputed mel from zip, hdf5 or a .npy file.
            if hparams.use_zip:
                with zipfile.ZipFile(hparams.zip_path, 'r') as f:
                    data = f.read(wavname)
                    melspec = np.load(io.BytesIO(data))
                melspec = torch.FloatTensor(melspec.astype(np.float32))
            elif hparams.use_hdf5:
                with h5py.File(hparams.hdf5_file, 'r') as h5:
                    melspec = h5[wavname][:]
                melspec = torch.FloatTensor(melspec.astype(np.float32))
            else:
                filename = os.path.join(hparams.wav_dir, wavname + '.npy')
                melspec = torch.from_numpy(np.load(filename))
        melspec = torch.unsqueeze(melspec, 0)
        return melspec

    def get_inputs(self, meta_data):
        """Build model inputs from the first metadata entry.

        Returns ``(inputs, filename)``; the layout of ``inputs`` depends
        on the hparams flags and mirrors the unpacking in gen_mel.
        """
        hparams = self.hparams
        fields = meta_data[0].strip().split('|')
        filename = fields[0]
        print(fields[-1])
        print(fields[1])
        sequence = np.array(
            self.text_to_sequence(fields[-1], ['english_cleaners']))
        print(sequence)
        # torch.autograd.Variable is a deprecated no-op wrapper; the bare
        # tensor is equivalent.
        sequence = torch.from_numpy(sequence).to(device).long()

        inputs = None
        if hparams.is_multi_speakers:
            if hparams.use_pretrained_spkemb:
                # NOTE(review): meta_data is indexed by 'r' here but by
                # [0] above — confirm the metadata container type.
                ref_file = meta_data['r']
                spk_embedding = np.array(np.load(ref_file))
                spk_embedding = torch.from_numpy(spk_embedding).to(
                    device).float()
                inputs = (sequence, spk_embedding)
            else:
                speaker_name = filename.split('_')[0]
                speaker_id = np.array([self.speaker_id_dict[speaker_name]])
                speaker_id = torch.from_numpy(speaker_id).to(device).long()
                inputs = (sequence, speaker_id)

        # Multi-style (and then VQ-VAE) take precedence over the
        # multi-speaker inputs, matching the unpack order in gen_mel.
        if hparams.is_multi_styles:
            style_id = np.array([int(fields[1])])
            style_id = torch.from_numpy(style_id).to(device).long()
            inputs = (sequence, style_id)
        elif hparams.use_vqvae:
            ref_file = meta_data['r']
            spk_ref = self.get_mel_gt(ref_file)
            inputs = (sequence, spk_ref)
        elif inputs is None:
            # BUG FIX: the original unconditional `else` overwrote the
            # multi-speaker tuple with the bare sequence whenever both
            # is_multi_styles and use_vqvae were off, breaking the
            # unpacking in gen_mel.
            inputs = sequence

        return inputs, filename

    def gen_mel(self, meta_data):
        """Synthesize one utterance; save attention and spectrogram plots.

        Returns ``(mel_outputs, filename)`` with the mel as a numpy array
        of shape (dim, length).
        """
        inputs, filename = self.get_inputs(meta_data)
        speaker_id = None
        style_id = None
        spk_embedding = None
        spk_ref = None
        # Unpack `inputs` with the same flag logic used to build it in
        # get_inputs (multi-style re-unpacks last).
        if self.hparams.is_multi_speakers:
            if self.hparams.use_pretrained_spkemb:
                sequence, spk_embedding = inputs
            else:
                sequence, speaker_id = inputs
        elif self.hparams.use_vqvae:
            # BUG FIX: the original referenced a module-level `hparams`
            # here instead of self.hparams.
            sequence, spk_ref = inputs
        else:
            sequence = inputs

        if self.hparams.is_multi_styles:
            sequence, style_id = inputs

        # Decode the text input and collect diagnostics.
        with torch.no_grad():
            mel_outputs, gate_outputs, att_ws = self.model.inference(
                sequence,
                self.hparams,
                spk_id=speaker_id,
                style_id=style_id,
                spemb=spk_embedding,
                spk_ref=spk_ref)

            duration_list = DurationCalculator._calculate_duration(att_ws)
            print('att_ws.shape=', att_ws.shape)
            print('duration=', duration_list)
            print('duration_sum=', torch.sum(duration_list))
            print('focus_rete=',
                  DurationCalculator._calculate_focus_rete(att_ws))

            mel_outputs = mel_outputs.transpose(
                0, 1).float().data.cpu().numpy()  # (dim, length)
            mel_outputs_with_duration = get_duration_matrix(
                char_text_dir=self.text_file,
                duration_tensor=duration_list,
                save_mode='phone').transpose(0, 1).float().data.cpu().numpy()
            gate_outputs = gate_outputs.float().data.cpu().numpy()
            att_ws = att_ws.float().data.cpu().numpy()

        # Attention plot.
        image_path = os.path.join(self.out_png_dir,
                                  "{}_att.png".format(filename))
        _plot_and_save(att_ws, image_path)

        # Spectrogram / duration-expanded spectrogram / stop-gate plot.
        image_path = os.path.join(self.out_png_dir,
                                  "{}_spec_stop.png".format(filename))
        fig, axes = plt.subplots(3, 1, figsize=(8, 8))
        axes[0].imshow(mel_outputs,
                       aspect='auto',
                       origin='bottom',
                       interpolation='none')
        axes[1].imshow(mel_outputs_with_duration,
                       aspect='auto',
                       origin='bottom',
                       interpolation='none')
        axes[2].scatter(range(len(gate_outputs)),
                        gate_outputs,
                        alpha=0.5,
                        color='red',
                        marker='.',
                        s=5,
                        label='predicted')
        plt.savefig(image_path, format='png')
        plt.close()

        return mel_outputs, filename

    def gen_wav_griffin_lim(self, mel_outputs, filename):
        """Invert a mel spectrogram with Griffin-Lim and save the wav."""
        grf_wav = self.audio_class.inv_mel_spectrogram(mel_outputs)
        grf_wav = self.audio_class.inv_preemphasize(grf_wav)
        wav_path = os.path.join(self.out_wav_dir, "{}-gl.wav".format(filename))
        self.audio_class.save_wav(grf_wav, wav_path)

    def gen_wavenet_feature(self, mel_outputs, filename, add_end_sil=True):
        """Write a WaveNet-ready .mel feature file, padded with silence.

        The mel is denormalized, rescaled to [0, 1], and padded with a
        0.3 s silence segment (taken from ``self.sil_file``) at the start
        and, if ``add_end_sil``, at the end.
        """
        # Denormalize, then renormalize to the 0-1 range WaveNet expects.
        mel = self.audio_class._denormalize(mel_outputs)
        mel = np.clip(((mel - self.audio_class.hparams.min_level_db) /
                       (-self.audio_class.hparams.min_level_db)), 0, 1)

        mel = mel.T.astype(np.float32)

        # 0.3 s of silence at 16 kHz, 200 samples per frame.
        frame_size = 200
        SILSEG = 0.3
        SAMPLING = 16000
        sil_samples = int(SILSEG * SAMPLING)
        sil_frames = int(sil_samples / frame_size)
        sil_data, _ = soundfile.read(self.sil_file)
        sil_data = sil_data[:sil_samples]

        sil_mel_spec, _ = self.audio_class._magnitude_spectrogram(
            sil_data, clip_norm=True)
        # Map the silence mel into the same 0-1 range as the model output.
        sil_mel_spec = (sil_mel_spec + 4.0) / 8.0

        pad_mel_data = np.concatenate((sil_mel_spec[:sil_frames], mel), axis=0)
        if add_end_sil:
            pad_mel_data = np.concatenate(
                (pad_mel_data, sil_mel_spec[:sil_frames]), axis=0)
        out_mel_file = os.path.join(self.out_mel_dir,
                                    '{}-wn.mel'.format(filename))
        save_htk_data(pad_mel_data, out_mel_file)

    def inference_f(self, meta_item=None):
        """Synthesize one utterance and write wav / WaveNet features.

        ``meta_item`` is accepted (and currently unused) so that
        inference() can map this method over metadata entries without a
        TypeError; as before, the metadata is re-read from
        ``self.text_file``.
        """
        meta_data = _read_meta_yyh(self.text_file)
        mel_outputs, filename = self.gen_mel(meta_data)
        print('my_mel_outputs=', mel_outputs)
        print('my_mel_outputs_max=', np.max(mel_outputs))
        print('my_mel_outputs_min=', np.min(mel_outputs))
        # NOTE(review): debug/comparison code — the freshly generated mel
        # is discarded and replaced by a hard-coded file before waveform
        # generation. Remove once the comparison is no longer needed.
        mel_outputs = np.load(r'../out_0.npy').transpose(1, 0)
        mel_outputs = mel_outputs * 8.0 - 4.0
        print('his_mel_outputs=', mel_outputs)
        print('his_mel_outputs_max=', np.max(mel_outputs))
        print('his_mel_outputs_min=', np.min(mel_outputs))
        if self.use_griffin_lim:
            self.gen_wav_griffin_lim(mel_outputs, filename)
        if self.gen_wavenet_fea:
            self.gen_wavenet_feature(mel_outputs, filename)
        return filename

    def inference(self):
        """Run inference_f once per metadata entry in ``self.text_file``."""
        # BUG FIX: the original referenced a module-level `hparams` here;
        # mapping also raised TypeError before inference_f accepted an
        # (optional) argument.
        all_meta_data = _read_meta(self.text_file, self.hparams.meta_format)
        list(map(self.inference_f, all_meta_data))