Ejemplo n.º 1
0
def get_model(conf, spkr_size=0, device="cuda"):
    """Build the networks described by ``conf`` and move them to ``device``.

    A ``VQVAE2`` generator is always created under key ``"G"``. Depending on
    flags in ``conf``, the returned dict may additionally contain:

    * ``"SPKRADV"`` -- a ``SpeakerAdversarialNetwork``, when
      ``conf["use_spkradv_training"]`` is truthy;
    * ``"C"`` -- a ``ParallelWaveGANDiscriminator`` with ``spkr_size``
      outputs, when ``conf["use_spkr_classifier"]`` is truthy;
    * ``"D"`` -- a ``ParallelWaveGANDiscriminator``, when
      ``conf["trainer_type"]`` is ``"lsgan"``, ``"cyclegan"`` or ``"stargan"``.

    Every created model is logged with ``logging.info``.

    Args:
        conf: Configuration mapping. Which keys are read depends on the
            enabled networks (see the body).
        spkr_size: Number of speakers. Used as the speaker networks' size
            and, when enabled, to widen the discriminator's input/output.
        device: Target passed to ``.to()`` for every model.

    Returns:
        dict mapping model name (``"G"``, ``"SPKRADV"``, ``"C"``, ``"D"``)
        to the model instance.

    Raises:
        NotImplementedError: If a discriminator is required and
            ``conf["gan_type"]`` is not ``"lsgan"``.
    """
    models = {"G": VQVAE2(conf, spkr_size=spkr_size).to(device)}
    logging.info(models["G"])

    # speaker adversarial network
    if conf["use_spkradv_training"]:
        SPKRADV = SpeakerAdversarialNetwork(conf, spkr_size)
        models.update({"SPKRADV": SPKRADV.to(device)})
        logging.info(models["SPKRADV"])

    # speaker classifier network: one output channel per speaker
    if conf["use_spkr_classifier"]:
        C = ParallelWaveGANDiscriminator(
            in_channels=conf["input_size"],
            out_channels=spkr_size,
            kernel_size=conf["spkr_classifier_kernel_size"],
            layers=conf["n_spkr_classifier_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
            use_weight_norm=True,
        )
        models.update({"C": C.to(device)})
        logging.info(models["C"])

    # discriminator
    if conf["trainer_type"] in ["lsgan", "cyclegan", "stargan"]:
        input_channels = conf["input_size"]
        if conf["use_D_uv"]:
            input_channels += 1  # for uv flag
        if conf["use_D_spkrcode"]:
            # The speaker code is either spkr_size wide or, when speaker
            # embeddings are used, spkr_embedding_size wide.
            if not conf["use_spkr_embedding"]:
                input_channels += spkr_size
            else:
                input_channels += conf["spkr_embedding_size"]

        # Only the LSGAN head has a defined output size. Fail loudly for
        # anything else instead of hitting an unbound output_channels below.
        if conf["gan_type"] != "lsgan":
            raise NotImplementedError(
                f"Unsupported gan_type: {conf['gan_type']!r}"
            )
        output_channels = 1
        if conf["acgan_flag"]:
            # auxiliary classifier: extra per-speaker output channels
            output_channels += spkr_size

        D = ParallelWaveGANDiscriminator(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=conf["discriminator_kernel_size"],
            layers=conf["n_discriminator_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
            use_weight_norm=True,
        )
        models.update({"D": D.to(device)})
        logging.info(models["D"])
    return models
Ejemplo n.º 2
0
def get_model(conf, spkr_size=0, device="cuda"):
    """Build the networks described by ``conf`` and move them to ``device``.

    A ``VQVAE2`` generator is always created under key ``"G"``. A
    ``ParallelWaveGANDiscriminator`` is added under ``"D"`` when
    ``conf["trainer_type"]`` is ``"lsgan"`` or ``"cyclegan"``. A
    ``SpeakerAdversarialNetwork`` is added under ``"SPKRADV"`` when
    ``conf["speaker_adversarial"]`` is truthy. Created G and D models, and
    SPKRADV, are logged with ``logging.info``.

    Args:
        conf: Configuration mapping. Which keys are read depends on the
            enabled networks (see the body).
        spkr_size: Number of speakers; widens the discriminator output when
            ``conf["acgan_flag"]`` is set.
        device: Target passed to ``.to()`` for every model.

    Returns:
        dict mapping model name to the model instance.

    Raises:
        NotImplementedError: If a discriminator is required and either
            ``conf["discriminator_type"]`` is not ``"pwg"`` or
            ``conf["gan_type"]`` is not ``"lsgan"``.
    """
    models = {"G": VQVAE2(conf, spkr_size=spkr_size).to(device)}
    logging.info(models["G"])

    # discriminator (only needed by these trainers; the output-size settings
    # are read here so configs without a discriminator don't need them)
    if conf["trainer_type"] in ["lsgan", "cyclegan"]:
        if conf["discriminator_type"] != "pwg":
            raise NotImplementedError(
                f"Unsupported discriminator_type: {conf['discriminator_type']!r}"
            )
        # Only the LSGAN head has a defined output size.
        if conf["gan_type"] != "lsgan":
            raise NotImplementedError(
                f"Unsupported gan_type: {conf['gan_type']!r}"
            )
        output_channels = 1
        if conf["acgan_flag"]:
            # auxiliary classifier: extra per-speaker output channels
            output_channels += spkr_size

        D = ParallelWaveGANDiscriminator(
            in_channels=conf["input_size"],
            out_channels=output_channels,
            kernel_size=conf["discriminator_kernel_size"],
            layers=conf["n_discriminator_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
            use_weight_norm=True,
        )
        models.update({"D": D.to(device)})
        logging.info(models["D"])

    # speaker adversarial network
    if conf["speaker_adversarial"]:
        SPKRADV = SpeakerAdversarialNetwork(conf, spkr_size)
        models.update({"SPKRADV": SPKRADV.to(device)})
        logging.info(models["SPKRADV"])
    return models
Ejemplo n.º 3
0
def get_model(conf, spkr_size=0, device="cuda"):
    """Build a ``VQVAE2`` generator and a PWG discriminator on ``device``.

    Args:
        conf: Configuration mapping. Reads ``gan_type``, ``acgan_flag``,
            ``discriminator_type``, ``input_size``, ``kernel_size`` (its
            first entry is used for the discriminator) and
            ``n_discriminator_layers``, plus whatever ``VQVAE2`` reads.
        spkr_size: Number of speakers; widens the discriminator output when
            ``conf["acgan_flag"]`` is set.
        device: Target passed to ``.to()`` for both models.

    Returns:
        dict with keys ``"G"`` (generator) and ``"D"`` (discriminator).

    Raises:
        NotImplementedError: If ``conf["gan_type"]`` is not ``"lsgan"`` or
            ``conf["discriminator_type"]`` is not ``"pwg"``. Without these
            checks, output_channels or D would be unbound.
    """
    G = VQVAE2(conf, spkr_size=spkr_size).to(device)

    # discriminator output size: only the LSGAN head is defined
    if conf["gan_type"] != "lsgan":
        raise NotImplementedError(f"Unsupported gan_type: {conf['gan_type']!r}")
    output_channels = 1
    if conf["acgan_flag"]:
        # auxiliary classifier: extra per-speaker output channels
        output_channels += spkr_size

    if conf["discriminator_type"] != "pwg":
        raise NotImplementedError(
            f"Unsupported discriminator_type: {conf['discriminator_type']!r}"
        )
    D = ParallelWaveGANDiscriminator(
        in_channels=conf["input_size"],
        out_channels=output_channels,
        kernel_size=conf["kernel_size"][0],
        layers=conf["n_discriminator_layers"],
        conv_channels=64,
        dilation_factor=1,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        bias=True,
        use_weight_norm=True,
    )
    # G was already moved to device above; no second .to() needed.
    return {"G": G, "D": D.to(device)}