def train(args): """train.""" np.random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available() and args.auto_select_gpu is True: cvd = use_single_gpu() logging.info(f"GPU {cvd} is used") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # torch.backends.cudnn.enabled = False device = torch.device("cuda" if torch.cuda.is_available() else "cpu") elif torch.cuda.is_available() and args.auto_select_gpu is False: torch.cuda.set_device(args.gpu_id) logging.info(f"GPU {args.gpu_id} is used") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # torch.backends.cudnn.enabled = False device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device("cpu") logging.info("Warning: CPU is used") train_set = SVSDataset( align_root_path=args.train_align, pitch_beat_root_path=args.train_pitch, wav_root_path=args.train_wav, char_max_len=args.char_max_len, max_len=args.num_frames, sr=args.sampling_rate, preemphasis=args.preemphasis, nfft=args.nfft, frame_shift=args.frame_shift, frame_length=args.frame_length, n_mels=args.n_mels, power=args.power, max_db=args.max_db, ref_db=args.ref_db, sing_quality=args.sing_quality, standard=args.standard, db_joint=args.db_joint, Hz2semitone=args.Hz2semitone, semitone_min=args.semitone_min, semitone_max=args.semitone_max, phone_shift_size=args.phone_shift_size, semitone_shift=args.semitone_shift, ) dev_set = SVSDataset( align_root_path=args.val_align, pitch_beat_root_path=args.val_pitch, wav_root_path=args.val_wav, char_max_len=args.char_max_len, max_len=args.num_frames, sr=args.sampling_rate, preemphasis=args.preemphasis, nfft=args.nfft, frame_shift=args.frame_shift, frame_length=args.frame_length, n_mels=args.n_mels, power=args.power, max_db=args.max_db, ref_db=args.ref_db, sing_quality=args.sing_quality, standard=args.standard, db_joint=args.db_joint, Hz2semitone=args.Hz2semitone, semitone_min=args.semitone_min, semitone_max=args.semitone_max, phone_shift_size=-1, semitone_shift=False, ) collate_fn_svs_train = SVSCollator( args.num_frames, args.char_max_len, args.use_asr_post, args.phone_size, args.n_mels, args.db_joint, args.random_crop, args.crop_min_length, args.Hz2semitone, ) collate_fn_svs_val = SVSCollator( args.num_frames, args.char_max_len, args.use_asr_post, args.phone_size, args.n_mels, args.db_joint, False, # random crop -1, # crop_min_length args.Hz2semitone, ) train_loader = torch.utils.data.DataLoader( dataset=train_set, batch_size=args.batchsize, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn_svs_train, pin_memory=True, ) dev_loader = torch.utils.data.DataLoader( dataset=dev_set, batch_size=args.batchsize, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn_svs_val, pin_memory=True, ) assert (args.feat_dim == dev_set[0]["spec"].shape[1] or args.feat_dim == dev_set[0]["mel"].shape[1]) if args.collect_stats: collect_stats(train_loader, args) logging.info("collect_stats finished !") quit() # prepare model if args.model_type == "GLU_Transformer": if args.db_joint: model = GLU_TransformerSVS_combine( phone_size=args.phone_size, singer_size=args.singer_size, embed_size=args.embedding_size, hidden_size=args.hidden_size, glu_num_layers=args.glu_num_layers, dropout=args.dropout, output_dim=args.feat_dim, dec_nhead=args.dec_nhead, dec_num_block=args.dec_num_block, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, local_gaussian=args.local_gaussian, Hz2semitone=args.Hz2semitone, semitone_size=args.semitone_size, 
device=device, ) else: model = GLU_TransformerSVS( phone_size=args.phone_size, embed_size=args.embedding_size, hidden_size=args.hidden_size, glu_num_layers=args.glu_num_layers, dropout=args.dropout, output_dim=args.feat_dim, dec_nhead=args.dec_nhead, dec_num_block=args.dec_num_block, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, local_gaussian=args.local_gaussian, Hz2semitone=args.Hz2semitone, semitone_size=args.semitone_size, device=device, ) elif args.model_type == "LSTM": if args.db_joint: model = LSTMSVS_combine( phone_size=args.phone_size, singer_size=args.singer_size, embed_size=args.embedding_size, d_model=args.hidden_size, num_layers=args.num_rnn_layers, dropout=args.dropout, d_output=args.feat_dim, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, Hz2semitone=args.Hz2semitone, semitone_size=args.semitone_size, device=device, use_asr_post=args.use_asr_post, ) else: model = LSTMSVS( phone_size=args.phone_size, embed_size=args.embedding_size, d_model=args.hidden_size, num_layers=args.num_rnn_layers, dropout=args.dropout, d_output=args.feat_dim, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, Hz2semitone=args.Hz2semitone, semitone_size=args.semitone_size, device=device, use_asr_post=args.use_asr_post, ) elif args.model_type == "GRU_gs": model = GRUSVS_gs( phone_size=args.phone_size, embed_size=args.embedding_size, d_model=args.hidden_size, num_layers=args.num_rnn_layers, dropout=args.dropout, d_output=args.feat_dim, n_mels=args.n_mels, device=device, use_asr_post=args.use_asr_post, ) elif args.model_type == "PureTransformer": model = TransformerSVS( phone_size=args.phone_size, embed_size=args.embedding_size, hidden_size=args.hidden_size, glu_num_layers=args.glu_num_layers, dropout=args.dropout, output_dim=args.feat_dim, dec_nhead=args.dec_nhead, dec_num_block=args.dec_num_block, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, local_gaussian=args.local_gaussian, device=device, ) elif args.model_type == "Conformer": model = ConformerSVS( phone_size=args.phone_size, embed_size=args.embedding_size, enc_attention_dim=args.enc_attention_dim, enc_attention_heads=args.enc_attention_heads, enc_linear_units=args.enc_linear_units, enc_num_blocks=args.enc_num_blocks, enc_dropout_rate=args.enc_dropout_rate, enc_positional_dropout_rate=args.enc_positional_dropout_rate, enc_attention_dropout_rate=args.enc_attention_dropout_rate, enc_input_layer=args.enc_input_layer, enc_normalize_before=args.enc_normalize_before, enc_concat_after=args.enc_concat_after, enc_positionwise_layer_type=args.enc_positionwise_layer_type, enc_positionwise_conv_kernel_size=( args.enc_positionwise_conv_kernel_size), enc_macaron_style=args.enc_macaron_style, enc_pos_enc_layer_type=args.enc_pos_enc_layer_type, enc_selfattention_layer_type=args.enc_selfattention_layer_type, enc_activation_type=args.enc_activation_type, enc_use_cnn_module=args.enc_use_cnn_module, enc_cnn_module_kernel=args.enc_cnn_module_kernel, enc_padding_idx=args.enc_padding_idx, output_dim=args.feat_dim, dec_nhead=args.dec_nhead, dec_num_block=args.dec_num_block, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, local_gaussian=args.local_gaussian, dec_dropout=args.dec_dropout, Hz2semitone=args.Hz2semitone, semitone_size=args.semitone_size, device=device, ) elif args.model_type == "Comformer_full": if args.db_joint: model = ConformerSVS_FULL_combine( phone_size=args.phone_size, singer_size=args.singer_size, embed_size=args.embedding_size, output_dim=args.feat_dim, n_mels=args.n_mels, 
enc_attention_dim=args.enc_attention_dim, enc_attention_heads=args.enc_attention_heads, enc_linear_units=args.enc_linear_units, enc_num_blocks=args.enc_num_blocks, enc_dropout_rate=args.enc_dropout_rate, enc_positional_dropout_rate=args.enc_positional_dropout_rate, enc_attention_dropout_rate=args.enc_attention_dropout_rate, enc_input_layer=args.enc_input_layer, enc_normalize_before=args.enc_normalize_before, enc_concat_after=args.enc_concat_after, enc_positionwise_layer_type=args.enc_positionwise_layer_type, enc_positionwise_conv_kernel_size=( args.enc_positionwise_conv_kernel_size), enc_macaron_style=args.enc_macaron_style, enc_pos_enc_layer_type=args.enc_pos_enc_layer_type, enc_selfattention_layer_type=args.enc_selfattention_layer_type, enc_activation_type=args.enc_activation_type, enc_use_cnn_module=args.enc_use_cnn_module, enc_cnn_module_kernel=args.enc_cnn_module_kernel, enc_padding_idx=args.enc_padding_idx, dec_attention_dim=args.dec_attention_dim, dec_attention_heads=args.dec_attention_heads, dec_linear_units=args.dec_linear_units, dec_num_blocks=args.dec_num_blocks, dec_dropout_rate=args.dec_dropout_rate, dec_positional_dropout_rate=args.dec_positional_dropout_rate, dec_attention_dropout_rate=args.dec_attention_dropout_rate, dec_input_layer=args.dec_input_layer, dec_normalize_before=args.dec_normalize_before, dec_concat_after=args.dec_concat_after, dec_positionwise_layer_type=args.dec_positionwise_layer_type, dec_positionwise_conv_kernel_size=( args.dec_positionwise_conv_kernel_size), dec_macaron_style=args.dec_macaron_style, dec_pos_enc_layer_type=args.dec_pos_enc_layer_type, dec_selfattention_layer_type=args.dec_selfattention_layer_type, dec_activation_type=args.dec_activation_type, dec_use_cnn_module=args.dec_use_cnn_module, dec_cnn_module_kernel=args.dec_cnn_module_kernel, dec_padding_idx=args.dec_padding_idx, Hz2semitone=args.Hz2semitone, semitone_size=args.semitone_size, device=device, ) else: model = ConformerSVS_FULL( phone_size=args.phone_size, embed_size=args.embedding_size, output_dim=args.feat_dim, n_mels=args.n_mels, enc_attention_dim=args.enc_attention_dim, enc_attention_heads=args.enc_attention_heads, enc_linear_units=args.enc_linear_units, enc_num_blocks=args.enc_num_blocks, enc_dropout_rate=args.enc_dropout_rate, enc_positional_dropout_rate=args.enc_positional_dropout_rate, enc_attention_dropout_rate=args.enc_attention_dropout_rate, enc_input_layer=args.enc_input_layer, enc_normalize_before=args.enc_normalize_before, enc_concat_after=args.enc_concat_after, enc_positionwise_layer_type=args.enc_positionwise_layer_type, enc_positionwise_conv_kernel_size=( args.enc_positionwise_conv_kernel_size), enc_macaron_style=args.enc_macaron_style, enc_pos_enc_layer_type=args.enc_pos_enc_layer_type, enc_selfattention_layer_type=args.enc_selfattention_layer_type, enc_activation_type=args.enc_activation_type, enc_use_cnn_module=args.enc_use_cnn_module, enc_cnn_module_kernel=args.enc_cnn_module_kernel, enc_padding_idx=args.enc_padding_idx, dec_attention_dim=args.dec_attention_dim, dec_attention_heads=args.dec_attention_heads, dec_linear_units=args.dec_linear_units, dec_num_blocks=args.dec_num_blocks, dec_dropout_rate=args.dec_dropout_rate, dec_positional_dropout_rate=args.dec_positional_dropout_rate, dec_attention_dropout_rate=args.dec_attention_dropout_rate, dec_input_layer=args.dec_input_layer, dec_normalize_before=args.dec_normalize_before, dec_concat_after=args.dec_concat_after, dec_positionwise_layer_type=args.dec_positionwise_layer_type, dec_positionwise_conv_kernel_size=( 
args.dec_positionwise_conv_kernel_size), dec_macaron_style=args.dec_macaron_style, dec_pos_enc_layer_type=args.dec_pos_enc_layer_type, dec_selfattention_layer_type=args.dec_selfattention_layer_type, dec_activation_type=args.dec_activation_type, dec_use_cnn_module=args.dec_use_cnn_module, dec_cnn_module_kernel=args.dec_cnn_module_kernel, dec_padding_idx=args.dec_padding_idx, Hz2semitone=args.Hz2semitone, semitone_size=args.semitone_size, device=device, ) elif args.model_type == "USTC_DAR": model = USTC_SVS( phone_size=args.phone_size, embed_size=args.embedding_size, middle_dim_fc=args.middle_dim_fc, output_dim=args.feat_dim, multi_history_num=args.multi_history_num, middle_dim_prenet=args.middle_dim_prenet, n_blocks_prenet=args.n_blocks_prenet, n_heads_prenet=args.n_heads_prenet, kernel_size_prenet=args.kernel_size_prenet, bi_d_model=args.bi_d_model, bi_num_layers=args.bi_num_layers, uni_d_model=args.uni_d_model, uni_num_layers=args.uni_num_layers, dropout=args.dropout, feedbackLink_drop_rate=args.feedbackLink_drop_rate, device=device, ) else: raise ValueError("Not Support Model Type %s" % args.model_type) logging.info(f"{model}") model = model.to(device) logging.info( f"The model has {count_parameters(model):,} trainable parameters") model_load_dir = "" pretrain_encoder_dir = "" start_epoch = 1 # FIX ME if args.pretrain_encoder != "": pretrain_encoder_dir = args.pretrain_encoder if args.initmodel != "": model_load_dir = args.initmodel if args.resume and os.path.exists(args.model_save_dir): checks = os.listdir(args.model_save_dir) start_epoch = max( list( map( lambda x: int(x.split(".")[0].split("_")[-1]) if x.endswith("pth.tar") else -1, checks, ))) model_temp_load_dir = "{}/epoch_loss_{}.pth.tar".format( args.model_save_dir, start_epoch) if start_epoch < 0: model_load_dir = "" elif os.path.isfile(model_temp_load_dir): model_load_dir = model_temp_load_dir else: model_load_dir = "{}/epoch_spec_loss_{}.pth.tar".format( args.model_save_dir, start_epoch) # load encoder parm from Transformer-TTS if pretrain_encoder_dir != "": pretrain = torch.load(pretrain_encoder_dir, map_location=device) pretrain_dict = pretrain["model"] model_dict = model.state_dict() state_dict_new = {} para_list = [] i = 0 for k, v in pretrain_dict.items(): k_new = k[7:] if (k_new in model_dict and model_dict[k_new].size() == pretrain_dict[k].size()): i += 1 state_dict_new[k_new] = v model_dict.update(state_dict_new) model.load_state_dict(model_dict) logging.info(f"Load {i} layers total. 
Load pretrain encoder success !") # load weights for pre-trained model if model_load_dir != "": logging.info(f"Model Start to Load, dir: {model_load_dir}") model_load = torch.load(model_load_dir, map_location=device) loading_dict = model_load["state_dict"] model_dict = model.state_dict() state_dict_new = {} para_list = [] for k, v in loading_dict.items(): # assert k in model_dict if (k == "normalizer.mean" or k == "normalizer.std" or k == "mel_normalizer.mean" or k == "mel_normalizer.std"): continue if model_dict[k].size() == loading_dict[k].size(): state_dict_new[k] = v else: para_list.append(k) logging.info(f"Total {len(loading_dict)} parameter sets, " f"Loaded {len(state_dict_new)} parameter sets") if len(para_list) > 0: logging.warning("Not loading {} because of different sizes".format( ", ".join(para_list))) model_dict.update(state_dict_new) model.load_state_dict(model_dict) logging.info(f"Loaded checkpoint {args.initmodel}") # setup optimizer if args.optimizer == "noam": optimizer = ScheduledOptim( torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09), args.hidden_size, args.noam_warmup_steps, args.noam_scale, ) elif args.optimizer == "adam": optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09) if args.scheduler == "OneCycleLR": scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=args.lr, steps_per_epoch=len(train_loader), epochs=args.max_epochs, ) elif args.scheduler == "ReduceLROnPlateau": scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, "min", verbose=True, patience=10, factor=0.5) elif args.scheduler == "ExponentialLR": scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, verbose=True, gamma=0.9886) else: raise ValueError("Not Support Optimizer") # Setup tensorborad logger if args.use_tfboard: from tensorboardX import SummaryWriter logger = SummaryWriter("{}/log".format(args.model_save_dir)) else: logger = None if args.loss == "l1": loss = MaskedLoss("l1", mask_free=args.mask_free) elif args.loss == "mse": loss = MaskedLoss("mse", mask_free=args.mask_free) else: raise ValueError("Not Support Loss Type") if args.perceptual_loss > 0: win_length = int(args.sampling_rate * args.frame_length) psd_dict, bark_num = cal_psd2bark_dict(fs=args.sampling_rate, win_len=win_length) sf = cal_spread_function(bark_num) loss_perceptual_entropy = PerceptualEntropy(bark_num, sf, args.sampling_rate, win_length, psd_dict) else: loss_perceptual_entropy = None # Training total_loss_epoch_to_save = {} total_loss_counter = 0 spec_loss_epoch_to_save = {} spec_loss_counter = 0 total_learning_step = 0 # args.num_saved_model = 5 # preload vocoder model voc_model = [] if args.vocoder_category == "wavernn": logging.info("load voc_model from {}".format(args.wavernn_voc_model)) voc_model = WaveRNN( rnn_dims=args.voc_rnn_dims, fc_dims=args.voc_fc_dims, bits=args.voc_bits, pad=args.voc_pad, upsample_factors=( args.voc_upsample_factors_0, args.voc_upsample_factors_1, args.voc_upsample_factors_2, ), feat_dims=args.n_mels, compute_dims=args.voc_compute_dims, res_out_dims=args.voc_res_out_dims, res_blocks=args.voc_res_blocks, hop_length=args.hop_length, sample_rate=args.sampling_rate, mode=args.voc_mode, ).to(device) voc_model.load(args.wavernn_voc_model) for epoch in range(start_epoch + 1, 1 + args.max_epochs): """Train Stage""" start_t_train = time.time() train_info = train_one_epoch( train_loader, model, device, optimizer, loss, loss_perceptual_entropy, epoch, args, voc_model, ) end_t_train = time.time() out_log 
= "Train epoch: {:04d}, ".format(epoch) if args.optimizer == "noam": out_log += "lr: {:.6f}, ".format( optimizer._optimizer.param_groups[0]["lr"]) elif args.optimizer == "adam": out_log += "lr: {:.6f}, ".format(optimizer.param_groups[0]["lr"]) if args.vocoder_category == "wavernn": out_log += "loss: {:.4f} ".format(train_info["loss"]) else: out_log += "loss: {:.4f}, spec_loss: {:.4f} ".format( train_info["loss"], train_info["spec_loss"]) if args.n_mels > 0: out_log += "mel_loss: {:.4f}, ".format(train_info["mel_loss"]) if args.perceptual_loss > 0: out_log += "pe_loss: {:.4f}, ".format(train_info["pe_loss"]) logging.info("{} time: {:.2f}s".format(out_log, end_t_train - start_t_train)) """Dev Stage""" torch.backends.cudnn.enabled = False # 莫名的bug,关掉才可以跑 # start_t_dev = time.time() dev_info = validate( dev_loader, model, device, loss, loss_perceptual_entropy, epoch, args, voc_model, ) end_t_dev = time.time() dev_log = "Dev epoch: {:04d}, loss: {:.4f}, spec_loss: {:.4f}, ".format( epoch, dev_info["loss"], dev_info["spec_loss"]) dev_log += "mcd_value: {:.4f}, ".format(dev_info["mcd_value"]) if args.n_mels > 0: dev_log += "mel_loss: {:.4f}, ".format(dev_info["mel_loss"]) if args.perceptual_loss > 0: dev_log += "pe_loss: {:.4f}, ".format(dev_info["pe_loss"]) logging.info("{} time: {:.2f}s".format(dev_log, end_t_dev - start_t_train)) sys.stdout.flush() torch.backends.cudnn.enabled = True if args.scheduler == "OneCycleLR": scheduler.step() elif args.scheduler == "ReduceLROnPlateau": scheduler.step(dev_info["loss"]) elif args.scheduler == "ExponentialLR": before = total_learning_step // args.lr_decay_learning_steps total_learning_step += len(train_loader) after = total_learning_step // args.lr_decay_learning_steps if after > before: # decay per 250 learning steps scheduler.step() """Save model Stage""" if not os.path.exists(args.model_save_dir): os.makedirs(args.model_save_dir) (total_loss_counter, total_loss_epoch_to_save) = Auto_save_model( args, epoch, model, optimizer, train_info, dev_info, logger, total_loss_counter, total_loss_epoch_to_save, save_loss_select="loss", ) if (dev_info["spec_loss"] != 0): # spec_loss 有意义时再存模型,比如 USTC DAR model 不需要计算线性谱spec loss (spec_loss_counter, spec_loss_epoch_to_save) = Auto_save_model( args, epoch, model, optimizer, train_info, dev_info, logger, spec_loss_counter, spec_loss_epoch_to_save, save_loss_select="spec_loss", ) if args.use_tfboard: logger.close()
def data_filter(args): model, device = load_model(args) start_t_test = time.time() # Decode test_set = SVSDataset_filter( align_root_path=args.test_align, pitch_beat_root_path=args.test_pitch, wav_root_path=args.test_wav, char_max_len=args.char_max_len, max_len=args.num_frames, sr=args.sampling_rate, preemphasis=args.preemphasis, nfft=args.nfft, frame_shift=args.frame_shift, frame_length=args.frame_length, n_mels=args.n_mels, power=args.power, max_db=args.max_db, ref_db=args.ref_db, standard=args.standard, sing_quality=args.sing_quality, Hz2semitone=args.Hz2semitone, semitone_min=args.semitone_min, semitone_max=args.semitone_max, ) collate_fn_svs = SVSCollator( args.num_frames, args.char_max_len, args.use_asr_post, args.phone_size, args.n_mels, args.db_joint, False, # random crop -1, # crop_min_length args.Hz2semitone, ) test_loader = torch.utils.data.DataLoader( dataset=test_set, batch_size=1, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn_svs, pin_memory=True, ) with torch.no_grad(): for ( step, data_step, ) in enumerate(test_loader, 1): if args.db_joint: (phone, beat, pitch, spec, real, imag, length, chars, char_len_list, mel, singer_id, semitone, filename_list) = data_step else: print( "No support for augmentation with args.db_joint == False") quit() singer_id = np.array(singer_id).reshape(np.shape(phone)[0], -1) # [batch size, 1] singer_vec = singer_id.repeat(np.shape(phone)[1], axis=1) # [batch size, length] singer_vec = torch.from_numpy(singer_vec).to(device) singer_id = torch.from_numpy(singer_id).to(device) phone = phone.to(device) beat = beat.to(device) pitch = pitch.to(device).float() if semitone is not None: semitone = semitone.to(device) spec = spec.to(device).float() mel = mel.to(device).float() real = real.to(device).float() imag = imag.to(device).float() length_mask = (length > 0).int().unsqueeze(2) length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float() length_mask = length_mask.repeat(1, 1, spec.shape[2]).float() length_mask = length_mask.to(device) length_mel_mask = length_mel_mask.to(device) length = length.to(device) char_len_list = char_len_list.to(device) if not args.use_asr_post: chars = chars.to(device) char_len_list = char_len_list.to(device) else: phone = phone.float() if args.Hz2semitone: pitch = semitone if args.normalize: sepc_normalizer = GlobalMVN(args.stats_file) mel_normalizer = GlobalMVN(args.stats_mel_file) spec, _ = sepc_normalizer(spec, length) mel, _ = mel_normalizer(mel, length) len_list, _ = torch.max(length, dim=1) # [len1, len2, len3, ...] 
len_list = len_list.cpu().detach().numpy() singer_out, phone_out, semitone_out = model(spec, len_list) # calculate num batch_size = np.shape(spec)[0] singer_id = singer_id.view(-1) # [batch size] _, singer_predict = torch.max(singer_out, dim=1) # [batch size] singer_correct = singer_predict.eq(singer_id).cpu().sum().numpy() for i in range(batch_size): phone_i = phone[i, :len_list[i], :].view(-1) # [valid seq len] phone_out_i = phone_out[ i, :len_list[i], :] # [valid seq len, phone_size] _, phone_predict = torch.max(phone_out_i, dim=1) phone_correct = phone_predict.eq(phone_i).cpu().sum().numpy() semitone_i = semitone[i, :len_list[i], :].view( -1) # [valid seq len] semitone_out_i = semitone_out[ i, :len_list[i], :] # [valid seq len, semitone_size] _, semitone_predict = torch.max(semitone_out_i, dim=1) semitone_correct = semitone_predict.eq( semitone_i).cpu().sum().numpy() with open(os.path.join(args.prediction_path, "filter_res.txt"), "a+") as f: f.write( f"{filename_list[i]}|{singer_predict[i]}|{phone_correct}|{semitone_correct}|{len_list[i]}\n" ) end = time.time() logging.info( f"{filename_list[i]} -- sum_time: {(end - start_t_test)}s")
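# Illustrative sketch (assumption): `GlobalMVN` used above behaves like a global
# mean/variance normalizer built from the statistics written by `collect_stats`;
# the code only relies on `normalizer(x, lengths)` and
# `normalizer.inverse(x, lengths)` both returning `(features, lengths)`.
# A minimal stand-in (taking mean/std directly instead of a stats file):
class _GlobalMVNSketch(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float32))
        self.register_buffer("std", torch.as_tensor(std, dtype=torch.float32))

    def forward(self, x, lengths):
        # x: (batch, time, feat_dim); padded frames are normalized too and are
        # masked out later by length_mask / length_mel_mask.
        return (x - self.mean) / self.std, lengths

    def inverse(self, x, lengths):
        return x * self.std + self.mean, lengths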
def infer_predictor(args): """infer.""" torch.cuda.set_device(args.gpu_id) logging.info(f"GPU {args.gpu_id} is used") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False torch.backends.cudnn.enabled = False device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # prepare model model = RNN_Discriminator( embed_size=128, d_model=128, hidden_size=128, num_layers=2, n_specs=1025, singer_size=7, phone_size=43, simitone_size=59, dropout=0.1, bidirectional=True, device=device, ) logging.info(f"{model}") model = model.to(device) logging.info( f"The model has {count_parameters(model):,} trainable parameters") # Load model weights logging.info(f"Loading pretrained weights from {args.model_file}") checkpoint = torch.load(args.model_file, map_location=device) state_dict = checkpoint["state_dict"] model_dict = model.state_dict() state_dict_new = {} para_list = [] for k, v in state_dict.items(): # assert k in model_dict if (k == "normalizer.mean" or k == "normalizer.std" or k == "mel_normalizer.mean" or k == "mel_normalizer.std"): continue if model_dict[k].size() == state_dict[k].size(): state_dict_new[k] = v else: para_list.append(k) logging.info(f"Total {len(state_dict)} parameter sets, " f"loaded {len(state_dict_new)} parameter set") if len(para_list) > 0: logging.warning(f"Not loading {para_list} because of different sizes") model.load_state_dict(state_dict_new) logging.info(f"Loaded checkpoint {args.model_file}") model = model.to(device) model.eval() # Decode test_set = SVSDataset( align_root_path=args.test_align, pitch_beat_root_path=args.test_pitch, wav_root_path=args.test_wav, char_max_len=args.char_max_len, max_len=args.num_frames, sr=args.sampling_rate, preemphasis=args.preemphasis, nfft=args.nfft, frame_shift=args.frame_shift, frame_length=args.frame_length, n_mels=args.n_mels, power=args.power, max_db=args.max_db, ref_db=args.ref_db, standard=args.standard, sing_quality=args.sing_quality, db_joint=args.db_joint, Hz2semitone=args.Hz2semitone, semitone_min=args.semitone_min, semitone_max=args.semitone_max, phone_shift_size=-1, semitone_shift=False, ) collate_fn_svs = SVSCollator( args.num_frames, args.char_max_len, args.use_asr_post, args.phone_size, args.n_mels, args.db_joint, False, # random crop -1, # crop_min_length args.Hz2semitone, ) test_loader = torch.utils.data.DataLoader( dataset=test_set, batch_size=1, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn_svs, pin_memory=True, ) criterion = nn.CrossEntropyLoss(reduction="sum") start_t_test = time.time() singer_losses = AverageMeter() phone_losses = AverageMeter() semitone_losses = AverageMeter() singer_count = AverageMeter() phone_count = AverageMeter() semitone_count = AverageMeter() with torch.no_grad(): for (step, data_step) in enumerate(test_loader, 1): if args.db_joint: ( phone, beat, pitch, spec, real, imag, length, chars, char_len_list, mel, singer_id, semitone, ) = data_step singer_id = np.array(singer_id).reshape( np.shape(phone)[0], -1) # [batch size, 1] singer_vec = singer_id.repeat(np.shape(phone)[1], axis=1) # [batch size, length] singer_vec = torch.from_numpy(singer_vec).to(device) singer_id = torch.from_numpy(singer_id).to(device) else: ( phone, beat, pitch, spec, real, imag, length, chars, char_len_list, mel, semitone, ) = data_step phone = phone.to(device) beat = beat.to(device) pitch = pitch.to(device).float() if semitone is not None: semitone = semitone.to(device) spec = spec.to(device).float() mel = mel.to(device).float() real = real.to(device).float() imag 
= imag.to(device).float() length_mask = (length > 0).int().unsqueeze(2) length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float() length_mask = length_mask.repeat(1, 1, spec.shape[2]).float() length_mask = length_mask.to(device) length_mel_mask = length_mel_mask.to(device) length = length.to(device) char_len_list = char_len_list.to(device) if not args.use_asr_post: chars = chars.to(device) char_len_list = char_len_list.to(device) else: phone = phone.float() if args.Hz2semitone: pitch = semitone if args.normalize: sepc_normalizer = GlobalMVN(args.stats_file) mel_normalizer = GlobalMVN(args.stats_mel_file) spec, _ = sepc_normalizer(spec, length) mel, _ = mel_normalizer(mel, length) len_list, _ = torch.max(length, dim=1) # [len1, len2, len3, ...] len_list = len_list.cpu().detach().numpy() singer_out, phone_out, semitone_out = model(spec, len_list) # calculate CrossEntropy loss (defination - reduction:sum) phone_loss = 0 semitone_loss = 0 phone_correct = 0 semitone_correct = 0 batch_size = np.shape(spec)[0] for i in range(batch_size): phone_i = phone[i, :len_list[i], :].view(-1) # [valid seq len] phone_out_i = phone_out[ i, :len_list[i], :] # [valid seq len, phone_size] phone_loss += criterion(phone_out_i, phone_i) _, phone_predict = torch.max(phone_out_i, dim=1) phone_correct += phone_predict.eq(phone_i).cpu().sum().numpy() semitone_i = semitone[i, :len_list[i], :].view( -1) # [valid seq len] semitone_out_i = semitone_out[ i, :len_list[i], :] # [valid seq len, semitone_size] semitone_loss += criterion(semitone_out_i, semitone_i) _, semitone_predict = torch.max(semitone_out_i, dim=1) semitone_correct += semitone_predict.eq( semitone_i).cpu().sum().numpy() singer_id = singer_id.view(-1) # [batch size] _, singer_predict = torch.max(singer_out, dim=1) singer_correct = singer_predict.eq(singer_id).cpu().sum().numpy() phone_loss /= np.sum(len_list) semitone_loss /= np.sum(len_list) singer_loss = criterion(singer_out, singer_id) / batch_size # restore loss info singer_losses.update(singer_loss.item(), batch_size) phone_losses.update(phone_loss.item(), np.sum(len_list)) semitone_losses.update(semitone_loss.item(), np.sum(len_list)) singer_count.update(singer_correct.item() / batch_size, batch_size) phone_count.update(phone_correct.item() / np.sum(len_list), np.sum(len_list)) semitone_count.update(semitone_correct.item() / np.sum(len_list), np.sum(len_list)) if step % 1 == 0: end = time.time() out_log = "step {}: loss {:.6f}, ".format( step, singer_losses.avg + phone_losses.avg + semitone_losses.avg) out_log += "\t singer_loss: {:.4f} ".format(singer_losses.avg) out_log += "phone_loss: {:.4f} ".format(phone_losses.avg) out_log += "semitone_loss: {:.4f} \n".format( semitone_losses.avg) out_log += "\t singer_accuracy: {:.4f}% ".format( singer_count.avg * 100) out_log += "phone_accuracy: {:.4f}% ".format(phone_count.avg * 100) out_log += "semitone_accuracy: {:.4f}% ".format( semitone_count.avg * 100) print("{} -- sum_time: {:.2f}s".format(out_log, (end - start_t_test))) end_t_test = time.time() out_log = "\nTest Stage: " out_log += "loss: {:.4f}, ".format(singer_losses.avg + phone_losses.avg + semitone_losses.avg) out_log += "singer_loss: {:.4f}, ".format(singer_losses.avg) out_log += "phone_loss: {:.4f}, semitone_loss: {:.4f} \n".format( phone_losses.avg, semitone_losses.avg, ) out_log += "singer_accuracy: {:.4f}%, ".format(singer_count.avg * 100) out_log += "phone_accuracy: {:.4f}%, semitone_accuracy: {:.4f}% ".format( phone_count.avg * 100, semitone_count.avg * 100) logging.info("{} time: 
{:.2f}s".format(out_log, end_t_test - start_t_test))
def augmentation(args, target_singer_id, output_path): if not os.path.exists(output_path): os.makedirs(output_path) model, device = load_model(args) start_t_test = time.time() # Decode test_set = SVSDataset( align_root_path=args.test_align, pitch_beat_root_path=args.test_pitch, wav_root_path=args.test_wav, char_max_len=args.char_max_len, max_len=args.num_frames, sr=args.sampling_rate, preemphasis=args.preemphasis, nfft=args.nfft, frame_shift=args.frame_shift, frame_length=args.frame_length, n_mels=args.n_mels, power=args.power, max_db=args.max_db, ref_db=args.ref_db, standard=args.standard, sing_quality=args.sing_quality, db_joint=args.db_joint, Hz2semitone=args.Hz2semitone, semitone_min=args.semitone_min, semitone_max=args.semitone_max, phone_shift_size=-1, semitone_shift=False, ) collate_fn_svs = SVSCollator( args.num_frames, args.char_max_len, args.use_asr_post, args.phone_size, args.n_mels, args.db_joint, False, # random crop -1, # crop_min_length args.Hz2semitone, ) test_loader = torch.utils.data.DataLoader( dataset=test_set, batch_size=1, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn_svs, pin_memory=True, ) with torch.no_grad(): for ( step, data_step, ) in enumerate(test_loader, 1): if args.db_joint: (phone, beat, pitch, spec, real, imag, length, chars, char_len_list, mel, singer_id, semitone, filename_list, flag_filter_list) = data_step else: print( "No support for augmentation with args.db_joint == False") quit() t_singer_id = np.array(target_singer_id).reshape( np.shape(phone)[0], -1) # [batch size, 1] singer_vec = t_singer_id.repeat(np.shape(phone)[1], axis=1) # [batch size, length] singer_vec = torch.from_numpy(singer_vec).to(device) phone = phone.to(device) beat = beat.to(device) pitch = pitch.to(device).float() if semitone is not None: semitone = semitone.to(device) spec = spec.to(device).float() mel = mel.to(device).float() real = real.to(device).float() imag = imag.to(device).float() length_mask = (length > 0).int().unsqueeze(2) length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float() length_mask = length_mask.repeat(1, 1, spec.shape[2]).float() length_mask = length_mask.to(device) length_mel_mask = length_mel_mask.to(device) length = length.to(device) char_len_list = char_len_list.to(device) if not args.use_asr_post: chars = chars.to(device) char_len_list = char_len_list.to(device) else: phone = phone.float() if args.Hz2semitone: pitch = semitone if args.model_type == "GLU_Transformer": if args.db_joint: output, att, output_mel, output_mel2 = model( chars, phone, pitch, beat, singer_vec, pos_char=char_len_list, pos_spec=length, ) elif args.model_type == "LSTM": if args.db_joint: output, hidden, output_mel, output_mel2 = model( phone, pitch, beat, singer_vec) elif args.model_type == "Comformer_full": if args.db_joint: output, att, output_mel, output_mel2 = model( chars, phone, pitch, beat, singer_vec, pos_char=char_len_list, pos_spec=length, ) if args.normalize: sepc_normalizer = GlobalMVN(args.stats_file) mel_normalizer = GlobalMVN(args.stats_mel_file) # normalize inverse stage if args.normalize and args.stats_file: output, _ = sepc_normalizer.inverse(output, length) # write wav output = output.cpu().detach().numpy()[0] length = np.max(length.cpu().detach().numpy()[0]) output = output[:length] wav = spectrogram2wav( output, args.max_db, args.ref_db, args.preemphasis, args.power, args.sampling_rate, args.frame_shift, args.frame_length, args.nfft, ) wr_fname = filename_list[0] + "-" + str( target_singer_id) # batch_size = 1 if librosa.__version__ 
< "0.8.0": librosa.output.write_wav( os.path.join(output_path, "{}.wav".format(wr_fname)), wav, args.sampling_rate, ) else: # librosa > 0.8 remove librosa.output.write_wav module sf.write( os.path.join(output_path, "{}.wav".format(wr_fname)), wav, args.sampling_rate, format="wav", subtype="PCM_24", ) end = time.time() out_log = os.path.join(output_path, "{}.wav".format(wr_fname)) logging.info(f"{out_log} -- sum_time: {(end - start_t_test)}s")
def infer(args): """infer.""" torch.cuda.set_device(args.gpu_id) logging.info(f"GPU {args.gpu_id} is used") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False torch.backends.cudnn.enabled = False device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # prepare model if args.model_type == "GLU_Transformer": model = GLU_TransformerSVS( phone_size=args.phone_size, embed_size=args.embedding_size, hidden_size=args.hidden_size, glu_num_layers=args.glu_num_layers, dropout=args.dropout, output_dim=args.feat_dim, dec_nhead=args.dec_nhead, dec_num_block=args.dec_num_block, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, local_gaussian=args.local_gaussian, device=device, ) elif args.model_type == "LSTM": model = LSTMSVS( phone_size=args.phone_size, embed_size=args.embedding_size, d_model=args.hidden_size, num_layers=args.num_rnn_layers, dropout=args.dropout, d_output=args.feat_dim, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, device=device, use_asr_post=args.use_asr_post, ) elif args.model_type == "GRU_gs": model = GRUSVS_gs( phone_size=args.phone_size, embed_size=args.embedding_size, d_model=args.hidden_size, num_layers=args.num_rnn_layers, dropout=args.dropout, d_output=args.feat_dim, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, device=device, use_asr_post=args.use_asr_post, ) elif args.model_type == "PureTransformer": model = TransformerSVS( phone_size=args.phone_size, embed_size=args.embedding_size, hidden_size=args.hidden_size, glu_num_layers=args.glu_num_layers, dropout=args.dropout, output_dim=args.feat_dim, dec_nhead=args.dec_nhead, dec_num_block=args.dec_num_block, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, local_gaussian=args.local_gaussian, device=device, ) elif args.model_type == "Conformer": model = ConformerSVS( phone_size=args.phone_size, embed_size=args.embedding_size, enc_attention_dim=args.enc_attention_dim, enc_attention_heads=args.enc_attention_heads, enc_linear_units=args.enc_linear_units, enc_num_blocks=args.enc_num_blocks, enc_dropout_rate=args.enc_dropout_rate, enc_positional_dropout_rate=args.enc_positional_dropout_rate, enc_attention_dropout_rate=args.enc_attention_dropout_rate, enc_input_layer=args.enc_input_layer, enc_normalize_before=args.enc_normalize_before, enc_concat_after=args.enc_concat_after, enc_positionwise_layer_type=args.enc_positionwise_layer_type, enc_positionwise_conv_kernel_size=(args.enc_positionwise_conv_kernel_size), enc_macaron_style=args.enc_macaron_style, enc_pos_enc_layer_type=args.enc_pos_enc_layer_type, enc_selfattention_layer_type=args.enc_selfattention_layer_type, enc_activation_type=args.enc_activation_type, enc_use_cnn_module=args.enc_use_cnn_module, enc_cnn_module_kernel=args.enc_cnn_module_kernel, enc_padding_idx=args.enc_padding_idx, output_dim=args.feat_dim, dec_nhead=args.dec_nhead, dec_num_block=args.dec_num_block, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, local_gaussian=args.local_gaussian, dec_dropout=args.dec_dropout, device=device, ) elif args.model_type == "Comformer_full": model = ConformerSVS_FULL( phone_size=args.phone_size, embed_size=args.embedding_size, output_dim=args.feat_dim, n_mels=args.n_mels, enc_attention_dim=args.enc_attention_dim, enc_attention_heads=args.enc_attention_heads, enc_linear_units=args.enc_linear_units, enc_num_blocks=args.enc_num_blocks, enc_dropout_rate=args.enc_dropout_rate, enc_positional_dropout_rate=args.enc_positional_dropout_rate, enc_attention_dropout_rate=args.enc_attention_dropout_rate, 
enc_input_layer=args.enc_input_layer, enc_normalize_before=args.enc_normalize_before, enc_concat_after=args.enc_concat_after, enc_positionwise_layer_type=args.enc_positionwise_layer_type, enc_positionwise_conv_kernel_size=(args.enc_positionwise_conv_kernel_size), enc_macaron_style=args.enc_macaron_style, enc_pos_enc_layer_type=args.enc_pos_enc_layer_type, enc_selfattention_layer_type=args.enc_selfattention_layer_type, enc_activation_type=args.enc_activation_type, enc_use_cnn_module=args.enc_use_cnn_module, enc_cnn_module_kernel=args.enc_cnn_module_kernel, enc_padding_idx=args.enc_padding_idx, dec_attention_dim=args.dec_attention_dim, dec_attention_heads=args.dec_attention_heads, dec_linear_units=args.dec_linear_units, dec_num_blocks=args.dec_num_blocks, dec_dropout_rate=args.dec_dropout_rate, dec_positional_dropout_rate=args.dec_positional_dropout_rate, dec_attention_dropout_rate=args.dec_attention_dropout_rate, dec_input_layer=args.dec_input_layer, dec_normalize_before=args.dec_normalize_before, dec_concat_after=args.dec_concat_after, dec_positionwise_layer_type=args.dec_positionwise_layer_type, dec_positionwise_conv_kernel_size=(args.dec_positionwise_conv_kernel_size), dec_macaron_style=args.dec_macaron_style, dec_pos_enc_layer_type=args.dec_pos_enc_layer_type, dec_selfattention_layer_type=args.dec_selfattention_layer_type, dec_activation_type=args.dec_activation_type, dec_use_cnn_module=args.dec_use_cnn_module, dec_cnn_module_kernel=args.dec_cnn_module_kernel, dec_padding_idx=args.dec_padding_idx, device=device, ) else: raise ValueError("Not Support Model Type %s" % args.model_type) logging.info(f"{model}") logging.info(f"The model has {count_parameters(model):,} trainable parameters") # Load model weights logging.info(f"Loading pretrained weights from {args.model_file}") checkpoint = torch.load(args.model_file, map_location=device) state_dict = checkpoint["state_dict"] model_dict = model.state_dict() state_dict_new = {} para_list = [] for k, v in state_dict.items(): # assert k in model_dict if ( k == "normalizer.mean" or k == "normalizer.std" or k == "mel_normalizer.mean" or k == "mel_normalizer.std" ): continue if model_dict[k].size() == state_dict[k].size(): state_dict_new[k] = v else: para_list.append(k) logging.info( f"Total {len(state_dict)} parameter sets, " f"loaded {len(state_dict_new)} parameter set" ) if len(para_list) > 0: logging.warning(f"Not loading {para_list} because of different sizes") model.load_state_dict(state_dict_new) logging.info(f"Loaded checkpoint {args.model_file}") model = model.to(device) model.eval() # Decode test_set = SVSDataset( align_root_path=args.test_align, pitch_beat_root_path=args.test_pitch, wav_root_path=args.test_wav, char_max_len=args.char_max_len, max_len=args.num_frames, sr=args.sampling_rate, preemphasis=args.preemphasis, nfft=args.nfft, frame_shift=args.frame_shift, frame_length=args.frame_length, n_mels=args.n_mels, power=args.power, max_db=args.max_db, ref_db=args.ref_db, standard=args.standard, sing_quality=args.sing_quality, ) collate_fn_svs = SVSCollator( args.num_frames, args.char_max_len, args.use_asr_post, args.phone_size, ) test_loader = torch.utils.data.DataLoader( dataset=test_set, batch_size=1, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn_svs, pin_memory=True, ) if args.loss == "l1": criterion = MaskedLoss("l1") elif args.loss == "mse": criterion = MaskedLoss("mse") else: raise ValueError("Not Support Loss Type") losses = AverageMeter() spec_losses = AverageMeter() if args.perceptual_loss > 0: pe_losses = 
AverageMeter() if args.n_mels > 0: mel_losses = AverageMeter() mcd_metric = AverageMeter() f0_distortion_metric, vuv_error_metric = ( AverageMeter(), AverageMeter(), ) if args.double_mel_loss: double_mel_losses = AverageMeter() model.eval() if not os.path.exists(args.prediction_path): os.makedirs(args.prediction_path) f0_ground_truth_all = np.reshape(np.array([]), (-1, 1)) f0_synthesis_all = np.reshape(np.array([]), (-1, 1)) start_t_test = time.time() # preload vocoder model if args.vocoder_category == "wavernn": voc_model = WaveRNN( rnn_dims=512, fc_dims=512, bits=9, pad=2, upsample_factors=( 5, 5, 11, ), feat_dims=80, compute_dims=128, res_out_dims=128, res_blocks=10, hop_length=275, # 12.5ms - in line with Tacotron 2 paper sample_rate=22050, mode="MOL", ).to(device) voc_model.load("./weights/wavernn/latest_weights.pyt") with torch.no_grad(): for ( step, ( phone, beat, pitch, spec, real, imag, length, chars, char_len_list, mel, ), ) in enumerate(test_loader, 1): # if step >= args.decode_sample: # break phone = phone.to(device) beat = beat.to(device) pitch = pitch.to(device).float() spec = spec.to(device).float() mel = mel.to(device).float() real = real.to(device).float() imag = imag.to(device).float() length_mask = length.unsqueeze(2) length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float() length_mask = length_mask.repeat(1, 1, spec.shape[2]).float() length_mask = length_mask.to(device) length_mel_mask = length_mel_mask.to(device) length = length.to(device) char_len_list = char_len_list.to(device) if not args.use_asr_post: chars = chars.to(device) char_len_list = char_len_list.to(device) else: phone = phone.float() if args.model_type == "GLU_Transformer": output, att, output_mel, output_mel2 = model( chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length, ) elif args.model_type == "LSTM": output, hidden, output_mel, output_mel2 = model(phone, pitch, beat) att = None elif args.model_type == "GRU_gs": output, att, output_mel = model(spec, phone, pitch, beat, length, args) att = None elif args.model_type == "PureTransformer": output, att, output_mel, output_mel2 = model( chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length, ) elif args.model_type == "Conformer": output, att, output_mel, output_mel2 = model( chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length, ) elif args.model_type == "Comformer_full": output, att, output_mel, output_mel2 = model( chars, phone, pitch, beat, pos_char=char_len_list, pos_spec=length, ) spec_origin = spec.clone() # spec_origin = spec if args.normalize: sepc_normalizer = GlobalMVN(args.stats_file) mel_normalizer = GlobalMVN(args.stats_mel_file) spec, _ = sepc_normalizer(spec, length) mel, _ = mel_normalizer(mel, length) spec_loss = criterion(output, spec, length_mask) if args.n_mels > 0: mel_loss = criterion(output_mel, mel, length_mel_mask) else: mel_loss = 0 final_loss = mel_loss + spec_loss losses.update(final_loss.item(), phone.size(0)) spec_losses.update(spec_loss.item(), phone.size(0)) if args.n_mels > 0: mel_losses.update(mel_loss.item(), phone.size(0)) # normalize inverse stage if args.normalize and args.stats_file: output, _ = sepc_normalizer.inverse(output, length) # spec,_ = sepc_normalizer.inverse(spec,length) (mcd_value, length_sum,) = Metrics.Calculate_melcd_fromLinearSpectrum( output, spec_origin, length, args ) ( f0_distortion_value, voiced_frame_number_step, vuv_error_value, frame_number_step, f0_ground_truth_step, f0_synthesis_step, ) = Metrics.Calculate_f0RMSE_VUV_CORR_fromWav( output, 
spec_origin, length, args, "test" ) f0_ground_truth_all = np.concatenate( (f0_ground_truth_all, f0_ground_truth_step), axis=0 ) f0_synthesis_all = np.concatenate( (f0_synthesis_all, f0_synthesis_step), axis=0 ) mcd_metric.update(mcd_value, length_sum) f0_distortion_metric.update(f0_distortion_value, voiced_frame_number_step) vuv_error_metric.update(vuv_error_value, frame_number_step) if step % 1 == 0: if args.vocoder_category == "griffin": log_figure( step, output, spec_origin, att, length, args.prediction_path, args, ) elif args.vocoder_category == "wavernn": log_mel( step, output_mel, spec_origin, att, length, args.prediction_path, args, voc_model, ) out_log = ( "step {}:train_loss{:.4f};" "spec_loss{:.4f};mcd_value{:.4f};".format( step, losses.avg, spec_losses.avg, mcd_metric.avg, ) ) if args.perceptual_loss > 0: out_log += " pe_loss {:.4f}; ".format(pe_losses.avg) if args.n_mels > 0: out_log += " mel_loss {:.4f}; ".format(mel_losses.avg) if args.double_mel_loss: out_log += " dmel_loss {:.4f}; ".format(double_mel_losses.avg) end = time.time() logging.info(f"{out_log} -- sum_time: {(end - start_t_test)}s") end_t_test = time.time() out_log = "Test Stage: " out_log += "spec_loss: {:.4f} ".format(spec_losses.avg) if args.n_mels > 0: out_log += "mel_loss: {:.4f}, ".format(mel_losses.avg) # if args.perceptual_loss > 0: # out_log += 'pe_loss: {:.4f}, '.format(train_info['pe_loss']) f0_corr = Metrics.compute_f0_corr(f0_ground_truth_all, f0_synthesis_all) out_log += "\n\t mcd_value {:.4f} dB ".format(mcd_metric.avg) out_log += ( " f0_rmse_value {:.4f} Hz, " "vuv_error_value {:.4f} %, F0_CORR {:.4f}; ".format( np.sqrt(f0_distortion_metric.avg), vuv_error_metric.avg * 100, f0_corr, ) ) logging.info("{} time: {:.2f}s".format(out_log, end_t_test - start_t_test))
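# Illustrative sketch (assumption): `load_model_weights`, called by train_joint()
# below, presumably wraps the same checkpoint-filtering pattern that appears
# inline in train(), infer_predictor() and infer() above: skip the stored
# normalizer statistics and drop tensors whose shapes no longer match the model.
def _load_model_weights_sketch(checkpoint_path, model, device):
    """Hypothetical equivalent of `load_model_weights`."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    loaded_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()
    filtered, skipped = {}, []
    for k, v in loaded_dict.items():
        if k in ("normalizer.mean", "normalizer.std",
                 "mel_normalizer.mean", "mel_normalizer.std"):
            continue
        if k in model_dict and model_dict[k].size() == v.size():
            filtered[k] = v
        else:
            skipped.append(k)
    if skipped:
        logging.warning("Not loading %s because of different sizes",
                        ", ".join(skipped))
    model_dict.update(filtered)
    model.load_state_dict(model_dict)
    return model.to(device)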
def train_joint(args): """train_joint.""" np.random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available() and args.auto_select_gpu is True: cvd = use_single_gpu() logging.info(f"GPU {cvd} is used") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # torch.backends.cudnn.enabled = False device = torch.device("cuda" if torch.cuda.is_available() else "cpu") elif torch.cuda.is_available() and args.auto_select_gpu is False: torch.cuda.set_device(args.gpu_id) logging.info(f"GPU {args.gpu_id} is used") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # torch.backends.cudnn.enabled = False device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device("cpu") logging.info("Warning: CPU is used") train_set = SVSDataset( align_root_path=args.train_align, pitch_beat_root_path=args.train_pitch, wav_root_path=args.train_wav, char_max_len=args.char_max_len, max_len=args.num_frames, sr=args.sampling_rate, preemphasis=args.preemphasis, nfft=args.nfft, frame_shift=args.frame_shift, frame_length=args.frame_length, n_mels=args.n_mels, power=args.power, max_db=args.max_db, ref_db=args.ref_db, sing_quality=args.sing_quality, standard=args.standard, db_joint=args.db_joint, Hz2semitone=args.Hz2semitone, semitone_min=args.semitone_min, semitone_max=args.semitone_max, phone_shift_size=-1, semitone_shift=False, ) dev_set = SVSDataset( align_root_path=args.val_align, pitch_beat_root_path=args.val_pitch, wav_root_path=args.val_wav, char_max_len=args.char_max_len, max_len=args.num_frames, sr=args.sampling_rate, preemphasis=args.preemphasis, nfft=args.nfft, frame_shift=args.frame_shift, frame_length=args.frame_length, n_mels=args.n_mels, power=args.power, max_db=args.max_db, ref_db=args.ref_db, sing_quality=args.sing_quality, standard=args.standard, db_joint=args.db_joint, Hz2semitone=args.Hz2semitone, semitone_min=args.semitone_min, semitone_max=args.semitone_max, phone_shift_size=-1, semitone_shift=False, ) collate_fn_svs_train = SVSCollator( args.num_frames, args.char_max_len, args.use_asr_post, args.phone_size, args.n_mels, args.db_joint, args.random_crop, args.crop_min_length, args.Hz2semitone, ) collate_fn_svs_val = SVSCollator( args.num_frames, args.char_max_len, args.use_asr_post, args.phone_size, args.n_mels, args.db_joint, False, # random crop -1, # crop_min_length args.Hz2semitone, ) train_loader = torch.utils.data.DataLoader( dataset=train_set, batch_size=args.batchsize, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn_svs_train, pin_memory=True, ) dev_loader = torch.utils.data.DataLoader( dataset=dev_set, batch_size=args.batchsize, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn_svs_val, pin_memory=True, ) assert (args.feat_dim == dev_set[0]["spec"].shape[1] or args.feat_dim == dev_set[0]["mel"].shape[1]) if args.collect_stats: collect_stats(train_loader, args) logging.info("collect_stats finished !") quit() # init model_generate if args.model_type == "GLU_Transformer": if args.db_joint: model_generate = GLU_TransformerSVS_combine( phone_size=args.phone_size, singer_size=args.singer_size, embed_size=args.embedding_size, hidden_size=args.hidden_size, glu_num_layers=args.glu_num_layers, dropout=args.dropout, output_dim=args.feat_dim, dec_nhead=args.dec_nhead, dec_num_block=args.dec_num_block, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, local_gaussian=args.local_gaussian, Hz2semitone=args.Hz2semitone, semitone_size=args.semitone_size, 
device=device, ) else: model_generate = GLU_TransformerSVS( phone_size=args.phone_size, embed_size=args.embedding_size, hidden_size=args.hidden_size, glu_num_layers=args.glu_num_layers, dropout=args.dropout, output_dim=args.feat_dim, dec_nhead=args.dec_nhead, dec_num_block=args.dec_num_block, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, local_gaussian=args.local_gaussian, Hz2semitone=args.Hz2semitone, semitone_size=args.semitone_size, device=device, ) elif args.model_type == "LSTM": if args.db_joint: model_generate = LSTMSVS_combine( phone_size=args.phone_size, singer_size=args.singer_size, embed_size=args.embedding_size, d_model=args.hidden_size, num_layers=args.num_rnn_layers, dropout=args.dropout, d_output=args.feat_dim, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, Hz2semitone=args.Hz2semitone, semitone_size=args.semitone_size, device=device, use_asr_post=args.use_asr_post, ) else: model_generate = LSTMSVS( phone_size=args.phone_size, embed_size=args.embedding_size, d_model=args.hidden_size, num_layers=args.num_rnn_layers, dropout=args.dropout, d_output=args.feat_dim, n_mels=args.n_mels, double_mel_loss=args.double_mel_loss, Hz2semitone=args.Hz2semitone, semitone_size=args.semitone_size, device=device, use_asr_post=args.use_asr_post, ) # init model_predict model_predict = RNN_Discriminator( embed_size=128, d_model=128, hidden_size=128, num_layers=2, n_specs=1025, singer_size=7, phone_size=43, simitone_size=59, dropout=0.1, bidirectional=True, device=device, ) logging.info(f"*********** model_generate ***********") logging.info(f"{model_generate}") logging.info( f"The model has {count_parameters(model_generate):,} trainable parameters" ) logging.info(f"*********** model_predict ***********") logging.info(f"{model_predict}") logging.info( f"The model has {count_parameters(model_predict):,} trainable parameters" ) model_generate = load_model_weights(args.initmodel_generator, model_generate, device) model_predict = load_model_weights(args.initmodel_predictor, model_predict, device) model = Joint_generator_predictor(model_generate, model_predict) # setup optimizer if args.optimizer == "noam": optimizer = ScheduledOptim( torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09), args.hidden_size, args.noam_warmup_steps, args.noam_scale, ) elif args.optimizer == "adam": optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09) if args.scheduler == "OneCycleLR": scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=args.lr, steps_per_epoch=len(train_loader), epochs=args.max_epochs, ) elif args.scheduler == "ReduceLROnPlateau": scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, "min", verbose=True, patience=10, factor=0.5) elif args.scheduler == "ExponentialLR": scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, verbose=True, gamma=0.9886) else: raise ValueError("Not Support Optimizer") # setup loss function loss_predict = nn.CrossEntropyLoss(reduction="sum") if args.loss == "l1": loss_generate = MaskedLoss("l1", mask_free=args.mask_free) elif args.loss == "mse": loss_generate = MaskedLoss("mse", mask_free=args.mask_free) else: raise ValueError("Not Support Loss Type") if args.perceptual_loss > 0: win_length = int(args.sampling_rate * args.frame_length) psd_dict, bark_num = cal_psd2bark_dict(fs=args.sampling_rate, win_len=win_length) sf = cal_spread_function(bark_num) loss_perceptual_entropy = PerceptualEntropy(bark_num, sf, args.sampling_rate, win_length, psd_dict) else: 
loss_perceptual_entropy = None # Training generator_loss_epoch_to_save = {} generator_loss_counter = 0 spec_loss_epoch_to_save = {} spec_loss_counter = 0 predictor_loss_epoch_to_save = {} predictor_loss_counter = 0 for epoch in range(0, 1 + args.max_epochs): """Train Stage""" start_t_train = time.time() train_info = train_one_epoch_joint( train_loader, model, device, optimizer, loss_generate, loss_predict, loss_perceptual_entropy, epoch, args, ) end_t_train = time.time() # Print Total info out_log = "Train epoch: {:04d} ".format(epoch) if args.optimizer == "noam": out_log += "lr: {:.6f}, \n\t".format( optimizer._optimizer.param_groups[0]["lr"]) elif args.optimizer == "adam": out_log += "lr: {:.6f}, \n\t".format( optimizer.param_groups[0]["lr"]) out_log += "total_loss: {:.4f} \n\t".format(train_info["loss"]) # Print Generator info if args.vocoder_category == "wavernn": out_log += "generator_loss: {:.4f} ".format( train_info["generator_loss"]) else: out_log += "generator_loss: {:.4f}, spec_loss: {:.4f} ".format( train_info["generator_loss"], train_info["spec_loss"]) if args.n_mels > 0: out_log += "mel_loss: {:.4f}, ".format(train_info["mel_loss"]) if args.perceptual_loss > 0: out_log += "pe_loss: {:.4f}\n\t".format(train_info["pe_loss"]) # Print Predictor info out_log += "predictor_loss: {:.4f}, singer_loss: {:.4f}, ".format( train_info["predictor_loss"], train_info["singer_loss"], ) out_log += "phone_loss: {:.4f}, semitone_loss: {:.4f} \n\t\t".format( train_info["phone_loss"], train_info["semitone_loss"], ) out_log += "singer_accuracy: {:.4f}%, ".format( train_info["singer_accuracy"] * 100, ) out_log += "phone_accuracy: {:.4f}%, semitone_accuracy: {:.4f}% ".format( train_info["phone_accuracy"] * 100, train_info["semitone_accuracy"] * 100, ) logging.info("{} time: {:.2f}s".format(out_log, end_t_train - start_t_train)) """Dev Stage""" torch.backends.cudnn.enabled = False # inexplicable cuDNN bug: validation only runs with cuDNN disabled # start_t_dev = time.time() dev_info = validate_one_epoch_joint( dev_loader, model, device, optimizer, loss_generate, loss_predict, loss_perceptual_entropy, epoch, args, ) end_t_dev = time.time() # Print Total info dev_log = "Dev epoch: {:04d} ".format(epoch) if args.optimizer == "noam": dev_log += "lr: {:.6f}, \n\t".format( optimizer._optimizer.param_groups[0]["lr"]) elif args.optimizer == "adam": dev_log += "lr: {:.6f}, \n\t".format( optimizer.param_groups[0]["lr"]) dev_log += "total_loss: {:.4f} \n\t".format(dev_info["loss"]) # Print Generator info if args.vocoder_category == "wavernn": dev_log += "generator_loss: {:.4f} ".format( dev_info["generator_loss"]) else: dev_log += "generator_loss: {:.4f}, spec_loss: {:.4f} ".format( dev_info["generator_loss"], dev_info["spec_loss"]) if args.n_mels > 0: dev_log += "mel_loss: {:.4f}, ".format(dev_info["mel_loss"]) if args.perceptual_loss > 0: dev_log += "pe_loss: {:.4f}\n\t".format(dev_info["pe_loss"]) # Print Predictor info dev_log += "predictor_loss: {:.4f}, singer_loss: {:.4f}, ".format( dev_info["predictor_loss"], dev_info["singer_loss"], ) dev_log += "phone_loss: {:.4f}, semitone_loss: {:.4f} \n\t\t".format( dev_info["phone_loss"], dev_info["semitone_loss"], ) dev_log += "singer_accuracy: {:.4f}%, ".format( dev_info["singer_accuracy"] * 100, ) dev_log += "phone_accuracy: {:.4f}%, semitone_accuracy: {:.4f}% ".format( dev_info["phone_accuracy"] * 100, dev_info["semitone_accuracy"] * 100, ) logging.info("{} time: {:.2f}s".format(dev_log, end_t_dev - start_t_train)) sys.stdout.flush() torch.backends.cudnn.enabled = True """Save model Stage""" if not 
os.path.exists(args.model_save_dir): os.makedirs(args.model_save_dir) (generator_loss_counter, generator_loss_epoch_to_save) = Auto_save_model( args, epoch, model, optimizer, train_info, dev_info, None, # logger generator_loss_counter, generator_loss_epoch_to_save, save_loss_select="generator_loss", ) (spec_loss_counter, spec_loss_epoch_to_save) = Auto_save_model( args, epoch, model, optimizer, train_info, dev_info, None, # logger spec_loss_counter, spec_loss_epoch_to_save, save_loss_select="spec_loss", ) (predictor_loss_counter, predictor_loss_epoch_to_save) = Auto_save_model( args, epoch, model, optimizer, train_info, dev_info, None, # logger predictor_loss_counter, predictor_loss_epoch_to_save, save_loss_select="predictor_loss", )
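# Illustrative sketch (assumption): `ScheduledOptim` used above wraps the inner
# Adam optimizer (exposed as `_optimizer`, which the lr logging reads) and
# applies the standard Noam warmup schedule,
#   lr = scale * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5).
# A minimal hypothetical equivalent (method names here are illustrative only):
class _NoamScheduledOptimSketch:
    def __init__(self, optimizer, d_model, warmup_steps, scale):
        self._optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.scale = scale
        self.step_num = 0

    def step(self):
        # Update the learning rate, then take one optimizer step.
        self.step_num += 1
        lr = (self.scale * self.d_model ** -0.5 *
              min(self.step_num ** -0.5,
                  self.step_num * self.warmup_steps ** -1.5))
        for group in self._optimizer.param_groups:
            group["lr"] = lr
        self._optimizer.step()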
def train_predictor(args): """train_predictor.""" np.random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available() and args.auto_select_gpu is True: cvd = use_single_gpu() logging.info(f"GPU {cvd} is used") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # torch.backends.cudnn.enabled = False device = torch.device("cuda" if torch.cuda.is_available() else "cpu") elif torch.cuda.is_available() and args.auto_select_gpu is False: torch.cuda.set_device(args.gpu_id) logging.info(f"GPU {args.gpu_id} is used") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # torch.backends.cudnn.enabled = False device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device("cpu") logging.info("Warning: CPU is used") train_set = SVSDataset( align_root_path=args.train_align, pitch_beat_root_path=args.train_pitch, wav_root_path=args.train_wav, char_max_len=args.char_max_len, max_len=args.num_frames, sr=args.sampling_rate, preemphasis=args.preemphasis, nfft=args.nfft, frame_shift=args.frame_shift, frame_length=args.frame_length, n_mels=args.n_mels, power=args.power, max_db=args.max_db, ref_db=args.ref_db, sing_quality=args.sing_quality, standard=args.standard, db_joint=args.db_joint, Hz2semitone=args.Hz2semitone, semitone_min=args.semitone_min, semitone_max=args.semitone_max, phone_shift_size=-1, semitone_shift=False, ) dev_set = SVSDataset( align_root_path=args.val_align, pitch_beat_root_path=args.val_pitch, wav_root_path=args.val_wav, char_max_len=args.char_max_len, max_len=args.num_frames, sr=args.sampling_rate, preemphasis=args.preemphasis, nfft=args.nfft, frame_shift=args.frame_shift, frame_length=args.frame_length, n_mels=args.n_mels, power=args.power, max_db=args.max_db, ref_db=args.ref_db, sing_quality=args.sing_quality, standard=args.standard, db_joint=args.db_joint, Hz2semitone=args.Hz2semitone, semitone_min=args.semitone_min, semitone_max=args.semitone_max, phone_shift_size=-1, semitone_shift=False, ) collate_fn_svs_train = SVSCollator( args.num_frames, args.char_max_len, args.use_asr_post, args.phone_size, args.n_mels, args.db_joint, args.random_crop, args.crop_min_length, args.Hz2semitone, ) collate_fn_svs_val = SVSCollator( args.num_frames, args.char_max_len, args.use_asr_post, args.phone_size, args.n_mels, args.db_joint, False, # random crop -1, # crop_min_length args.Hz2semitone, ) train_loader = torch.utils.data.DataLoader( dataset=train_set, batch_size=args.batchsize, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn_svs_train, pin_memory=True, ) dev_loader = torch.utils.data.DataLoader( dataset=dev_set, batch_size=args.batchsize, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn_svs_val, pin_memory=True, ) assert ( args.feat_dim == dev_set[0]["spec"].shape[1] or args.feat_dim == dev_set[0]["mel"].shape[1] ) if args.collect_stats: collect_stats(train_loader, args) logging.info("collect_stats finished !") quit() # init model model = RNN_Discriminator( embed_size=128, d_model=128, hidden_size=128, num_layers=2, n_specs=1025, singer_size=7, phone_size=43, simitone_size=59, dropout=0.1, bidirectional=True, device=device, ) logging.info(f"{model}") model = model.to(device) logging.info(f"The model has {count_parameters(model):,} trainable parameters") # setup optimizer if args.optimizer == "adam": optimizer = torch.optim.Adam( model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09 ) else: raise ValueError("Not Support Optimizer") loss = 
nn.CrossEntropyLoss(reduction="sum") # Training total_loss_epoch_to_save = {} total_loss_counter = 0 for epoch in range(0, 1 + args.max_epochs): """Train Stage""" start_t_train = time.time() train_info = train_one_epoch_discriminator( train_loader, model, device, optimizer, loss, epoch, args, ) end_t_train = time.time() out_log = "Train epoch: {:04d}, ".format(epoch) if args.optimizer == "noam": out_log += "lr: {:.6f}, ".format(optimizer._optimizer.param_groups[0]["lr"]) elif args.optimizer == "adam": out_log += "lr: {:.6f}, ".format(optimizer.param_groups[0]["lr"]) out_log += "loss: {:.4f}, singer_loss: {:.4f}, ".format( train_info["loss"], train_info["singer_loss"], ) out_log += "phone_loss: {:.4f}, semitone_loss: {:.4f} \n".format( train_info["phone_loss"], train_info["semitone_loss"], ) out_log += "singer_accuracy: {:.4f}%, ".format( train_info["singer_accuracy"] * 100, ) out_log += "phone_accuracy: {:.4f}%, semitone_accuracy: {:.4f}% ".format( train_info["phone_accuracy"] * 100, train_info["semitone_accuracy"] * 100, ) logging.info("{} time: {:.2f}s".format(out_log, end_t_train - start_t_train)) """Dev Stage""" torch.backends.cudnn.enabled = False # inexplicable cuDNN bug: validation only runs with cuDNN disabled # start_t_dev = time.time() dev_info = validate_one_epoch_discriminator( dev_loader, model, device, loss, epoch, args, ) end_t_dev = time.time() dev_log = "Dev epoch: {:04d} ".format(epoch) dev_log += "loss: {:.4f}, singer_loss: {:.4f}, ".format( dev_info["loss"], dev_info["singer_loss"], ) dev_log += "phone_loss: {:.4f}, semitone_loss: {:.4f} \n".format( dev_info["phone_loss"], dev_info["semitone_loss"], ) dev_log += "singer_accuracy: {:.4f}%, ".format( dev_info["singer_accuracy"] * 100, ) dev_log += "phone_accuracy: {:.4f}%, semitone_accuracy: {:.4f}% ".format( dev_info["phone_accuracy"] * 100, dev_info["semitone_accuracy"] * 100, ) logging.info("{} time: {:.2f}s".format(dev_log, end_t_dev - start_t_train)) sys.stdout.flush() torch.backends.cudnn.enabled = True """Save model Stage""" if not os.path.exists(args.model_save_dir): os.makedirs(args.model_save_dir) (total_loss_counter, total_loss_epoch_to_save) = Auto_save_model( args, epoch, model, optimizer, train_info, dev_info, None, # logger total_loss_counter, total_loss_epoch_to_save, save_loss_select="loss", )