def train_joint(args):
    """Train an SVS generator jointly with an ``RNN_Discriminator`` predictor.

    The generator (GLU_Transformer or LSTM family, chosen by
    ``args.model_type`` and ``args.db_joint``) and the predictor are wrapped in
    ``Joint_generator_predictor`` and optimized together. Every epoch runs a
    train stage, a dev stage, and checkpointing through ``Auto_save_model``
    under three criteria: ``generator_loss``, ``spec_loss`` and
    ``predictor_loss``.

    Args:
        args: configuration namespace; every option used below is read as an
            attribute of it.

    Raises:
        ValueError: on an unsupported model type, optimizer, scheduler or loss,
            or when ``args.feat_dim`` matches neither the spec nor the mel
            width of the first dev sample.
        SystemExit: after statistics collection when ``args.collect_stats``
            is set.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # ---- device ------------------------------------------------------------
    if torch.cuda.is_available():
        # Truthiness instead of `is True` / `is False`: the flag may be an int
        # (e.g. from strtobool), and an identity check would then silently
        # fall through to the CPU branch even though a GPU is available.
        if args.auto_select_gpu:
            cvd = use_single_gpu()
            logging.info(f"GPU {cvd} is used")
        else:
            torch.cuda.set_device(args.gpu_id)
            logging.info(f"GPU {args.gpu_id} is used")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.info("Warning: CPU is used")

    # ---- data --------------------------------------------------------------
    # Options shared by the train and dev datasets. Phone/semitone shifting is
    # off for both splits (-1 / False, the values train() uses for its dev set).
    dataset_kwargs = dict(
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        sing_quality=args.sing_quality,
        standard=args.standard,
        db_joint=args.db_joint,
        Hz2semitone=args.Hz2semitone,
        semitone_min=args.semitone_min,
        semitone_max=args.semitone_max,
        phone_shift_size=-1,
        semitone_shift=False,
    )
    train_set = SVSDataset(
        align_root_path=args.train_align,
        pitch_beat_root_path=args.train_pitch,
        wav_root_path=args.train_wav,
        **dataset_kwargs,
    )
    dev_set = SVSDataset(
        align_root_path=args.val_align,
        pitch_beat_root_path=args.val_pitch,
        wav_root_path=args.val_wav,
        **dataset_kwargs,
    )

    collate_fn_svs_train = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        args.random_crop,
        args.crop_min_length,
        args.Hz2semitone,
    )
    collate_fn_svs_val = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        False,  # random crop
        -1,  # crop_min_length
        args.Hz2semitone,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs_train,
        pin_memory=True,
    )
    dev_loader = torch.utils.data.DataLoader(
        dataset=dev_set,
        batch_size=args.batchsize,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs_val,
        pin_memory=True,
    )

    # Explicit check rather than `assert` (which `python -O` strips). The dev
    # sample is fetched once because indexing the dataset may be expensive.
    first_dev = dev_set[0]
    if not (args.feat_dim == first_dev["spec"].shape[1]
            or args.feat_dim == first_dev["mel"].shape[1]):
        raise ValueError(
            f"feat_dim={args.feat_dim} matches neither the spec nor the mel "
            "dimension of the dev data")

    if args.collect_stats:
        collect_stats(train_loader, args)
        logging.info("collect_stats finished !")
        sys.exit()

    # ---- generator ---------------------------------------------------------
    if args.model_type == "GLU_Transformer":
        glu_kwargs = dict(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            Hz2semitone=args.Hz2semitone,
            semitone_size=args.semitone_size,
            device=device,
        )
        if args.db_joint:
            model_generate = GLU_TransformerSVS_combine(
                singer_size=args.singer_size, **glu_kwargs)
        else:
            model_generate = GLU_TransformerSVS(**glu_kwargs)
    elif args.model_type == "LSTM":
        lstm_kwargs = dict(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            Hz2semitone=args.Hz2semitone,
            semitone_size=args.semitone_size,
            device=device,
            use_asr_post=args.use_asr_post,
        )
        if args.db_joint:
            model_generate = LSTMSVS_combine(
                singer_size=args.singer_size, **lstm_kwargs)
        else:
            model_generate = LSTMSVS(**lstm_kwargs)
    else:
        # Previously fell through and failed later with a NameError.
        raise ValueError("Not Support Model Type %s" % args.model_type)

    # ---- predictor ---------------------------------------------------------
    # Hyper-parameters are hard-coded. NOTE(review): presumably they must match
    # the checkpoint given by args.initmodel_predictor — confirm.
    model_predict = RNN_Discriminator(
        embed_size=128,
        d_model=128,
        hidden_size=128,
        num_layers=2,
        n_specs=1025,
        singer_size=7,
        phone_size=43,
        simitone_size=59,
        dropout=0.1,
        bidirectional=True,
        device=device,
    )

    logging.info("*********** model_generate ***********")
    logging.info(f"{model_generate}")
    logging.info(
        f"The model has {count_parameters(model_generate):,} trainable parameters")
    logging.info("*********** model_predict ***********")
    logging.info(f"{model_predict}")
    logging.info(
        f"The model has {count_parameters(model_predict):,} trainable parameters")

    model_generate = load_model_weights(
        args.initmodel_generator, model_generate, device)
    model_predict = load_model_weights(
        args.initmodel_predictor, model_predict, device)
    # NOTE(review): unlike train(), the joint model is never moved with
    # .to(device) here; presumably the constructors' `device` argument or
    # load_model_weights handle placement — confirm.
    model = Joint_generator_predictor(model_generate, model_predict)

    # ---- optimizer / scheduler ---------------------------------------------
    if args.optimizer == "noam":
        optimizer = ScheduledOptim(
            torch.optim.Adam(model.parameters(), lr=args.lr,
                             betas=(0.9, 0.98), eps=1e-09),
            args.hidden_size,
            args.noam_warmup_steps,
            args.noam_scale,
        )
    elif args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                     betas=(0.9, 0.98), eps=1e-09)
    else:
        # Previously left `optimizer` undefined (NameError further down).
        raise ValueError("Not Support Optimizer")

    # NOTE(review): the scheduler is validated and constructed but never
    # stepped in this function, so it currently does not affect training.
    if args.scheduler == "OneCycleLR":
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=args.lr,
            steps_per_epoch=len(train_loader),
            epochs=args.max_epochs,
        )
    elif args.scheduler == "ReduceLROnPlateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, "min", verbose=True, patience=10, factor=0.5)
    elif args.scheduler == "ExponentialLR":
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, verbose=True, gamma=0.9886)
    else:
        raise ValueError("Not Support Scheduler")

    # ---- losses ------------------------------------------------------------
    loss_predict = nn.CrossEntropyLoss(reduction="sum")
    if args.loss in ("l1", "mse"):
        loss_generate = MaskedLoss(args.loss, mask_free=args.mask_free)
    else:
        raise ValueError("Not Support Loss Type")

    if args.perceptual_loss > 0:
        win_length = int(args.sampling_rate * args.frame_length)
        psd_dict, bark_num = cal_psd2bark_dict(fs=args.sampling_rate,
                                               win_len=win_length)
        sf = cal_spread_function(bark_num)
        loss_perceptual_entropy = PerceptualEntropy(
            bark_num, sf, args.sampling_rate, win_length, psd_dict)
    else:
        loss_perceptual_entropy = None

    def _stage_log(header, info):
        """Build the train/dev summary text; both stages share one layout."""
        if args.optimizer == "noam":
            lr = optimizer._optimizer.param_groups[0]["lr"]
        else:
            lr = optimizer.param_groups[0]["lr"]
        text = header + "lr: {:.6f}, \n\t".format(lr)
        text += "total_loss: {:.4f} \n\t".format(info["loss"])
        # Generator part
        if args.vocoder_category == "wavernn":
            text += "generator_loss: {:.4f} ".format(info["generator_loss"])
        else:
            text += "generator_loss: {:.4f}, spec_loss: {:.4f} ".format(
                info["generator_loss"], info["spec_loss"])
        if args.n_mels > 0:
            text += "mel_loss: {:.4f}, ".format(info["mel_loss"])
        if args.perceptual_loss > 0:
            text += "pe_loss: {:.4f}\n\t".format(info["pe_loss"])
        # Predictor part
        text += "predictor_loss: {:.4f}, singer_loss: {:.4f}, ".format(
            info["predictor_loss"], info["singer_loss"])
        text += "phone_loss: {:.4f}, semitone_loss: {:.4f} \n\t\t".format(
            info["phone_loss"], info["semitone_loss"])
        text += "singer_accuracy: {:.4f}%, ".format(
            info["singer_accuracy"] * 100)
        text += "phone_accuracy: {:.4f}%, semitone_accuracy: {:.4f}% ".format(
            info["phone_accuracy"] * 100, info["semitone_accuracy"] * 100)
        return text

    # ---- training loop -----------------------------------------------------
    # (counter, epoch->loss bookkeeping) threaded through Auto_save_model,
    # one entry per checkpoint-selection criterion, in saving order.
    save_state = {
        "generator_loss": (0, {}),
        "spec_loss": (0, {}),
        "predictor_loss": (0, {}),
    }
    for epoch in range(0, 1 + args.max_epochs):
        # Train stage
        start_t_train = time.time()
        train_info = train_one_epoch_joint(
            train_loader,
            model,
            device,
            optimizer,
            loss_generate,
            loss_predict,
            loss_perceptual_entropy,
            epoch,
            args,
        )
        end_t_train = time.time()
        out_log = _stage_log("Train epoch: {:04d} ".format(epoch), train_info)
        logging.info("{} time: {:.2f}s".format(out_log,
                                               end_t_train - start_t_train))

        # Dev stage. cuDNN is switched off as a workaround for an unexplained
        # bug (original note: "inexplicable bug, it only runs with this off");
        # try/finally guarantees it is re-enabled even if validation raises.
        torch.backends.cudnn.enabled = False
        try:
            dev_info = validate_one_epoch_joint(
                dev_loader,
                model,
                device,
                optimizer,
                loss_generate,
                loss_predict,
                loss_perceptual_entropy,
                epoch,
                args,
            )
        finally:
            torch.backends.cudnn.enabled = True
        end_t_dev = time.time()
        dev_log = _stage_log("Dev epoch: {:04d} ".format(epoch), dev_info)
        # Measured from the start of the train stage, so this is the whole
        # epoch's wall time (train + dev).
        logging.info("{} time: {:.2f}s".format(dev_log,
                                               end_t_dev - start_t_train))
        sys.stdout.flush()

        # Save stage
        os.makedirs(args.model_save_dir, exist_ok=True)
        for loss_name, (counter, epoch_to_save) in save_state.items():
            save_state[loss_name] = Auto_save_model(
                args,
                epoch,
                model,
                optimizer,
                train_info,
                dev_info,
                None,  # logger
                counter,
                epoch_to_save,
                save_loss_select=loss_name,
            )
def train(args):
    """Train a single SVS model selected by ``args.model_type``.

    Builds the train/dev data pipeline and constructs the model. It then
    optionally loads a pretrained encoder (``args.pretrain_encoder``) and/or
    full weights. Full weights come from ``args.initmodel`` or, with
    ``args.resume``, from the newest checkpoint in ``args.model_save_dir``.
    Finally it runs train/validate epochs, steps the LR scheduler and saves
    checkpoints through ``Auto_save_model``.

    Args:
        args: configuration namespace; every option used below is read as an
            attribute of it.

    Raises:
        ValueError: on an unsupported model type, optimizer, scheduler or loss,
            or when ``args.feat_dim`` matches neither the spec nor the mel
            width of the first dev sample.
        SystemExit: after statistics collection when ``args.collect_stats``
            is set.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # ---- device ------------------------------------------------------------
    if torch.cuda.is_available():
        # Truthiness instead of `is True` / `is False`: the flag may be an int
        # (e.g. from strtobool), and an identity check would then silently
        # fall through to the CPU branch even though a GPU is available.
        if args.auto_select_gpu:
            cvd = use_single_gpu()
            logging.info(f"GPU {cvd} is used")
        else:
            torch.cuda.set_device(args.gpu_id)
            logging.info(f"GPU {args.gpu_id} is used")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.info("Warning: CPU is used")

    # ---- data --------------------------------------------------------------
    # Options shared by the train and dev datasets; only the shift settings
    # differ (train uses the configured ones, dev always -1 / False).
    dataset_kwargs = dict(
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        sing_quality=args.sing_quality,
        standard=args.standard,
        db_joint=args.db_joint,
        Hz2semitone=args.Hz2semitone,
        semitone_min=args.semitone_min,
        semitone_max=args.semitone_max,
    )
    train_set = SVSDataset(
        align_root_path=args.train_align,
        pitch_beat_root_path=args.train_pitch,
        wav_root_path=args.train_wav,
        phone_shift_size=args.phone_shift_size,
        semitone_shift=args.semitone_shift,
        **dataset_kwargs,
    )
    dev_set = SVSDataset(
        align_root_path=args.val_align,
        pitch_beat_root_path=args.val_pitch,
        wav_root_path=args.val_wav,
        phone_shift_size=-1,
        semitone_shift=False,
        **dataset_kwargs,
    )

    collate_fn_svs_train = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        args.random_crop,
        args.crop_min_length,
        args.Hz2semitone,
    )
    collate_fn_svs_val = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        False,  # random crop
        -1,  # crop_min_length
        args.Hz2semitone,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs_train,
        pin_memory=True,
    )
    dev_loader = torch.utils.data.DataLoader(
        dataset=dev_set,
        batch_size=args.batchsize,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs_val,
        pin_memory=True,
    )

    # Explicit check rather than `assert` (which `python -O` strips). The dev
    # sample is fetched once because indexing the dataset may be expensive.
    first_dev = dev_set[0]
    if not (args.feat_dim == first_dev["spec"].shape[1]
            or args.feat_dim == first_dev["mel"].shape[1]):
        raise ValueError(
            f"feat_dim={args.feat_dim} matches neither the spec nor the mel "
            "dimension of the dev data")

    if args.collect_stats:
        collect_stats(train_loader, args)
        logging.info("collect_stats finished !")
        sys.exit()

    # ---- model -------------------------------------------------------------
    # Suffixes of the conformer hyper-parameters; each is read from
    # args.<prefix>_<suffix> with prefix "enc" (and "dec" for the full model).
    conformer_fields = (
        "attention_dim",
        "attention_heads",
        "linear_units",
        "num_blocks",
        "dropout_rate",
        "positional_dropout_rate",
        "attention_dropout_rate",
        "input_layer",
        "normalize_before",
        "concat_after",
        "positionwise_layer_type",
        "positionwise_conv_kernel_size",
        "macaron_style",
        "pos_enc_layer_type",
        "selfattention_layer_type",
        "activation_type",
        "use_cnn_module",
        "cnn_module_kernel",
        "padding_idx",
    )

    def _conformer_kwargs(prefix):
        """Collect ``<prefix>_<field>`` keyword arguments from ``args``."""
        return {
            f"{prefix}_{field}": getattr(args, f"{prefix}_{field}")
            for field in conformer_fields
        }

    if args.model_type == "GLU_Transformer":
        glu_kwargs = dict(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            Hz2semitone=args.Hz2semitone,
            semitone_size=args.semitone_size,
            device=device,
        )
        if args.db_joint:
            model = GLU_TransformerSVS_combine(singer_size=args.singer_size,
                                               **glu_kwargs)
        else:
            model = GLU_TransformerSVS(**glu_kwargs)
    elif args.model_type == "LSTM":
        lstm_kwargs = dict(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            Hz2semitone=args.Hz2semitone,
            semitone_size=args.semitone_size,
            device=device,
            use_asr_post=args.use_asr_post,
        )
        if args.db_joint:
            model = LSTMSVS_combine(singer_size=args.singer_size,
                                    **lstm_kwargs)
        else:
            model = LSTMSVS(**lstm_kwargs)
    elif args.model_type == "GRU_gs":
        model = GRUSVS_gs(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            device=device,
            use_asr_post=args.use_asr_post,
        )
    elif args.model_type == "PureTransformer":
        model = TransformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            device=device,
        )
    elif args.model_type == "Conformer":
        model = ConformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            dec_dropout=args.dec_dropout,
            Hz2semitone=args.Hz2semitone,
            semitone_size=args.semitone_size,
            device=device,
            **_conformer_kwargs("enc"),
        )
    elif args.model_type == "Comformer_full":
        # Spelling kept as-is: it is the value matched against the CLI option.
        full_kwargs = dict(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            output_dim=args.feat_dim,
            n_mels=args.n_mels,
            Hz2semitone=args.Hz2semitone,
            semitone_size=args.semitone_size,
            device=device,
            **_conformer_kwargs("enc"),
            **_conformer_kwargs("dec"),
        )
        if args.db_joint:
            model = ConformerSVS_FULL_combine(singer_size=args.singer_size,
                                              **full_kwargs)
        else:
            model = ConformerSVS_FULL(**full_kwargs)
    elif args.model_type == "USTC_DAR":
        model = USTC_SVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            middle_dim_fc=args.middle_dim_fc,
            output_dim=args.feat_dim,
            multi_history_num=args.multi_history_num,
            middle_dim_prenet=args.middle_dim_prenet,
            n_blocks_prenet=args.n_blocks_prenet,
            n_heads_prenet=args.n_heads_prenet,
            kernel_size_prenet=args.kernel_size_prenet,
            bi_d_model=args.bi_d_model,
            bi_num_layers=args.bi_num_layers,
            uni_d_model=args.uni_d_model,
            uni_num_layers=args.uni_num_layers,
            dropout=args.dropout,
            feedbackLink_drop_rate=args.feedbackLink_drop_rate,
            device=device,
        )
    else:
        raise ValueError("Not Support Model Type %s" % args.model_type)
    logging.info(f"{model}")
    model = model.to(device)
    logging.info(
        f"The model has {count_parameters(model):,} trainable parameters")

    # ---- checkpoint selection ----------------------------------------------
    model_load_dir = ""
    pretrain_encoder_dir = ""
    # FIXME (carried over from the original "FIX ME"): without resuming, the
    # epoch loop below starts at start_epoch + 1 == 2.
    start_epoch = 1
    if args.pretrain_encoder != "":
        pretrain_encoder_dir = args.pretrain_encoder
    if args.initmodel != "":
        model_load_dir = args.initmodel
    if args.resume and os.path.exists(args.model_save_dir):
        # Checkpoints are named "<prefix>_<epoch>.pth.tar"; other files count
        # as -1. `default=-1` makes an empty directory behave like one with no
        # checkpoints instead of max() raising ValueError on an empty sequence.
        start_epoch = max(
            (int(name.split(".")[0].split("_")[-1])
             if name.endswith("pth.tar") else -1
             for name in os.listdir(args.model_save_dir)),
            default=-1,
        )
        model_temp_load_dir = "{}/epoch_loss_{}.pth.tar".format(
            args.model_save_dir, start_epoch)
        if start_epoch < 0:
            # NOTE(review): this also discards args.initmodel when resuming
            # with no saved checkpoint — confirm that is intended.
            model_load_dir = ""
        elif os.path.isfile(model_temp_load_dir):
            model_load_dir = model_temp_load_dir
        else:
            model_load_dir = "{}/epoch_spec_loss_{}.pth.tar".format(
                args.model_save_dir, start_epoch)

    # Load encoder parameters from a Transformer-TTS checkpoint.
    if pretrain_encoder_dir != "":
        pretrain = torch.load(pretrain_encoder_dir, map_location=device)
        pretrain_dict = pretrain["model"]
        model_dict = model.state_dict()
        state_dict_new = {}
        loaded = 0
        for k, v in pretrain_dict.items():
            # Drop the first 7 characters of each key (presumably a "module."
            # prefix from DataParallel — confirm) and keep only tensors whose
            # name and size match this model.
            k_new = k[7:]
            if k_new in model_dict and model_dict[k_new].size() == v.size():
                loaded += 1
                state_dict_new[k_new] = v
        model_dict.update(state_dict_new)
        model.load_state_dict(model_dict)
        logging.info(
            f"Load {loaded} layers total. Load pretrain encoder success !")

    # Load weights for a pre-trained / resumed model.
    if model_load_dir != "":
        logging.info(f"Model Start to Load, dir: {model_load_dir}")
        model_load = torch.load(model_load_dir, map_location=device)
        loading_dict = model_load["state_dict"]
        model_dict = model.state_dict()
        # Normalizer statistics are deliberately not restored from the
        # checkpoint (presumably they come from the current stats — confirm).
        skipped_keys = {
            "normalizer.mean",
            "normalizer.std",
            "mel_normalizer.mean",
            "mel_normalizer.std",
        }
        state_dict_new = {}
        para_list = []  # present in the model but with a different size
        missing_list = []  # absent from the model (previously a KeyError)
        for k, v in loading_dict.items():
            if k in skipped_keys:
                continue
            if k not in model_dict:
                missing_list.append(k)
            elif model_dict[k].size() == v.size():
                state_dict_new[k] = v
            else:
                para_list.append(k)
        logging.info(f"Total {len(loading_dict)} parameter sets, "
                     f"Loaded {len(state_dict_new)} parameter sets")
        if para_list:
            logging.warning("Not loading {} because of different sizes".format(
                ", ".join(para_list)))
        if missing_list:
            logging.warning(
                "Not loading {} because they are not in the model".format(
                    ", ".join(missing_list)))
        model_dict.update(state_dict_new)
        model.load_state_dict(model_dict)
        # Report the path actually loaded (may be a resumed checkpoint rather
        # than args.initmodel).
        logging.info(f"Loaded checkpoint {model_load_dir}")

    # ---- optimizer / scheduler ---------------------------------------------
    if args.optimizer == "noam":
        optimizer = ScheduledOptim(
            torch.optim.Adam(model.parameters(), lr=args.lr,
                             betas=(0.9, 0.98), eps=1e-09),
            args.hidden_size,
            args.noam_warmup_steps,
            args.noam_scale,
        )
    elif args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                     betas=(0.9, 0.98), eps=1e-09)
    else:
        # Previously left `optimizer` undefined (NameError further down).
        raise ValueError("Not Support Optimizer")

    if args.scheduler == "OneCycleLR":
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=args.lr,
            steps_per_epoch=len(train_loader),
            epochs=args.max_epochs,
        )
    elif args.scheduler == "ReduceLROnPlateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, "min", verbose=True, patience=10, factor=0.5)
    elif args.scheduler == "ExponentialLR":
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, verbose=True, gamma=0.9886)
    else:
        raise ValueError("Not Support Scheduler")

    # Setup tensorboard logger
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter("{}/log".format(args.model_save_dir))
    else:
        logger = None

    # ---- losses ------------------------------------------------------------
    if args.loss in ("l1", "mse"):
        loss = MaskedLoss(args.loss, mask_free=args.mask_free)
    else:
        raise ValueError("Not Support Loss Type")

    if args.perceptual_loss > 0:
        win_length = int(args.sampling_rate * args.frame_length)
        psd_dict, bark_num = cal_psd2bark_dict(fs=args.sampling_rate,
                                               win_len=win_length)
        sf = cal_spread_function(bark_num)
        loss_perceptual_entropy = PerceptualEntropy(
            bark_num, sf, args.sampling_rate, win_length, psd_dict)
    else:
        loss_perceptual_entropy = None

    # ---- training state ----------------------------------------------------
    total_loss_epoch_to_save = {}
    total_loss_counter = 0
    spec_loss_epoch_to_save = {}
    spec_loss_counter = 0
    total_learning_step = 0

    # Preload the vocoder; the empty list is the placeholder passed through to
    # train_one_epoch / validate when no WaveRNN vocoder is used.
    voc_model = []
    if args.vocoder_category == "wavernn":
        logging.info("load voc_model from {}".format(args.wavernn_voc_model))
        voc_model = WaveRNN(
            rnn_dims=args.voc_rnn_dims,
            fc_dims=args.voc_fc_dims,
            bits=args.voc_bits,
            pad=args.voc_pad,
            upsample_factors=(
                args.voc_upsample_factors_0,
                args.voc_upsample_factors_1,
                args.voc_upsample_factors_2,
            ),
            feat_dims=args.n_mels,
            compute_dims=args.voc_compute_dims,
            res_out_dims=args.voc_res_out_dims,
            res_blocks=args.voc_res_blocks,
            hop_length=args.hop_length,
            sample_rate=args.sampling_rate,
            mode=args.voc_mode,
        ).to(device)
        voc_model.load(args.wavernn_voc_model)

    # ---- training loop -----------------------------------------------------
    for epoch in range(start_epoch + 1, 1 + args.max_epochs):
        # Train stage
        start_t_train = time.time()
        train_info = train_one_epoch(
            train_loader,
            model,
            device,
            optimizer,
            loss,
            loss_perceptual_entropy,
            epoch,
            args,
            voc_model,
        )
        end_t_train = time.time()

        out_log = "Train epoch: {:04d}, ".format(epoch)
        if args.optimizer == "noam":
            out_log += "lr: {:.6f}, ".format(
                optimizer._optimizer.param_groups[0]["lr"])
        else:
            out_log += "lr: {:.6f}, ".format(optimizer.param_groups[0]["lr"])
        if args.vocoder_category == "wavernn":
            out_log += "loss: {:.4f} ".format(train_info["loss"])
        else:
            out_log += "loss: {:.4f}, spec_loss: {:.4f} ".format(
                train_info["loss"], train_info["spec_loss"])
        if args.n_mels > 0:
            out_log += "mel_loss: {:.4f}, ".format(train_info["mel_loss"])
        if args.perceptual_loss > 0:
            out_log += "pe_loss: {:.4f}, ".format(train_info["pe_loss"])
        logging.info("{} time: {:.2f}s".format(out_log,
                                               end_t_train - start_t_train))

        # Dev stage. cuDNN is switched off as a workaround for an unexplained
        # bug (original note: "inexplicable bug, it only runs with this off");
        # try/finally guarantees it is re-enabled even if validation raises.
        torch.backends.cudnn.enabled = False
        try:
            dev_info = validate(
                dev_loader,
                model,
                device,
                loss,
                loss_perceptual_entropy,
                epoch,
                args,
                voc_model,
            )
        finally:
            torch.backends.cudnn.enabled = True
        end_t_dev = time.time()

        dev_log = "Dev epoch: {:04d}, loss: {:.4f}, spec_loss: {:.4f}, ".format(
            epoch, dev_info["loss"], dev_info["spec_loss"])
        dev_log += "mcd_value: {:.4f}, ".format(dev_info["mcd_value"])
        if args.n_mels > 0:
            dev_log += "mel_loss: {:.4f}, ".format(dev_info["mel_loss"])
        if args.perceptual_loss > 0:
            dev_log += "pe_loss: {:.4f}, ".format(dev_info["pe_loss"])
        # Measured from the start of the train stage, so this is the whole
        # epoch's wall time (train + dev).
        logging.info("{} time: {:.2f}s".format(dev_log,
                                               end_t_dev - start_t_train))
        sys.stdout.flush()

        # LR scheduling (once per epoch).
        if args.scheduler == "OneCycleLR":
            # NOTE(review): configured with steps_per_epoch=len(train_loader)
            # but stepped once per epoch, so only a fraction of the one-cycle
            # schedule is traversed — confirm this is intended.
            scheduler.step()
        elif args.scheduler == "ReduceLROnPlateau":
            scheduler.step(dev_info["loss"])
        elif args.scheduler == "ExponentialLR":
            before = total_learning_step // args.lr_decay_learning_steps
            total_learning_step += len(train_loader)
            after = total_learning_step // args.lr_decay_learning_steps
            # Decay (at most once per epoch) whenever the cumulative number of
            # batches crosses a multiple of args.lr_decay_learning_steps.
            if after > before:
                scheduler.step()

        # Save stage
        os.makedirs(args.model_save_dir, exist_ok=True)
        (total_loss_counter, total_loss_epoch_to_save) = Auto_save_model(
            args,
            epoch,
            model,
            optimizer,
            train_info,
            dev_info,
            logger,
            total_loss_counter,
            total_loss_epoch_to_save,
            save_loss_select="loss",
        )
        # Only select by spec_loss when it is meaningful; e.g. the USTC DAR
        # model does not compute a linear-spectrogram spec loss.
        if dev_info["spec_loss"] != 0:
            (spec_loss_counter, spec_loss_epoch_to_save) = Auto_save_model(
                args,
                epoch,
                model,
                optimizer,
                train_info,
                dev_info,
                logger,
                spec_loss_counter,
                spec_loss_epoch_to_save,
                save_loss_select="spec_loss",
            )

    if args.use_tfboard:
        logger.close()
