コード例 #1
0
def test_parallel_wavegan_with_residual_discriminator_trainable(
        dict_g, dict_d, dict_loss):
    """Run one optimizer step for the generator and one for the residual discriminator.

    The test passes if the forward passes, the loss computations, ``backward`` and
    ``RAdam.step`` all complete without raising.

    Args:
        dict_g: keyword overrides forwarded to ``make_generator_args``.
        dict_d: keyword overrides forwarded to ``make_residual_discriminator_args``.
        dict_loss: keyword overrides forwarded to ``make_mutli_reso_stft_loss_args``.
            (All three are presumably supplied by pytest parametrization.)
    """
    n_batch, n_samples = 4, 4096
    gen_conf = make_generator_args(**dict_g)
    disc_conf = make_residual_discriminator_args(**dict_d)
    stft_conf = make_mutli_reso_stft_loss_args(**dict_loss)

    # Random noise input and random "real" target, both shaped (batch, 1, samples).
    noise = torch.randn(n_batch, 1, n_samples)
    target = torch.randn(n_batch, 1, n_samples)
    # Auxiliary-feature length: samples // prod(upsample_scales), plus
    # aux_context_window extra frames on each side.
    hop = np.prod(gen_conf["upsample_params"]["upsample_scales"])
    n_frames = n_samples // hop + 2 * gen_conf["aux_context_window"]
    aux = torch.randn(n_batch, gen_conf["aux_channels"], n_frames)

    generator = ParallelWaveGANGenerator(**gen_conf)
    discriminator = ResidualParallelWaveGANDiscriminator(**disc_conf)
    stft_loss = MultiResolutionSTFTLoss(**stft_conf)
    g_adv_loss_fn = GeneratorAdversarialLoss()
    d_adv_loss_fn = DiscriminatorAdversarialLoss()
    g_opt = RAdam(generator.parameters())
    d_opt = RAdam(discriminator.parameters())

    # --- generator step: adversarial loss plus multi-resolution STFT loss ---
    fake = generator(noise, aux)
    fake_score = discriminator(fake)
    g_adv = g_adv_loss_fn(fake_score)
    spectral_conv, log_mag = stft_loss(fake, target)
    g_total = g_adv + (spectral_conv + log_mag)
    g_opt.zero_grad()
    g_total.backward()
    g_opt.step()

    # --- discriminator step: generator output is detached so only D updates ---
    real_score = discriminator(target)
    fake_score = discriminator(fake.detach())
    loss_real, loss_fake = d_adv_loss_fn(fake_score, real_score)
    d_total = loss_real + loss_fake
    d_opt.zero_grad()
    d_total.backward()
    d_opt.step()
コード例 #2
0
def test_parallel_wavegan_with_residual_discriminator_trainable(dict_g, dict_d, dict_loss):
    """Check that the generator and residual discriminator each survive one update.

    This variant spells out least-squares adversarial targets with ``F.mse_loss``:
    - the generator is pushed toward a discriminator score of 1 on its output;
    - the discriminator is pushed toward 1 on the random "real" batch and 0 on
      the generated batch.

    Args:
        dict_g, dict_d, dict_loss: keyword overrides forwarded to the
            ``make_*_args`` helpers (presumably supplied by pytest parametrization).
    """
    n_batch = 4
    n_samples = 4096
    gen_kwargs = make_generator_args(**dict_g)
    disc_kwargs = make_residual_discriminator_args(**dict_d)
    stft_kwargs = make_mutli_reso_stft_loss_args(**dict_loss)

    noise = torch.randn(n_batch, 1, n_samples)
    real = torch.randn(n_batch, 1, n_samples)
    hop = np.prod(gen_kwargs["upsample_params"]["upsample_scales"])
    aux = torch.randn(
        n_batch,
        gen_kwargs["aux_channels"],
        n_samples // hop + 2 * gen_kwargs["aux_context_window"],
    )

    generator = ParallelWaveGANGenerator(**gen_kwargs)
    discriminator = ResidualParallelWaveGANDiscriminator(**disc_kwargs)
    stft_loss = MultiResolutionSTFTLoss(**stft_kwargs)
    opt_g = RAdam(generator.parameters())
    opt_d = RAdam(discriminator.parameters())

    # Generator update. The size-1 axis 1 is squeezed away before the losses.
    # For the waveforms this turns (batch, 1, samples) into (batch, samples);
    # generator output is assumed to share that layout.
    fake = generator(noise, aux)
    fake_score = discriminator(fake).squeeze(1)
    adv = F.mse_loss(fake_score, torch.ones_like(fake_score))
    sc, mag = stft_loss(fake.squeeze(1), real.squeeze(1))
    total_g = adv + (sc + mag)
    opt_g.zero_grad()
    total_g.backward()
    opt_g.step()

    # Discriminator update on the real batch and on the detached generated batch.
    real_score = discriminator(real).squeeze(1)
    fake_score = discriminator(fake.detach()).squeeze(1)
    real_term = F.mse_loss(real_score, torch.ones_like(real_score))
    fake_term = F.mse_loss(fake_score, torch.zeros_like(fake_score))
    total_d = real_term + fake_term
    opt_d.zero_grad()
    total_d.backward()
    opt_d.step()
