Example #1
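This test (apparently from ESPnet2's VITS test suite) checks that a multi-speaker, multi-lingual VITS model completes one generator step and one discriminator step on GPU and can then decode. The arguments gen_dict, dis_dict, loss_dict, spks, spk_embed_dim, and langs are supplied by pytest parametrization that is not shown here.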
import torch

from espnet2.gan_tts.vits.vits import VITS

# NOTE: make_vits_generator_args, make_vits_discriminator_args, and
# make_vits_loss_args are assumed to be helper factories defined elsewhere
# in the test module; each returns a dict of keyword arguments for VITS.

def test_multi_speaker_vits_is_trainable_and_decodable_on_gpu(
        gen_dict, dis_dict, loss_dict, spks, spk_embed_dim, langs):
    idim = 10
    odim = 5
    global_channels = 8
    gen_args = make_vits_generator_args(**gen_dict)
    gen_args["generator_params"]["spks"] = spks
    gen_args["generator_params"]["langs"] = langs
    gen_args["generator_params"]["spk_embed_dim"] = spk_embed_dim
    gen_args["generator_params"]["global_channels"] = global_channels
    dis_args = make_vits_discriminator_args(**dis_dict)
    loss_args = make_vits_loss_args(**loss_dict)
    model = VITS(
        idim=idim,
        odim=odim,
        **gen_args,
        **dis_args,
        **loss_args,
    )
    model.train()
    upsample_factor = model.generator.upsample_factor
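    # Dummy batch of two utterances: token lengths 8 and 5, feature lengths
    # 16 and 13 frames; waveform lengths are the frame counts scaled by the
    # generator's upsampling factor.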
    inputs = dict(
        text=torch.randint(0, idim, (2, 8)),
        text_lengths=torch.tensor([8, 5], dtype=torch.long),
        feats=torch.randn(2, 16, odim),
        feats_lengths=torch.tensor([16, 13], dtype=torch.long),
        speech=torch.randn(2, 16 * upsample_factor),
        speech_lengths=torch.tensor([16, 13], dtype=torch.long) * upsample_factor,
    )
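    # Optional conditioning inputs, attached only when the corresponding
    # capability is enabled: speaker IDs, language IDs, and external
    # speaker embeddings.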
    if spks > 0:
        inputs["sids"] = torch.randint(0, spks, (2, 1))
    if langs > 0:
        inputs["lids"] = torch.randint(0, langs, (2, 1))
    if spk_embed_dim > 0:
        inputs["spembs"] = torch.randn(2, spk_embed_dim)
    device = torch.device("cuda")
    model.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}
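    # One adversarial training step: forward_generator=True returns the
    # generator-side loss, forward_generator=False the discriminator-side
    # loss; each must be backpropagatable.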
    gen_loss = model(forward_generator=True, **inputs)["loss"]
    gen_loss.backward()
    dis_loss = model(forward_generator=False, **inputs)["loss"]
    dis_loss.backward()

    with torch.no_grad():
        model.eval()

        # check inference
        inputs = dict(text=torch.randint(0, idim, (5,)))
        if spks > 0:
            inputs["sids"] = torch.randint(0, spks, (1, ))
        if langs > 0:
            inputs["lids"] = torch.randint(0, langs, (1, ))
        if spk_embed_dim > 0:
            inputs["spembs"] = torch.randn(spk_embed_dim)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        model.inference(**inputs)

        # check inference with predefined duration
        inputs = dict(
            text=torch.randint(0, idim, (5,)),
            durations=torch.tensor([1, 2, 3, 4, 5], dtype=torch.long),
        )
        if spks > 0:
            inputs["sids"] = torch.randint(0, spks, (1, ))
        if langs > 0:
            inputs["lids"] = torch.randint(0, langs, (1, ))
        if spk_embed_dim > 0:
            inputs["spembs"] = torch.randn(spk_embed_dim)
        inputs = {k: v.to(device) for k, v in inputs.items()}
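        # With explicit durations the output length is deterministic:
        # durations.sum() frames times the upsampling factor.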
        output_dict = model.inference(**inputs)
        assert output_dict["wav"].size(0) == inputs["durations"].sum() * upsample_factor

        # check inference with teacher forcing
        inputs = dict(
            text=torch.randint(0, idim, (5,)),
            feats=torch.randn(16, odim),
        )
        if spks > 0:
            inputs["sids"] = torch.randint(0, spks, (1, ))
        if langs > 0:
            inputs["lids"] = torch.randint(0, langs, (1, ))
        if spk_embed_dim > 0:
            inputs["spembs"] = torch.randn(spk_embed_dim)
        inputs = {k: v.to(device) for k, v in inputs.items()}
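        # Teacher forcing encodes the ground-truth features with the posterior
        # encoder, so the waveform covers exactly feats.size(0) frames.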
        output_dict = model.inference(**inputs, use_teacher_forcing=True)
        assert output_dict["wav"].size(0) == inputs["feats"].size(0) * upsample_factor
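
For reference, the dict arguments and capability flags above would normally be injected by pytest parametrization. A minimal sketch of that wiring (the value grid below is an illustrative assumption, not the original test matrix):

import pytest

@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is needed")
@pytest.mark.parametrize("gen_dict, dis_dict, loss_dict", [({}, {}, {})])
@pytest.mark.parametrize("spks, spk_embed_dim, langs",
                         [(10, 0, 0), (0, 192, 0), (0, 0, 4), (10, 192, 4)])
def test_multi_speaker_vits_is_trainable_and_decodable_on_gpu(
        gen_dict, dis_dict, loss_dict, spks, spk_embed_dim, langs):
    ...  # body as in the example above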
Example #2
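This variant exercises the same training and inference paths for a plain single-speaker VITS model on CPU; it relies on the same imports and helper factories as Example #1.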
def test_vits_is_trainable_and_decodable(gen_dict, dis_dict, loss_dict):
    idim = 10
    odim = 5
    gen_args = make_vits_generator_args(**gen_dict)
    dis_args = make_vits_discriminator_args(**dis_dict)
    loss_args = make_vits_loss_args(**loss_dict)
    model = VITS(
        idim=idim,
        odim=odim,
        **gen_args,
        **dis_args,
        **loss_args,
    )
    model.train()
    upsample_factor = model.generator.upsample_factor
    inputs = dict(
        text=torch.randint(0, idim, (2, 8)),
        text_lengths=torch.tensor([8, 5], dtype=torch.long),
        feats=torch.randn(2, 16, odim),
        feats_lengths=torch.tensor([16, 13], dtype=torch.long),
        speech=torch.randn(2, 16 * upsample_factor),
        speech_lengths=torch.tensor([16, 13], dtype=torch.long) * upsample_factor,
    )
    gen_loss = model(forward_generator=True, **inputs)["loss"]
    gen_loss.backward()
    dis_loss = model(forward_generator=False, **inputs)["loss"]
    dis_loss.backward()

    with torch.no_grad():
        model.eval()

        # check inference
        inputs = dict(text=torch.randint(0, idim, (5,)))
        model.inference(**inputs)

        # check inference with predefined durations
        inputs = dict(
            text=torch.randint(0, idim, (5,)),
            durations=torch.tensor([1, 2, 3, 4, 5], dtype=torch.long),
        )
        output_dict = model.inference(**inputs)
        assert output_dict["wav"].size(0) == inputs["durations"].sum() * upsample_factor

        # check inference with teacher forcing
        inputs = dict(
            text=torch.randint(0, idim, (5,)),
            feats=torch.randn(16, odim),
        )
        output_dict = model.inference(**inputs, use_teacher_forcing=True)
        assert output_dict["wav"].size(0) == inputs["feats"].size(0) * upsample_factor
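
Outside the test harness, the decode path exercised above boils down to a short synthesis call. A minimal sketch, assuming the helper factories can be called without overrides to obtain default arguments:

model = VITS(idim=10, odim=5,
             **make_vits_generator_args(),
             **make_vits_discriminator_args(),
             **make_vits_loss_args())
model.eval()
with torch.no_grad():
    # Synthesize from a dummy token sequence; the returned dict's "wav"
    # entry is a 1-D tensor of waveform samples.
    out = model.inference(text=torch.randint(0, 10, (7,)))
wav = out["wav"]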