def _select_training_device(args):
    """Select the torch device to train on and configure cuDNN for it.

    When CUDA is available, a GPU is either picked by ``use_single_gpu()``
    (``args.auto_select_gpu`` truthy) or bound explicitly to ``args.gpu_id``,
    and cuDNN is put in deterministic, non-benchmarking mode. Otherwise the
    CPU is used.

    Args:
        args: parsed config; reads ``auto_select_gpu`` and ``gpu_id``.

    Returns:
        torch.device: ``cuda`` when CUDA is available, else ``cpu``.
    """
    if not torch.cuda.is_available():
        logging.info("Warning: CPU is used")
        return torch.device("cpu")
    # Truthiness rather than `is True` / `is False`: a non-bool flag such as
    # 1 or 0 would otherwise match neither branch and silently use the CPU.
    if args.auto_select_gpu:
        cvd = use_single_gpu()
        logging.info(f"GPU {cvd} is used")
    else:
        torch.cuda.set_device(args.gpu_id)
        logging.info(f"GPU {args.gpu_id} is used")
    # Favour run-to-run reproducibility over cuDNN autotuning speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return torch.device("cuda")


def _build_svs_dataset(args, align_root_path, pitch_beat_root_path, wav_root_path):
    """Build an ``SVSDataset`` for one data split.

    Train and dev splits share every feature setting taken from ``args``; only
    the three data roots differ. ``phone_shift_size=-1`` and
    ``semitone_shift=False`` are fixed for both splits.
    """
    return SVSDataset(
        align_root_path=align_root_path,
        pitch_beat_root_path=pitch_beat_root_path,
        wav_root_path=wav_root_path,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        sing_quality=args.sing_quality,
        standard=args.standard,
        db_joint=args.db_joint,
        Hz2semitone=args.Hz2semitone,
        semitone_min=args.semitone_min,
        semitone_max=args.semitone_max,
        phone_shift_size=-1,
        semitone_shift=False,
    )


