def test_melgan_trainable_with_melgan_discriminator(dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(batch_size, args_g["in_channels"],
                    batch_length // np.prod(args_g["upsample_scales"]))
    model_g = MelGANGenerator(**args_g)
    model_d = MelGANMultiScaleDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    y, y_hat = y.squeeze(1), y_hat.squeeze(1)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    adv_loss = 0.0
    for i in range(len(p_hat)):
        adv_loss += F.mse_loss(
            p_hat[i][-1], p_hat[i][-1].new_ones(p_hat[i][-1].size()))
    adv_loss /= (i + 1)
    with torch.no_grad():
        p = model_d(y.unsqueeze(1))
    fm_loss = 0.0
    for i in range(len(p_hat)):
        for j in range(len(p_hat[i]) - 1):
            fm_loss += F.l1_loss(p_hat[i][j], p[i][j].detach())
    fm_loss /= (i + 1) * j
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    y, y_hat = y.unsqueeze(1), y_hat.unsqueeze(1).detach()
    p = model_d(y)
    p_hat = model_d(y_hat)
    real_loss = 0.0
    fake_loss = 0.0
    for i in range(len(p)):
        real_loss += F.mse_loss(
            p[i][-1], p[i][-1].new_ones(p[i][-1].size()))
        fake_loss += F.mse_loss(
            p_hat[i][-1], p_hat[i][-1].new_zeros(p_hat[i][-1].size()))
    real_loss /= (i + 1)
    fake_loss /= (i + 1)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
def test_melgan_trainable_with_melgan_discriminator(dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = MelGANGenerator(**args_g)
    model_d = MelGANMultiScaleDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    feat_match_criterion = FeatureMatchLoss()
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    adv_loss = gen_adv_criterion(p_hat)
    with torch.no_grad():
        p = model_d(y)
    fm_loss = feat_match_criterion(p_hat, p)
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
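
# NOTE: The refactored test above delegates the hand-rolled loss loops of the older
# test to dedicated criterion modules. The classes below are only an illustrative
# sketch of what such criteria compute (LSGAN-style adversarial loss and L1 feature
# matching), assuming the multi-scale discriminator returns a list of per-scale
# feature lists whose last entry is the logits; they reuse the file's existing
# `torch` / `F` imports and are not necessarily the library's actual implementation.
class _SketchGeneratorAdversarialLoss(torch.nn.Module):
    def forward(self, p_hats):
        # average the LSGAN generator loss over discriminator scales
        adv_loss = 0.0
        for p_hat in p_hats:
            logits = p_hat[-1]
            adv_loss = adv_loss + F.mse_loss(logits, logits.new_ones(logits.size()))
        return adv_loss / len(p_hats)


class _SketchFeatureMatchLoss(torch.nn.Module):
    def forward(self, p_hats, ps):
        # L1 distance between fake and real intermediate feature maps
        fm_loss = 0.0
        for p_hat, p in zip(p_hats, ps):
            for feat_hat, feat in zip(p_hat[:-1], p[:-1]):
                fm_loss = fm_loss + F.l1_loss(feat_hat, feat.detach())
        return fm_loss / len(p_hats)
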
def test_parallel_wavegan_with_residual_discriminator_trainable(
        dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_generator_args(**dict_g)
    args_d = make_residual_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    z = torch.randn(batch_size, 1, batch_length)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_params"]["upsample_scales"])
        + 2 * args_g["aux_context_window"],
    )
    model_g = ParallelWaveGANGenerator(**args_g)
    model_d = ResidualParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(z, c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
def test_parallel_wavegan_trainable(dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_generator_args(**dict_g)
    args_d = make_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    z = torch.randn(batch_size, 1, batch_length)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_params"]["upsample_scales"])
        + 2 * args_g["aux_context_window"])
    model_g = ParallelWaveGANGenerator(**args_g)
    model_d = ParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(z, c)
    p_hat = model_d(y_hat)
    y, y_hat, p_hat = y.squeeze(1), y_hat.squeeze(1), p_hat.squeeze(1)
    adv_loss = F.mse_loss(p_hat, p_hat.new_ones(p_hat.size()))
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    y, y_hat = y.unsqueeze(1), y_hat.unsqueeze(1).detach()
    p = model_d(y)
    p_hat = model_d(y_hat)
    p, p_hat = p.squeeze(1), p_hat.squeeze(1)
    loss_d = F.mse_loss(p, p.new_ones(p.size())) + \
        F.mse_loss(p_hat, p_hat.new_zeros(p_hat.size()))
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
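
# NOTE: Every test above sums the two terms returned by MultiResolutionSTFTLoss.
# The function below is a minimal sketch of a *single-resolution* STFT loss, showing
# how the spectral-convergence and log-magnitude terms can be computed for (B, T)
# waveforms; the fft_size/hop_size/win_length defaults are illustrative only and
# this is not the library's actual implementation.
def _sketch_single_resolution_stft_loss(y_hat, y, fft_size=1024, hop_size=256,
                                         win_length=1024):
    window = torch.hann_window(win_length, device=y.device)
    # magnitude spectrograms of predicted and target waveforms
    mag_hat = torch.stft(y_hat, fft_size, hop_size, win_length, window,
                         return_complex=True).abs()
    mag = torch.stft(y, fft_size, hop_size, win_length, window,
                     return_complex=True).abs()
    # spectral convergence: Frobenius-norm ratio of the magnitude error
    sc_loss = torch.norm(mag - mag_hat, p="fro") / torch.norm(mag, p="fro")
    # log STFT magnitude: L1 distance between log magnitudes
    mag_loss = F.l1_loss(torch.log(mag_hat.clamp(min=1e-7)),
                         torch.log(mag.clamp(min=1e-7)))
    return sc_loss, mag_loss
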
def main(): """Run training process.""" parser = argparse.ArgumentParser( description= "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)." ) parser.add_argument("--train-dumpdir", type=str, required=True, help="directory including training data.") parser.add_argument("--dev-dumpdir", type=str, required=True, help="directory including development data.") parser.add_argument("--outdir", type=str, required=True, help="directory to save checkpoints.") parser.add_argument("--config", type=str, required=True, help="yaml format configuration file.") parser.add_argument( "--resume", default="", type=str, nargs="?", help="checkpoint file path to resume training. (default=\"\")") parser.add_argument( "--verbose", type=int, default=1, help="logging level. higher is more logging. (default=1)") parser.add_argument( "--rank", "--local_rank", default=0, type=int, help="rank for distributed training. no need to explictly specify.") args = parser.parse_args() args.distributed = False if not torch.cuda.is_available(): device = torch.device("cpu") else: device = torch.device("cuda") # effective when using fixed size inputs # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936 torch.backends.cudnn.benchmark = True torch.cuda.set_device(args.rank) # setup for distributed training # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed if "WORLD_SIZE" in os.environ: args.world_size = int(os.environ["WORLD_SIZE"]) args.distributed = args.world_size > 1 if args.distributed: torch.distributed.init_process_group(backend="nccl", init_method="env://") # suppress logging for distributed training if args.rank != 0: sys.stdout = open(os.devnull, "w") # set logger if args.verbose > 1: logging.basicConfig( level=logging.DEBUG, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") elif args.verbose > 0: logging.basicConfig( level=logging.INFO, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") else: logging.basicConfig( level=logging.WARN, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") logging.warning("Skip DEBUG/INFO messages") # check directory existence if not os.path.exists(args.outdir): os.makedirs(args.outdir) # load and save config with open(args.config) as f: config = yaml.load(f, Loader=yaml.Loader) config.update(vars(args)) config["version"] = parallel_wavegan.__version__ # add version info with open(os.path.join(args.outdir, "config.yml"), "w") as f: yaml.dump(config, f, Dumper=yaml.Dumper) for key, value in config.items(): logging.info(f"{key} = {value}") # get dataset if config["remove_short_samples"]: mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \ 2 * config["generator_params"].get("aux_context_window", 0) else: mel_length_threshold = None if config["format"] == "hdf5": audio_query, mel_query = "*.h5", "*.h5" audio_load_fn = lambda x: read_hdf5(x, "wave") # NOQA mel_load_fn = lambda x: read_hdf5(x, "feats") # NOQA elif config["format"] == "npy": audio_query, mel_query = "*-wave.npy", "*-feats.npy" audio_load_fn = np.load mel_load_fn = np.load else: raise ValueError("support only hdf5 or npy format.") dataset = { "train": AudioMelDataset( root_dir=args.train_dumpdir, audio_query=audio_query, mel_query=mel_query, audio_load_fn=audio_load_fn, mel_load_fn=mel_load_fn, mel_length_threshold=mel_length_threshold, allow_cache=config.get("allow_cache", False), # keep compatibility ), 
"dev": AudioMelDataset( root_dir=args.dev_dumpdir, audio_query=audio_query, mel_query=mel_query, audio_load_fn=audio_load_fn, mel_load_fn=mel_load_fn, mel_length_threshold=mel_length_threshold, allow_cache=config.get("allow_cache", False), # keep compatibility ), } # get data loader collater = Collater( batch_max_steps=config["batch_max_steps"], hop_size=config["hop_size"], # keep compatibility aux_context_window=config["generator_params"].get( "aux_context_window", 0), # keep compatibility use_noise_input=config.get( "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator", ) train_sampler, dev_sampler = None, None if args.distributed: # setup sampler for distributed training from torch.utils.data.distributed import DistributedSampler train_sampler = DistributedSampler( dataset=dataset["train"], num_replicas=args.world_size, rank=args.rank, shuffle=True, ) dev_sampler = DistributedSampler( dataset=dataset["dev"], num_replicas=args.world_size, rank=args.rank, shuffle=False, ) data_loader = { "train": DataLoader( dataset=dataset["train"], shuffle=False if args.distributed else True, collate_fn=collater, batch_size=config["batch_size"], num_workers=config["num_workers"], sampler=train_sampler, pin_memory=config["pin_memory"], ), "dev": DataLoader( dataset=dataset["dev"], shuffle=False if args.distributed else True, collate_fn=collater, batch_size=config["batch_size"], num_workers=config["num_workers"], sampler=dev_sampler, pin_memory=config["pin_memory"], ), } # define models and optimizers generator_class = getattr( parallel_wavegan.models, # keep compatibility config.get("generator_type", "ParallelWaveGANGenerator"), ) discriminator_class = getattr( parallel_wavegan.models, # keep compatibility config.get("discriminator_type", "ParallelWaveGANDiscriminator"), ) model = { "generator": generator_class(**config["generator_params"]).to(device), "discriminator": discriminator_class(**config["discriminator_params"]).to(device), } criterion = { "stft": MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device), "mse": torch.nn.MSELoss().to(device), } if config.get("use_feat_match_loss", False): # keep compatibility criterion["l1"] = torch.nn.L1Loss().to(device) optimizer = { "generator": RAdam(model["generator"].parameters(), **config["generator_optimizer_params"]), "discriminator": RAdam(model["discriminator"].parameters(), **config["discriminator_optimizer_params"]), } scheduler = { "generator": torch.optim.lr_scheduler.StepLR( optimizer=optimizer["generator"], **config["generator_scheduler_params"]), "discriminator": torch.optim.lr_scheduler.StepLR( optimizer=optimizer["discriminator"], **config["discriminator_scheduler_params"]), } if args.distributed: # wrap model for distributed training try: from apex.parallel import DistributedDataParallel except ImportError: raise ImportError( "apex is not installed. please check https://github.com/NVIDIA/apex." 
) model["generator"] = DistributedDataParallel(model["generator"]) model["discriminator"] = DistributedDataParallel( model["discriminator"]) logging.info(model["generator"]) logging.info(model["discriminator"]) # define trainer trainer = Trainer( steps=0, epochs=0, data_loader=data_loader, model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, config=config, device=device, ) # resume from checkpoint if len(args.resume) != 0: trainer.load_checkpoint(args.resume) logging.info(f"Successfully resumed from {args.resume}.") # run training loop try: trainer.run() except KeyboardInterrupt: trainer.save_checkpoint( os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl")) logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
def main(): """Run training process.""" parser = argparse.ArgumentParser( description= "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)." ) parser.add_argument("--train-dumpdir", type=str, required=True, help="directory including trainning data.") parser.add_argument("--dev-dumpdir", type=str, required=True, help="directory including development data.") parser.add_argument("--outdir", type=str, required=True, help="directory to save checkpoints.") parser.add_argument("--config", type=str, required=True, help="yaml format configuration file.") parser.add_argument( "--resume", default="", type=str, nargs="?", help="checkpoint file path to resume training. (default=\"\")") parser.add_argument( "--verbose", type=int, default=1, help="logging level. higher is more logging. (default=1)") args = parser.parse_args() # set logger if args.verbose > 1: logging.basicConfig( level=logging.DEBUG, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") elif args.verbose > 0: logging.basicConfig( level=logging.INFO, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") else: logging.basicConfig( level=logging.WARN, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") logging.warning('skip DEBUG/INFO messages') # check directory existence if not os.path.exists(args.outdir): os.makedirs(args.outdir) # load and save config with open(args.config) as f: config = yaml.load(f, Loader=yaml.Loader) config.update(vars(args)) with open(os.path.join(args.outdir, "config.yml"), "w") as f: yaml.dump(config, f, Dumper=yaml.Dumper) for key, value in config.items(): logging.info(f"{key} = {value}") # get dataset if config["remove_short_samples"]: mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \ 2 * config["generator_params"]["aux_context_window"] else: mel_length_threshold = None if config["format"] == "hdf5": audio_query, mel_query = "*.h5", "*.h5" audio_load_fn = lambda x: read_hdf5(x, "wave") # NOQA mel_load_fn = lambda x: read_hdf5(x, "feats") # NOQA elif config["format"] == "npy": audio_query, mel_query = "*-wave.npy", "*-feats.npy" audio_load_fn = np.load mel_load_fn = np.load else: raise ValueError("support only hdf5 or npy format.") dataset = { "train": AudioMelDataset( root_dir=args.train_dumpdir, audio_query=audio_query, mel_query=mel_query, audio_load_fn=audio_load_fn, mel_load_fn=mel_load_fn, mel_length_threshold=mel_length_threshold, allow_cache=config.get("allow_cache", False), # keep compatibilty ), "dev": AudioMelDataset( root_dir=args.dev_dumpdir, audio_query=audio_query, mel_query=mel_query, audio_load_fn=audio_load_fn, mel_load_fn=mel_load_fn, mel_length_threshold=mel_length_threshold, allow_cache=config.get("allow_cache", False), # keep compatibilty ), } # get data loader if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") collater = Collater( batch_max_steps=config["batch_max_steps"], hop_size=config["hop_size"], aux_context_window=config["generator_params"]["aux_context_window"], ) data_loader = { "train": DataLoader(dataset=dataset["train"], shuffle=True, collate_fn=collater, batch_size=config["batch_size"], num_workers=config["num_workers"], pin_memory=config["pin_memory"]), "dev": DataLoader(dataset=dataset["dev"], shuffle=True, collate_fn=collater, batch_size=config["batch_size"], num_workers=config["num_workers"], pin_memory=config["pin_memory"]), } # define models and optimizers model = { "generator": 
ParallelWaveGANGenerator(**config["generator_params"]).to(device), "discriminator": ParallelWaveGANDiscriminator( **config["discriminator_params"]).to(device), } criterion = { "stft": MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device), "mse": torch.nn.MSELoss().to(device), } optimizer = { "generator": RAdam(model["generator"].parameters(), **config["generator_optimizer_params"]), "discriminator": RAdam(model["discriminator"].parameters(), **config["discriminator_optimizer_params"]), } scheduler = { "generator": torch.optim.lr_scheduler.StepLR( optimizer=optimizer["generator"], **config["generator_scheduler_params"]), "discriminator": torch.optim.lr_scheduler.StepLR( optimizer=optimizer["discriminator"], **config["discriminator_scheduler_params"]), } logging.info(model["generator"]) logging.info(model["discriminator"]) # define trainer trainer = Trainer( steps=0, epochs=0, data_loader=data_loader, model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, config=config, device=device, ) # resume from checkpoint if len(args.resume) != 0: trainer.load_checkpoint(args.resume) logging.info(f"resumed from {args.resume}.") # run training loop try: trainer.run() finally: trainer.save_checkpoint( os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl")) logging.info(f"successfully saved checkpoint @ {trainer.steps}steps.")
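
# NOTE: The Trainer internals are not shown in this file. The function below is a
# rough, illustrative sketch of a single generator/discriminator update, assuming the
# Trainer combines the "stft" and "mse" criteria the same way as the unit tests above
# (LSGAN objective plus multi-resolution STFT auxiliary loss) and that each batch is
# an (z, c, y) tuple of noise, auxiliary features, and target audio. It is not the
# library's actual training loop.
def _sketch_train_step(batch, model, criterion, optimizer, device):
    z, c, y = (x.to(device) for x in batch)
    # generator update
    y_hat = model["generator"](z, c)
    p_hat = model["discriminator"](y_hat)
    adv_loss = criterion["mse"](p_hat, p_hat.new_ones(p_hat.size()))
    sc_loss, mag_loss = criterion["stft"](y_hat.squeeze(1), y.squeeze(1))
    loss_g = adv_loss + sc_loss + mag_loss
    optimizer["generator"].zero_grad()
    loss_g.backward()
    optimizer["generator"].step()
    # discriminator update on real and detached fake audio
    p = model["discriminator"](y)
    p_hat = model["discriminator"](y_hat.detach())
    loss_d = criterion["mse"](p, p.new_ones(p.size())) + \
        criterion["mse"](p_hat, p_hat.new_zeros(p_hat.size()))
    optimizer["discriminator"].zero_grad()
    loss_d.backward()
    optimizer["discriminator"].step()
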
def main(): """The main function that runs training process.""" # initialize the argument parser parser = argparse.ArgumentParser(description="Train Parallel WaveGAN.") # just a description of the job that the parser is used to support. # Add arguments to the parser #first is name of the argument #default: The value produced if the argument is absent from the command line #type: The type to which the command-line argument should be converted #help: hint that appears when the user doesnot know what is this argument [-h] #required: Whether or not the command-line option may be omitted (optionals only) #nargs:The number of command-line arguments that should be consumed # "?" One argument will be consumed from the command line if possible, and produced as a single item. If no command-line argument is present, # the value from default will be produced. # Note that for optional arguments, there is an additional case - the option string is present but not followed by a command-line argument. In this case the value from const will be produced. parser.add_argument("--train-wav-scp", default=None, type=str, help="kaldi-style wav.scp file for training. " "you need to specify either train-*-scp or train-dumpdir.") parser.add_argument("--train-feats-scp", default=None, type=str, help="kaldi-style feats.scp file for training. " "you need to specify either train-*-scp or train-dumpdir.") parser.add_argument("--train-segments", default=None, type=str, help="kaldi-style segments file for training.") parser.add_argument("--train-dumpdir", default=None, type=str, help="directory including training data. " "you need to specify either train-*-scp or train-dumpdir.") parser.add_argument("--dev-wav-scp", default=None, type=str, help="kaldi-style wav.scp file for validation. " "you need to specify either dev-*-scp or dev-dumpdir.") parser.add_argument("--dev-feats-scp", default=None, type=str, help="kaldi-style feats.scp file for vaidation. " "you need to specify either dev-*-scp or dev-dumpdir.") parser.add_argument("--dev-segments", default=None, type=str, help="kaldi-style segments file for validation.") parser.add_argument("--dev-dumpdir", default=None, type=str, help="directory including development data. " "you need to specify either dev-*-scp or dev-dumpdir.") parser.add_argument("--outdir", type=str, required=True, help="directory to save checkpoints.") parser.add_argument("--config", type=str, required=True, help="yaml format configuration file.") parser.add_argument("--pretrain", default="", type=str, nargs="?", help="checkpoint file path to load pretrained params. (default=\"\")") parser.add_argument("--resume", default="", type=str, nargs="?", help="checkpoint file path to resume training. (default=\"\")") parser.add_argument("--verbose", type=int, default=1, help="logging level. higher is more logging. (default=1)") parser.add_argument("--rank", "--local_rank", default=0, type=int, help="rank for distributed training. 
no need to explictly specify.") # parse all the input arguments args = parser.parse_args() args.distributed = False if not torch.cuda.is_available(): #if gpu is not available device = torch.device("cpu") #train on cpu else: #GPU device = torch.device("cuda")#train on gpu torch.backends.cudnn.benchmark = True # effective when using fixed size inputs (no conditional layers or layers inside loops),benchmark mode in cudnn,faster runtime torch.cuda.set_device(args.rank) # sets the default GPU for distributed training if "WORLD_SIZE" in os.environ:#determine max number of parallel processes (distributed) args.world_size = int(os.environ["WORLD_SIZE"]) #get the world size from the os args.distributed = args.world_size > 1 #set distributed if woldsize > 1 if args.distributed: torch.distributed.init_process_group(backend="nccl", init_method="env://") #Use the NCCL backend for distributed GPU training (Rule of thumb) #NCCL:since it currently provides the best distributed GPU training performance, especially for multiprocess single-node or multi-node distributed training # suppress logging for distributed training if args.rank != 0: #if process is not p0 sys.stdout = open(os.devnull, "w")#DEVNULL is Special value that can be used as the stdin, stdout or stderr argument to # set logger if args.verbose > 1: #if level of logging is heigher then 1 logging.basicConfig( #configure the logging level=logging.DEBUG, stream=sys.stdout, #heigh logging level,detailed information, typically of interest only when diagnosing problems. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") #format includes Time,module,line#,level,and message. elif args.verbose > 0:#if level of logging is between 0,1 logging.basicConfig(#configure the logging level=logging.INFO, stream=sys.stdout,#moderate logging level,Confirmation that things are working as expected. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")#format includes Time,module,line#,level,and message. else:#if level of logging is 0 logging.basicConfig(#configure the logging level=logging.WARN, stream=sys.stdout,#low logging level,An indication that something unexpected happened, or indicative of some problem in the near future (e.g. ‘disk space low’). format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")#format includes Time,module,line#,level,and message. logging.warning("Skip DEBUG/INFO messages")#tell the user that he will skip logging DEBUG/INFO messages by choosing this level. 
    # check directory existence
    if not os.path.exists(args.outdir):
        # create the checkpoint directory if it does not exist
        os.makedirs(args.outdir)

    # check arguments
    if (args.train_feats_scp is not None and args.train_dumpdir is not None) or \
            (args.train_feats_scp is None and args.train_dumpdir is None):
        # the user specified both kinds of training data source, or neither of them
        raise ValueError("Please specify either --train-dumpdir or --train-*-scp.")
    if (args.dev_feats_scp is not None and args.dev_dumpdir is not None) or \
            (args.dev_feats_scp is None and args.dev_dumpdir is None):
        # the user specified both kinds of validation data source, or neither of them
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.")

    # load config
    with open(args.config) as f:
        # parse the yaml configuration file into a python object
        config = yaml.load(f, Loader=yaml.Loader)

    # update config
    config.update(vars(args))  # merge the command-line arguments into the config
    config["version"] = parallel_wavegan.__version__  # add parallel wavegan version info

    # save config
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        # dump the python object back into a yaml document under outdir
        yaml.dump(config, f, Dumper=yaml.Dumper)

    # add config info to the logger
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        # length threshold =
        #   floor(batch_max_steps / hop_size) + 2 * generator_params.aux_context_window
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None  # no threshold
    if args.train_wav_scp is None or args.dev_wav_scp is None:
        # at least one of the training / development sets comes from a dump directory
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            # lambda example: f = lambda a, b: a * b; f(5, 6) == 30
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            # any other data format is not supported
            raise ValueError("support only hdf5 or npy format.")
    if args.train_dumpdir is not None:
        # training data comes from a dump directory
        train_dataset = AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,      # audio query for the chosen format
            mel_query=mel_query,          # mel query for the chosen format
            audio_load_fn=audio_load_fn,  # loader for the chosen audio format
            mel_load_fn=mel_load_fn,      # loader for the chosen mel format
            mel_length_threshold=mel_length_threshold,  # drop short samples (see above)
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        # training data comes from kaldi-style scp files
        train_dataset = AudioMelSCPDataset(
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments,  # optional kaldi-style segments file
            mel_length_threshold=mel_length_threshold,  # drop short samples (see above)
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    # log the size of the training set
    logging.info(f"The number of training files = {len(train_dataset)}.")
    if args.dev_dumpdir is not None:
        # development data comes from a dump directory
        dev_dataset = AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        # development data comes from kaldi-style scp files
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    # log the size of the development set
    logging.info(f"The number of development files = {len(dev_dataset)}.")
    # the whole dataset is split into training and development subsets
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get("aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator": generator_class(**config["generator_params"]).to(device),
        "discriminator": discriminator_class(**config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    optimizer = {
        "generator": RAdam(
            model["generator"].parameters(),
            **config["generator_optimizer_params"]),
        "discriminator": RAdam(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex.")
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    if len(args.pretrain) != 0:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
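
# NOTE: The --*-scp options above follow Kaldi conventions: wav.scp maps utterance IDs
# to audio paths, feats.scp maps utterance IDs to feature entries, and segments
# optionally maps segment IDs to (recording ID, start time, end time). The strings
# below are an illustrative sketch of such files; all IDs and paths are made up.
example_wav_scp = (
    "utt001 /data/wavs/utt001.wav\n"
    "utt002 /data/wavs/utt002.wav\n"
)
example_segments = (
    "utt001-seg1 utt001 0.00 3.52\n"
    "utt002-seg1 utt002 0.00 2.87\n"
)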