# Training bootstrap: connect to Visdom, parse options, build the dataset
# pipeline and the CycleGAN model, then initialize loss-tracking state.
from evaluate.draw_loss_figure import draw_loss_figure
from models.cycle_gan_model import CycleGANModel
from utilSet.visualizer import Visualizer, save_opt
from data.dataset import DataLoader

import numpy as np
import torch  # FIX: torch.zeros is used below but torch was never imported
from visdom import Visdom

# NOTE(review): `TrainOptions` is referenced below but never imported in this
# file — presumably something like `from options.train_options import
# TrainOptions`; verify against the project layout. Without that import this
# script raises NameError at module load.

# Connect to a running Visdom server and start from a clean dashboard.
# FIX: was `assert viz.check_connection()` — asserts are stripped under
# `python -O`, so raise explicitly instead.
viz = Visdom()
if not viz.check_connection():
    raise RuntimeError("Cannot connect to the Visdom server; "
                       "start it with `python -m visdom.server`.")
viz.close()

# Parse and persist the training options, then build the data pipeline.
opt = TrainOptions().parse()
save_opt(opt)
data_loader = DataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)

# Build and initialize the CycleGAN model and its visualizer.
model = CycleGANModel()
model.initialize(opt)
visualizer = Visualizer(opt)

if __name__ == '__main__':
    total_steps = 0
    # Per-iteration and running-average sparse-C loss points, filled in by
    # the training loop (presumably continues beyond this chunk).
    sparse_c_loss_points, sparse_c_loss_avr_points = [], []
    # Visdom line plot that later loss points are appended to.
    win_sparse_C = viz.line(X=torch.zeros((1, )),
                            Y=torch.zeros((1, )),
                            name="win_sparse_C")
def _timestamp():
    """Return a filesystem-safe timestamp (spaces/colons/dots -> dashes)."""
    return str(datetime.now()).replace(" ", "-").replace(":", "-").replace(".", "-")