def _build_svs_loader(args, dataset, is_train):
    """Wrap ``dataset`` in a ``DataLoader`` that batches with ``SVSCollator``.

    The training loader shuffles and passes the configured crop settings
    (``args.random_crop`` / ``args.crop_min_length``). The dev loader keeps
    dataset order and passes ``False`` / ``-1`` for those two settings.
    """
    collate_fn = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        args.random_crop if is_train else False,  # random crop
        args.crop_min_length if is_train else -1,  # crop_min_length
        args.Hz2semitone,
    )
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=args.batchsize,
        shuffle=is_train,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )


def _format_discriminator_metrics(info):
    """Format the loss/accuracy fields shared by the train and dev log lines.

    Accuracy values are multiplied by 100 and printed with a ``%`` sign
    (presumably they are stored as fractions in [0, 1] -- confirm upstream).
    """
    return (
        "loss: {:.4f}, singer_loss: {:.4f}, ".format(
            info["loss"], info["singer_loss"]
        )
        + "phone_loss: {:.4f}, semitone_loss: {:.4f} \n".format(
            info["phone_loss"], info["semitone_loss"]
        )
        + "singer_accuracy: {:.4f}%, ".format(info["singer_accuracy"] * 100)
        + "phone_accuracy: {:.4f}%, semitone_accuracy: {:.4f}% ".format(
            info["phone_accuracy"] * 100, info["semitone_accuracy"] * 100
        )
    )


