def test_melgan_trainable_with_melgan_discriminator(dict_g, dict_d, dict_loss):
    """Check MelGAN generator and multi-scale discriminator can each take one
    optimization step (adversarial + STFT + feature-matching losses).

    Args:
        dict_g: Overrides for the generator config.
        dict_d: Overrides for the discriminator config.
        dict_loss: Overrides for the multi-resolution STFT loss config.
    """
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    # conditioning features: one frame per total upsampling factor
    c = torch.randn(batch_size, args_g["in_channels"],
                    batch_length // np.prod(args_g["upsample_scales"]))
    model_g = MelGANGenerator(**args_g)
    model_d = MelGANMultiScaleDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    y, y_hat = y.squeeze(1), y_hat.squeeze(1)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    adv_loss = 0.0
    for i in range(len(p_hat)):
        # last element of each scale's outputs is the discriminator score
        adv_loss += F.mse_loss(
            p_hat[i][-1], p_hat[i][-1].new_ones(p_hat[i][-1].size()))
    adv_loss /= (i + 1)  # average over discriminator scales
    with torch.no_grad():
        # real-sample intermediate features for feature matching
        p = model_d(y.unsqueeze(1))
    fm_loss = 0.0
    for i in range(len(p_hat)):
        for j in range(len(p_hat[i]) - 1):
            fm_loss += F.l1_loss(p_hat[i][j], p[i][j].detach())
    # FIX: average over (#scales * #layers). The original divided by the last
    # layer *index* `j` instead of the layer *count* `j + 1`, which mis-scales
    # the loss and raises ZeroDivisionError when only one feature layer exists.
    fm_loss /= (i + 1) * (j + 1)
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    y, y_hat = y.unsqueeze(1), y_hat.unsqueeze(1).detach()
    p = model_d(y)
    p_hat = model_d(y_hat)
    real_loss = 0.0
    fake_loss = 0.0
    for i in range(len(p)):
        # LSGAN targets: real -> 1, fake -> 0
        real_loss += F.mse_loss(
            p[i][-1], p[i][-1].new_ones(p[i][-1].size()))
        fake_loss += F.mse_loss(
            p_hat[i][-1], p_hat[i][-1].new_zeros(p_hat[i][-1].size()))
    real_loss /= (i + 1)
    fake_loss /= (i + 1)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
def test_hifigan_trainable(dict_g, dict_d, dict_loss):
    """Check HiFi-GAN generator and multi-scale+multi-period discriminator
    can each complete one optimization step.

    Uses adversarial, multi-resolution STFT, and feature-matching criteria.

    Args:
        dict_g: Overrides for the generator config.
        dict_d: Overrides for the discriminator config.
        dict_loss: Overrides for the multi-resolution STFT loss config.
    """
    # setup
    batch_size = 4
    batch_length = 2**13
    args_g = make_hifigan_generator_args(**dict_g)
    args_d = make_hifigan_multi_scale_multi_period_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = HiFiGANGenerator(**args_g)
    model_d = HiFiGANMultiScaleMultiPeriodDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    feat_match_criterion = FeatureMatchLoss(
        average_by_layers=False,
        average_by_discriminators=False,
        include_final_outputs=True,
    )
    gen_adv_criterion = GeneratorAdversarialLoss(
        average_by_discriminators=False,
    )
    dis_adv_criterion = DiscriminatorAdversarialLoss(
        average_by_discriminators=False,
    )
    optimizer_g = torch.optim.AdamW(model_g.parameters())
    optimizer_d = torch.optim.AdamW(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    adv_loss = gen_adv_criterion(p_hat)
    with torch.no_grad():
        # real-sample features for feature matching; no grads needed
        p = model_d(y)
    fm_loss = feat_match_criterion(p_hat, p)
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
    # FIX: removed leftover `print(model_d)` / `print(model_g)` debug
    # statements that spammed the test output.
def test_melgan_trainable_with_melgan_discriminator(dict_g, dict_d, dict_loss):
    """Verify MelGAN generator/discriminator each survive one training step."""
    # configs and random inputs
    gen_conf = make_melgan_generator_args(**dict_g)
    dis_conf = make_melgan_discriminator_args(**dict_d)
    stft_conf = make_mutli_reso_stft_loss_args(**dict_loss)
    n_batch, n_samples = 4, 4096
    wav = torch.randn(n_batch, 1, n_samples)
    mel = torch.randn(
        n_batch,
        gen_conf["in_channels"],
        n_samples // np.prod(gen_conf["upsample_scales"]),
    )

    # models, criteria, optimizers
    generator = MelGANGenerator(**gen_conf)
    discriminator = MelGANMultiScaleDiscriminator(**dis_conf)
    stft_criterion = MultiResolutionSTFTLoss(**stft_conf)
    fm_criterion = FeatureMatchLoss()
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    gen_opt = RAdam(generator.parameters())
    dis_opt = RAdam(discriminator.parameters())

    # one generator update: adversarial + STFT aux + feature matching
    wav_fake = generator(mel)
    outs_fake = discriminator(wav_fake)
    sc, mag = stft_criterion(wav_fake, wav)
    adv = gen_adv_criterion(outs_fake)
    with torch.no_grad():
        outs_real = discriminator(wav)
    total_g = adv + (sc + mag) + fm_criterion(outs_fake, outs_real)
    gen_opt.zero_grad()
    total_g.backward()
    gen_opt.step()

    # one discriminator update on detached fake audio
    outs_real = discriminator(wav)
    outs_fake = discriminator(wav_fake.detach())
    loss_real, loss_fake = dis_adv_criterion(outs_fake, outs_real)
    total_d = loss_real + loss_fake
    dis_opt.zero_grad()
    total_d.backward()
    dis_opt.step()
def test_parallel_wavegan_with_residual_discriminator_trainable(
        dict_g, dict_d, dict_loss):
    """Verify PWG generator and residual discriminator each take one step."""
    # configs and random inputs
    n_batch, n_samples = 4, 4096
    gen_conf = make_generator_args(**dict_g)
    dis_conf = make_residual_discriminator_args(**dict_d)
    stft_conf = make_mutli_reso_stft_loss_args(**dict_loss)
    noise = torch.randn(n_batch, 1, n_samples)
    wav = torch.randn(n_batch, 1, n_samples)
    # aux features padded by the context window on both sides
    n_frames = n_samples // np.prod(
        gen_conf["upsample_params"]["upsample_scales"])
    aux = torch.randn(
        n_batch,
        gen_conf["aux_channels"],
        n_frames + 2 * gen_conf["aux_context_window"],
    )

    # models, criteria, optimizers
    generator = ParallelWaveGANGenerator(**gen_conf)
    discriminator = ResidualParallelWaveGANDiscriminator(**dis_conf)
    stft_criterion = MultiResolutionSTFTLoss(**stft_conf)
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    opt_g = RAdam(generator.parameters())
    opt_d = RAdam(discriminator.parameters())

    # one generator update: adversarial + STFT aux loss
    wav_fake = generator(noise, aux)
    scores_fake = discriminator(wav_fake)
    adv = gen_adv_criterion(scores_fake)
    sc, mag = stft_criterion(wav_fake, wav)
    total_g = adv + (sc + mag)
    opt_g.zero_grad()
    total_g.backward()
    opt_g.step()

    # one discriminator update on detached fake audio
    scores_real = discriminator(wav)
    scores_fake = discriminator(wav_fake.detach())
    loss_real, loss_fake = dis_adv_criterion(scores_fake, scores_real)
    total_d = loss_real + loss_fake
    opt_d.zero_grad()
    total_d.backward()
    opt_d.step()
def test_style_melgan_trainable(dict_g, dict_d, dict_loss, loss_type):
    """Verify StyleMelGAN generator/discriminator each take one step.

    ``loss_type`` selects the adversarial-loss flavor for both criteria.
    """
    # configs
    gen_conf = make_style_melgan_generator_args(**dict_g)
    dis_conf = make_style_melgan_discriminator_args(**dict_d)
    stft_conf = make_mutli_reso_stft_loss_args(**dict_loss)

    # random inputs; total upsampling factor fixes the waveform length
    n_batch = 4
    n_samples = np.prod(gen_conf["noise_upsample_scales"]) * np.prod(
        gen_conf["upsample_scales"]
    )
    wav = torch.randn(n_batch, 1, n_samples)
    mel = torch.randn(
        n_batch,
        gen_conf["aux_channels"],
        n_samples // np.prod(gen_conf["upsample_scales"]),
    )

    # models, criteria, optimizers
    generator = StyleMelGANGenerator(**gen_conf)
    discriminator = StyleMelGANDiscriminator(**dis_conf)
    stft_criterion = MultiResolutionSTFTLoss(**stft_conf)
    gen_adv_criterion = GeneratorAdversarialLoss(loss_type=loss_type)
    dis_adv_criterion = DiscriminatorAdversarialLoss(loss_type=loss_type)
    opt_g = torch.optim.Adam(generator.parameters())
    opt_d = torch.optim.Adam(discriminator.parameters())

    # one generator update: adversarial + STFT aux loss
    wav_fake = generator(mel)
    scores_fake = discriminator(wav_fake)
    adv = gen_adv_criterion(scores_fake)
    sc, mag = stft_criterion(wav_fake, wav)
    total_g = adv + (sc + mag)
    opt_g.zero_grad()
    total_g.backward()
    opt_g.step()

    # one discriminator update on detached fake audio
    scores_real = discriminator(wav)
    scores_fake = discriminator(wav_fake.detach())
    loss_real, loss_fake = dis_adv_criterion(scores_fake, scores_real)
    total_d = loss_real + loss_fake
    opt_d.zero_grad()
    total_d.backward()
    opt_d.step()
def test_parallel_wavegan_trainable(dict_g, dict_d, dict_loss):
    """Verify PWG generator/discriminator each take one LSGAN training step."""
    # configs and random inputs
    n_batch, n_samples = 4, 4096
    gen_conf = make_generator_args(**dict_g)
    dis_conf = make_discriminator_args(**dict_d)
    stft_conf = make_mutli_reso_stft_loss_args(**dict_loss)
    noise = torch.randn(n_batch, 1, n_samples)
    wav = torch.randn(n_batch, 1, n_samples)
    # aux features padded by the context window on both sides
    n_frames = n_samples // np.prod(
        gen_conf["upsample_params"]["upsample_scales"])
    aux = torch.randn(
        n_batch,
        gen_conf["aux_channels"],
        n_frames + 2 * gen_conf["aux_context_window"],
    )

    # models, criterion, optimizers
    generator = ParallelWaveGANGenerator(**gen_conf)
    discriminator = ParallelWaveGANDiscriminator(**dis_conf)
    stft_criterion = MultiResolutionSTFTLoss(**stft_conf)
    opt_g = RAdam(generator.parameters())
    opt_d = RAdam(discriminator.parameters())

    # one generator update: LSGAN adversarial (target 1) + STFT aux loss
    wav_fake = generator(noise, aux)
    score_fake = discriminator(wav_fake).squeeze(1)
    fake_1d = wav_fake.squeeze(1)
    adv = F.mse_loss(score_fake, torch.ones_like(score_fake))
    sc, mag = stft_criterion(fake_1d, wav.squeeze(1))
    total_g = adv + (sc + mag)
    opt_g.zero_grad()
    total_g.backward()
    opt_g.step()

    # one discriminator update: real -> 1, detached fake -> 0
    score_real = discriminator(wav).squeeze(1)
    score_fake = discriminator(fake_1d.unsqueeze(1).detach()).squeeze(1)
    total_d = (F.mse_loss(score_real, torch.ones_like(score_real))
               + F.mse_loss(score_fake, torch.zeros_like(score_fake)))
    opt_d.zero_grad()
    total_d.backward()
    opt_d.step()
