def main():
    """Run preprocessing process.

    Normalize dumped raw features with pre-computed statistics. Every
    (audio, mel) pair found under ``--rootdir`` is written to ``--dumpdir``:
    the audio is copied as float32 and the mel features are standardized with
    the restored ``StandardScaler``. The on-disk layout follows
    ``config["format"]`` (``"hdf5"`` or ``"npy"``).

    Raises:
        ValueError: If ``config["format"]`` is neither ``"hdf5"`` nor ``"npy"``.
    """
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py).")
    parser.add_argument("--rootdir", type=str, required=True,
                        help="directory including feature files to be normalized.")
    parser.add_argument("--dumpdir", type=str, required=True,
                        help="directory to dump normalized feature files.")
    parser.add_argument("--stats", type=str, required=True,
                        help="statistics file.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--n_jobs", type=int, default=16,
                        help="number of parallel jobs. (default=16)")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning('Skip DEBUG/INFO messages')

    # load config; command-line arguments override keys of the same name
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # create the output directory (no-op if it already exists)
    os.makedirs(args.dumpdir, exist_ok=True)

    # get dataset
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = AudioMelDataset(
        root_dir=args.rootdir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        return_filename=True,
    )
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler from the stored statistics instead of fitting it
    scaler = StandardScaler()
    if config["format"] == "hdf5":
        scaler.mean_ = read_hdf5(args.stats, "mean")
        scaler.scale_ = read_hdf5(args.stats, "scale")
    elif config["format"] == "npy":
        # the first entry holds the mean, the second the scale; load the file once
        stats = np.load(args.stats)
        scaler.mean_ = stats[0]
        scaler.scale_ = stats[1]
    else:
        raise ValueError("support only hdf5 or npy format.")
    # The scaler is never fit(), so newer scikit-learn releases (0.23+) need
    # n_features_in_ set explicitly before transform() can validate its input.
    scaler.n_features_in_ = scaler.mean_.shape[0]

    def _process_single_file(data):
        """Normalize one (audio, mel) pair and write both into ``args.dumpdir``."""
        # parse inputs for each audio
        audio_name, mel_name, audio, mel = data

        # normalize: StandardScaler.transform computes (mel - mean_) / scale_
        mel = scaler.transform(mel)

        # save under the same base file names inside the dump directory
        audio_path = os.path.join(args.dumpdir, os.path.basename(audio_name))
        mel_path = os.path.join(args.dumpdir, os.path.basename(mel_name))
        if config["format"] == "hdf5":
            write_hdf5(audio_path, "wave", audio.astype(np.float32))
            write_hdf5(mel_path, "feats", mel.astype(np.float32))
        elif config["format"] == "npy":
            np.save(audio_path, audio.astype(np.float32), allow_pickle=False)
            np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
        else:
            raise ValueError("support only hdf5 or npy format.")

    # process in parallel; delayed() captures the function and its arguments
    # so joblib can dispatch each call to a worker
    Parallel(n_jobs=args.n_jobs, verbose=args.verbose)(
        [delayed(_process_single_file)(data) for data in tqdm(dataset)])