def train_predictor(args):
    """Train an ``RNN_Discriminator`` on the train split and validate on dev.

    For each epoch in ``0..args.max_epochs`` (inclusive), the function does
    three things:

    - Runs ``train_one_epoch_discriminator``, then
      ``validate_one_epoch_discriminator``.
    - Logs the singer/phone/semitone losses and accuracies they report.
    - Passes checkpointing to ``Auto_save_model``, selected on ``"loss"``.

    If ``args.collect_stats`` is set, feature statistics are collected on the
    training loader instead and the process exits.

    Args:
        args: parsed training config (seed, device, data paths, feature and
            loader settings, optimizer, epochs, model-save directory).

    Raises:
        ValueError: if ``args.optimizer`` is not ``"adam"``, or if
            ``args.feat_dim`` matches neither the ``spec`` nor the ``mel``
            width of the first dev sample.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    device = _select_training_device(args)

    train_set = _build_svs_dataset(
        args, args.train_align, args.train_pitch, args.train_wav
    )
    dev_set = _build_svs_dataset(args, args.val_align, args.val_pitch, args.val_wav)
    train_loader = _build_svs_loader(args, train_set, is_train=True)
    dev_loader = _build_svs_loader(args, dev_set, is_train=False)

    # Explicit check instead of `assert`, which `python -O` strips. The sample
    # is fetched once; "mel" is only read when "spec" does not match.
    first_dev = dev_set[0]
    if not (
        args.feat_dim == first_dev["spec"].shape[1]
        or args.feat_dim == first_dev["mel"].shape[1]
    ):
        raise ValueError(
            f"feat_dim={args.feat_dim} matches neither the spec nor the mel "
            "feature width of the dev data"
        )

    if args.collect_stats:
        collect_stats(train_loader, args)
        logging.info("collect_stats finished !")
        # sys.exit instead of `quit()`, which is an interactive helper
        # installed by the `site` module and is absent under `python -S`.
        sys.exit()

    # NOTE(review): the discriminator sizes are hard-coded rather than read
    # from args (e.g. args.phone_size); confirm they match the dataset.
    model = RNN_Discriminator(
        embed_size=128,
        d_model=128,
        hidden_size=128,
        num_layers=2,
        n_specs=1025,
        singer_size=7,
        phone_size=43,
        simitone_size=59,
        dropout=0.1,
        bidirectional=True,
        device=device,
    )
    logging.info(f"{model}")
    model = model.to(device)
    logging.info(f"The model has {count_parameters(model):,} trainable parameters")

    if args.optimizer != "adam":
        raise ValueError("Not Support Optimizer")
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09
    )

    loss = nn.CrossEntropyLoss(reduction="sum")

    total_loss_epoch_to_save = {}
    total_loss_counter = 0
    for epoch in range(0, 1 + args.max_epochs):
        # Train stage
        start_t_train = time.time()
        train_info = train_one_epoch_discriminator(
            train_loader, model, device, optimizer, loss, epoch, args
        )
        end_t_train = time.time()

        out_log = "Train epoch: {:04d}, ".format(epoch)
        # Only Adam survives the optimizer check above, so the learning rate
        # is read directly from its param groups.
        out_log += "lr: {:.6f}, ".format(optimizer.param_groups[0]["lr"])
        out_log += _format_discriminator_metrics(train_info)
        logging.info("{} time: {:.2f}s".format(out_log, end_t_train - start_t_train))

        # Dev stage. Validation only runs with cuDNN disabled (an unexplained
        # bug); restore it even if validation raises.
        torch.backends.cudnn.enabled = False
        start_t_dev = time.time()
        try:
            dev_info = validate_one_epoch_discriminator(
                dev_loader, model, device, loss, epoch, args
            )
        finally:
            torch.backends.cudnn.enabled = True
        end_t_dev = time.time()

        dev_log = "Dev epoch: {:04d} ".format(epoch)
        dev_log += _format_discriminator_metrics(dev_info)
        # Time the dev pass alone, measured from start_t_dev rather than
        # start_t_train so the training time is not included.
        logging.info("{} time: {:.2f}s".format(dev_log, end_t_dev - start_t_dev))
        sys.stdout.flush()

        # Save model stage
        os.makedirs(args.model_save_dir, exist_ok=True)
        total_loss_counter, total_loss_epoch_to_save = Auto_save_model(
            args,
            epoch,
            model,
            optimizer,
            train_info,
            dev_info,
            None,  # logger: none is created for the predictor
            total_loss_counter,
            total_loss_epoch_to_save,
            save_loss_select="loss",
        )