def get_model(conf, spkr_size=0, device="cuda"):
    models = {"G": VQVAE2(conf, spkr_size=spkr_size).to(device)}
    logging.info(models["G"])

    # discriminator
    if conf["gan_type"] == "lsgan":
        output_channels = 1
        if conf["acgan_flag"]:
            output_channels += spkr_size
    if conf["trainer_type"] in ["lsgan", "cyclegan"]:
        if conf["discriminator_type"] == "pwg":
            D = ParallelWaveGANDiscriminator(
                in_channels=conf["input_size"],
                out_channels=output_channels,
                kernel_size=conf["discriminator_kernel_size"],
                layers=conf["n_discriminator_layers"],
                conv_channels=64,
                dilation_factor=1,
                nonlinear_activation="LeakyReLU",
                nonlinear_activation_params={"negative_slope": 0.2},
                bias=True,
                use_weight_norm=True,
            )
        else:
            raise NotImplementedError()
        models.update({"D": D.to(device)})
        logging.info(models["D"])

    if conf["speaker_adversarial"]:
        SPKRADV = SpeakerAdversarialNetwork(conf, spkr_size)
        models.update({"SPKRADV": SPKRADV.to(device)})
        logging.info(models["SPKRADV"])
    return models
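# --- Usage sketch (illustrative, not from the repo) --------------------------
# A minimal `conf` covering only the keys this get_model() reads. All values
# below are hypothetical; a real config must also carry every key consumed by
# VQVAE2 and SpeakerAdversarialNetwork, which this sketch glosses over.
example_conf = {
    "gan_type": "lsgan",
    "acgan_flag": True,             # adds spkr_size AC-GAN channels to D's output
    "trainer_type": "lsgan",        # any of ["lsgan", "cyclegan"] builds D
    "discriminator_type": "pwg",    # the only type implemented above
    "input_size": 80,               # hypothetical feature dimension
    "discriminator_kernel_size": 3, # hypothetical
    "n_discriminator_layers": 10,   # hypothetical
    "speaker_adversarial": True,    # also builds the SPKRADV network
}
# models = get_model(example_conf, spkr_size=4, device="cpu")  # needs the full VQVAE2 conf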
def _construct_net(self):
    self.grl = GradientReversalLayer(scale=self.conf["spkradv_lambda"])
    # TODO(k2kobayashi): investigate performance of residual network
    # if self.conf["use_residual_network"]:
    #     self.classifier = ResidualParallelWaveGANDiscriminator(
    #         in_channels=sum(
    #             self.conf["emb_dim"][:self.conf["n_vq_stacks"]]),
    #         out_channels=self.spkr_size,
    #         kernel_size=self.conf["spkradv_kernel_size"],
    #         layers=self.conf["n_spkradv_layers"],
    #         stacks=self.conf["n_spkradv_layers"] // 2,
    #     )
    # else:
    self.classifier = ParallelWaveGANDiscriminator(
        in_channels=sum(self.conf["emb_dim"][:self.conf["n_vq_stacks"]]),
        out_channels=self.spkr_size,
        kernel_size=self.conf["spkradv_kernel_size"],
        layers=self.conf["n_spkradv_layers"],
        conv_channels=64,
        dilation_factor=1,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        bias=True,
        use_weight_norm=True,
    )
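# --- GradientReversalLayer sketch (not necessarily this repo's exact code) ---
# GradientReversalLayer is used by _construct_net above but not defined in this
# excerpt. A common implementation, following Ganin & Lempitsky (2015), is the
# identity on the forward pass with the gradient multiplied by -scale on the
# backward pass, so the encoder upstream is trained to *remove* speaker
# information that the classifier could exploit:
import torch


class _ReverseGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.clone()  # identity on the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # flip and scale the gradient; None is the (unused) gradient w.r.t. scale
        return -ctx.scale * grad_output, None


class GradientReversalLayer(torch.nn.Module):
    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale  # corresponds to conf["spkradv_lambda"] above

    def forward(self, x):
        return _ReverseGrad.apply(x, self.scale)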
def test_parallel_wavegan_trainable(dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_generator_args(**dict_g)
    args_d = make_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    z = torch.randn(batch_size, 1, batch_length)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_params"]["upsample_scales"])
        + 2 * args_g["aux_context_window"],
    )
    model_g = ParallelWaveGANGenerator(**args_g)
    model_d = ParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(z, c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
def test_parallel_wavegan_trainable(dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_generator_args(**dict_g)
    args_d = make_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    z = torch.randn(batch_size, 1, batch_length)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_params"]["upsample_scales"])
        + 2 * args_g["aux_context_window"])
    model_g = ParallelWaveGANGenerator(**args_g)
    model_d = ParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(z, c)
    p_hat = model_d(y_hat)
    y, y_hat, p_hat = y.squeeze(1), y_hat.squeeze(1), p_hat.squeeze(1)
    adv_loss = F.mse_loss(p_hat, p_hat.new_ones(p_hat.size()))
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    y, y_hat = y.unsqueeze(1), y_hat.unsqueeze(1).detach()
    p = model_d(y)
    p_hat = model_d(y_hat)
    p, p_hat = p.squeeze(1), p_hat.squeeze(1)
    loss_d = F.mse_loss(p, p.new_ones(p.size())) + F.mse_loss(
        p_hat, p_hat.new_zeros(p_hat.size()))
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
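# --- Note on the two test revisions above (a sketch, not repo code) ----------
# The explicit F.mse_loss targets in this older revision implement the
# least-squares GAN (LSGAN) objective: G is pushed toward D(y_hat) = 1, while
# D is pushed toward D(y) = 1 and D(y_hat) = 0. Assuming that the
# GeneratorAdversarialLoss class used in the newer revision defaults to the
# same MSE formulation, the two revisions compute matching generator losses:
import torch
import torch.nn.functional as F
from parallel_wavegan.losses import GeneratorAdversarialLoss  # path assumed

p_hat = torch.randn(4, 1, 4096)  # stand-in for model_d(y_hat)
manual = F.mse_loss(p_hat, p_hat.new_ones(p_hat.size()))
wrapped = GeneratorAdversarialLoss()(p_hat)  # assumes default "mse" loss type
assert torch.allclose(manual, wrapped)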
def _construct_net(self):
    self.grl = GradientReversalLayer(scale=self.conf["spkradv_lambda"])
    self.classifier = ParallelWaveGANDiscriminator(
        in_channels=sum(self.conf["emb_dim"][:self.conf["n_vq_stacks"]]),
        out_channels=self.spkr_size,
        kernel_size=self.conf["spkradv_kernel_size"],
        layers=self.conf["n_spkradv_layers"],
        conv_channels=64,
        dilation_factor=1,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        bias=True,
        use_weight_norm=True,
    )
def get_model(conf, spkr_size=0, device="cuda"):
    G = VQVAE2(conf, spkr_size=spkr_size).to(device)

    # discriminator
    if conf["gan_type"] == "lsgan":
        output_channels = 1
        if conf["acgan_flag"]:
            output_channels += spkr_size
    if conf["discriminator_type"] == "pwg":
        D = ParallelWaveGANDiscriminator(
            in_channels=conf["input_size"],
            out_channels=output_channels,
            kernel_size=conf["kernel_size"][0],
            layers=conf["n_discriminator_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
            use_weight_norm=True,
        )
    return {"G": G.to(device), "D": D.to(device)}
def get_model(conf, spkr_size=0, device="cuda"):
    models = {"G": VQVAE2(conf, spkr_size=spkr_size).to(device)}
    logging.info(models["G"])

    # speaker adversarial network
    if conf["use_spkradv_training"]:
        SPKRADV = SpeakerAdversarialNetwork(conf, spkr_size)
        models.update({"SPKRADV": SPKRADV.to(device)})
        logging.info(models["SPKRADV"])

    # spkr classifier network
    if conf["use_spkr_classifier"]:
        C = ParallelWaveGANDiscriminator(
            in_channels=conf["input_size"],
            out_channels=spkr_size,
            kernel_size=conf["spkr_classifier_kernel_size"],
            layers=conf["n_spkr_classifier_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
            use_weight_norm=True,
        )
        models.update({"C": C.to(device)})
        logging.info(models["C"])

    # discriminator
    if conf["trainer_type"] in ["lsgan", "cyclegan", "stargan"]:
        input_channels = conf["input_size"]
        if conf["use_D_uv"]:
            input_channels += 1  # for uv flag
        if conf["use_D_spkrcode"]:
            if not conf["use_spkr_embedding"]:
                input_channels += spkr_size
            else:
                input_channels += conf["spkr_embedding_size"]
        if conf["gan_type"] == "lsgan":
            output_channels = 1
            if conf["acgan_flag"]:
                output_channels += spkr_size
        D = ParallelWaveGANDiscriminator(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=conf["discriminator_kernel_size"],
            layers=conf["n_discriminator_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
            use_weight_norm=True,
        )
        models.update({"D": D.to(device)})
        logging.info(models["D"])
    return models
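# --- AC-GAN output layout sketch (illustrative, not repo trainer code) -------
# With acgan_flag set, D's out_channels = 1 + spkr_size: one LSGAN real/fake
# channel stacked with spkr_size auxiliary speaker-classification channels.
# A self-contained sketch of splitting such an output (all sizes hypothetical):
import torch

spkr_size = 4
p = torch.randn(8, 1 + spkr_size, 100)  # stand-in for D(x): (B, 1 + spkr_size, T)
real_fake = p[:, :1]    # per-frame real/fake score for the LSGAN term
spkr_logits = p[:, 1:]  # per-frame speaker logits for the AC-GAN term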
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py).")
    parser.add_argument("--train-dumpdir", type=str, required=True,
                        help="directory including training data.")
    parser.add_argument("--dev-dumpdir", type=str, required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--resume", default="", type=str, nargs="?",
                        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"]["aux_context_window"]
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train": AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
        "dev": AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
    }

    # get data loader
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        aux_context_window=config["generator_params"]["aux_context_window"],
    )
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            pin_memory=config["pin_memory"]),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            pin_memory=config["pin_memory"]),
    }

    # define models and optimizers
    model = {
        "generator": ParallelWaveGANGenerator(
            **config["generator_params"]).to(device),
        "discriminator": ParallelWaveGANDiscriminator(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(
            **config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    optimizer = {
        "generator": RAdam(
            model["generator"].parameters(),
            **config["generator_optimizer_params"]),
        "discriminator": RAdam(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    finally:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"successfully saved checkpoint @ {trainer.steps}steps.")
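# --- Launch sketch for the single-GPU main() above (paths hypothetical) ------
#
#     python train.py \
#         --train-dumpdir dump/train --dev-dumpdir dump/dev \
#         --outdir exp/parallel_wavegan --config conf/parallel_wavegan.yaml \
#         --resume exp/parallel_wavegan/checkpoint-100000steps.pkl
#
# --resume is optional (nargs="?" with default ""); note that the finally
# block above saves a checkpoint even when training is interrupted.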
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py).")
    parser.add_argument("--train-dumpdir", type=str, required=True,
                        help="directory including training data.")
    parser.add_argument("--dev-dumpdir", type=str, required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--resume", default="", type=str, nargs="?",
                        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    parser.add_argument("--rank", "--local_rank", default=0, type=int,
                        help="rank for distributed training. no need to explicitly specify.")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG, stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO, stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN, stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"]["aux_context_window"]
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train": AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
        "dev": AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        aux_context_window=config["generator_params"]["aux_context_window"],
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True)
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False)
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"]),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"]),
    }

    # define models and optimizers
    model = {
        "generator": ParallelWaveGANGenerator(
            **config["generator_params"]).to(device),
        "discriminator": ParallelWaveGANDiscriminator(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(
            **config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    optimizer = {
        "generator": RAdam(
            model["generator"].parameters(),
            **config["generator_optimizer_params"]),
        "discriminator": RAdam(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex.")
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"successfully saved checkpoint @ {trainer.steps}steps.")
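# --- Launch sketch for the distributed main() above (paths hypothetical) -----
# torch.distributed.launch sets WORLD_SIZE in the environment and passes
# --local_rank to each process, which is exactly what this script consumes:
#
#     python -m torch.distributed.launch --nproc_per_node=4 train.py \
#         --train-dumpdir dump/train --dev-dumpdir dump/dev \
#         --outdir exp/parallel_wavegan --config conf/parallel_wavegan.yaml
#
# With a single process (no WORLD_SIZE set), args.distributed stays False and
# the script falls back to ordinary single-GPU training.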