# Example no. 1 (votes: 0)
def get_model(conf, spkr_size=0, device="cuda"):
    """Build the training models.

    Args:
        conf (dict): Experiment configuration.
        spkr_size (int): Number of speakers (0 disables speaker-dependent parts).
        device (str): Device identifier to place the modules on.

    Returns:
        dict: Always contains "G" (VQVAE2 generator); "D" is added for
            adversarial trainer types and "SPKRADV" when
            conf["speaker_adversarial"] is enabled.

    Raises:
        NotImplementedError: If a discriminator is required but
            conf["gan_type"] or conf["discriminator_type"] is unsupported.
    """
    models = {"G": VQVAE2(conf, spkr_size=spkr_size).to(device)}
    logging.info(models["G"])

    # discriminator (only needed for adversarial trainer types)
    if conf["trainer_type"] in ["lsgan", "cyclegan"]:
        # number of discriminator outputs: one real/fake logit for LSGAN,
        # plus per-speaker logits when ACGAN is enabled
        if conf["gan_type"] == "lsgan":
            output_channels = 1
        else:
            # fail fast with a clear error instead of the NameError the old
            # flow produced when output_channels was used undefined
            raise NotImplementedError(
                "gan_type {} is not supported".format(conf["gan_type"]))
        if conf["acgan_flag"]:
            output_channels += spkr_size
        if conf["discriminator_type"] == "pwg":
            D = ParallelWaveGANDiscriminator(
                in_channels=conf["input_size"],
                out_channels=output_channels,
                kernel_size=conf["discriminator_kernel_size"],
                layers=conf["n_discriminator_layers"],
                conv_channels=64,
                dilation_factor=1,
                nonlinear_activation="LeakyReLU",
                nonlinear_activation_params={"negative_slope": 0.2},
                bias=True,
                use_weight_norm=True,
            )
        else:
            raise NotImplementedError(
                "discriminator_type {} is not supported".format(
                    conf["discriminator_type"]))
        models.update({"D": D.to(device)})
        logging.info(models["D"])

    # speaker adversarial network
    if conf["speaker_adversarial"]:
        SPKRADV = SpeakerAdversarialNetwork(conf, spkr_size)
        models.update({"SPKRADV": SPKRADV.to(device)})
        logging.info(models["SPKRADV"])
    return models
    def _construct_net(self):
        """Build the gradient-reversal layer and the speaker classifier."""
        conf = self.conf
        self.grl = GradientReversalLayer(scale=conf["spkradv_lambda"])

        # TODO(k2kobayashi): investigate performance of a residual network
        # (ResidualParallelWaveGANDiscriminator gated by "use_residual_network")
        # as an alternative classifier backbone.
        self.classifier = ParallelWaveGANDiscriminator(
            in_channels=sum(conf["emb_dim"][:conf["n_vq_stacks"]]),
            out_channels=self.spkr_size,
            kernel_size=conf["spkradv_kernel_size"],
            layers=conf["n_spkradv_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
            use_weight_norm=True,
        )
def test_parallel_wavegan_trainable(dict_g, dict_d, dict_loss):
    """Smoke-test one optimization step for generator and discriminator."""
    # build argument dicts and random inputs
    n_batch = 4
    n_samples = 4096
    args_g = make_generator_args(**dict_g)
    args_d = make_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    upsample_factor = np.prod(args_g["upsample_params"]["upsample_scales"])
    n_frames = n_samples // upsample_factor + 2 * args_g["aux_context_window"]
    z = torch.randn(n_batch, 1, n_samples)
    y = torch.randn(n_batch, 1, n_samples)
    c = torch.randn(n_batch, args_g["aux_channels"], n_frames)

    # models, criteria, optimizers
    model_g = ParallelWaveGANGenerator(**args_g)
    model_d = ParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # one generator update: adversarial loss + multi-resolution STFT loss
    y_hat = model_g(z, c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # one discriminator update on real audio and detached fake audio
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
def test_parallel_wavegan_trainable(dict_g, dict_d, dict_loss):
    """Smoke-test generator/discriminator updates with hand-rolled LSGAN losses."""
    # build argument dicts and random inputs
    n_batch = 4
    n_samples = 4096
    args_g = make_generator_args(**dict_g)
    args_d = make_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    upsample_factor = np.prod(args_g["upsample_params"]["upsample_scales"])
    n_frames = n_samples // upsample_factor + 2 * args_g["aux_context_window"]
    z = torch.randn(n_batch, 1, n_samples)
    y = torch.randn(n_batch, 1, n_samples)
    c = torch.randn(n_batch, args_g["aux_channels"], n_frames)
    model_g = ParallelWaveGANGenerator(**args_g)
    model_d = ParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # generator step: push discriminator outputs toward "real" (ones)
    # and add the multi-resolution STFT auxiliary loss
    y_hat = model_g(z, c)
    p_hat = model_d(y_hat)
    y, y_hat, p_hat = y.squeeze(1), y_hat.squeeze(1), p_hat.squeeze(1)
    adv_loss = F.mse_loss(p_hat, p_hat.new_ones(p_hat.size()))
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # discriminator step: real -> ones, detached fake -> zeros
    y, y_hat = y.unsqueeze(1), y_hat.unsqueeze(1).detach()
    p = model_d(y)
    p_hat = model_d(y_hat)
    p, p_hat = p.squeeze(1), p_hat.squeeze(1)
    real_term = F.mse_loss(p, p.new_ones(p.size()))
    fake_term = F.mse_loss(p_hat, p_hat.new_zeros(p_hat.size()))
    loss_d = real_term + fake_term
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
# Example no. 5 (votes: 0)
 def _construct_net(self):
     """Construct the speaker-adversarial sub-networks.

     Builds a gradient reversal layer scaled by conf["spkradv_lambda"]
     and a ParallelWaveGANDiscriminator used as a speaker classifier.
     """
     self.grl = GradientReversalLayer(scale=self.conf["spkradv_lambda"])
     # input channels: total embedding dim across the VQ stacks in use;
     # output channels: one logit per speaker
     self.classifier = ParallelWaveGANDiscriminator(
         in_channels=sum(self.conf["emb_dim"][:self.conf["n_vq_stacks"]]),
         out_channels=self.spkr_size,
         kernel_size=self.conf["spkradv_kernel_size"],
         layers=self.conf["n_spkradv_layers"],
         conv_channels=64,
         dilation_factor=1,
         nonlinear_activation="LeakyReLU",
         nonlinear_activation_params={"negative_slope": 0.2},
         bias=True,
         use_weight_norm=True,
     )
# Example no. 6 (votes: 0)
def get_model(conf, spkr_size=0, device="cuda"):
    """Build generator and discriminator models.

    Args:
        conf (dict): Experiment configuration.
        spkr_size (int): Number of speakers (0 disables speaker handling).
        device (str): Device identifier to place the modules on.

    Returns:
        dict: {"G": VQVAE2 generator, "D": PWG discriminator}, both on device.

    Raises:
        NotImplementedError: If conf["gan_type"] or
            conf["discriminator_type"] is unsupported.
    """
    G = VQVAE2(conf, spkr_size=spkr_size).to(device)

    # discriminator output channels: one real/fake logit for LSGAN,
    # plus per-speaker logits when ACGAN is enabled
    if conf["gan_type"] == "lsgan":
        output_channels = 1
    else:
        # fail fast instead of a NameError when output_channels is used below
        raise NotImplementedError(
            "gan_type {} is not supported".format(conf["gan_type"]))
    if conf["acgan_flag"]:
        output_channels += spkr_size

    if conf["discriminator_type"] == "pwg":
        D = ParallelWaveGANDiscriminator(
            in_channels=conf["input_size"],
            out_channels=output_channels,
            kernel_size=conf["kernel_size"][0],
            layers=conf["n_discriminator_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
            use_weight_norm=True,
        )
    else:
        # the original fell through and crashed with a NameError on D
        raise NotImplementedError(
            "discriminator_type {} is not supported".format(
                conf["discriminator_type"]))
    return {"G": G.to(device), "D": D.to(device)}
# Example no. 7 (votes: 0)
def get_model(conf, spkr_size=0, device="cuda"):
    """Build the dictionary of training models.

    Args:
        conf (dict): Experiment configuration.
        spkr_size (int): Number of speakers (0 disables speaker handling).
        device (str): Device identifier to place the modules on.

    Returns:
        dict: "G" always; "SPKRADV", "C" and "D" added depending on conf.

    Raises:
        NotImplementedError: If a discriminator is required but
            conf["gan_type"] is unsupported.
    """
    models = {"G": VQVAE2(conf, spkr_size=spkr_size).to(device)}
    logging.info(models["G"])

    # speaker adversarial network
    if conf["use_spkradv_training"]:
        SPKRADV = SpeakerAdversarialNetwork(conf, spkr_size)
        models.update({"SPKRADV": SPKRADV.to(device)})
        logging.info(models["SPKRADV"])

    # spkr classifier network (PWG discriminator reused as a classifier)
    if conf["use_spkr_classifier"]:
        C = ParallelWaveGANDiscriminator(
            in_channels=conf["input_size"],
            out_channels=spkr_size,
            kernel_size=conf["spkr_classifier_kernel_size"],
            layers=conf["n_spkr_classifier_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
            use_weight_norm=True,
        )
        models.update({"C": C.to(device)})
        logging.info(models["C"])

    # discriminator
    if conf["trainer_type"] in ["lsgan", "cyclegan", "stargan"]:
        input_channels = conf["input_size"]
        if conf["use_D_uv"]:
            input_channels += 1  # for uv flag
        if conf["use_D_spkrcode"]:
            if not conf["use_spkr_embedding"]:
                input_channels += spkr_size
            else:
                input_channels += conf["spkr_embedding_size"]
        if conf["gan_type"] == "lsgan":
            output_channels = 1
        else:
            # fail fast with a clear error instead of the NameError the old
            # flow produced when output_channels was used undefined
            raise NotImplementedError(
                "gan_type {} is not supported".format(conf["gan_type"]))
        if conf["acgan_flag"]:
            output_channels += spkr_size
        D = ParallelWaveGANDiscriminator(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=conf["discriminator_kernel_size"],
            layers=conf["n_discriminator_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
            use_weight_norm=True,
        )
        models.update({"D": D.to(device)})
        logging.info(models["D"])
    return models
# Example no. 8 (votes: 0)
def main():
    """Run training process.

    Parses CLI arguments, loads and re-dumps the yaml configuration,
    builds datasets, data loaders, models, criteria, optimizers and
    schedulers, then delegates the training loop to ``Trainer``.
    A checkpoint is always saved when the loop exits (``finally``).
    """
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument("--train-dumpdir",
                        type=str,
                        required=True,
                        help="directory including training data.")
    parser.add_argument("--dev-dumpdir",
                        type=str,
                        required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning('skip DEBUG/INFO messages')

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load and save config
    # NOTE: yaml.Loader can construct arbitrary python objects; only use it
    # on trusted, local configuration files, never on untrusted input.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"]["aux_context_window"]
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train":
        AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
        "dev":
        AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
    }

    # get data loader
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        aux_context_window=config["generator_params"]["aux_context_window"],
    )
    data_loader = {
        "train":
        DataLoader(dataset=dataset["train"],
                   shuffle=True,
                   collate_fn=collater,
                   batch_size=config["batch_size"],
                   num_workers=config["num_workers"],
                   pin_memory=config["pin_memory"]),
        "dev":
        DataLoader(dataset=dataset["dev"],
                   shuffle=True,
                   collate_fn=collater,
                   batch_size=config["batch_size"],
                   num_workers=config["num_workers"],
                   pin_memory=config["pin_memory"]),
    }

    # define models and optimizers
    model = {
        "generator":
        ParallelWaveGANGenerator(**config["generator_params"]).to(device),
        "discriminator":
        ParallelWaveGANDiscriminator(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft":
        MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    optimizer = {
        "generator":
        RAdam(model["generator"].parameters(),
              **config["generator_optimizer_params"]),
        "discriminator":
        RAdam(model["discriminator"].parameters(),
              **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    finally:
        # always persist the current state so training can be resumed,
        # even after an exception or Ctrl-C
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"successfully saved checkpoint @ {trainer.steps}steps.")
# Example no. 9 (votes: 0)
def main():
    """Run training process.

    Distributed-capable variant: parses CLI arguments (including a
    ``--rank`` used by the launcher), optionally initializes NCCL
    process groups, builds datasets/loaders with DistributedSampler,
    wraps models in apex DistributedDataParallel, and runs ``Trainer``.
    """
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument("--train-dumpdir",
                        type=str,
                        required=True,
                        help="directory including training data.")
    parser.add_argument("--dev-dumpdir",
                        type=str,
                        required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    # decide device; distributed mode is enabled only when CUDA is
    # available and the launcher exported WORLD_SIZE > 1
    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load and save config
    # NOTE: yaml.Loader can construct arbitrary python objects; only use it
    # on trusted, local configuration files, never on untrusted input.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"]["aux_context_window"]
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train":
        AudioMelDataset(root_dir=args.train_dumpdir,
                        audio_query=audio_query,
                        mel_query=mel_query,
                        audio_load_fn=audio_load_fn,
                        mel_load_fn=mel_load_fn,
                        mel_length_threshold=mel_length_threshold,
                        allow_cache=config.get("allow_cache",
                                               False)),  # keep compatibility
        "dev":
        AudioMelDataset(root_dir=args.dev_dumpdir,
                        audio_query=audio_query,
                        mel_query=mel_query,
                        audio_load_fn=audio_load_fn,
                        mel_load_fn=mel_load_fn,
                        mel_length_threshold=mel_length_threshold,
                        allow_cache=config.get("allow_cache",
                                               False)),  # keep compatibility
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        aux_context_window=config["generator_params"]["aux_context_window"],
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(dataset=dataset["train"],
                                           num_replicas=args.world_size,
                                           rank=args.rank,
                                           shuffle=True)
        dev_sampler = DistributedSampler(dataset=dataset["dev"],
                                         num_replicas=args.world_size,
                                         rank=args.rank,
                                         shuffle=False)
    data_loader = {
        "train":
        DataLoader(dataset=dataset["train"],
                   shuffle=False if args.distributed else True,
                   collate_fn=collater,
                   batch_size=config["batch_size"],
                   num_workers=config["num_workers"],
                   sampler=train_sampler,
                   pin_memory=config["pin_memory"]),
        "dev":
        DataLoader(dataset=dataset["dev"],
                   shuffle=False if args.distributed else True,
                   collate_fn=collater,
                   batch_size=config["batch_size"],
                   num_workers=config["num_workers"],
                   sampler=dev_sampler,
                   pin_memory=config["pin_memory"]),
    }

    # define models and optimizers
    model = {
        "generator":
        ParallelWaveGANGenerator(**config["generator_params"]).to(device),
        "discriminator":
        ParallelWaveGANDiscriminator(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft":
        MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    optimizer = {
        "generator":
        RAdam(model["generator"].parameters(),
              **config["generator_optimizer_params"]),
        "discriminator":
        RAdam(model["discriminator"].parameters(),
              **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            )
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        # save current state on Ctrl-C so training can be resumed later
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"successfully saved checkpoint @ {trainer.steps}steps.")