def predict(hparams, model_dir, checkpoint_path, output_dir, test_files):
    audio = Audio(hparams)

    def predict_input_fn():
        records = tf.data.TFRecordDataset(list(test_files))
        dataset = DatasetSource(records, hparams)
        batched = dataset.make_source_and_target().group_by_batch(
            batch_size=1).arrange_for_prediction()
        return batched.dataset

    estimator = WaveNetModel(hparams, model_dir)

    # Wrap each raw prediction dict in a PredictedAudio record.
    predictions = map(
        lambda p: PredictedAudio(p["id"], p["key"], p["predicted_waveform"],
                                 p["ground_truth_waveform"], p["mel"], p["text"]),
        estimator.predict(predict_input_fn, checkpoint_path=checkpoint_path))

    for v in predictions:
        key = v.key.decode('utf-8')

        # Save the predicted waveform.
        audio_filename = f"{key}.wav"
        audio_filepath = os.path.join(output_dir, audio_filename)
        tf.logging.info(f"Saving {audio_filepath}")
        audio.save_wav(v.predicted_waveform, audio_filepath)

        # Plot predicted vs. ground-truth waveforms.
        png_filename = f"{key}.png"
        png_filepath = os.path.join(output_dir, png_filename)
        tf.logging.info(f"Saving {png_filepath}")
        # ToDo: pass global step
        plot_wav(png_filepath, v.predicted_waveform, v.ground_truth_waveform,
                 key, 0, v.text.decode('utf-8'), hparams.sample_rate)
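# --- Hypothetical usage sketch (not part of the original module) ---
# Illustrates how the predict() above might be invoked. `create_hparams` and the
# paths below are assumptions for this example, not names defined in this repo.
import os
from glob import glob

def example_wavenet_predict_run():
    hparams = create_hparams()  # assumed helper returning the hparams namespace
    test_files = sorted(glob("data/test/*.tfrecord"))
    os.makedirs("predictions", exist_ok=True)
    predict(hparams,
            model_dir="checkpoints/wavenet",
            checkpoint_path=None,  # None lets the Estimator restore the latest checkpoint
            output_dir="predictions",
            test_files=test_files)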
class LJSpeech:
    def __init__(self, in_dir, mel_out_dir, wav_out_dir, hparams):
        self.in_dir = in_dir
        self.mel_out_dir = mel_out_dir
        self.wav_out_dir = wav_out_dir
        self.audio = Audio(hparams)

    @property
    def record_ids(self):
        return map(lambda v: str(v), range(1, 13101))

    def record_file_path(self, record_id, kind):
        assert kind in ["source", "target"]
        return os.path.join(self.mel_out_dir,
                            f"ljspeech-{kind}-{int(record_id):05d}.tfrecord")

    def text_and_path_rdd(self, sc: SparkContext):
        return sc.parallelize(self._extract_all_text_and_path())

    def process_wav(self, rdd: RDD):
        return rdd.mapValues(self._process_wav)

    def _extract_text_and_path(self, line, index):
        parts = line.strip().split('|')
        key = parts[0]
        text = parts[2]
        wav_path = os.path.join(self.in_dir, 'wavs', '%s.wav' % key)
        return TextAndPath(index, key, wav_path, None, text)

    def _extract_all_text_and_path(self):
        with open(os.path.join(self.in_dir, 'metadata.csv'), mode='r', encoding='utf-8') as f:
            for index, line in enumerate(f):
                extracted = self._extract_text_and_path(line, index)
                if extracted is not None:
                    yield (index, extracted)

    def _process_wav(self, paths: TextAndPath):
        wav = self.audio.load_wav(paths.wav_path)
        mel_spectrogram = self.audio.melspectrogram(wav).astype(np.float32).T
        mel_spectrogram = self.audio.normalize_mel(mel_spectrogram)
        mel_filepath = os.path.join(self.mel_out_dir, f"{paths.key}.mfbsp")
        wav_filepath = os.path.join(self.wav_out_dir, f"{paths.key}.wav")
        mel_spectrogram.tofile(mel_filepath, format="<f4")
        self.audio.save_wav(wav, wav_filepath)
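# --- Hypothetical preprocessing driver (illustrative sketch, not original code) ---
# Shows how the LJSpeech class above could be driven end to end: parallelize the
# (index, TextAndPath) pairs on Spark and run the wav -> mel extraction on the
# executors. `create_hparams` and the directory paths are placeholder assumptions.
import os
from pyspark import SparkContext

def example_preprocess_ljspeech(in_dir="LJSpeech-1.1",
                                mel_out_dir="mels",
                                wav_out_dir="wavs_out"):
    hparams = create_hparams()  # assumed helper returning the hparams namespace
    os.makedirs(mel_out_dir, exist_ok=True)
    os.makedirs(wav_out_dir, exist_ok=True)
    sc = SparkContext.getOrCreate()
    corpus = LJSpeech(in_dir, mel_out_dir, wav_out_dir, hparams)
    rdd = corpus.text_and_path_rdd(sc)  # RDD of (index, TextAndPath)
    corpus.process_wav(rdd).count()     # force evaluation; writes .mfbsp and .wav files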
def predict(hparams, model_dir, checkpoint_path, output_dir, test_source_files, test_target_files):
    if hparams.half_precision:
        backend.set_floatx(tf.float16.name)
        backend.set_epsilon(1e-4)
    audio = Audio(hparams)

    def predict_input_fn():
        source = tf.data.TFRecordDataset(list(test_source_files))
        target = tf.data.TFRecordDataset(list(test_target_files))
        dataset = dataset_factory(source, target, hparams)
        batched = dataset.prepare_and_zip().group_by_batch(
            batch_size=1).move_mel_to_source()
        return batched.dataset

    estimator = model_factory(hparams, model_dir, None)

    # Wrap each raw prediction dict in a PredictedMel record.
    predictions = map(
        lambda p: PredictedMel(p["id"], p["key"], p["mel"], p.get("mel_postnet"),
                               p["mel"].shape[1], p["mel"].shape[0],
                               p["ground_truth_mel"], p["alignment"],
                               p.get("alignment2"), p.get("alignment3"),
                               p.get("alignment4"), p.get("alignment5"),
                               p.get("alignment6"),
                               p.get("attention2_gate_activation"),
                               p["source"], p["text"], p.get("accent_type"), p),
        estimator.predict(predict_input_fn, checkpoint_path=checkpoint_path))

    for v in predictions:
        key = v.key.decode('utf-8')
        mel_filename = f"{key}.{hparams.predicted_mel_extension}"
        mel_filepath = os.path.join(output_dir, mel_filename)
        ground_truth_mel = v.ground_truth_mel.astype(np.float32)
        predicted_mel = v.predicted_mel.astype(np.float32)

        # Denormalize the predicted mel and synthesize a waveform with Griffin-Lim.
        mel_denormalized = audio.denormalize_mel(predicted_mel)
        linear_spec = audio.logmelspc_to_linearspc(mel_denormalized)
        wav = audio.griffin_lim(linear_spec)
        audio.save_wav(wav, os.path.join(output_dir, f"{key}.wav"))

        assert mel_denormalized.shape[1] == hparams.num_mels
        mel_denormalized.tofile(mel_filepath, format='<f4')

        text = v.text.decode("utf-8")
        plot_filename = f"{key}.png"
        plot_filepath = os.path.join(output_dir, plot_filename)
        alignments = [x.astype(np.float32) for x in [
            v.alignment, v.alignment2, v.alignment3,
            v.alignment4, v.alignment5, v.alignment6] if x is not None]

        if hparams.model == "SSNTModel":
            ssnt_metrics.save_alignment_and_log_probs(
                [v.alignment],
                [v.all_fields["log_emit_and_shift_probs"]],
                [v.all_fields["log_output_prob"]],
                [None], text, v.key, 0,
                os.path.join(output_dir, f"{key}_probs.png"))
            ssnt_metrics.write_prediction_result(
                v.id, key, text,
                v.all_fields["log_emit_and_shift_probs"],
                v.all_fields["log_output_prob"],
                os.path.join(output_dir, f"{key}_probs.tfrecord"))

        plot_predictions(alignments, ground_truth_mel, predicted_mel, text, v.key, plot_filepath)

        prediction_filename = f"{key}.tfrecord"
        prediction_filepath = os.path.join(output_dir, prediction_filename)
        write_prediction_result(v.id, key, alignments, mel_denormalized,
                                audio.denormalize_mel(ground_truth_mel), text,
                                v.source, prediction_filepath)
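# --- Hypothetical usage sketch (not part of the original module) ---
# Illustrates how the mel-level predict() above might be invoked with paired
# source/target tfrecords. `create_hparams` and the paths are placeholder assumptions.
import os
from glob import glob

def example_mel_predict_run():
    hparams = create_hparams()  # assumed helper returning the hparams namespace
    test_source_files = sorted(glob("data/test/*-source-*.tfrecord"))
    test_target_files = sorted(glob("data/test/*-target-*.tfrecord"))
    os.makedirs("predictions", exist_ok=True)
    predict(hparams,
            model_dir="checkpoints/tacotron",
            checkpoint_path=None,  # None lets the Estimator restore the latest checkpoint
            output_dir="predictions",
            test_source_files=test_source_files,
            test_target_files=test_target_files)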
class Synthesizer():
    def __init__(self, model_path, out_dir, text_file, sil_file, use_griffin_lim, hparams):
        self.model_path = model_path
        self.out_dir = out_dir
        self.text_file = text_file
        self.sil_file = sil_file
        self.use_griffin_lim = use_griffin_lim
        self.hparams = hparams
        self.model = get_model(model_path, hparams)
        self.audio_class = Audio(hparams)
        if hparams.use_phone:
            from text.phones import Phones
            phone_class = Phones(hparams.phone_set_file)
            self.text_to_sequence = phone_class.text_to_sequence
        else:
            from text import text_to_sequence
            self.text_to_sequence = text_to_sequence
        # self.out_png_dir = os.path.join(self.out_dir, 'png')
        # os.makedirs(self.out_png_dir, exist_ok=True)
        self.out_wav_dir = os.path.join(self.out_dir, 'wav')
        os.makedirs(self.out_wav_dir, exist_ok=True)

    def get_inputs(self, meta_data):
        hparams = self.hparams
        SEQUENCE_ = []
        SPEAKERID_ = []
        STYLE_ID_ = []
        FILENAME_ = []
        # Prepare text input
        for i in range(len(meta_data)):
            filename = meta_data[i].strip().split('|')[1]
            print('Filename=', filename)
            phone_text = meta_data[i].strip().split('|')[-1]
            print('Text=', phone_text)
            speaker_id = int(meta_data[i].strip().split('|')[-2])
            print('SpeakerID=', speaker_id)
            sequence = np.array(
                self.text_to_sequence(meta_data[i].strip().split('|')[-1],
                                      ['english_cleaners']))  # [None, :]
            print(sequence)
            sequence = torch.autograd.Variable(
                torch.from_numpy(sequence)).to(device).long()
            speaker_id = torch.LongTensor(
                [speaker_id]).to(device) if hparams.is_multi_speakers else None
            style_id = torch.LongTensor([
                int(meta_data[i].strip().split('|')[1])
            ]).to(device) if hparams.is_multi_styles else None
            SEQUENCE_.append(sequence)
            SPEAKERID_.append(speaker_id)
            STYLE_ID_.append(style_id)
            FILENAME_.append(filename)
        return SEQUENCE_, SPEAKERID_, STYLE_ID_, FILENAME_

    def gen_mel(self, meta_data):
        SEQUENCE_, SPEAKERID_, STYLE_ID_, FILENAME_ = self.get_inputs(meta_data)
        MEL_OUTPUTS_ = []
        FILENAME_NEW_ = []
        # Decode text input and plot results
        with torch.no_grad():
            for i in range(len(SEQUENCE_)):
                # NOTE: the utterance-mel bank is rebuilt on every iteration; hoist it
                # out of the loop if this becomes a bottleneck.
                mel_outputs, _, _ = self.model.inference(
                    text=SEQUENCE_[i],
                    spk_ids=SPEAKERID_[i],
                    utt_mels=TextMelLoader_refine(self.hparams.training_files,
                                                  self.hparams).utt_mels)
                MEL_OUTPUTS_.append(
                    mel_outputs.transpose(0, 1).float().data.cpu().numpy())  # (dim, length)
                FILENAME_NEW_.append('spkid_' + str(SPEAKERID_[i].item()) +
                                     '_filenum_' + FILENAME_[i])
                print('mel_outputs.shape=', mel_outputs.shape)
                # image_path = os.path.join(self.out_png_dir, "{}_spec_stop.png".format(filename))
                # fig, axes = plt.subplots(1, 1, figsize=(8, 8))
                # axes.imshow(mel_outputs, aspect='auto', origin='bottom', interpolation='none')
                # plt.savefig(image_path, format='png')
                # plt.close()
        return MEL_OUTPUTS_, FILENAME_NEW_

    def gen_wav_griffin_lim(self, mel_outputs, filename):
        grf_wav = self.audio_class.inv_mel_spectrogram(mel_outputs)
        grf_wav = self.audio_class.inv_preemphasize(grf_wav)
        wav_path = os.path.join(self.out_wav_dir, "{}-gl.wav".format(filename))
        self.audio_class.save_wav(grf_wav, wav_path)

    def inference_f(self):
        # print(meta_data['n'])
        meta_data = _read_meta_yyh(self.text_file)
        MEL_OUTPUTS_, FILENAME_NEW_ = self.gen_mel(meta_data)
        for i in range(len(MEL_OUTPUTS_)):
            np.save(
                os.path.join(
                    self.out_wav_dir,
                    "{}.npy".format(FILENAME_NEW_[i] + '_' + self.model_path.split('/')[-1])),
                MEL_OUTPUTS_[i].transpose(1, 0))
            if self.use_griffin_lim:
                self.gen_wav_griffin_lim(
                    MEL_OUTPUTS_[i],
                    FILENAME_NEW_[i] + '_' + self.model_path.split('/')[-1])
        # Hand the generated mels over to the external MelGAN pipeline via the
        # hard-coded staging directories below.
        source_dir = self.out_wav_dir
        target_dir_npy = r'../../../Melgan_pipeline/melgan_file/gen_npy'
        target_dir_wav = r'../../../Melgan_pipeline/melgan_file/gen_wav'
        target_dir_wav_16k = r'../../../Melgan_pipeline/melgan_file/gen_wav_16k'
        # Move the .npy mels into the MelGAN input folder.
        for file in sorted(os.listdir(source_dir)):
            temp_file_dir = source_dir + os.sep + file
            shutil.copy(temp_file_dir, target_dir_npy)
            os.remove(temp_file_dir)
        # Run the external MelGAN vocoder and 16 kHz conversion scripts.
        os.system(
            'cd ../../../Melgan_pipeline/melgan_file/scripts/ && python generate_from_folder_mels_4-4_npy.py --load_path=../../personalvoice/melgan16k --folder=../gen_npy --save_path=../gen_wav'
        )
        os.system(
            'cd ../../../Tools/16kConverter/ && python Convert16k.py --wave_path=../../Melgan_pipeline/melgan_file/gen_wav --output_dir=../../Melgan_pipeline/melgan_file/gen_wav_16k'
        )
        # Clean up the staging folders and collect the 16 kHz wavs back into out_dir.
        for file in sorted(os.listdir(target_dir_npy)):
            temp_file_dir = target_dir_npy + os.sep + file
            os.remove(temp_file_dir)
        for file in sorted(os.listdir(target_dir_wav)):
            temp_file_dir = target_dir_wav + os.sep + file
            os.remove(temp_file_dir)
        for file in sorted(os.listdir(target_dir_wav_16k)):
            temp_file_dir = target_dir_wav_16k + os.sep + file
            shutil.copy(temp_file_dir, source_dir)
            os.remove(temp_file_dir)
        print('finished')
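# --- Hypothetical usage sketch for the Synthesizer above (not original code) ---
# The checkpoint, meta file, and silence wav paths are placeholders, and
# `create_hparams` is an assumed helper; inference_f() additionally expects the
# hard-coded ../../../Melgan_pipeline and ../../../Tools directories to exist.
def example_synthesizer_run():
    hparams = create_hparams()  # assumed helper returning the hparams namespace
    synth = Synthesizer(model_path="checkpoints/model_100000",
                        out_dir="synth_out",
                        text_file="meta/test_meta.txt",  # format expected by _read_meta_yyh
                        sil_file="assets/sil_300ms.wav",
                        use_griffin_lim=True,
                        hparams=hparams)
    synth.inference_f()  # writes .npy mels (and Griffin-Lim wavs) under synth_out/wav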
class Synthesizer():
    def __init__(self, model_path, out_dir, text_file, sil_file, use_griffin_lim, gen_wavenet_fea, hparams):
        self.out_dir = out_dir
        self.text_file = text_file
        self.sil_file = sil_file
        self.use_griffin_lim = use_griffin_lim
        self.gen_wavenet_fea = gen_wavenet_fea
        self.hparams = hparams
        self.model = get_model(model_path, hparams)
        self.audio_class = Audio(hparams)
        if hparams.use_phone:
            from text.phones import Phones
            phone_class = Phones(hparams.phone_set_file)
            self.text_to_sequence = phone_class.text_to_sequence
        else:
            from text import text_to_sequence
            self.text_to_sequence = text_to_sequence
        if hparams.is_multi_speakers and not hparams.use_pretrained_spkemb:
            self.speaker_id_dict = gen_speaker_id_dict(hparams)
        self.out_png_dir = os.path.join(self.out_dir, 'png')
        os.makedirs(self.out_png_dir, exist_ok=True)
        if self.use_griffin_lim:
            self.out_wav_dir = os.path.join(self.out_dir, 'wav')
            os.makedirs(self.out_wav_dir, exist_ok=True)
        if self.gen_wavenet_fea:
            self.out_mel_dir = os.path.join(self.out_dir, 'mel')
            os.makedirs(self.out_mel_dir, exist_ok=True)

    def get_mel_gt(self, wavname):
        # Load a ground-truth mel spectrogram, either computed from a wav file or
        # read from precomputed features (zip / hdf5 / .npy).
        hparams = self.hparams
        if not hparams.load_mel:
            if hparams.use_hdf5:
                with h5py.File(hparams.hdf5_file, 'r') as h5:
                    data = h5[wavname][:]
            else:
                filename = os.path.join(hparams.wav_dir, wavname + '.wav')
                sr_t, audio = wavread(filename)
                assert sr_t == hparams.sample_rate
                audio_norm = audio / hparams.max_wav_value
                wav = self.audio_class._preemphasize(audio_norm)
                melspec = self.audio_class.melspectrogram(wav, clip_norm=True)
                melspec = torch.FloatTensor(melspec.astype(np.float32))
        else:
            if hparams.use_zip:
                with zipfile.ZipFile(hparams.zip_path, 'r') as f:
                    data = f.read(wavname)
                melspec = np.load(io.BytesIO(data))
                melspec = torch.FloatTensor(melspec.astype(np.float32))
            elif hparams.use_hdf5:
                with h5py.File(hparams.hdf5_file, 'r') as h5:
                    melspec = h5[wavname][:]
                melspec = torch.FloatTensor(melspec.astype(np.float32))
            else:
                filename = os.path.join(hparams.wav_dir, wavname + '.npy')
                melspec = torch.from_numpy(np.load(filename))
        melspec = torch.unsqueeze(melspec, 0)
        return melspec

    def get_inputs(self, meta_data):
        hparams = self.hparams
        # Prepare text input
        # filename = meta_data['n']
        # filename = os.path.splitext(os.path.basename(filename))[0]
        filename = meta_data[0].strip().split('|')[0]
        print(meta_data[0].strip().split('|')[-1])
        print(meta_data[0].strip().split('|')[1])
        sequence = np.array(
            self.text_to_sequence(meta_data[0].strip().split('|')[-1],
                                  ['english_cleaners']))  # [None, :]
        # sequence = torch.autograd.Variable(
        #     torch.from_numpy(sequence)).cuda().long()
        print(sequence)
        sequence = torch.autograd.Variable(
            torch.from_numpy(sequence)).to(device).long()
        if hparams.is_multi_speakers:
            if hparams.use_pretrained_spkemb:
                ref_file = meta_data['r']
                spk_embedding = np.array(np.load(ref_file))
                spk_embedding = torch.autograd.Variable(
                    torch.from_numpy(spk_embedding)).to(device).float()
                inputs = (sequence, spk_embedding)
            else:
                speaker_name = filename.split('_')[0]
                speaker_id = self.speaker_id_dict[speaker_name]
                speaker_id = np.array([speaker_id])
                # speaker_id = torch.autograd.Variable(
                #     torch.from_numpy(speaker_id)).cuda().long()
                speaker_id = torch.autograd.Variable(
                    torch.from_numpy(speaker_id)).to(device).long()
                inputs = (sequence, speaker_id)
            if hparams.is_multi_styles:
                style_id = np.array([int(meta_data[0].strip().split('|')[1])])
                style_id = torch.autograd.Variable(
                    torch.from_numpy(style_id)).to(device).long()
                inputs = (sequence, style_id)
        elif hparams.use_vqvae:
            ref_file = meta_data['r']
            spk_ref = self.get_mel_gt(ref_file)
            inputs = (sequence, spk_ref)
        else:
            inputs = sequence  # note: a bare sequence, not a tuple
        return inputs, filename

    def gen_mel(self, meta_data):
        inputs, filename = self.get_inputs(meta_data)
        speaker_id = None
        style_id = None
        spk_embedding = None
        spk_ref = None
        if self.hparams.is_multi_speakers:
            if self.hparams.use_pretrained_spkemb:
                sequence, spk_embedding = inputs
            else:
                sequence, speaker_id = inputs
        elif self.hparams.use_vqvae:
            sequence, spk_ref = inputs
        else:
            sequence = inputs
        if self.hparams.is_multi_styles:
            sequence, style_id = inputs

        # Decode text input and plot results
        with torch.no_grad():
            mel_outputs, gate_outputs, att_ws = self.model.inference(
                sequence, self.hparams,
                spk_id=speaker_id,
                style_id=style_id,
                spemb=spk_embedding,
                spk_ref=spk_ref)
        duration_list = DurationCalculator._calculate_duration(att_ws)
        print('att_ws.shape=', att_ws.shape)
        print('duration=', duration_list)
        print('duration_sum=', torch.sum(duration_list))
        print('focus_rete=', DurationCalculator._calculate_focus_rete(att_ws))
        # print(mel_outputs.shape)  # (length, dim)
        mel_outputs = mel_outputs.transpose(0, 1).float().data.cpu().numpy()  # (dim, length)
        mel_outputs_with_duration = get_duration_matrix(
            char_text_dir=self.text_file,
            duration_tensor=duration_list,
            save_mode='phone').transpose(0, 1).float().data.cpu().numpy()
        gate_outputs = gate_outputs.float().data.cpu().numpy()
        att_ws = att_ws.float().data.cpu().numpy()

        # Save attention and spectrogram / stop-token plots.
        image_path = os.path.join(self.out_png_dir, "{}_att.png".format(filename))
        _plot_and_save(att_ws, image_path)
        image_path = os.path.join(self.out_png_dir, "{}_spec_stop.png".format(filename))
        fig, axes = plt.subplots(3, 1, figsize=(8, 8))
        axes[0].imshow(mel_outputs, aspect='auto', origin='bottom', interpolation='none')
        axes[1].imshow(mel_outputs_with_duration, aspect='auto', origin='bottom', interpolation='none')
        axes[2].scatter(range(len(gate_outputs)), gate_outputs,
                        alpha=0.5, color='red', marker='.', s=5, label='predicted')
        plt.savefig(image_path, format='png')
        plt.close()
        return mel_outputs, filename

    def gen_wav_griffin_lim(self, mel_outputs, filename):
        grf_wav = self.audio_class.inv_mel_spectrogram(mel_outputs)
        grf_wav = self.audio_class.inv_preemphasize(grf_wav)
        wav_path = os.path.join(self.out_wav_dir, "{}-gl.wav".format(filename))
        self.audio_class.save_wav(grf_wav, wav_path)

    def gen_wavenet_feature(self, mel_outputs, filename, add_end_sil=True):
        # denormalize
        mel = self.audio_class._denormalize(mel_outputs)
        # normalize to 0-1
        mel = np.clip(((mel - self.audio_class.hparams.min_level_db) /
                       (-self.audio_class.hparams.min_level_db)), 0, 1)
        mel = mel.T.astype(np.float32)
        frame_size = 200
        SILSEG = 0.3
        SAMPLING = 16000
        sil_samples = int(SILSEG * SAMPLING)
        sil_frames = int(sil_samples / frame_size)
        sil_data, _ = soundfile.read(self.sil_file)
        sil_data = sil_data[:sil_samples]
        sil_mel_spec, _ = self.audio_class._magnitude_spectrogram(sil_data, clip_norm=True)
        sil_mel_spec = (sil_mel_spec + 4.0) / 8.0
        # Pad leading (and optionally trailing) silence frames around the mel.
        pad_mel_data = np.concatenate((sil_mel_spec[:sil_frames], mel), axis=0)
        if add_end_sil:
            pad_mel_data = np.concatenate((pad_mel_data, sil_mel_spec[:sil_frames]), axis=0)
        out_mel_file = os.path.join(self.out_mel_dir, '{}-wn.mel'.format(filename))
        save_htk_data(pad_mel_data, out_mel_file)

    def inference_f(self, meta_data=None):
        # NOTE: the meta file is re-read here; an entry passed in by inference() is ignored.
        # print(meta_data['n'])
        meta_data = _read_meta_yyh(self.text_file)
        mel_outputs, filename = self.gen_mel(meta_data)
        print('my_mel_outputs=', mel_outputs)
        print('my_mel_outputs_max=', np.max(mel_outputs))
        print('my_mel_outputs_min=', np.min(mel_outputs))
        # Debug override: replace the generated mel with a reference dump and rescale
        # it from [0, 1] back to the [-4, 4] range expected downstream.
        mel_outputs = np.load(r'../out_0.npy').transpose(1, 0)
        mel_outputs = mel_outputs * 8.0 - 4.0
        print('his_mel_outputs=', mel_outputs)
        print('his_mel_outputs_max=', np.max(mel_outputs))
        print('his_mel_outputs_min=', np.min(mel_outputs))
        if self.use_griffin_lim:
            self.gen_wav_griffin_lim(mel_outputs, filename)
        if self.gen_wavenet_fea:
            self.gen_wavenet_feature(mel_outputs, filename)
        return filename

    def inference(self):
        all_meta_data = _read_meta(self.text_file, self.hparams.meta_format)
        list(map(self.inference_f, all_meta_data))
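# --- Hypothetical usage sketch for the Synthesizer above (not original code) ---
# Paths and `create_hparams` are placeholder assumptions. With gen_wavenet_fea=True
# the class also writes HTK-format mel features for a WaveNet vocoder under out_dir/mel.
def example_batch_synthesis():
    hparams = create_hparams()  # assumed helper returning the hparams namespace
    synth = Synthesizer(model_path="checkpoints/model_100000",
                        out_dir="synth_out",
                        text_file="meta/test_meta.txt",  # format expected by _read_meta
                        sil_file="assets/sil_300ms.wav",
                        use_griffin_lim=True,
                        gen_wavenet_fea=False,
                        hparams=hparams)
    synth.inference()  # runs inference_f for every entry in the meta file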