def run(args):
    """Train a neural vocoder selected by ``args.model_name``.

    Supported models: ``melgan``, ``hifigan``, ``multiband-hifigan`` and
    ``basis-melgan``. Builds the generator and discriminator, their
    optimizers (optionally with apex mixed precision), restores a checkpoint
    when one exists, then runs the epoch/batch training loop, periodically
    validating and logging to TensorBoard.

    Args:
        args: parsed CLI namespace; reads model_name, config, mixprecision,
            learning_rate(_discriminator), use_scheduler, checkpoint_path,
            restore_step and the audio/mel index paths.

    Raises:
        Exception: for an unknown model name, or when mixed precision is
            requested for basis-melgan.
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    logger.info(f"Loading Model of {args.model_name}...")
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    # NOTE(review): "lamda_stft" (sic) is the key actually present in the
    # YAML configs — do not "fix" the spelling here without changing them.
    hp.lambda_stft = config["lamda_stft"]
    hp.use_feature_map_loss = config["use_feature_map_loss"]
    if args.model_name == "melgan":
        model = MelGANGenerator(
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            kernel_size=config["kernel_size"],
            channels=config["channels"],
            upsample_scales=config["upsample_scales"],
            stack_kernel_size=config["stack_kernel_size"],
            stacks=config["stacks"],
            use_weight_norm=config["use_weight_norm"],
            use_causal_conv=config["use_causal_conv"]).to(device)
    elif args.model_name == "hifigan":
        model = HiFiGANGenerator(
            resblock_kernel_sizes=config["resblock_kernel_sizes"],
            upsample_rates=config["upsample_rates"],
            upsample_initial_channel=config["upsample_initial_channel"],
            resblock_type=config["resblock_type"],
            upsample_kernel_sizes=config["upsample_kernel_sizes"],
            resblock_dilation_sizes=config["resblock_dilation_sizes"],
            transposedconv=config["transposedconv"],
            bias=config["bias"]).to(device)
    elif args.model_name == "multiband-hifigan":
        model = MultiBandHiFiGANGenerator(
            resblock_kernel_sizes=config["resblock_kernel_sizes"],
            upsample_rates=config["upsample_rates"],
            upsample_initial_channel=config["upsample_initial_channel"],
            resblock_type=config["resblock_type"],
            upsample_kernel_sizes=config["upsample_kernel_sizes"],
            resblock_dilation_sizes=config["resblock_dilation_sizes"],
            transposedconv=config["transposedconv"],
            bias=config["bias"]).to(device)
    elif args.model_name == "basis-melgan":
        # The basis-signal layer is initialized from a precomputed weight
        # matrix shipped with the dataset.
        basis_signal_weight = np.load(
            os.path.join("Basis-MelGAN-dataset", "basis_signal_weight.npy"))
        basis_signal_weight = torch.from_numpy(basis_signal_weight)
        model = BasisMelGANGenerator(
            basis_signal_weight=basis_signal_weight,
            L=config["L"],
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            kernel_size=config["kernel_size"],
            channels=config["channels"],
            upsample_scales=config["upsample_scales"],
            stack_kernel_size=config["stack_kernel_size"],
            stacks=config["stacks"],
            use_weight_norm=config["use_weight_norm"],
            use_causal_conv=config["use_causal_conv"],
            transposedconv=config["transposedconv"]).to(device)
    else:
        raise Exception("no model find!")

    pqmf = None
    # YAML parses `true`/`false` to real booleans, so truthiness is
    # equivalent to the original `== True` comparison.
    if config["multiband"]:
        logger.info("Define PQMF")
        pqmf = PQMF().to(device)

    logger.info(f"model is {str(model)}")
    discriminator = Discriminator().to(device)
    logger.info("Model Has Been Defined")
    num_param = get_param_num(model)
    logger.info(f'Number of TTS Parameters: {num_param}')

    # Optimizer and loss
    basis_signal_optimizer = None
    if not args.mixprecision:
        if args.model_name == "basis-melgan":
            # Only the melgan sub-network trains with the main optimizer.
            optimizer = Adam(model.melgan.parameters(),
                             lr=args.learning_rate,
                             eps=1.0e-6,
                             weight_decay=0.0)
            # freeze basis signal layer
            basis_signal_optimizer = Adam(model.basis_signal.parameters())
        else:
            optimizer = Adam(model.parameters(),
                             lr=args.learning_rate,
                             eps=1.0e-6,
                             weight_decay=0.0)
        discriminator_optimizer = Adam(discriminator.parameters(),
                                       lr=args.learning_rate_discriminator,
                                       eps=1.0e-6,
                                       weight_decay=0.0)
    else:
        if args.model_name == "basis-melgan":
            raise Exception("basis melgan don't support amp!")
        optimizer = apex.optimizers.FusedAdam(model.parameters(),
                                              lr=args.learning_rate)
        discriminator_optimizer = apex.optimizers.FusedAdam(
            discriminator.parameters(), lr=args.learning_rate_discriminator)
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          keep_batchnorm_fp32=None)
        discriminator, discriminator_optimizer = amp.initialize(
            discriminator, discriminator_optimizer, opt_level="O1")
        logger.info("Start mix precision training...")

    if args.use_scheduler:
        scheduler = CosineAnnealingLR(optimizer,
                                      T_max=2500,
                                      eta_min=args.learning_rate / 10.)
        discriminator_scheduler = CosineAnnealingLR(
            discriminator_optimizer,
            T_max=2500,
            eta_min=args.learning_rate_discriminator / 10.)
    else:
        scheduler = None
        discriminator_scheduler = None

    vocoder_loss = Loss().to(device)
    logger.info("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists; a fresh timestamped directory is used for
    # checkpoints produced by this run either way.
    os.makedirs(hp.checkpoint_path, exist_ok=True)
    current_checkpoint_path = os.path.join(hp.checkpoint_path, _timestamp())
    try:
        checkpoint = torch.load(args.checkpoint_path,
                                map_location=torch.device(device))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if 'discriminator' in checkpoint:
            logger.info("loading discriminator")
            discriminator.load_state_dict(checkpoint['discriminator'])
            discriminator_optimizer.load_state_dict(
                checkpoint['discriminator_optimizer'])
        os.makedirs(current_checkpoint_path, exist_ok=True)
        if args.mixprecision:
            amp.load_state_dict(checkpoint['amp'])
        logger.info("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt /
        # SystemExit. A missing or incompatible checkpoint is deliberately
        # treated as "start fresh", so we still catch broadly here.
        logger.info("\n---Start New Training---\n")
        os.makedirs(current_checkpoint_path, exist_ok=True)

    # Init logger
    os.makedirs(hp.logger_path, exist_ok=True)
    current_logger_path = _timestamp()
    writer = SummaryWriter(
        os.path.join(hp.tensorboard_path, current_logger_path))
    current_logger_path = os.path.join(hp.logger_path, current_logger_path)
    os.makedirs(current_logger_path, exist_ok=True)

    # Get buffer (basis-melgan streams from disk via WeightDataset instead)
    if args.model_name != "basis-melgan":
        logger.info("Load data to buffer")
        buffer = load_data_to_buffer(args.audio_index_path,
                                     args.mel_index_path,
                                     logger,
                                     feature_savepath="features_train.bin")
        logger.info("Load valid data to buffer")
        valid_buffer = load_data_to_buffer(
            args.audio_index_valid_path,
            args.mel_index_valid_path,
            logger,
            feature_savepath="features_valid.bin")

    # Get dataset
    if args.model_name == "basis-melgan":
        dataset = WeightDataset(args.audio_index_path, args.mel_index_path,
                                config["L"])
        valid_dataset = WeightDataset(args.audio_index_valid_path,
                                      args.mel_index_valid_path, config["L"])
    else:
        dataset = BufferDataset(buffer)
        valid_dataset = BufferDataset(valid_buffer)

    # Get Training Loader. Each loader item is a list of hp.batch_expand_size
    # real batches (see the inner `for j, db` loop below).
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_expand_size *
                                 hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn_tensor,
                                 drop_last=True,
                                 num_workers=4,
                                 prefetch_factor=2,
                                 pin_memory=True)
    logger.info(f"Length of training loader is {len(training_loader)}")
    total_step = hp.epochs * len(training_loader) * hp.batch_expand_size

    # Define Some Information
    time_list = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(training_loader):
            # real batch start here
            for j, db in enumerate(batchs):
                current_step = (i * hp.batch_expand_size + j +
                                args.restore_step +
                                epoch * len(training_loader) *
                                hp.batch_expand_size + 1)

                # Get Data (mel arrives (batch, time, n_mels); the model
                # presumably expects (batch, n_mels, time) — hence transpose)
                mel = db["mel"].float().to(device)
                wav = db["wav"].float().to(device)
                mel = mel.contiguous().transpose(1, 2)
                weight = None
                if "weight" in db:
                    weight = db["weight"].float().to(device)

                # Training step (trainer handles logging/checkpointing and
                # returns the updated per-step timing array).
                time_list = trainer(
                    model,
                    discriminator,
                    optimizer,
                    discriminator_optimizer,
                    scheduler,
                    discriminator_scheduler,
                    vocoder_loss,
                    mel,
                    wav,
                    epoch,
                    current_step,
                    total_step,
                    time_list,
                    Start,
                    current_checkpoint_path,
                    current_logger_path,
                    writer,
                    weight=weight,
                    basis_signal_optimizer=basis_signal_optimizer,
                    pqmf=pqmf,
                    mixprecision=args.mixprecision)

                # Periodic validation: average STFT loss over hp.valid_num
                # one-sample batches.
                if current_step % hp.valid_step == 0:
                    logger.info("Start valid...")
                    valid_loader = DataLoader(
                        valid_dataset,
                        batch_size=1,
                        shuffle=True,
                        collate_fn=collate_fn_tensor_valid,
                        num_workers=0)
                    valid_loss_all = 0.
                    for ii, valid_batch in enumerate(valid_loader):
                        valid_mel = valid_batch["mel"].float().to(device)
                        valid_mel = valid_mel.contiguous().transpose(1, 2)
                        valid_wav = valid_batch["wav"].float().to(device)
                        with torch.no_grad():
                            if args.model_name == "basis-melgan":
                                # basis-melgan returns (waveform, weights)
                                valid_est_source, _ = model(valid_mel)
                            else:
                                valid_est_source = model(valid_mel)
                            valid_stft_loss, _ = vocoder_loss(
                                valid_est_source, valid_wav, pqmf=pqmf)
                            valid_loss_all += valid_stft_loss.item()
                        if ii == hp.valid_num:
                            break
                    writer.add_scalar('valid_stft_loss',
                                      valid_loss_all / float(hp.valid_num),
                                      global_step=current_step)

    writer.export_scalars_to_json("all_scalars.json")
    writer.close()
    return