def main():
    """Run training process.

    Parse command-line arguments and, when CUDA is available and the
    ``WORLD_SIZE`` environment variable is greater than 1, initialize NCCL
    distributed training. Build the datasets, data loaders, models, losses,
    RAdam optimizers and StepLR schedulers from the YAML config, then run the
    trainer. If training is interrupted with Ctrl-C (``KeyboardInterrupt``),
    a checkpoint of the current step is written to ``config["outdir"]``.

    Raises:
        ValueError: If ``config["format"]`` is neither ``"hdf5"`` nor ``"npy"``.
        ImportError: If distributed training is requested but apex is missing.
    """
    parser = argparse.ArgumentParser(
        description="Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py).")
    parser.add_argument("--train-dumpdir", type=str, required=True,
                        help="directory including training data.")
    parser.add_argument("--dev-dumpdir", type=str, required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--resume", default="", type=str, nargs="?",
                        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    parser.add_argument("--rank", "--local_rank", default=0, type=int,
                        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # suppress logging for distributed training (only rank 0 prints)
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, stream=sys.stdout, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, stream=sys.stdout, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, stream=sys.stdout, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # create the output directory; exist_ok avoids a FileExistsError race when
    # several distributed ranks reach this line at the same time
    os.makedirs(args.outdir, exist_ok=True)

    # load and save config; command-line arguments override config keys
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train": AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
        "dev": AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get("aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    # DataLoader forbids shuffle=True together with a sampler, so shuffling is
    # delegated to the DistributedSampler in distributed mode
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator": generator_class(**config["generator_params"]).to(device),
        "discriminator": discriminator_class(**config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    optimizer = {
        "generator": RAdam(
            model["generator"].parameters(),
            **config["generator_optimizer_params"]),
        "discriminator": RAdam(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError as err:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            ) from err
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint; a bare "--resume" yields None (nargs="?" without
    # const), which must be treated like "no checkpoint" instead of crashing
    if args.resume:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
def main():
    """Run training process.

    Parse command-line arguments, build the datasets and data loaders, the
    Parallel WaveGAN generator/discriminator, the losses, RAdam optimizers and
    StepLR schedulers from the YAML config, then run the trainer. A checkpoint
    of the current step is written to ``config["outdir"]`` whenever the
    training loop exits, whether normally or through an exception (including
    Ctrl-C).

    Raises:
        ValueError: If ``config["format"]`` is neither ``"hdf5"`` nor ``"npy"``.
    """
    parser = argparse.ArgumentParser(
        description="Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py).")
    parser.add_argument("--train-dumpdir", type=str, required=True,
                        help="directory including trainning data.")
    parser.add_argument("--dev-dumpdir", type=str, required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--resume", default="", type=str, nargs="?",
                        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning('skip DEBUG/INFO messages')

    # create the output directory (no-op if it already exists)
    os.makedirs(args.outdir, exist_ok=True)

    # load and save config; command-line arguments override config keys
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"]["aux_context_window"]
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train": AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibilty
        ),
        "dev": AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibilty
        ),
    }

    # get data loader
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        aux_context_window=config["generator_params"]["aux_context_window"],
    )
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    model = {
        "generator": ParallelWaveGANGenerator(
            **config["generator_params"]).to(device),
        "discriminator": ParallelWaveGANDiscriminator(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    optimizer = {
        "generator": RAdam(
            model["generator"].parameters(),
            **config["generator_optimizer_params"]),
        "discriminator": RAdam(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint; a bare "--resume" yields None (nargs="?" without
    # const), which must be treated like "no checkpoint" instead of crashing
    if args.resume:
        trainer.load_checkpoint(args.resume)
        logging.info(f"resumed from {args.resume}.")

    # run training loop; the checkpoint is saved however the loop exits
    try:
        trainer.run()
    finally:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"successfully saved checkpoint @ {trainer.steps}steps.")
def main():
    """The main function that runs training process.

    Training and development data each come from either a dump directory
    (``--*-dumpdir``, read with ``AudioMelDataset``) or Kaldi-style scp files
    (``--*-wav-scp`` plus ``--*-feats-scp``, optionally ``--*-segments``, read
    with ``AudioMelSCPDataset``). When CUDA is available and the
    ``WORLD_SIZE`` environment variable is greater than 1, NCCL distributed
    training is used.

    The function builds data loaders, models, losses, RAdam optimizers and
    StepLR schedulers from the YAML config. It optionally loads pretrained
    parameters and/or resumes from a checkpoint, then runs the trainer. On
    Ctrl-C (``KeyboardInterrupt``), a checkpoint of the current step is written
    to ``config["outdir"]``.

    Raises:
        ValueError: If the data-source arguments are inconsistent, or
            ``config["format"]`` is neither ``"hdf5"`` nor ``"npy"`` while a
            dump directory is used.
        ImportError: If distributed training is requested but apex is missing.
    """
    parser = argparse.ArgumentParser(description="Train Parallel WaveGAN.")
    parser.add_argument("--train-wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file for training. "
                             "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument("--train-feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file for training. "
                             "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument("--train-segments", default=None, type=str,
                        help="kaldi-style segments file for training.")
    parser.add_argument("--train-dumpdir", default=None, type=str,
                        help="directory including training data. "
                             "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument("--dev-wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file for validation. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--dev-feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file for vaidation. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--dev-segments", default=None, type=str,
                        help="kaldi-style segments file for validation.")
    parser.add_argument("--dev-dumpdir", default=None, type=str,
                        help="directory including development data. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--pretrain", default="", type=str, nargs="?",
                        help="checkpoint file path to load pretrained params. (default=\"\")")
    parser.add_argument("--resume", default="", type=str, nargs="?",
                        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    parser.add_argument("--rank", "--local_rank", default=0, type=int,
                        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # cuDNN autotuning pays off when input sizes stay fixed between steps
        torch.backends.cudnn.benchmark = True
        # bind this process to the GPU matching its (local) rank
        torch.cuda.set_device(args.rank)
        # WORLD_SIZE is set by the distributed launcher
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # suppress logging for distributed training (only rank 0 prints)
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, otherwise WARN
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, stream=sys.stdout, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, stream=sys.stdout, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, stream=sys.stdout, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # create the output directory; exist_ok avoids a FileExistsError race when
    # several distributed ranks reach this line at the same time
    os.makedirs(args.outdir, exist_ok=True)

    # check arguments: exactly one data source per split, and scp-based data
    # needs both the wav and the feats scp files
    if (args.train_feats_scp is not None and args.train_dumpdir is not None) or \
            (args.train_feats_scp is None and args.train_dumpdir is None):
        raise ValueError("Please specify either --train-dumpdir or --train-*-scp.")
    if (args.dev_feats_scp is not None and args.dev_dumpdir is not None) or \
            (args.dev_feats_scp is None and args.dev_dumpdir is None):
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.")
    if args.train_feats_scp is not None and args.train_wav_scp is None:
        raise ValueError("Please specify --train-wav-scp together with --train-feats-scp.")
    if args.dev_feats_scp is not None and args.dev_wav_scp is None:
        raise ValueError("Please specify --dev-wav-scp together with --dev-feats-scp.")

    # load config; command-line arguments override config keys
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info

    # save config and log every entry
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        # minimum mel length: frames covered by one batch plus context on both sides
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None

    # File patterns and loaders are only needed by dump-directory datasets.
    # Keying this on the dumpdir arguments guarantees they are defined whenever
    # AudioMelDataset is built below.
    if args.train_dumpdir is not None or args.dev_dumpdir is not None:
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")

    if args.train_dumpdir is not None:
        train_dataset = AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility.
        )
    else:
        train_dataset = AudioMelSCPDataset(
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of training files = {len(train_dataset)}.")

    if args.dev_dumpdir is not None:
        dev_dataset = AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of development files = {len(dev_dataset)}.")

    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get("aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    # DataLoader forbids shuffle=True together with a sampler, so shuffling is
    # delegated to the DistributedSampler in distributed mode
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator": generator_class(
            **config["generator_params"]).to(device),
        "discriminator": discriminator_class(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(
            **config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    optimizer = {
        "generator": RAdam(
            model["generator"].parameters(),
            **config["generator_optimizer_params"]),
        "discriminator": RAdam(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError as err:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            ) from err
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # A bare "--pretrain"/"--resume" yields None (nargs="?" without const);
    # treat it like "not given" instead of crashing on len(None).
    # load pretrained parameters from checkpoint
    if args.pretrain:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # resume from checkpoint
    if args.resume:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
def main():
    """Run preprocessing process.

    Normalize dumped features with a pre-computed mean/scale and save them.

    Input is read from exactly one of ``--rootdir`` or ``--feats-scp``.
    With ``--rootdir`` and ``--skip-wav-copy`` only the feature type chosen
    by ``--ftype`` (``mel`` or ``spec``) is loaded. Passing an empty
    ``--dumpdir`` ("") with the npy format saves each normalized feature at
    the path of its source feature file with the trailing ``.npy`` replaced
    by ``-norm.npy``; this mode requires ``--skip-wav-copy``.

    Raises:
        ValueError: On inconsistent arguments, an unsupported ``format``, or
            ``--ftype spec`` with a format that has no spectrogram pattern.
    """
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py).")
    parser.add_argument("--rootdir", default=None, type=str,
                        help="directory including feature files to be normalized. "
                             "you need to specify either *-scp or rootdir.")
    parser.add_argument("--wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file. "
                             "you need to specify either *-scp or rootdir.")
    parser.add_argument("--feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file. "
                             "you need to specify either *-scp or rootdir.")
    parser.add_argument("--segments", default=None, type=str,
                        help="kaldi-style segments file.")
    parser.add_argument("--dumpdir", type=str, required=True,
                        help="directory to dump normalized feature files.")
    parser.add_argument("--stats", type=str, required=True,
                        help="statistics file.")
    parser.add_argument("--skip-wav-copy", default=False, action="store_true",
                        help="whether to skip the copy of wav files.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--ftype", default="mel", type=str, choices=("mel", "spec"),
                        help="feature type")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    if log_level == logging.WARN:
        logging.warning('Skip DEBUG/INFO messages')

    # load config
    # NOTE: yaml.Loader can construct arbitrary Python objects; only load
    # trusted configuration files.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments
    if (args.feats_scp is not None) == (args.rootdir is not None):
        raise ValueError("Please specify either --rootdir or --feats-scp.")
    # An empty --dumpdir means "save next to the source feature file"; only the
    # --skip-wav-copy datasets carry the source path needed for that.
    in_place = args.dumpdir == ""
    if in_place and config["format"] == "npy" and not args.skip_wav_copy:
        raise ValueError("Please include --skip-wav-copy in arguments when --dumpdir is empty.")

    # check directory existence
    if not in_place:
        os.makedirs(args.dumpdir, exist_ok=True)

    # get dataset
    if args.rootdir is not None:
        # spectrogram files only have a pattern in the npy layout
        spc_query, spc_load_fn = None, None
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query, spc_query = "*.wav.npy", "*.mel.npy", "*.spec.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
            spc_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")

        if not args.skip_wav_copy:
            dataset = AudioMelDataset(
                root_dir=args.rootdir,
                audio_query=audio_query,
                mel_query=mel_query,
                audio_load_fn=audio_load_fn,
                mel_load_fn=mel_load_fn,
                return_utt_id=True,
            )
        elif args.ftype == "mel":
            # build only the dataset that is actually processed
            dataset = MelDatasetNew(
                root_dir=args.rootdir,
                mel_query=mel_query,
                mel_load_fn=mel_load_fn,
                return_utt_id=True,
            )
        elif spc_query is None:
            raise ValueError("--ftype spec is supported only with npy format.")
        else:
            dataset = SpcDatasetNew(
                root_dir=args.rootdir,
                spc_query=spc_query,
                spc_load_fn=spc_load_fn,
                return_utt_id=True,
            )
    elif not args.skip_wav_copy:
        dataset = AudioMelSCPDataset(
            wav_scp=args.wav_scp,
            feats_scp=args.feats_scp,
            segments=args.segments,
            return_utt_id=True,
        )
    else:
        dataset = MelSCPDataset(
            feats_scp=args.feats_scp,
            return_utt_id=True,
        )
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler
    scaler = StandardScaler()
    if config["format"] == "hdf5":
        scaler.mean_ = read_hdf5(args.stats, "mean")
        scaler.scale_ = read_hdf5(args.stats, "scale")
    elif config["format"] == "npy":
        # index 0 holds the mean, index 1 the scale; load the file once
        stats = np.load(args.stats)
        scaler.mean_ = stats[0]
        scaler.scale_ = stats[1]
    else:
        raise ValueError("support only hdf5 or npy format.")
    # from version 0.23.0, this information is needed
    scaler.n_features_in_ = scaler.mean_.shape[0]

    # process each file
    for items in tqdm(dataset):
        feat_file = None
        if not args.skip_wav_copy:
            utt_id, audio, feat = items
        else:
            # presumably the *DatasetNew classes yield the source feature path
            # as a third item while MelSCPDataset may not -- TODO confirm
            utt_id, feat, *extra = items
            if extra:
                feat_file = extra[0]

        # normalize; identical to (feat - scaler.mean_) / scaler.scale_
        feat = scaler.transform(feat)

        # save
        if config["format"] == "hdf5":
            write_hdf5(os.path.join(args.dumpdir, f"{utt_id}.h5"), "feats",
                       feat.astype(np.float32))
            if not args.skip_wav_copy:
                write_hdf5(os.path.join(args.dumpdir, f"{utt_id}.h5"), "wave",
                           audio.astype(np.float32))
        elif config["format"] == "npy":
            if in_place:
                if feat_file is None:
                    raise ValueError(
                        "An empty --dumpdir requires a dataset that provides the "
                        "source feature path; please use --rootdir.")
                # strip only the trailing extension, not every ".npy" in the path
                base = str(feat_file)
                if base.endswith(".npy"):
                    base = base[:-len(".npy")]
                np.save(base + "-norm.npy", feat.astype(np.float32), allow_pickle=False)
            else:
                np.save(os.path.join(args.dumpdir, f"{utt_id}.npy"),
                        feat.astype(np.float32), allow_pickle=False)
                if not args.skip_wav_copy:
                    np.save(os.path.join(args.dumpdir, f"{utt_id}.wav.npy"),
                            audio.astype(np.float32), allow_pickle=False)
        else:
            raise ValueError("support only hdf5 or npy format.")
def main():
    """Run preprocessing process.

    Normalize dumped features with the mean/scale stored in ``--stats`` and
    write them (plus the waveform unless ``--skip-wav-copy`` is given) to
    ``--dumpdir``. Input comes from exactly one of ``--rootdir`` or
    ``--feats-scp``; the latter also needs ``--wav-scp`` unless
    ``--skip-wav-copy`` is set.

    Raises:
        ValueError: On inconsistent arguments or an unsupported ``format``.
    """
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        help="directory including feature files to be normalized. "
        "you need to specify either *-scp or rootdir.",
    )
    parser.add_argument(
        "--wav-scp",
        default=None,
        type=str,
        help="kaldi-style wav.scp file. "
        "you need to specify either *-scp or rootdir.",
    )
    parser.add_argument(
        "--feats-scp",
        default=None,
        type=str,
        help="kaldi-style feats.scp file. "
        "you need to specify either *-scp or rootdir.",
    )
    parser.add_argument(
        "--segments",
        default=None,
        type=str,
        help="kaldi-style segments file.",
    )
    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump normalized feature files.",
    )
    parser.add_argument(
        "--stats",
        type=str,
        required=True,
        help="statistics file.",
    )
    parser.add_argument(
        "--skip-wav-copy",
        default=False,
        action="store_true",
        help="whether to skip the copy of wav files.",
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if log_level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")

    # load config
    # NOTE: yaml.Loader can construct arbitrary Python objects; only load
    # trusted configuration files.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments
    if (args.feats_scp is not None and args.rootdir is not None) or (
        args.feats_scp is None and args.rootdir is None
    ):
        raise ValueError("Please specify either --rootdir or --feats-scp.")
    # the scp path copies waveforms too unless told otherwise, so it needs wav.scp
    if args.feats_scp is not None and args.wav_scp is None and not args.skip_wav_copy:
        raise ValueError("Please specify --wav-scp or use --skip-wav-copy.")

    # check directory existence (no check-then-create race)
    os.makedirs(args.dumpdir, exist_ok=True)

    # get dataset
    if args.rootdir is not None:
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
        if not args.skip_wav_copy:
            dataset = AudioMelDataset(
                root_dir=args.rootdir,
                audio_query=audio_query,
                mel_query=mel_query,
                audio_load_fn=audio_load_fn,
                mel_load_fn=mel_load_fn,
                return_utt_id=True,
            )
        else:
            dataset = MelDataset(
                root_dir=args.rootdir,
                mel_query=mel_query,
                mel_load_fn=mel_load_fn,
                return_utt_id=True,
            )
    else:
        if not args.skip_wav_copy:
            dataset = AudioMelSCPDataset(
                wav_scp=args.wav_scp,
                feats_scp=args.feats_scp,
                segments=args.segments,
                return_utt_id=True,
            )
        else:
            dataset = MelSCPDataset(
                feats_scp=args.feats_scp,
                return_utt_id=True,
            )
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler
    scaler = StandardScaler()
    if config["format"] == "hdf5":
        scaler.mean_ = read_hdf5(args.stats, "mean")
        scaler.scale_ = read_hdf5(args.stats, "scale")
    elif config["format"] == "npy":
        # index 0 holds the mean, index 1 the scale; load the file once
        stats = np.load(args.stats)
        scaler.mean_ = stats[0]
        scaler.scale_ = stats[1]
    else:
        raise ValueError("support only hdf5 or npy format.")
    # from version 0.23.0, this information is needed
    scaler.n_features_in_ = scaler.mean_.shape[0]

    # process each file
    for items in tqdm(dataset):
        if not args.skip_wav_copy:
            utt_id, audio, mel = items
        else:
            utt_id, mel = items

        # normalize
        mel = scaler.transform(mel)

        # save
        if config["format"] == "hdf5":
            write_hdf5(
                os.path.join(args.dumpdir, f"{utt_id}.h5"),
                "feats",
                mel.astype(np.float32),
            )
            if not args.skip_wav_copy:
                write_hdf5(
                    os.path.join(args.dumpdir, f"{utt_id}.h5"),
                    "wave",
                    audio.astype(np.float32),
                )
        elif config["format"] == "npy":
            np.save(
                os.path.join(args.dumpdir, f"{utt_id}-feats.npy"),
                mel.astype(np.float32),
                allow_pickle=False,
            )
            if not args.skip_wav_copy:
                np.save(
                    os.path.join(args.dumpdir, f"{utt_id}-wave.npy"),
                    audio.astype(np.float32),
                    allow_pickle=False,
                )
        else:
            raise ValueError("support only hdf5 or npy format.")
def main():
    """Run training process.

    Parse command-line arguments, build the train/dev datasets and data
    loaders, instantiate the generator/discriminator, criteria, optimizers
    and schedulers named in the YAML config, and run ``Trainer``.

    Each split is read from exactly one of ``--{train,dev}-dumpdir`` or
    ``--{train,dev}-feats-scp`` (the latter together with
    ``--{train,dev}-wav-scp``). The merged config is written to
    ``<outdir>/config.yml``. When ``trainer.run()`` returns or raises
    (including ``KeyboardInterrupt``), a checkpoint named
    ``checkpoint-<steps>steps.pkl`` is saved to ``outdir``.

    Raises:
        ValueError: On inconsistent data arguments, an unsupported
            ``format``, or a subband STFT loss configured for a generator
            with ``out_channels <= 1``.
        ImportError: If distributed training is requested but apex is
            not installed.
    """
    parser = argparse.ArgumentParser(description=(
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    ))
    parser.add_argument(
        "--train-wav-scp",
        default=None,
        type=str,
        help=("kaldi-style wav.scp file for training. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--train-feats-scp",
        default=None,
        type=str,
        help=("kaldi-style feats.scp file for training. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--train-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for training.",
    )
    parser.add_argument(
        "--train-dumpdir",
        default=None,
        type=str,
        help=("directory including training data. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--dev-wav-scp",
        default=None,
        type=str,
        help=("kaldi-style wav.scp file for validation. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--dev-feats-scp",
        default=None,
        type=str,
        help=("kaldi-style feats.scp file for vaidation. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--dev-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for validation.",
    )
    parser.add_argument(
        "--dev-dumpdir",
        default=None,
        type=str,
        help=("directory including development data. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="directory to save checkpoints.",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--pretrain",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to load pretrained params. (default="")',
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.",
    )
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        # this handle is intentionally never closed in this function
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        stream=sys.stdout,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if log_level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence (no check-then-create race)
    os.makedirs(args.outdir, exist_ok=True)

    # check arguments
    if (args.train_feats_scp is not None) == (args.train_dumpdir is not None):
        raise ValueError(
            "Please specify either --train-dumpdir or --train-*-scp.")
    if (args.dev_feats_scp is not None) == (args.dev_dumpdir is not None):
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.")
    # the scp datasets below always read waveforms, so wav.scp is mandatory
    if args.train_feats_scp is not None and args.train_wav_scp is None:
        raise ValueError(
            "Please specify --train-wav-scp together with --train-feats-scp.")
    if args.dev_feats_scp is not None and args.dev_wav_scp is None:
        raise ValueError(
            "Please specify --dev-wav-scp together with --dev-feats-scp.")

    # load and save config
    # NOTE: yaml.Loader can construct arbitrary Python objects; only load
    # trusted configuration files.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None
    allow_cache = config.get("allow_cache", False)  # keep compatibility
    # file patterns / loaders are needed exactly when some split reads a dumpdir
    if args.train_dumpdir is not None or args.dev_dumpdir is not None:
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
    if args.train_dumpdir is not None:
        train_dataset = AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=allow_cache,
        )
    else:
        train_dataset = AudioMelSCPDataset(
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=allow_cache,
        )
    logging.info(f"The number of training files = {len(train_dataset)}.")
    if args.dev_dumpdir is not None:
        dev_dataset = AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=allow_cache,
        )
    else:
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=allow_cache,
        )
    logging.info(f"The number of development files = {len(dev_dataset)}.")
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get(
            "aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get("generator_type", "ParallelWaveGANGenerator")
        in ["ParallelWaveGANGenerator"],
    )
    sampler = {"train": None, "dev": None}
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        sampler["train"] = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        sampler["dev"] = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    # DataLoader forbids shuffle=True together with a sampler, hence the toggle
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["train"],
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["dev"],
            pin_memory=config["pin_memory"],
        ),
    }

    # define models
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator": generator_class(
            **config["generator_params"],
        ).to(device),
        "discriminator": discriminator_class(
            **config["discriminator_params"],
        ).to(device),
    }

    # define criterions
    criterion = {
        "gen_adv": GeneratorAdversarialLoss(
            # keep compatibility
            **config.get("generator_adv_loss_params", {})).to(device),
        "dis_adv": DiscriminatorAdversarialLoss(
            # keep compatibility
            **config.get("discriminator_adv_loss_params", {})).to(device),
    }
    if config.get("use_stft_loss", True):  # keep compatibility
        config["use_stft_loss"] = True
        criterion["stft"] = MultiResolutionSTFTLoss(
            **config["stft_loss_params"],
        ).to(device)
    if config.get("use_subband_stft_loss", False):  # keep compatibility
        # explicit check instead of assert, which is stripped under `python -O`
        if config["generator_params"]["out_channels"] <= 1:
            raise ValueError(
                "use_subband_stft_loss requires generator_params.out_channels > 1.")
        criterion["sub_stft"] = MultiResolutionSTFTLoss(
            **config["subband_stft_loss_params"],
        ).to(device)
    else:
        config["use_subband_stft_loss"] = False
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["feat_match"] = FeatureMatchLoss(
            # keep compatibility
            **config.get("feat_match_loss_params", {}),
        ).to(device)
    else:
        config["use_feat_match_loss"] = False
    if config.get("use_mel_loss", False):  # keep compatibility
        if config.get("mel_loss_params", None) is None:
            criterion["mel"] = MelSpectrogramLoss(
                fs=config["sampling_rate"],
                fft_size=config["fft_size"],
                hop_size=config["hop_size"],
                win_length=config["win_length"],
                window=config["window"],
                num_mels=config["num_mels"],
                fmin=config["fmin"],
                fmax=config["fmax"],
            ).to(device)
        else:
            criterion["mel"] = MelSpectrogramLoss(
                **config["mel_loss_params"],
            ).to(device)
    else:
        config["use_mel_loss"] = False

    # define special module for subband processing
    if config["generator_params"]["out_channels"] > 1:
        criterion["pqmf"] = PQMF(
            subbands=config["generator_params"]["out_channels"],
            # keep compatibility
            **config.get("pqmf_params", {}),
        ).to(device)

    # define optimizers and schedulers
    generator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("generator_optimizer_type", "RAdam"),
    )
    discriminator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("discriminator_optimizer_type", "RAdam"),
    )
    optimizer = {
        "generator": generator_optimizer_class(
            model["generator"].parameters(),
            **config["generator_optimizer_params"],
        ),
        "discriminator": discriminator_optimizer_class(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"],
        ),
    }
    generator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("generator_scheduler_type", "StepLR"),
    )
    discriminator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("discriminator_scheduler_type", "StepLR"),
    )
    scheduler = {
        "generator": generator_scheduler_class(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"],
        ),
        "discriminator": discriminator_scheduler_class(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"],
        ),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError as err:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            ) from err
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])

    # show settings
    logging.info(model["generator"])
    logging.info(model["discriminator"])
    logging.info(optimizer["generator"])
    logging.info(optimizer["discriminator"])
    logging.info(scheduler["generator"])
    logging.info(scheduler["discriminator"])
    for criterion_ in criterion.values():
        logging.info(criterion_)

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        sampler=sampler,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    # (truthiness also covers a bare `--pretrain`, which argparse maps to None)
    if args.pretrain:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # resume from checkpoint
    if args.resume:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop; always persist progress, even on error / Ctrl-C
    try:
        trainer.run()
    finally:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")