def data_filter(args):
    model, device = load_model(args)
    start_t_test = time.time()

    # Decode
    test_set = SVSDataset_filter(
        align_root_path=args.test_align,
        pitch_beat_root_path=args.test_pitch,
        wav_root_path=args.test_wav,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        standard=args.standard,
        sing_quality=args.sing_quality,
        Hz2semitone=args.Hz2semitone,
        semitone_min=args.semitone_min,
        semitone_max=args.semitone_max,
    )
    collate_fn_svs = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        False,  # random crop
        -1,  # crop_min_length
        args.Hz2semitone,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs,
        pin_memory=True,
    )

    with torch.no_grad():
        for step, data_step in enumerate(test_loader, 1):
            if args.db_joint:
                (phone, beat, pitch, spec, real, imag, length, chars,
                 char_len_list, mel, singer_id, semitone, filename_list) = data_step
            else:
                print("No support for data filtering with args.db_joint == False")
                quit()

            singer_id = np.array(singer_id).reshape(np.shape(phone)[0], -1)  # [batch size, 1]
            singer_vec = singer_id.repeat(np.shape(phone)[1], axis=1)  # [batch size, length]
            singer_vec = torch.from_numpy(singer_vec).to(device)
            singer_id = torch.from_numpy(singer_id).to(device)

            phone = phone.to(device)
            beat = beat.to(device)
            pitch = pitch.to(device).float()
            if semitone is not None:
                semitone = semitone.to(device)
            spec = spec.to(device).float()
            mel = mel.to(device).float()
            real = real.to(device).float()
            imag = imag.to(device).float()
            length_mask = (length > 0).int().unsqueeze(2)
            length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float()
            length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
            length_mask = length_mask.to(device)
            length_mel_mask = length_mel_mask.to(device)
            length = length.to(device)
            char_len_list = char_len_list.to(device)

            if not args.use_asr_post:
                chars = chars.to(device)
                char_len_list = char_len_list.to(device)
            else:
                phone = phone.float()

            if args.Hz2semitone:
                pitch = semitone

            if args.normalize:
                sepc_normalizer = GlobalMVN(args.stats_file)
                mel_normalizer = GlobalMVN(args.stats_mel_file)
                spec, _ = sepc_normalizer(spec, length)
                mel, _ = mel_normalizer(mel, length)

            len_list, _ = torch.max(length, dim=1)  # [len1, len2, len3, ...]
            len_list = len_list.cpu().detach().numpy()

            singer_out, phone_out, semitone_out = model(spec, len_list)

            # count correct frame-level predictions
            batch_size = np.shape(spec)[0]

            singer_id = singer_id.view(-1)  # [batch size]
            _, singer_predict = torch.max(singer_out, dim=1)  # [batch size]
            singer_correct = singer_predict.eq(singer_id).cpu().sum().numpy()

            for i in range(batch_size):
                phone_i = phone[i, :len_list[i], :].view(-1)  # [valid seq len]
                phone_out_i = phone_out[i, :len_list[i], :]  # [valid seq len, phone_size]
                _, phone_predict = torch.max(phone_out_i, dim=1)
                phone_correct = phone_predict.eq(phone_i).cpu().sum().numpy()

                semitone_i = semitone[i, :len_list[i], :].view(-1)  # [valid seq len]
                semitone_out_i = semitone_out[i, :len_list[i], :]  # [valid seq len, semitone_size]
                _, semitone_predict = torch.max(semitone_out_i, dim=1)
                semitone_correct = semitone_predict.eq(semitone_i).cpu().sum().numpy()

                with open(os.path.join(args.prediction_path, "filter_res.txt"), "a+") as f:
                    f.write(
                        f"{filename_list[i]}|{singer_predict[i]}|{phone_correct}"
                        f"|{semitone_correct}|{len_list[i]}\n"
                    )

                end = time.time()
                logging.info(f"{filename_list[i]} -- sum_time: {(end - start_t_test)}s")
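
# The sketch below is illustrative only and not part of the original pipeline: it shows one
# way to read back the "filter_res.txt" lines written by data_filter(), assuming the
# "name|predicted_singer|phone_correct|semitone_correct|valid_len" format used above.
# The helper name and the accuracy threshold are hypothetical.
def _example_parse_filter_results(path, min_phone_acc=0.5):
    """Return filenames whose frame-level phone accuracy reaches min_phone_acc."""
    kept = []
    with open(path, "r") as f:
        for line in f:
            name, _singer, phone_correct, _semitone_correct, valid_len = line.strip().split("|")
            if int(phone_correct) / max(int(valid_len), 1) >= min_phone_acc:
                kept.append(name)
    return kept
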
def validate(dev_loader, model, device, criterion, perceptual_entropy, epoch, args, voc_model):
    """validate."""
    losses = AverageMeter()
    spec_losses = AverageMeter()
    if args.perceptual_loss > 0:
        pe_losses = AverageMeter()
    if args.n_mels > 0:
        mel_losses = AverageMeter()
        mcd_metric = AverageMeter()
        if args.double_mel_loss:
            double_mel_losses = AverageMeter()
    model.eval()

    log_save_dir = os.path.join(args.model_save_dir, "epoch{}/log_val_figure".format(epoch))
    if not os.path.exists(log_save_dir):
        os.makedirs(log_save_dir)

    start = time.time()

    with torch.no_grad():
        for step, data_step in enumerate(dev_loader, 1):
            if args.db_joint:
                (phone, beat, pitch, spec, real, imag, length, chars,
                 char_len_list, mel, singer_id) = data_step

                singer_id = np.array(singer_id).reshape(np.shape(phone)[0], -1)  # [batch size, 1]
                singer_vec = singer_id.repeat(np.shape(phone)[1], axis=1)  # [batch size, length]
                singer_vec = torch.from_numpy(singer_vec).to(device)
            else:
                (phone, beat, pitch, spec, real, imag, length, chars,
                 char_len_list, mel) = data_step

            phone = phone.to(device)
            beat = beat.to(device)
            pitch = pitch.to(device).float()
            spec = spec.to(device).float()
            if mel is not None:
                mel = mel.to(device).float()
            real = real.to(device).float()
            imag = imag.to(device).float()
            length_mask = length.unsqueeze(2)
            if mel is not None:
                length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float()
                length_mel_mask = length_mel_mask.to(device)
            length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
            length_mask = length_mask.to(device)
            length = length.to(device)
            char_len_list = char_len_list.to(device)

            if not args.use_asr_post:
                chars = chars.to(device)
                char_len_list = char_len_list.to(device)
            else:
                phone = phone.float()

            if args.model_type == "GLU_Transformer":
                if args.db_joint:
                    output, att, output_mel, output_mel2 = model(
                        chars, phone, pitch, beat, singer_vec,
                        pos_char=char_len_list, pos_spec=length,
                    )
                else:
                    output, att, output_mel, output_mel2 = model(
                        chars, phone, pitch, beat,
                        pos_char=char_len_list, pos_spec=length,
                    )
            elif args.model_type == "LSTM":
                if args.db_joint:
                    output, hidden, output_mel, output_mel2 = model(phone, pitch, beat, singer_vec)
                else:
                    output, hidden, output_mel, output_mel2 = model(phone, pitch, beat)
                att = None
            elif args.model_type == "GRU_gs":
                output, att, output_mel = model(spec, phone, pitch, beat, length, args)
                att = None
            elif args.model_type == "PureTransformer":
                output, att, output_mel, output_mel2 = model(
                    chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length)
            elif args.model_type == "Conformer":
                output, att, output_mel, output_mel2 = model(
                    chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length)
            elif args.model_type == "Comformer_full":
                if args.db_joint:
                    output, att, output_mel, output_mel2 = model(
                        chars, phone, pitch, beat, singer_vec,
                        pos_char=char_len_list, pos_spec=length,
                    )
                else:
                    output, att, output_mel, output_mel2 = model(
                        chars, phone, pitch, beat,
                        pos_char=char_len_list, pos_spec=length,
                    )
            elif args.model_type == "USTC_DAR":
                output_mel = model(phone, pitch, beat, length, args)
                att = None

            spec_origin = spec.clone()
            mel_origin = mel.clone()
            if args.normalize:
                sepc_normalizer = GlobalMVN(args.stats_file)
                mel_normalizer = GlobalMVN(args.stats_mel_file)
                output_mel_normalizer = GlobalMVN(args.stats_mel_file)
                spec, _ = sepc_normalizer(spec, length)
                mel, _ = mel_normalizer(mel, length)

            if args.model_type == "USTC_DAR":
                spec_loss = 0
            else:
                spec_loss = criterion(output, spec, length_mask)

            if args.n_mels > 0:
                mel_loss = criterion(output_mel, mel, length_mel_mask)
                if args.double_mel_loss:
                    double_mel_loss = criterion(output_mel2, mel, length_mel_mask)
                else:
                    double_mel_loss = 0
            else:
                mel_loss = 0
                double_mel_loss = 0

            if args.vocoder_category == "wavernn":
                dev_loss = mel_loss + double_mel_loss
            else:
                dev_loss = mel_loss + double_mel_loss + spec_loss

            if args.perceptual_loss > 0:
                pe_loss = perceptual_entropy(output, real, imag)
                final_loss = (args.perceptual_loss * pe_loss
                              + (1 - args.perceptual_loss) * dev_loss)
            else:
                final_loss = dev_loss

            losses.update(final_loss.item(), phone.size(0))
            if args.model_type != "USTC_DAR":
                spec_losses.update(spec_loss.item(), phone.size(0))
            if args.perceptual_loss > 0:
                # pe_loss = perceptual_entropy(output, real, imag)
                pe_losses.update(pe_loss.item(), phone.size(0))
            if args.n_mels > 0:
                mel_losses.update(mel_loss.item(), phone.size(0))
                if args.double_mel_loss:
                    double_mel_losses.update(double_mel_loss.item(), phone.size(0))

            if args.model_type == "USTC_DAR":
                # normalize inverse stage
                if args.normalize and args.stats_file:
                    # output_mel, _ = mel_normalizer.inverse(output_mel, length)
                    mel, _ = mel_normalizer.inverse(mel, length)
                    output_mel, _ = output_mel_normalizer.inverse(output_mel, length)
                mcd_value, length_sum = 0, 1  # FIX ME! Calculate_melcd_fromMelSpectrum
            else:
                # normalize inverse stage
                if args.normalize and args.stats_file:
                    output, _ = sepc_normalizer.inverse(output, length)
                    # output_mel, _ = mel_normalizer.inverse(output_mel, length)
                    mel, _ = mel_normalizer.inverse(mel, length)
                    output_mel, _ = output_mel_normalizer.inverse(output_mel, length)
                mcd_value, length_sum = Metrics.Calculate_melcd_fromLinearSpectrum(
                    output, spec_origin, length, args)

            mcd_metric.update(mcd_value, length_sum)

            if step % args.dev_step_log == 0:
                if args.model_type == "USTC_DAR":
                    log_figure_mel(step, output_mel, mel_origin, att, length, log_save_dir, args)
                else:
                    if args.vocoder_category == "wavernn":
                        for i in range(output_mel.shape[0]):
                            one_batch_output_mel = output_mel[i].unsqueeze(0)
                            one_batch_mel = mel[i].unsqueeze(0)
                            log_mel(step, one_batch_output_mel, one_batch_mel, att,
                                    length, log_save_dir, args, voc_model)
                    else:
                        log_figure(step, output, spec_origin, att, length, log_save_dir, args)

                out_log = ("step {}: val_loss {:.4f}; "
                           "spec_loss {:.4f}; mcd_value {:.4f};".format(
                               step, losses.avg, spec_losses.avg, mcd_metric.avg))
                if args.perceptual_loss > 0:
                    out_log += "pe_loss {:.4f}; ".format(pe_losses.avg)
                if args.n_mels > 0:
                    out_log += "mel_loss {:.4f}; ".format(mel_losses.avg)
                    if args.double_mel_loss:
                        out_log += "dmel_loss {:.4f}; ".format(double_mel_losses.avg)
                end = time.time()
                print("{} -- sum_time: {}s".format(out_log, (end - start)))

    info = {
        "loss": losses.avg,
        "spec_loss": spec_losses.avg,
        "mcd_value": mcd_metric.avg,
    }
    if args.perceptual_loss > 0:
        info["pe_loss"] = pe_losses.avg
    if args.n_mels > 0:
        info["mel_loss"] = mel_losses.avg
    return info
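
# Minimal sketch (an assumption, for illustration only): a masked L1 loss with the same
# (output, target, mask) call pattern as the `criterion` used in validate() above. The real
# MaskedLoss class is defined elsewhere in the repo and may differ in details such as the
# reduction; this only demonstrates how the length mask excludes padded frames.
def _example_masked_l1(output, target, mask):
    diff = torch.abs(output - target) * mask  # zero out padded positions
    return diff.sum() / mask.sum().clamp(min=1)  # average over valid frames only
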
def train_one_epoch(train_loader, model, device, optimizer, criterion,
                    perceptual_entropy, epoch, args, voc_model):
    """train_one_epoch."""
    losses = AverageMeter()
    spec_losses = AverageMeter()
    if args.perceptual_loss > 0:
        pe_losses = AverageMeter()
    if args.n_mels > 0:
        mel_losses = AverageMeter()
        # mcd_metric = AverageMeter()
        # f0_distortion_metric, vuv_error_metric =
        #     AverageMeter(), AverageMeter()
        if args.double_mel_loss:
            double_mel_losses = AverageMeter()
    model.train()

    log_save_dir = os.path.join(args.model_save_dir, "epoch{}/log_train_figure".format(epoch))
    if not os.path.exists(log_save_dir):
        os.makedirs(log_save_dir)

    start = time.time()
    # f0_ground_truth_all = np.reshape(np.array([]), (-1, 1))
    # f0_synthesis_all = np.reshape(np.array([]), (-1, 1))

    for step, data_step in enumerate(train_loader, 1):
        if args.db_joint:
            (phone, beat, pitch, spec, real, imag, length, chars,
             char_len_list, mel, singer_id) = data_step

            singer_id = np.array(singer_id).reshape(np.shape(phone)[0], -1)  # [batch size, 1]
            singer_vec = singer_id.repeat(np.shape(phone)[1], axis=1)  # [batch size, length]
            singer_vec = torch.from_numpy(singer_vec).to(device)
        else:
            (phone, beat, pitch, spec, real, imag, length, chars,
             char_len_list, mel) = data_step

        phone = phone.to(device)
        beat = beat.to(device)
        pitch = pitch.to(device).float()
        spec = spec.to(device).float()
        if mel is not None:
            mel = mel.to(device).float()
        real = real.to(device).float()
        imag = imag.to(device).float()
        length_mask = length.unsqueeze(2)
        if mel is not None:
            length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float()
            length_mel_mask = length_mel_mask.to(device)
        length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
        length_mask = length_mask.to(device)
        length = length.to(device)
        char_len_list = char_len_list.to(device)

        if not args.use_asr_post:
            chars = chars.to(device)
            char_len_list = char_len_list.to(device)
        else:
            phone = phone.float()

        # output = [batch size, num frames, feat_dim]
        # output_mel = [batch size, num frames, n_mels dimension]
        if args.model_type == "GLU_Transformer":
            if args.db_joint:
                output, att, output_mel, output_mel2 = model(
                    chars, phone, pitch, beat, singer_vec,
                    pos_char=char_len_list, pos_spec=length,
                )
            else:
                output, att, output_mel, output_mel2 = model(
                    chars, phone, pitch, beat,
                    pos_char=char_len_list, pos_spec=length)
        elif args.model_type == "LSTM":
            if args.db_joint:
                output, hidden, output_mel, output_mel2 = model(phone, pitch, beat, singer_vec)
            else:
                output, hidden, output_mel, output_mel2 = model(phone, pitch, beat)
            att = None
        elif args.model_type == "GRU_gs":
            output, att, output_mel = model(spec, phone, pitch, beat, length, args)
            att = None
        elif args.model_type == "PureTransformer":
            output, att, output_mel, output_mel2 = model(
                chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length)
        elif args.model_type == "Conformer":
            # print(f"chars: {np.shape(chars)}, phone:
            # {np.shape(phone)}, length: {np.shape(length)}")
            output, att, output_mel, output_mel2 = model(
                chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length)
        elif args.model_type == "Comformer_full":
            if args.db_joint:
                output, att, output_mel, output_mel2 = model(
                    chars, phone, pitch, beat, singer_vec,
                    pos_char=char_len_list, pos_spec=length,
                )
            else:
                output, att, output_mel, output_mel2 = model(
                    chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length)
        elif args.model_type == "USTC_DAR":
            output_mel = model(phone, pitch, beat, length, args)
            # mel loss written in spec loss
            att = None

        spec_origin = spec.clone()
        mel_origin = mel.clone()
        if args.normalize:
            sepc_normalizer = GlobalMVN(args.stats_file)
            mel_normalizer = GlobalMVN(args.stats_mel_file)
            output_mel_normalizer = GlobalMVN(args.stats_mel_file)
            spec, _ = sepc_normalizer(spec, length)
            mel, _ = mel_normalizer(mel, length)

        if args.model_type == "USTC_DAR":
            spec_loss = 0
        else:
            spec_loss = criterion(output, spec, length_mask)

        if args.n_mels > 0:
            mel_loss = criterion(output_mel, mel, length_mel_mask)
            if args.double_mel_loss:
                double_mel_loss = criterion(output_mel2, mel, length_mel_mask)
            else:
                double_mel_loss = 0
        else:
            mel_loss = 0
            double_mel_loss = 0

        if args.vocoder_category == "wavernn":
            train_loss = mel_loss + double_mel_loss
        else:
            train_loss = mel_loss + double_mel_loss + spec_loss

        if args.perceptual_loss > 0:
            pe_loss = perceptual_entropy(output, real, imag)
            final_loss = (args.perceptual_loss * pe_loss
                          + (1 - args.perceptual_loss) * train_loss)
        else:
            final_loss = train_loss

        final_loss = final_loss / args.accumulation_steps
        final_loss.backward()

        if args.gradclip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradclip)

        # update parameters every accumulation_steps batches
        if (step + 1) % args.accumulation_steps == 0:
            if args.optimizer == "noam":
                optimizer.step_and_update_lr()
            else:
                optimizer.step()
            # zero the gradients
            optimizer.zero_grad()

        losses.update(final_loss.item(), phone.size(0))
        if args.model_type != "USTC_DAR":
            spec_losses.update(spec_loss.item(), phone.size(0))
        if args.perceptual_loss > 0:
            pe_losses.update(pe_loss.item(), phone.size(0))
        if args.n_mels > 0:
            mel_losses.update(mel_loss.item(), phone.size(0))
            if args.double_mel_loss:
                double_mel_losses.update(double_mel_loss.item(), phone.size(0))

        if step % args.train_step_log == 0:
            end = time.time()

            if args.model_type == "USTC_DAR":
                # the normalize inverse is only needed for logging/inference,
                # because the log step converts to wav and computes metrics such as MCD
                if args.normalize and args.stats_file:
                    output_mel, _ = mel_normalizer.inverse(output_mel, length)
                    mel, _ = mel_normalizer.inverse(mel, length)
                    output_mel, _ = output_mel_normalizer.inverse(output_mel, length)
                log_figure_mel(step, output_mel, mel_origin, att, length, log_save_dir, args)
                out_log = "step {}: train_loss {:.4f}; spec_loss {:.4f};".format(
                    step, losses.avg, spec_losses.avg)
            else:
                # the normalize inverse is only needed for logging/inference,
                # because the log step converts to wav and computes metrics such as MCD
                if args.normalize and args.stats_file:
                    output, _ = sepc_normalizer.inverse(output, length)
                    mel, _ = mel_normalizer.inverse(mel, length)
                    output_mel, _ = output_mel_normalizer.inverse(output_mel, length)
                if args.vocoder_category == "wavernn":
                    for i in range(output_mel[0].shape[0]):
                        one_batch_output_mel = output_mel[0][i].unsqueeze(0)
                        one_batch_mel = mel[i].unsqueeze(0)
                        log_mel(step, one_batch_output_mel, one_batch_mel, att,
                                length, log_save_dir, args, voc_model)
                else:
                    log_figure(step, output, spec_origin, att, length, log_save_dir, args)
                out_log = "step {}: train_loss {:.4f}; spec_loss {:.4f};".format(
                    step, losses.avg, spec_losses.avg)

            if args.perceptual_loss > 0:
                out_log += "pe_loss {:.4f}; ".format(pe_losses.avg)
            if args.n_mels > 0:
                out_log += "mel_loss {:.4f}; ".format(mel_losses.avg)
                if args.double_mel_loss:
                    out_log += "dmel_loss {:.4f}; ".format(double_mel_losses.avg)
            print("{} -- sum_time: {:.2f}s".format(out_log, (end - start)))

    info = {"loss": losses.avg, "spec_loss": spec_losses.avg}
    if args.perceptual_loss > 0:
        info["pe_loss"] = pe_losses.avg
    if args.n_mels > 0:
        info["mel_loss"] = mel_losses.avg
    return info
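
# Worked example (illustrative values only, not part of the training loop) of the loss
# combination used in train_one_epoch(): with args.perceptual_loss = w, the optimized loss
# is w * pe_loss + (1 - w) * train_loss, and it is further divided by accumulation_steps so
# that gradients accumulated over several batches average out before optimizer.step().
def _example_combined_loss(pe_loss, train_loss, w=0.5, accumulation_steps=2):
    final_loss = w * pe_loss + (1 - w) * train_loss
    return final_loss / accumulation_steps
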
def infer_predictor(args):
    """infer."""
    torch.cuda.set_device(args.gpu_id)
    logging.info(f"GPU {args.gpu_id} is used")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare model
    model = RNN_Discriminator(
        embed_size=128,
        d_model=128,
        hidden_size=128,
        num_layers=2,
        n_specs=1025,
        singer_size=7,
        phone_size=43,
        simitone_size=59,
        dropout=0.1,
        bidirectional=True,
        device=device,
    )
    logging.info(f"{model}")
    model = model.to(device)
    logging.info(f"The model has {count_parameters(model):,} trainable parameters")

    # Load model weights
    logging.info(f"Loading pretrained weights from {args.model_file}")
    checkpoint = torch.load(args.model_file, map_location=device)
    state_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()
    state_dict_new = {}
    para_list = []
    for k, v in state_dict.items():
        # assert k in model_dict
        if (k == "normalizer.mean" or k == "normalizer.std"
                or k == "mel_normalizer.mean" or k == "mel_normalizer.std"):
            continue
        if model_dict[k].size() == state_dict[k].size():
            state_dict_new[k] = v
        else:
            para_list.append(k)
    logging.info(f"Total {len(state_dict)} parameter sets, "
                 f"loaded {len(state_dict_new)} parameter sets")
    if len(para_list) > 0:
        logging.warning(f"Not loading {para_list} because of different sizes")
    model.load_state_dict(state_dict_new)
    logging.info(f"Loaded checkpoint {args.model_file}")
    model = model.to(device)
    model.eval()

    # Decode
    test_set = SVSDataset(
        align_root_path=args.test_align,
        pitch_beat_root_path=args.test_pitch,
        wav_root_path=args.test_wav,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        standard=args.standard,
        sing_quality=args.sing_quality,
        db_joint=args.db_joint,
        Hz2semitone=args.Hz2semitone,
        semitone_min=args.semitone_min,
        semitone_max=args.semitone_max,
        phone_shift_size=-1,
        semitone_shift=False,
    )
    collate_fn_svs = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        False,  # random crop
        -1,  # crop_min_length
        args.Hz2semitone,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs,
        pin_memory=True,
    )

    criterion = nn.CrossEntropyLoss(reduction="sum")

    start_t_test = time.time()

    singer_losses = AverageMeter()
    phone_losses = AverageMeter()
    semitone_losses = AverageMeter()
    singer_count = AverageMeter()
    phone_count = AverageMeter()
    semitone_count = AverageMeter()

    with torch.no_grad():
        for step, data_step in enumerate(test_loader, 1):
            if args.db_joint:
                (phone, beat, pitch, spec, real, imag, length, chars,
                 char_len_list, mel, singer_id, semitone) = data_step

                singer_id = np.array(singer_id).reshape(np.shape(phone)[0], -1)  # [batch size, 1]
                singer_vec = singer_id.repeat(np.shape(phone)[1], axis=1)  # [batch size, length]
                singer_vec = torch.from_numpy(singer_vec).to(device)
                singer_id = torch.from_numpy(singer_id).to(device)
            else:
                (phone, beat, pitch, spec, real, imag, length, chars,
                 char_len_list, mel, semitone) = data_step

            phone = phone.to(device)
            beat = beat.to(device)
            pitch = pitch.to(device).float()
            if semitone is not None:
                semitone = semitone.to(device)
            spec = spec.to(device).float()
            mel = mel.to(device).float()
            real = real.to(device).float()
            imag = imag.to(device).float()
            length_mask = (length > 0).int().unsqueeze(2)
            length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float()
            length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
            length_mask = length_mask.to(device)
            length_mel_mask = length_mel_mask.to(device)
            length = length.to(device)
            char_len_list = char_len_list.to(device)

            if not args.use_asr_post:
                chars = chars.to(device)
                char_len_list = char_len_list.to(device)
            else:
                phone = phone.float()

            if args.Hz2semitone:
                pitch = semitone

            if args.normalize:
                sepc_normalizer = GlobalMVN(args.stats_file)
                mel_normalizer = GlobalMVN(args.stats_mel_file)
                spec, _ = sepc_normalizer(spec, length)
                mel, _ = mel_normalizer(mel, length)

            len_list, _ = torch.max(length, dim=1)  # [len1, len2, len3, ...]
            len_list = len_list.cpu().detach().numpy()

            singer_out, phone_out, semitone_out = model(spec, len_list)

            # calculate CrossEntropy loss (defined with reduction="sum")
            phone_loss = 0
            semitone_loss = 0
            phone_correct = 0
            semitone_correct = 0
            batch_size = np.shape(spec)[0]
            for i in range(batch_size):
                phone_i = phone[i, :len_list[i], :].view(-1)  # [valid seq len]
                phone_out_i = phone_out[i, :len_list[i], :]  # [valid seq len, phone_size]
                phone_loss += criterion(phone_out_i, phone_i)
                _, phone_predict = torch.max(phone_out_i, dim=1)
                phone_correct += phone_predict.eq(phone_i).cpu().sum().numpy()

                semitone_i = semitone[i, :len_list[i], :].view(-1)  # [valid seq len]
                semitone_out_i = semitone_out[i, :len_list[i], :]  # [valid seq len, semitone_size]
                semitone_loss += criterion(semitone_out_i, semitone_i)
                _, semitone_predict = torch.max(semitone_out_i, dim=1)
                semitone_correct += semitone_predict.eq(semitone_i).cpu().sum().numpy()

            singer_id = singer_id.view(-1)  # [batch size]
            _, singer_predict = torch.max(singer_out, dim=1)
            singer_correct = singer_predict.eq(singer_id).cpu().sum().numpy()

            phone_loss /= np.sum(len_list)
            semitone_loss /= np.sum(len_list)
            singer_loss = criterion(singer_out, singer_id) / batch_size

            # record loss info
            singer_losses.update(singer_loss.item(), batch_size)
            phone_losses.update(phone_loss.item(), np.sum(len_list))
            semitone_losses.update(semitone_loss.item(), np.sum(len_list))

            singer_count.update(singer_correct.item() / batch_size, batch_size)
            phone_count.update(phone_correct.item() / np.sum(len_list), np.sum(len_list))
            semitone_count.update(semitone_correct.item() / np.sum(len_list), np.sum(len_list))

            if step % 1 == 0:
                end = time.time()
                out_log = "step {}: loss {:.6f}, ".format(
                    step, singer_losses.avg + phone_losses.avg + semitone_losses.avg)
                out_log += "\t singer_loss: {:.4f} ".format(singer_losses.avg)
                out_log += "phone_loss: {:.4f} ".format(phone_losses.avg)
                out_log += "semitone_loss: {:.4f} \n".format(semitone_losses.avg)
                out_log += "\t singer_accuracy: {:.4f}% ".format(singer_count.avg * 100)
                out_log += "phone_accuracy: {:.4f}% ".format(phone_count.avg * 100)
                out_log += "semitone_accuracy: {:.4f}% ".format(semitone_count.avg * 100)
                print("{} -- sum_time: {:.2f}s".format(out_log, (end - start_t_test)))

    end_t_test = time.time()

    out_log = "\nTest Stage: "
    out_log += "loss: {:.4f}, ".format(singer_losses.avg + phone_losses.avg
                                       + semitone_losses.avg)
    out_log += "singer_loss: {:.4f}, ".format(singer_losses.avg)
    out_log += "phone_loss: {:.4f}, semitone_loss: {:.4f} \n".format(
        phone_losses.avg, semitone_losses.avg)
    out_log += "singer_accuracy: {:.4f}%, ".format(singer_count.avg * 100)
    out_log += "phone_accuracy: {:.4f}%, semitone_accuracy: {:.4f}% ".format(
        phone_count.avg * 100, semitone_count.avg * 100)
    logging.info("{} time: {:.2f}s".format(out_log, end_t_test - start_t_test))
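
# Illustration only: how the frame-level accuracy in infer_predictor() is obtained from
# per-frame logits via argmax + eq. The tensors below are made-up toy values, not real
# model outputs.
def _example_frame_accuracy():
    logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]])  # [valid seq len, num classes]
    labels = torch.tensor([0, 1, 1])                             # [valid seq len]
    _, predict = torch.max(logits, dim=1)                        # argmax over classes
    correct = predict.eq(labels).sum().item()                    # matches 2 of 3 frames here
    return correct / labels.numel()
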
def augmentation(args, target_singer_id, output_path):
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    model, device = load_model(args)
    start_t_test = time.time()

    # Decode
    test_set = SVSDataset(
        align_root_path=args.test_align,
        pitch_beat_root_path=args.test_pitch,
        wav_root_path=args.test_wav,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        standard=args.standard,
        sing_quality=args.sing_quality,
        db_joint=args.db_joint,
        Hz2semitone=args.Hz2semitone,
        semitone_min=args.semitone_min,
        semitone_max=args.semitone_max,
        phone_shift_size=-1,
        semitone_shift=False,
    )
    collate_fn_svs = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        False,  # random crop
        -1,  # crop_min_length
        args.Hz2semitone,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs,
        pin_memory=True,
    )

    with torch.no_grad():
        for step, data_step in enumerate(test_loader, 1):
            if args.db_joint:
                (phone, beat, pitch, spec, real, imag, length, chars,
                 char_len_list, mel, singer_id, semitone, filename_list,
                 flag_filter_list) = data_step
            else:
                print("No support for augmentation with args.db_joint == False")
                quit()

            t_singer_id = np.array(target_singer_id).reshape(np.shape(phone)[0], -1)  # [batch size, 1]
            singer_vec = t_singer_id.repeat(np.shape(phone)[1], axis=1)  # [batch size, length]
            singer_vec = torch.from_numpy(singer_vec).to(device)

            phone = phone.to(device)
            beat = beat.to(device)
            pitch = pitch.to(device).float()
            if semitone is not None:
                semitone = semitone.to(device)
            spec = spec.to(device).float()
            mel = mel.to(device).float()
            real = real.to(device).float()
            imag = imag.to(device).float()
            length_mask = (length > 0).int().unsqueeze(2)
            length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float()
            length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
            length_mask = length_mask.to(device)
            length_mel_mask = length_mel_mask.to(device)
            length = length.to(device)
            char_len_list = char_len_list.to(device)

            if not args.use_asr_post:
                chars = chars.to(device)
                char_len_list = char_len_list.to(device)
            else:
                phone = phone.float()

            if args.Hz2semitone:
                pitch = semitone

            if args.model_type == "GLU_Transformer":
                if args.db_joint:
                    output, att, output_mel, output_mel2 = model(
                        chars, phone, pitch, beat, singer_vec,
                        pos_char=char_len_list, pos_spec=length,
                    )
            elif args.model_type == "LSTM":
                if args.db_joint:
                    output, hidden, output_mel, output_mel2 = model(phone, pitch, beat, singer_vec)
            elif args.model_type == "Comformer_full":
                if args.db_joint:
                    output, att, output_mel, output_mel2 = model(
                        chars, phone, pitch, beat, singer_vec,
                        pos_char=char_len_list, pos_spec=length,
                    )

            if args.normalize:
                sepc_normalizer = GlobalMVN(args.stats_file)
                mel_normalizer = GlobalMVN(args.stats_mel_file)

            # normalize inverse stage
            if args.normalize and args.stats_file:
                output, _ = sepc_normalizer.inverse(output, length)

            # write wav
            output = output.cpu().detach().numpy()[0]
            length = np.max(length.cpu().detach().numpy()[0])
            output = output[:length]

            wav = spectrogram2wav(
                output,
                args.max_db,
                args.ref_db,
                args.preemphasis,
                args.power,
                args.sampling_rate,
                args.frame_shift,
                args.frame_length,
                args.nfft,
            )

            wr_fname = filename_list[0] + "-" + str(target_singer_id)  # batch_size = 1

            if librosa.__version__ < "0.8.0":
                librosa.output.write_wav(
                    os.path.join(output_path, "{}.wav".format(wr_fname)),
                    wav,
                    args.sampling_rate,
                )
            else:
                # librosa > 0.8 remove librosa.output.write_wav module
                sf.write(
                    os.path.join(output_path, "{}.wav".format(wr_fname)),
                    wav,
                    args.sampling_rate,
                    format="wav",
                    subtype="PCM_24",
                )

            end = time.time()
            out_log = os.path.join(output_path, "{}.wav".format(wr_fname))
            logging.info(f"{out_log} -- sum_time: {(end - start_t_test)}s")
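
# Optional helper (not in the original code): the same librosa-version check used in
# augmentation(), factored into one place. It relies only on the two writer calls already
# used above (librosa.output.write_wav for librosa < 0.8, soundfile.write otherwise).
def _example_write_wav(path, wav, sr):
    if librosa.__version__ < "0.8.0":
        librosa.output.write_wav(path, wav, sr)
    else:
        sf.write(path, wav, sr, format="wav", subtype="PCM_24")
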
def infer(args):
    """infer."""
    torch.cuda.set_device(args.gpu_id)
    logging.info(f"GPU {args.gpu_id} is used")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare model
    if args.model_type == "GLU_Transformer":
        model = GLU_TransformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            device=device,
        )
    elif args.model_type == "LSTM":
        model = LSTMSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            device=device,
            use_asr_post=args.use_asr_post,
        )
    elif args.model_type == "GRU_gs":
        model = GRUSVS_gs(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            device=device,
            use_asr_post=args.use_asr_post,
        )
    elif args.model_type == "PureTransformer":
        model = TransformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            device=device,
        )
    elif args.model_type == "Conformer":
        model = ConformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            enc_attention_dim=args.enc_attention_dim,
            enc_attention_heads=args.enc_attention_heads,
            enc_linear_units=args.enc_linear_units,
            enc_num_blocks=args.enc_num_blocks,
            enc_dropout_rate=args.enc_dropout_rate,
            enc_positional_dropout_rate=args.enc_positional_dropout_rate,
            enc_attention_dropout_rate=args.enc_attention_dropout_rate,
            enc_input_layer=args.enc_input_layer,
            enc_normalize_before=args.enc_normalize_before,
            enc_concat_after=args.enc_concat_after,
            enc_positionwise_layer_type=args.enc_positionwise_layer_type,
            enc_positionwise_conv_kernel_size=args.enc_positionwise_conv_kernel_size,
            enc_macaron_style=args.enc_macaron_style,
            enc_pos_enc_layer_type=args.enc_pos_enc_layer_type,
            enc_selfattention_layer_type=args.enc_selfattention_layer_type,
            enc_activation_type=args.enc_activation_type,
            enc_use_cnn_module=args.enc_use_cnn_module,
            enc_cnn_module_kernel=args.enc_cnn_module_kernel,
            enc_padding_idx=args.enc_padding_idx,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            dec_dropout=args.dec_dropout,
            device=device,
        )
    elif args.model_type == "Comformer_full":
        model = ConformerSVS_FULL(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            output_dim=args.feat_dim,
            n_mels=args.n_mels,
            enc_attention_dim=args.enc_attention_dim,
            enc_attention_heads=args.enc_attention_heads,
            enc_linear_units=args.enc_linear_units,
            enc_num_blocks=args.enc_num_blocks,
            enc_dropout_rate=args.enc_dropout_rate,
            enc_positional_dropout_rate=args.enc_positional_dropout_rate,
            enc_attention_dropout_rate=args.enc_attention_dropout_rate,
            enc_input_layer=args.enc_input_layer,
            enc_normalize_before=args.enc_normalize_before,
            enc_concat_after=args.enc_concat_after,
            enc_positionwise_layer_type=args.enc_positionwise_layer_type,
            enc_positionwise_conv_kernel_size=args.enc_positionwise_conv_kernel_size,
            enc_macaron_style=args.enc_macaron_style,
            enc_pos_enc_layer_type=args.enc_pos_enc_layer_type,
            enc_selfattention_layer_type=args.enc_selfattention_layer_type,
            enc_activation_type=args.enc_activation_type,
            enc_use_cnn_module=args.enc_use_cnn_module,
            enc_cnn_module_kernel=args.enc_cnn_module_kernel,
            enc_padding_idx=args.enc_padding_idx,
            dec_attention_dim=args.dec_attention_dim,
            dec_attention_heads=args.dec_attention_heads,
            dec_linear_units=args.dec_linear_units,
            dec_num_blocks=args.dec_num_blocks,
            dec_dropout_rate=args.dec_dropout_rate,
            dec_positional_dropout_rate=args.dec_positional_dropout_rate,
            dec_attention_dropout_rate=args.dec_attention_dropout_rate,
            dec_input_layer=args.dec_input_layer,
            dec_normalize_before=args.dec_normalize_before,
            dec_concat_after=args.dec_concat_after,
            dec_positionwise_layer_type=args.dec_positionwise_layer_type,
            dec_positionwise_conv_kernel_size=args.dec_positionwise_conv_kernel_size,
            dec_macaron_style=args.dec_macaron_style,
            dec_pos_enc_layer_type=args.dec_pos_enc_layer_type,
            dec_selfattention_layer_type=args.dec_selfattention_layer_type,
            dec_activation_type=args.dec_activation_type,
            dec_use_cnn_module=args.dec_use_cnn_module,
            dec_cnn_module_kernel=args.dec_cnn_module_kernel,
            dec_padding_idx=args.dec_padding_idx,
            device=device,
        )
    else:
        raise ValueError("Unsupported model type: %s" % args.model_type)

    logging.info(f"{model}")
    logging.info(f"The model has {count_parameters(model):,} trainable parameters")

    # Load model weights
    logging.info(f"Loading pretrained weights from {args.model_file}")
    checkpoint = torch.load(args.model_file, map_location=device)
    state_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()
    state_dict_new = {}
    para_list = []
    for k, v in state_dict.items():
        # assert k in model_dict
        if (k == "normalizer.mean" or k == "normalizer.std"
                or k == "mel_normalizer.mean" or k == "mel_normalizer.std"):
            continue
        if model_dict[k].size() == state_dict[k].size():
            state_dict_new[k] = v
        else:
            para_list.append(k)
    logging.info(f"Total {len(state_dict)} parameter sets, "
                 f"loaded {len(state_dict_new)} parameter sets")
    if len(para_list) > 0:
        logging.warning(f"Not loading {para_list} because of different sizes")
    model.load_state_dict(state_dict_new)
    logging.info(f"Loaded checkpoint {args.model_file}")
    model = model.to(device)
    model.eval()

    # Decode
    test_set = SVSDataset(
        align_root_path=args.test_align,
        pitch_beat_root_path=args.test_pitch,
        wav_root_path=args.test_wav,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        standard=args.standard,
        sing_quality=args.sing_quality,
    )
    collate_fn_svs = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs,
        pin_memory=True,
    )

    if args.loss == "l1":
        criterion = MaskedLoss("l1")
    elif args.loss == "mse":
        criterion = MaskedLoss("mse")
    else:
        raise ValueError("Unsupported loss type")

    losses = AverageMeter()
    spec_losses = AverageMeter()
    if args.perceptual_loss > 0:
        pe_losses = AverageMeter()
    if args.n_mels > 0:
        mel_losses = AverageMeter()
        mcd_metric = AverageMeter()
        f0_distortion_metric, vuv_error_metric = AverageMeter(), AverageMeter()
        if args.double_mel_loss:
            double_mel_losses = AverageMeter()
    model.eval()

    if not os.path.exists(args.prediction_path):
        os.makedirs(args.prediction_path)

    f0_ground_truth_all = np.reshape(np.array([]), (-1, 1))
    f0_synthesis_all = np.reshape(np.array([]), (-1, 1))
    start_t_test = time.time()

    # preload vocoder model
    if args.vocoder_category == "wavernn":
        voc_model = WaveRNN(
            rnn_dims=512,
            fc_dims=512,
            bits=9,
            pad=2,
            upsample_factors=(5, 5, 11),
            feat_dims=80,
            compute_dims=128,
            res_out_dims=128,
            res_blocks=10,
            hop_length=275,  # 12.5ms - in line with Tacotron 2 paper
            sample_rate=22050,
            mode="MOL",
        ).to(device)
        voc_model.load("./weights/wavernn/latest_weights.pyt")

    with torch.no_grad():
        for step, (phone, beat, pitch, spec, real, imag, length, chars,
                   char_len_list, mel) in enumerate(test_loader, 1):
            # if step >= args.decode_sample:
            #     break
            phone = phone.to(device)
            beat = beat.to(device)
            pitch = pitch.to(device).float()
            spec = spec.to(device).float()
            mel = mel.to(device).float()
            real = real.to(device).float()
            imag = imag.to(device).float()
            length_mask = length.unsqueeze(2)
            length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float()
            length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
            length_mask = length_mask.to(device)
            length_mel_mask = length_mel_mask.to(device)
            length = length.to(device)
            char_len_list = char_len_list.to(device)

            if not args.use_asr_post:
                chars = chars.to(device)
                char_len_list = char_len_list.to(device)
            else:
                phone = phone.float()

            if args.model_type == "GLU_Transformer":
                output, att, output_mel, output_mel2 = model(
                    chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length,
                )
            elif args.model_type == "LSTM":
                output, hidden, output_mel, output_mel2 = model(phone, pitch, beat)
                att = None
            elif args.model_type == "GRU_gs":
                output, att, output_mel = model(spec, phone, pitch, beat, length, args)
                att = None
            elif args.model_type == "PureTransformer":
                output, att, output_mel, output_mel2 = model(
                    chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length,
                )
            elif args.model_type == "Conformer":
                output, att, output_mel, output_mel2 = model(
                    chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length,
                )
            elif args.model_type == "Comformer_full":
                output, att, output_mel, output_mel2 = model(
                    chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length,
                )

            spec_origin = spec.clone()
            # spec_origin = spec
            if args.normalize:
                sepc_normalizer = GlobalMVN(args.stats_file)
                mel_normalizer = GlobalMVN(args.stats_mel_file)
                spec, _ = sepc_normalizer(spec, length)
                mel, _ = mel_normalizer(mel, length)

            spec_loss = criterion(output, spec, length_mask)
            if args.n_mels > 0:
                mel_loss = criterion(output_mel, mel, length_mel_mask)
            else:
                mel_loss = 0

            final_loss = mel_loss + spec_loss

            losses.update(final_loss.item(), phone.size(0))
            spec_losses.update(spec_loss.item(), phone.size(0))
            if args.n_mels > 0:
                mel_losses.update(mel_loss.item(), phone.size(0))

            # normalize inverse stage
            if args.normalize and args.stats_file:
                output, _ = sepc_normalizer.inverse(output, length)
                # spec, _ = sepc_normalizer.inverse(spec, length)

            mcd_value, length_sum = Metrics.Calculate_melcd_fromLinearSpectrum(
                output, spec_origin, length, args)
            (
                f0_distortion_value,
                voiced_frame_number_step,
                vuv_error_value,
                frame_number_step,
                f0_ground_truth_step,
                f0_synthesis_step,
            ) = Metrics.Calculate_f0RMSE_VUV_CORR_fromWav(
                output, spec_origin, length, args, "test")

            f0_ground_truth_all = np.concatenate(
                (f0_ground_truth_all, f0_ground_truth_step), axis=0)
            f0_synthesis_all = np.concatenate(
                (f0_synthesis_all, f0_synthesis_step), axis=0)

            mcd_metric.update(mcd_value, length_sum)
            f0_distortion_metric.update(f0_distortion_value, voiced_frame_number_step)
            vuv_error_metric.update(vuv_error_value, frame_number_step)

            if step % 1 == 0:
                if args.vocoder_category == "griffin":
                    log_figure(step, output, spec_origin, att, length,
                               args.prediction_path, args)
                elif args.vocoder_category == "wavernn":
                    log_mel(step, output_mel, spec_origin, att, length,
                            args.prediction_path, args, voc_model)

                out_log = ("step {}: loss {:.4f}; "
                           "spec_loss {:.4f}; mcd_value {:.4f};".format(
                               step, losses.avg, spec_losses.avg, mcd_metric.avg))
                if args.perceptual_loss > 0:
                    out_log += " pe_loss {:.4f}; ".format(pe_losses.avg)
                if args.n_mels > 0:
                    out_log += " mel_loss {:.4f}; ".format(mel_losses.avg)
                    if args.double_mel_loss:
                        out_log += " dmel_loss {:.4f}; ".format(double_mel_losses.avg)
                end = time.time()
                logging.info(f"{out_log} -- sum_time: {(end - start_t_test)}s")

    end_t_test = time.time()

    out_log = "Test Stage: "
    out_log += "spec_loss: {:.4f} ".format(spec_losses.avg)
    if args.n_mels > 0:
        out_log += "mel_loss: {:.4f}, ".format(mel_losses.avg)
    # if args.perceptual_loss > 0:
    #     out_log += 'pe_loss: {:.4f}, '.format(train_info['pe_loss'])
    f0_corr = Metrics.compute_f0_corr(f0_ground_truth_all, f0_synthesis_all)

    out_log += "\n\t mcd_value {:.4f} dB ".format(mcd_metric.avg)
    out_log += (" f0_rmse_value {:.4f} Hz, "
                "vuv_error_value {:.4f} %, F0_CORR {:.4f}; ".format(
                    np.sqrt(f0_distortion_metric.avg), vuv_error_metric.avg * 100, f0_corr))
    logging.info("{} time: {:.2f}s".format(out_log, end_t_test - start_t_test))
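
# Sketch of the weighted averaging implied by the AverageMeter usage above
# (update(value, n) followed by .avg); the actual AverageMeter class lives elsewhere in the
# repo and may differ in detail. infer() reports f0_rmse_value as np.sqrt of such an average
# weighted by the number of voiced frames per utterance, and MCD as the average weighted by
# frame count.
def _example_weighted_average(values, weights):
    total = sum(v * w for v, w in zip(values, weights))
    count = sum(weights)
    return total / max(count, 1)
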