def main():
    """Run training process.

    Parses CLI arguments, builds datasets/loaders, instantiates the
    generator/discriminator with their losses, optimizers, and schedulers
    (optionally wrapped for apex distributed training), then runs the
    training loop, checkpointing on KeyboardInterrupt.
    """
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument("--train-dumpdir", type=str, required=True,
                        help="directory including training data.")
    parser.add_argument("--dev-dumpdir", type=str, required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--resume", default="", type=str, nargs="?",
        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument(
        "--verbose", type=int, default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--rank", "--local_rank", default=0, type=int,
        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training (only rank 0 prints)
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        # minimum #mel-frames needed to crop one training batch
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train": AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
        "dev": AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get(
            "aux_context_window", 0),
        # keep compatibility: MelGAN takes no noise input
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            # sampler already shuffles in distributed mode
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator": generator_class(**config["generator_params"]).to(device),
        "discriminator":
        discriminator_class(**config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    optimizer = {
        "generator": RAdam(model["generator"].parameters(),
                           **config["generator_optimizer_params"]),
        "discriminator": RAdam(model["discriminator"].parameters(),
                               **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            )
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        # save a checkpoint so an interrupted run can be resumed later
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
def main():
    """Run training process.

    Simple entry point: fixed ParallelWaveGAN generator/discriminator
    classes, no distributed-training support; always saves a checkpoint
    on exit (including exceptions) via the ``finally`` clause.
    """
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument("--train-dumpdir", type=str, required=True,
                        help="directory including trainning data.")
    parser.add_argument("--dev-dumpdir", type=str, required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--resume", default="", type=str, nargs="?",
        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument(
        "--verbose", type=int, default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning('skip DEBUG/INFO messages')

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        # minimum #mel-frames needed to crop one training batch
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"]["aux_context_window"]
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train": AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
        "dev": AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
    }

    # get data loader
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        aux_context_window=config["generator_params"]["aux_context_window"],
    )
    data_loader = {
        "train": DataLoader(dataset=dataset["train"],
                            shuffle=True,
                            collate_fn=collater,
                            batch_size=config["batch_size"],
                            num_workers=config["num_workers"],
                            pin_memory=config["pin_memory"]),
        "dev": DataLoader(dataset=dataset["dev"],
                          shuffle=True,
                          collate_fn=collater,
                          batch_size=config["batch_size"],
                          num_workers=config["num_workers"],
                          pin_memory=config["pin_memory"]),
    }

    # define models and optimizers
    model = {
        "generator":
        ParallelWaveGANGenerator(**config["generator_params"]).to(device),
        "discriminator":
        ParallelWaveGANDiscriminator(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft":
        MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    optimizer = {
        "generator": RAdam(model["generator"].parameters(),
                           **config["generator_optimizer_params"]),
        "discriminator": RAdam(model["discriminator"].parameters(),
                               **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    finally:
        # always checkpoint, even on exceptions / keyboard interrupt
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"successfully saved checkpoint @ {trainer.steps}steps.")
def main():
    """Run training process.

    Entry point supporting both kaldi-style scp inputs and dump directories,
    optional pretrained-parameter initialization, and apex-based distributed
    training; checkpoints on KeyboardInterrupt.
    """
    parser = argparse.ArgumentParser(description="Train Parallel WaveGAN.")
    # exactly one of the *-scp pair or the dumpdir must be given per split;
    # this is validated after parsing below.
    parser.add_argument("--train-wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file for training. "
                        "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument("--train-feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file for training. "
                        "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument("--train-segments", default=None, type=str,
                        help="kaldi-style segments file for training.")
    parser.add_argument("--train-dumpdir", default=None, type=str,
                        help="directory including training data. "
                        "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument("--dev-wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file for validation. "
                        "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--dev-feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file for vaidation. "
                        "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--dev-segments", default=None, type=str,
                        help="kaldi-style segments file for validation.")
    parser.add_argument("--dev-dumpdir", default=None, type=str,
                        help="directory including development data. "
                        "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--pretrain", default="", type=str, nargs="?",
                        help="checkpoint file path to load pretrained params. (default=\"\")")
    parser.add_argument("--resume", default="", type=str, nargs="?",
                        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    parser.add_argument("--rank", "--local_rank", default=0, type=int,
                        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # cudnn benchmark mode is effective with fixed-size inputs
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # distributed setup: WORLD_SIZE is exported by the launcher
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            # NCCL currently gives the best multi-GPU training performance
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training (only rank 0 prints)
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger: DEBUG / INFO / WARN depending on --verbose
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments: per split, exactly one of *-scp / dumpdir is allowed
    if (args.train_feats_scp is not None and args.train_dumpdir is not None) or \
            (args.train_feats_scp is None and args.train_dumpdir is None):
        raise ValueError("Please specify either --train-dumpdir or --train-*-scp.")
    if (args.dev_feats_scp is not None and args.dev_dumpdir is not None) or \
            (args.dev_feats_scp is None and args.dev_dumpdir is None):
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.")

    # load config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)

    # update config with CLI args
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info

    # save config
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        # minimum #mel-frames needed to crop one training batch
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None
    if args.train_wav_scp is None or args.dev_wav_scp is None:
        # at least one split uses a dump directory -> need queries / load fns
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
    if args.train_dumpdir is not None:
        train_dataset = AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        train_dataset = AudioMelSCPDataset(
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of training files = {len(train_dataset)}.")
    if args.dev_dumpdir is not None:
        dev_dataset = AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of development files = {len(dev_dataset)}.")
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],  # keep compatibility
        aux_context_window=config["generator_params"].get("aux_context_window", 0),  # keep compatibility
        # keep compatibility: MelGAN takes no noise input
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            # sampler already shuffles in distributed mode
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,  # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,  # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator": generator_class(
            **config["generator_params"]).to(device),
        "discriminator": discriminator_class(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(
            **config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    optimizer = {
        "generator": RAdam(
            model["generator"].parameters(),
            **config["generator_optimizer_params"]),
        "discriminator": RAdam(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError("apex is not installed. please check https://github.com/NVIDIA/apex.")
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint (params only, no state)
    if len(args.pretrain) != 0:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        # save a checkpoint so an interrupted run can be resumed later
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
def main(): """Run training process.""" parser = argparse.ArgumentParser(description=( "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)." )) parser.add_argument( "--train-wav-scp", default=None, type=str, help=("kaldi-style wav.scp file for training. " "you need to specify either train-*-scp or train-dumpdir."), ) parser.add_argument( "--train-feats-scp", default=None, type=str, help=("kaldi-style feats.scp file for training. " "you need to specify either train-*-scp or train-dumpdir."), ) parser.add_argument( "--train-segments", default=None, type=str, help="kaldi-style segments file for training.", ) parser.add_argument( "--train-dumpdir", default=None, type=str, help=("directory including training data. " "you need to specify either train-*-scp or train-dumpdir."), ) parser.add_argument( "--dev-wav-scp", default=None, type=str, help=("kaldi-style wav.scp file for validation. " "you need to specify either dev-*-scp or dev-dumpdir."), ) parser.add_argument( "--dev-feats-scp", default=None, type=str, help=("kaldi-style feats.scp file for vaidation. " "you need to specify either dev-*-scp or dev-dumpdir."), ) parser.add_argument( "--dev-segments", default=None, type=str, help="kaldi-style segments file for validation.", ) parser.add_argument( "--dev-dumpdir", default=None, type=str, help=("directory including development data. " "you need to specify either dev-*-scp or dev-dumpdir."), ) parser.add_argument( "--outdir", type=str, required=True, help="directory to save checkpoints.", ) parser.add_argument( "--config", type=str, required=True, help="yaml format configuration file.", ) parser.add_argument( "--pretrain", default="", type=str, nargs="?", help='checkpoint file path to load pretrained params. (default="")', ) parser.add_argument( "--resume", default="", type=str, nargs="?", help='checkpoint file path to resume training. (default="")', ) parser.add_argument( "--verbose", type=int, default=1, help="logging level. higher is more logging. 
(default=1)", ) parser.add_argument( "--rank", "--local_rank", default=0, type=int, help="rank for distributed training. no need to explictly specify.", ) args = parser.parse_args() args.distributed = False if not torch.cuda.is_available(): device = torch.device("cpu") else: device = torch.device("cuda") # effective when using fixed size inputs # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936 torch.backends.cudnn.benchmark = True torch.cuda.set_device(args.rank) # setup for distributed training # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed if "WORLD_SIZE" in os.environ: args.world_size = int(os.environ["WORLD_SIZE"]) args.distributed = args.world_size > 1 if args.distributed: torch.distributed.init_process_group(backend="nccl", init_method="env://") # suppress logging for distributed training if args.rank != 0: sys.stdout = open(os.devnull, "w") # set logger if args.verbose > 1: logging.basicConfig( level=logging.DEBUG, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) elif args.verbose > 0: logging.basicConfig( level=logging.INFO, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) else: logging.basicConfig( level=logging.WARN, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) logging.warning("Skip DEBUG/INFO messages") # check directory existence if not os.path.exists(args.outdir): os.makedirs(args.outdir) # check arguments if (args.train_feats_scp is not None and args.train_dumpdir is not None) or (args.train_feats_scp is None and args.train_dumpdir is None): raise ValueError( "Please specify either --train-dumpdir or --train-*-scp.") if (args.dev_feats_scp is not None and args.dev_dumpdir is not None) or ( args.dev_feats_scp is None and args.dev_dumpdir is None): raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.") # load and save config 
with open(args.config) as f: config = yaml.load(f, Loader=yaml.Loader) config.update(vars(args)) config["version"] = parallel_wavegan.__version__ # add version info with open(os.path.join(args.outdir, "config.yml"), "w") as f: yaml.dump(config, f, Dumper=yaml.Dumper) for key, value in config.items(): logging.info(f"{key} = {value}") # get dataset if config["remove_short_samples"]: mel_length_threshold = config["batch_max_steps"] // config[ "hop_size"] + 2 * config["generator_params"].get( "aux_context_window", 0) else: mel_length_threshold = None if args.train_wav_scp is None or args.dev_wav_scp is None: if config["format"] == "hdf5": audio_query, mel_query = "*.h5", "*.h5" audio_load_fn = lambda x: read_hdf5(x, "wave") # NOQA mel_load_fn = lambda x: read_hdf5(x, "feats") # NOQA elif config["format"] == "npy": audio_query, mel_query = "*-wave.npy", "*-feats.npy" audio_load_fn = np.load mel_load_fn = np.load else: raise ValueError("support only hdf5 or npy format.") if args.train_dumpdir is not None: train_dataset = AudioMelDataset( root_dir=args.train_dumpdir, audio_query=audio_query, mel_query=mel_query, audio_load_fn=audio_load_fn, mel_load_fn=mel_load_fn, mel_length_threshold=mel_length_threshold, allow_cache=config.get("allow_cache", False), # keep compatibility ) else: train_dataset = AudioMelSCPDataset( wav_scp=args.train_wav_scp, feats_scp=args.train_feats_scp, segments=args.train_segments, mel_length_threshold=mel_length_threshold, allow_cache=config.get("allow_cache", False), # keep compatibility ) logging.info(f"The number of training files = {len(train_dataset)}.") if args.dev_dumpdir is not None: dev_dataset = AudioMelDataset( root_dir=args.dev_dumpdir, audio_query=audio_query, mel_query=mel_query, audio_load_fn=audio_load_fn, mel_load_fn=mel_load_fn, mel_length_threshold=mel_length_threshold, allow_cache=config.get("allow_cache", False), # keep compatibility ) else: dev_dataset = AudioMelSCPDataset( wav_scp=args.dev_wav_scp, 
feats_scp=args.dev_feats_scp, segments=args.dev_segments, mel_length_threshold=mel_length_threshold, allow_cache=config.get("allow_cache", False), # keep compatibility ) logging.info(f"The number of development files = {len(dev_dataset)}.") dataset = { "train": train_dataset, "dev": dev_dataset, } # get data loader collater = Collater( batch_max_steps=config["batch_max_steps"], hop_size=config["hop_size"], # keep compatibility aux_context_window=config["generator_params"].get( "aux_context_window", 0), # keep compatibility use_noise_input=config.get("generator_type", "ParallelWaveGANGenerator") in ["ParallelWaveGANGenerator"], ) sampler = {"train": None, "dev": None} if args.distributed: # setup sampler for distributed training from torch.utils.data.distributed import DistributedSampler sampler["train"] = DistributedSampler( dataset=dataset["train"], num_replicas=args.world_size, rank=args.rank, shuffle=True, ) sampler["dev"] = DistributedSampler( dataset=dataset["dev"], num_replicas=args.world_size, rank=args.rank, shuffle=False, ) data_loader = { "train": DataLoader( dataset=dataset["train"], shuffle=False if args.distributed else True, collate_fn=collater, batch_size=config["batch_size"], num_workers=config["num_workers"], sampler=sampler["train"], pin_memory=config["pin_memory"], ), "dev": DataLoader( dataset=dataset["dev"], shuffle=False if args.distributed else True, collate_fn=collater, batch_size=config["batch_size"], num_workers=config["num_workers"], sampler=sampler["dev"], pin_memory=config["pin_memory"], ), } # define models generator_class = getattr( parallel_wavegan.models, # keep compatibility config.get("generator_type", "ParallelWaveGANGenerator"), ) discriminator_class = getattr( parallel_wavegan.models, # keep compatibility config.get("discriminator_type", "ParallelWaveGANDiscriminator"), ) model = { "generator": generator_class(**config["generator_params"], ).to(device), "discriminator": discriminator_class(**config["discriminator_params"], 
).to(device), } # define criterions criterion = { "gen_adv": GeneratorAdversarialLoss( # keep compatibility **config.get("generator_adv_loss_params", {})).to(device), "dis_adv": DiscriminatorAdversarialLoss( # keep compatibility **config.get("discriminator_adv_loss_params", {})).to(device), } if config.get("use_stft_loss", True): # keep compatibility config["use_stft_loss"] = True criterion["stft"] = MultiResolutionSTFTLoss( **config["stft_loss_params"], ).to(device) if config.get("use_subband_stft_loss", False): # keep compatibility assert config["generator_params"]["out_channels"] > 1 criterion["sub_stft"] = MultiResolutionSTFTLoss( **config["subband_stft_loss_params"], ).to(device) else: config["use_subband_stft_loss"] = False if config.get("use_feat_match_loss", False): # keep compatibility criterion["feat_match"] = FeatureMatchLoss( # keep compatibility **config.get("feat_match_loss_params", {}), ).to(device) else: config["use_feat_match_loss"] = False if config.get("use_mel_loss", False): # keep compatibility if config.get("mel_loss_params", None) is None: criterion["mel"] = MelSpectrogramLoss( fs=config["sampling_rate"], fft_size=config["fft_size"], hop_size=config["hop_size"], win_length=config["win_length"], window=config["window"], num_mels=config["num_mels"], fmin=config["fmin"], fmax=config["fmax"], ).to(device) else: criterion["mel"] = MelSpectrogramLoss(**config["mel_loss_params"], ).to(device) else: config["use_mel_loss"] = False # define special module for subband processing if config["generator_params"]["out_channels"] > 1: criterion["pqmf"] = PQMF( subbands=config["generator_params"]["out_channels"], # keep compatibility **config.get("pqmf_params", {}), ).to(device) # define optimizers and schedulers generator_optimizer_class = getattr( parallel_wavegan.optimizers, # keep compatibility config.get("generator_optimizer_type", "RAdam"), ) discriminator_optimizer_class = getattr( parallel_wavegan.optimizers, # keep compatibility 
config.get("discriminator_optimizer_type", "RAdam"), ) optimizer = { "generator": generator_optimizer_class( model["generator"].parameters(), **config["generator_optimizer_params"], ), "discriminator": discriminator_optimizer_class( model["discriminator"].parameters(), **config["discriminator_optimizer_params"], ), } generator_scheduler_class = getattr( torch.optim.lr_scheduler, # keep compatibility config.get("generator_scheduler_type", "StepLR"), ) discriminator_scheduler_class = getattr( torch.optim.lr_scheduler, # keep compatibility config.get("discriminator_scheduler_type", "StepLR"), ) scheduler = { "generator": generator_scheduler_class( optimizer=optimizer["generator"], **config["generator_scheduler_params"], ), "discriminator": discriminator_scheduler_class( optimizer=optimizer["discriminator"], **config["discriminator_scheduler_params"], ), } if args.distributed: # wrap model for distributed training try: from apex.parallel import DistributedDataParallel except ImportError: raise ImportError( "apex is not installed. please check https://github.com/NVIDIA/apex." 
) model["generator"] = DistributedDataParallel(model["generator"]) model["discriminator"] = DistributedDataParallel( model["discriminator"]) # show settings logging.info(model["generator"]) logging.info(model["discriminator"]) logging.info(optimizer["generator"]) logging.info(optimizer["discriminator"]) logging.info(scheduler["generator"]) logging.info(scheduler["discriminator"]) for criterion_ in criterion.values(): logging.info(criterion_) # define trainer trainer = Trainer( steps=0, epochs=0, data_loader=data_loader, sampler=sampler, model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, config=config, device=device, ) # load pretrained parameters from checkpoint if len(args.pretrain) != 0: trainer.load_checkpoint(args.pretrain, load_only_params=True) logging.info(f"Successfully load parameters from {args.pretrain}.") # resume from checkpoint if len(args.resume) != 0: trainer.load_checkpoint(args.resume) logging.info(f"Successfully resumed from {args.resume}.") # run training loop try: trainer.run() finally: trainer.save_checkpoint( os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl")) logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
def main(): """Run training process.""" parser = argparse.ArgumentParser( description= "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)." ) parser.add_argument('--input_training_file', default='filelists/train.txt') parser.add_argument('--input_validation_file', default='filelists/val.txt') parser.add_argument('--checkpoint_path', default='parallel_wavegan/outdir') parser.add_argument( '--config', default='parallel_wavegan/config/multi_band_melgan.v2.yaml') parser.add_argument('--checkpoint_interval', default=0, type=int) parser.add_argument('--max_steps', default=1000000, type=int) parser.add_argument('--n_models_to_keep', default=1, type=int) parser.add_argument( "--verbose", type=int, default=1, help="logging level. higher is more logging. (default=1)") parser.add_argument( "--rank", "--local_rank", default=0, type=int, help="rank for distributed training. no need to explictly specify.") args = parser.parse_args() args.distributed = False if not torch.cuda.is_available(): device = torch.device("cpu") else: device = torch.device("cuda") # effective when using fixed size inputs # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936 torch.backends.cudnn.benchmark = True torch.cuda.set_device(args.rank) # setup for distributed training # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed if "WORLD_SIZE" in os.environ: args.world_size = int(os.environ["WORLD_SIZE"]) args.distributed = args.world_size > 1 if args.distributed: torch.distributed.init_process_group(backend="nccl", init_method="env://") # suppress logging for distributed training if args.rank != 0: sys.stdout = open(os.devnull, "w") # set logger if args.verbose > 1: logging.basicConfig( level=logging.DEBUG, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") elif args.verbose > 0: logging.basicConfig( level=logging.INFO, stream=sys.stdout, format= "%(asctime)s 
(%(module)s:%(lineno)d) %(levelname)s: %(message)s") else: logging.basicConfig( level=logging.WARN, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") logging.warning("Skip DEBUG/INFO messages") # check directory existence if not os.path.exists(args.checkpoint_path): os.makedirs(args.checkpoint_path) # load and save config with open(args.config) as f: config = yaml.load(f, Loader=yaml.Loader) config.update(vars(args)) config["version"] = parallel_wavegan.__version__ # add version info if config["checkpoint_interval"] > 0: config["save_interval_steps"] = config["checkpoint_interval"] if config["max_steps"] and config["max_steps"] > 0: config["train_max_steps"] = config["max_steps"] with open(os.path.join(args.checkpoint_path, "config.yml"), "w") as f: yaml.dump(config, f, Dumper=yaml.Dumper) for key, value in config.items(): logging.info(f"{key} = {value}") data_loader = {} data_loader["train"], sampler = from_path(config["input_training_file"], config) data_loader["dev"] = from_path(config["input_validation_file"], config)[0] # define models and optimizers generator_class = getattr( parallel_wavegan.models, config.get("generator_type", "ParallelWaveGANGenerator"), ) discriminator_class = getattr( parallel_wavegan.models, config.get("discriminator_type", "ParallelWaveGANDiscriminator"), ) model = { "generator": generator_class(**config["generator_params"]).to(device), "discriminator": discriminator_class(**config["discriminator_params"]).to(device), } criterion = {# reconstruction loss "stft": MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device), "mse": torch.nn.MSELoss().to(device), } if config.get("use_feat_match_loss", False): # keep compatibility criterion["l1"] = torch.nn.L1Loss().to(device) if config["generator_params"]["out_channels"] > 1: criterion["pqmf"] = PQMF( subbands=config["generator_params"]["out_channels"], # keep compatibility **config.get("pqmf_params", {})).to(device) if 
config.get("use_subband_stft_loss", False): # keep compatibility assert config["generator_params"]["out_channels"] > 1 criterion["sub_stft"] = MultiResolutionSTFTLoss( **config["subband_stft_loss_params"]).to(device) generator_optimizer_class = getattr( parallel_wavegan.optimizers, config.get("generator_optimizer_type", "RAdam"), ) discriminator_optimizer_class = getattr( parallel_wavegan.optimizers, config.get("discriminator_optimizer_type", "RAdam"), ) optimizer = { "generator": generator_optimizer_class( model["generator"].parameters(), **config["generator_optimizer_params"], ), "discriminator": discriminator_optimizer_class( model["discriminator"].parameters(), **config["discriminator_optimizer_params"], ), } generator_scheduler_class = getattr( torch.optim.lr_scheduler, config.get("generator_scheduler_type", "StepLR"), ) discriminator_scheduler_class = getattr( torch.optim.lr_scheduler, config.get("discriminator_scheduler_type", "StepLR"), ) scheduler = { "generator": generator_scheduler_class( optimizer=optimizer["generator"], **config["generator_scheduler_params"], ), "discriminator": discriminator_scheduler_class( optimizer=optimizer["discriminator"], **config["discriminator_scheduler_params"], ), } if args.distributed: # wrap model for distributed training try: from apex.parallel import DistributedDataParallel except ImportError: raise ImportError( "apex is not installed. please check https://github.com/NVIDIA/apex." 
) model["generator"] = DistributedDataParallel(model["generator"]) model["discriminator"] = DistributedDataParallel( model["discriminator"]) logging.info(model["generator"]) logging.info(model["discriminator"]) # define trainer trainer = Trainer( steps=0, epochs=0, data_loader=data_loader, sampler=sampler, model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, config=config, device=device, ) # load pretrained parameters from checkpoint trainer.load_checkpoint() # run training loop try: trainer.run() except KeyboardInterrupt: logging.info(f"Saving checkpoint in 10...") for i in range(9, -1, -1): print(f"{i}...") trainer.save_checkpoint( os.path.join(config["outdir"], f"checkpoint-{trainer.steps:08}steps.pkl")) logging.info( f"Successfully saved checkpoint @ {trainer.steps:08}steps.")