import os
import time

import numpy as np
import torch
from tqdm import tqdm

# Repo-local helpers (init_distributed, load_model, prepare_dataloader,
# prepare_dataloaders, parse_batch, warm_start_model, get_global_mean, ...)
# are assumed to be imported from elsewhere in the project.


def extract_mels_teacher_forcing(
        output_directory,
        checkpoint_path,
        hparams,
        file_list,
        n_gpus,
        rank,
        group_name,
        extract_type="mels",
):
    device = torch.device("cuda:{:d}".format(rank))
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)
    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)
    np.random.seed(hparams.seed)

    eval_loader = prepare_dataloader(hparams, file_list)
    model = load_model(hparams)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint_dict["state_dict"])
    model.eval()
    # if hparams.fp16_run:
    #     model.half()

    for batch in tqdm(eval_loader):
        x, y = parse_batch(batch)
        with torch.no_grad():
            y_pred = model(x)

        if extract_type == "mels":
            for res, fname, out_length in zip(y_pred[2], y[0], x[4]):
                # derive the speaker folder and file stem from the utterance path
                fname = fname.split('/')[-1].split('.')[0]
                speaker_name = fname.split('_')[0]
                # trim padding frames and save the teacher-forced mel
                np.save(
                    os.path.join(output_directory, speaker_name, fname + ".npy"),
                    res.cpu().numpy()[:, :out_length],
                )
        elif extract_type == "alignments":
            for alignment, fname, seq_len, out_length in zip(
                    y_pred[4], y[0], x[1], x[4]):
                # crop the attention matrix to the real text and mel lengths
                alignment = alignment.T[:seq_len, :out_length]
                # save, per input token, how many output frames it wins the argmax for
                np.save(
                    os.path.join(output_directory, fname + ".npy"),
                    np.bincount(
                        np.argmax(alignment.cpu().numpy(), axis=0),
                        minlength=alignment.shape[0],
                    ),
                )
        else:
            raise Exception(f"Extracting {extract_type} is not supported.")
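# Usage sketch (not part of the original script; all paths and the hparams
# object are assumptions made for illustration): extracting teacher-forced
# mels on a single GPU might look like
#
#   hparams = create_hparams()                      # assumed hparams factory
#   extract_mels_teacher_forcing(
#       output_directory='gta_mels',                # hypothetical output dir
#       checkpoint_path='tacotron2_statedict.pt',   # hypothetical checkpoint
#       hparams=hparams,
#       file_list='filelists/train_filelist.txt',   # hypothetical file list
#       n_gpus=1, rank=0, group_name=None,
#       extract_type='mels')                        # or 'alignments'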
def GTA_Synthesis(output_directory, checkpoint_path, n_gpus, rank, group_name,
                  hparams, training_mode, verify_outputs, use_val_files,
                  fp16_save, extra_info='', audio_offset=0):
    """Generate Ground-Truth-Aligned spectrograms for training WaveGlow."""
    if audio_offset:
        hparams.load_mel_from_disk = False
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)
    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    if use_val_files:
        filelisttype = "val"
        hparams.training_files = hparams.validation_files
    else:
        filelisttype = "train"

    train_loader, _, collate_fn, train_sampler, train_set = prepare_dataloaders(
        hparams, audio_offset=audio_offset)

    if training_mode and hparams.drop_frame_rate > 0.:
        if rank != 0:
            # if the global mean has not been calculated yet, wait for the main process to do it
            while not os.path.exists(hparams.global_mean_npy):
                time.sleep(1)
        global_mean = get_global_mean(train_loader, hparams.global_mean_npy, hparams)
        hparams.global_mean = global_mean

    model = load_model(hparams)

    # Load checkpoint (required)
    assert checkpoint_path is not None
    model = warm_start_model(checkpoint_path, model)

    if training_mode:
        model.train()
    else:
        model.eval()

    batch_parser = model.parse_batch  # same in both the distributed and single-GPU case

    # ================ MAIN SYNTHESIS LOOP ===================
    os.makedirs(output_directory, exist_ok=True)
    f = open(os.path.join(output_directory, f'map_{filelisttype}_{rank}.txt'),
             'a', encoding='utf-8')
    os.makedirs(os.path.join(output_directory, 'mels'), exist_ok=True)

    total_number_of_data = len(train_set.audiopaths_and_text)
    max_iter = total_number_of_data // hparams.batch_size
    remainder_size = total_number_of_data % hparams.batch_size

    duration = time.time()
    total = len(train_loader)
    rolling_sum = StreamingMovingAverage(100)
    for i, batch in enumerate(train_loader):
        batch_size = hparams.batch_size if i != max_iter else remainder_size

        # wav file paths and speaker ids for this batch (in dataset order)
        audiopaths_and_text = train_set.audiopaths_and_text[
            i * hparams.batch_size:i * hparams.batch_size + batch_size]
        audiopaths = [x[0] for x in audiopaths_and_text]
        orig_speaker_ids = [x[2] for x in audiopaths_and_text]

        # text lengths, used to recover the sort order applied by the collate function
        indx_list = np.arange(i * hparams.batch_size,
                              i * hparams.batch_size + batch_size).tolist()
        len_text_list = []
        for batch_index in indx_list:
            text, *_ = train_set.__getitem__(batch_index)
            len_text_list.append(text.size(0))

        _, input_lengths, _, _, output_lengths, speaker_id, _, _ = batch  # output_lengths: original mel lengths
        input_lengths_, ids_sorted_decreasing = torch.sort(
            torch.LongTensor(len_text_list), dim=0, descending=True)
        ids_sorted_decreasing = ids_sorted_decreasing.numpy()  # original indices, sorted by decreasing text length

        org_audiopaths = []  # original file names
        mel_paths = []
        speaker_ids = []
        for k in range(batch_size):
            d = audiopaths[ids_sorted_decreasing[k]]
            org_audiopaths.append(d)
            mel_paths.append(d.replace(".npy", ".mel").replace('.wav', '.mel'))
            speaker_ids.append(orig_speaker_ids[ids_sorted_decreasing[k]])

        x, _ = batch_parser(batch)
        _, mel_outputs_postnet, _, _ = model(x, teacher_force_till=9999,
                                             p_teacher_forcing=1.0)
        mel_outputs_postnet = mel_outputs_postnet.data.cpu().numpy()

        for k in range(batch_size):
            wav_path = org_audiopaths[k].replace(".npy", ".wav")
            offset_append = '' if audio_offset == 0 else str(audio_offset)
            mel_path = mel_paths[k] + offset_append + '.npy'  # ext = '.mel.npy' or '.mel1.npy' ... '.mel599.npy'
            speaker_id = speaker_ids[k]
            map_line = "{}|{}|{}\n".format(wav_path, mel_path, speaker_id)
            f.write(map_line)

            mel = mel_outputs_postnet[k, :, :output_lengths[k]]
            print(wav_path, input_lengths[k], output_lengths[k],
                  mel_outputs_postnet.shape, mel.shape, speaker_id)
            if fp16_save:
                mel = mel.astype(np.float16)
            np.save(mel_path, mel)
            if verify_outputs:
                # the mel computed from wav_path must match the shape of the mel just saved
                orig_shape = train_set.get_mel(wav_path).shape
                assert orig_shape == mel.shape, (
                    f"Target shape {orig_shape} does not match generated mel shape {mel.shape}."
                    f"\nFilepath: '{wav_path}'")

        duration = time.time() - duration
        avg_duration = rolling_sum.process(duration)
        time_left = round(((total - i) * avg_duration) / 3600, 2)
        print(f'{extra_info}{i}/{total}: computed and saved GTA mel-spectrograms for batch {i}, '
              f'{duration}s, {time_left}hrs left')
        duration = time.time()
    f.close()
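# `StreamingMovingAverage` is used above but not defined in this section. Below
# is a minimal sketch of what the helper might look like, assuming `process`
# pushes a value and returns the current windowed mean; the project's real
# implementation may differ.
class StreamingMovingAverage:
    def __init__(self, window_size):
        self.window_size = window_size
        self.values = []
        self.sum = 0.0

    def process(self, value):
        # append the new value, evict the oldest once the window is full,
        # and return the running mean over the window
        self.values.append(value)
        self.sum += value
        if len(self.values) > self.window_size:
            self.sum -= self.values.pop(0)
        return self.sum / len(self.values)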
def GTA_Synthesis(output_directory, checkpoint_path, n_gpus, rank, group_name,
                  hparams):
    """Generate Ground-Truth-Aligned spectrograms for training WaveGlow."""
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)
    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams)
    train_loader, valset, collate_fn, train_sampler, train_set = prepare_dataloaders(
        hparams)

    # Load checkpoint (required)
    assert checkpoint_path is not None
    model = warm_start_model(checkpoint_path, model)
    model.eval()

    batch_parser = model.parse_batch  # same in both the distributed and single-GPU case

    # ================ MAIN SYNTHESIS LOOP ===================
    os.makedirs(output_directory, exist_ok=True)
    f = open(os.path.join(output_directory, f'map_{rank}.txt'), 'w',
             encoding='utf-8')
    os.makedirs(os.path.join(output_directory, 'mels'), exist_ok=True)

    total_number_of_data = len(train_set.audiopaths_and_text)
    max_iter = total_number_of_data // hparams.batch_size
    remainder_size = total_number_of_data % hparams.batch_size

    duration = time.time()
    total = len(train_loader)
    rolling_sum = StreamingMovingAverage(100)
    for i, batch in enumerate(train_loader):
        batch_size = hparams.batch_size if i != max_iter else remainder_size

        # wav file paths and speaker ids for this batch (in dataset order)
        audiopaths_and_text = train_set.audiopaths_and_text[
            i * hparams.batch_size:i * hparams.batch_size + batch_size]
        audiopaths = [x[0] for x in audiopaths_and_text]
        orig_speaker_ids = [x[2] for x in audiopaths_and_text]

        # text lengths, used to recover the sort order applied by the collate function
        indx_list = np.arange(i * hparams.batch_size,
                              i * hparams.batch_size + batch_size).tolist()
        len_text_list = []
        for batch_index in indx_list:
            text, _, _ = train_set.__getitem__(batch_index)
            len_text_list.append(text.size(0))

        _, input_lengths, _, _, output_lengths, speaker_id = batch  # output_lengths: original mel lengths
        input_lengths_, ids_sorted_decreasing = torch.sort(
            torch.LongTensor(len_text_list), dim=0, descending=True)
        ids_sorted_decreasing = ids_sorted_decreasing.numpy()  # original indices, sorted by decreasing text length

        org_audiopaths = []  # original file names
        mel_paths = []
        speaker_ids = []
        for k in range(batch_size):
            d = audiopaths[ids_sorted_decreasing[k]]
            org_audiopaths.append(d)
            mel_paths.append(d.replace(".npy", ".mel").replace('.wav', '.mel'))
            speaker_ids.append(orig_speaker_ids[ids_sorted_decreasing[k]])

        x, _ = batch_parser(batch)
        _, mel_outputs_postnet, _, _ = model(x, teacher_force_till=9999,
                                             p_teacher_forcing=1.0)
        mel_outputs_postnet = mel_outputs_postnet.data.cpu().numpy()

        for k in range(batch_size):
            wav_path = org_audiopaths[k].replace(".npy", ".wav")
            mel_path = mel_paths[k] + '.npy'
            speaker_id = speaker_ids[k]
            map_line = "{}|{}|{}\n".format(wav_path, mel_path, speaker_id)
            f.write(map_line)

            # TODO: size mismatch
            # diff = output_lengths[k] - (input_lengths[k] / hparams.hop_length)
            # diff = diff.data.cpu().numpy()
            mel = mel_outputs_postnet[k, :, :output_lengths[k]]
            print(wav_path, input_lengths[k], output_lengths[k],
                  mel_outputs_postnet.shape, mel.shape, speaker_id)
            np.save(mel_path, mel)

        duration = time.time() - duration
        avg_duration = rolling_sum.process(duration)
        time_left = round(((total - i) * avg_duration) / 3600, 2)
        print(f'{i}/{total}: computed and saved GTA mel-spectrograms for batch {i}, '
              f'{duration}s, {time_left}hrs left')
        duration = time.time()
    f.close()
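# The argument parser and hparams setup are not shown in this section. A
# minimal sketch of the `args`/`hparams` the driver code below expects,
# assuming flag names simply mirror the attribute names used below and that
# `create_hparams` is the project's hparams factory (both assumptions):
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--n_gpus', type=int, default=1)
#   parser.add_argument('--rank', type=int, default=0)
#   parser.add_argument('--group_name', type=str, default='group_name')
#   parser.add_argument('--use_validation_files', action='store_true')
#   parser.add_argument('--extremeGTA', type=int, default=0)
#   parser.add_argument('--save_letter_alignments', action='store_true')
#   parser.add_argument('--save_phone_alignments', action='store_true')
#   args = parser.parse_args()
#   hparams = create_hparams()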
torch.backends.cudnn.benchmark = hparams.cudnn_benchmark
print("FP16 Run:", hparams.fp16_run)
print("Distributed Run:", hparams.distributed_run)
print("Rank:", args.rank)

if hparams.fp16_run:
    from apex import amp

if not args.use_validation_files:
    # no gradients are stored, so the batch size can be raised substantially
    hparams.batch_size = hparams.batch_size * 6
torch.autograd.set_grad_enabled(False)

if hparams.distributed_run:
    init_distributed(hparams, args.n_gpus, args.rank, args.group_name)
torch.manual_seed(hparams.seed)
torch.cuda.manual_seed(hparams.seed)

if args.extremeGTA:
    # generate aligned spectrograms for all audio samples at every offset within one hop length
    for ind, ioffset in enumerate(range(0, hparams.hop_length, args.extremeGTA)):
        if ind < 0:
            continue
        GTA_Synthesis(hparams, args, audio_offset=ioffset,
                      extra_info=f"{ind+1}/{hparams.hop_length//args.extremeGTA} ")
elif args.save_letter_alignments and args.save_phone_alignments:
    # run twice: once with grapheme (letter) inputs, once with ARPAbet (phone) inputs
    hparams.p_arpabet = 0.0
    GTA_Synthesis(hparams, args, extra_info="1/2 ")
    hparams.p_arpabet = 1.0
    GTA_Synthesis(hparams, args, extra_info="2/2 ")
else:
    GTA_Synthesis(hparams, args)