コード例 #3
0
def get_model(conf, spkr_size=0, device="cuda"):
    """Build the networks described by ``conf`` and move each one to ``device``.

    Args:
        conf: configuration dict. Keys read here include ``use_spkradv_training``,
            ``use_spkr_classifier``, ``trainer_type``, ``gan_type``, ``acgan_flag``,
            ``use_residual_network``, ``input_size`` and the discriminator /
            classifier size options.
        spkr_size: number of speakers. It is used as the classifier output width
            and, when enabled, as extra discriminator input/output channels.
        device: target passed to ``Module.to``.

    Returns:
        dict mapping a short name to a module:
        - ``"G"``: always present (``VQVAE2``);
        - ``"SPKRADV"``: only if ``conf["use_spkradv_training"]``;
        - ``"C"``: only if ``conf["use_spkr_classifier"]``;
        - ``"D"``: only if ``conf["trainer_type"]`` is "lsgan", "cyclegan" or
          "stargan".
        Each module is also logged at INFO level.

    Raises:
        ValueError: if a discriminator is requested but ``conf["gan_type"]`` is
            not "lsgan" (the only objective whose output width is defined here).
    """
    models = {"G": VQVAE2(conf, spkr_size=spkr_size).to(device)}
    logging.info(models["G"])

    # speaker adversarial network
    if conf["use_spkradv_training"]:
        models["SPKRADV"] = SpeakerAdversarialNetwork(conf, spkr_size).to(device)
        logging.info(models["SPKRADV"])

    # speaker classifier network
    if conf["use_spkr_classifier"]:
        # TODO(k2kobayashi): investigate performance of residual network
        # if conf["use_residual_network"]:
        #     C = ResidualParallelWaveGANDiscriminator(
        #         in_channels=conf["input_size"],
        #         out_channels=spkr_size,
        #         kernel_size=conf["spkr_classifier_kernel_size"],
        #         layers=conf["n_spkr_classifier_layers"],
        #         stacks=conf["n_spkr_classifier_layers"],
        #     )
        # else:
        C = ParallelWaveGANDiscriminator(
            in_channels=conf["input_size"],
            out_channels=spkr_size,
            kernel_size=conf["spkr_classifier_kernel_size"],
            layers=conf["n_spkr_classifier_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
            use_weight_norm=True,
        )
        models["C"] = C.to(device)
        logging.info(models["C"])

    # discriminator
    if conf["trainer_type"] in ["lsgan", "cyclegan", "stargan"]:
        input_channels = conf["input_size"]
        if conf["use_D_uv"]:
            input_channels += 1  # for uv flag
        if conf["use_D_spkrcode"]:
            if not conf["use_spkr_embedding"]:
                input_channels += spkr_size
            else:
                input_channels += conf["spkr_embedding_size"]

        # Only "lsgan" defines the output width. Before this check, any other
        # gan_type left ``output_channels`` unbound and failed later with an
        # UnboundLocalError.
        if conf["gan_type"] != "lsgan":
            raise ValueError(
                "Unsupported gan_type for discriminator: {}".format(conf["gan_type"])
            )
        output_channels = 1
        if conf["acgan_flag"]:
            # extra spkr_size output channels (presumably speaker-class logits)
            output_channels += spkr_size

        # Total layer count is layers-per-stack times stacks for both variants.
        # The non-residual branch previously multiplied by the list literal
        # ["n_discriminator_stacks"] instead of looking the value up in conf.
        n_layers = conf["n_discriminator_layers"] * conf["n_discriminator_stacks"]
        if conf["use_residual_network"]:
            D = ResidualParallelWaveGANDiscriminator(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=conf["discriminator_kernel_size"],
                layers=n_layers,
                stacks=conf["n_discriminator_stacks"],
                dropout=conf["discriminator_dropout"],
            )
        else:
            D = ParallelWaveGANDiscriminator(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=conf["discriminator_kernel_size"],
                layers=n_layers,
                conv_channels=64,
                dilation_factor=1,
                nonlinear_activation="LeakyReLU",
                nonlinear_activation_params={"negative_slope": 0.2},
                bias=True,
                use_weight_norm=True,
            )
        models["D"] = D.to(device)
        logging.info(models["D"])
    return models