def test_multi_speaker_vits_is_trainable_and_decodable_on_gpu(
    gen_dict, dis_dict, loss_dict, spks, spk_embed_dim, langs
):
    idim = 10
    odim = 5
    global_channels = 8
    gen_args = make_vits_generator_args(**gen_dict)
    gen_args["generator_params"]["spks"] = spks
    gen_args["generator_params"]["langs"] = langs
    gen_args["generator_params"]["spk_embed_dim"] = spk_embed_dim
    gen_args["generator_params"]["global_channels"] = global_channels
    dis_args = make_vits_discriminator_args(**dis_dict)
    loss_args = make_vits_loss_args(**loss_dict)
    model = VITS(
        idim=idim,
        odim=odim,
        **gen_args,
        **dis_args,
        **loss_args,
    )
    model.train()
    upsample_factor = model.generator.upsample_factor
    inputs = dict(
        text=torch.randint(0, idim, (2, 8)),
        text_lengths=torch.tensor([8, 5], dtype=torch.long),
        feats=torch.randn(2, 16, odim),
        feats_lengths=torch.tensor([16, 13], dtype=torch.long),
        speech=torch.randn(2, 16 * upsample_factor),
        # NOTE: scale lengths element-wise; `[16, 13] * upsample_factor`
        # would repeat the Python list rather than multiply the lengths.
        speech_lengths=torch.tensor([16, 13], dtype=torch.long) * upsample_factor,
    )
    if spks > 0:
        inputs["sids"] = torch.randint(0, spks, (2, 1))
    if langs > 0:
        inputs["lids"] = torch.randint(0, langs, (2, 1))
    if spk_embed_dim > 0:
        inputs["spembs"] = torch.randn(2, spk_embed_dim)
    device = torch.device("cuda")
    model.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # check that both generator and discriminator losses are backpropagatable
    gen_loss = model(forward_generator=True, **inputs)["loss"]
    gen_loss.backward()
    dis_loss = model(forward_generator=False, **inputs)["loss"]
    dis_loss.backward()

    with torch.no_grad():
        model.eval()

        # check inference
        inputs = dict(text=torch.randint(0, idim, (5,)))
        if spks > 0:
            inputs["sids"] = torch.randint(0, spks, (1,))
        if langs > 0:
            inputs["lids"] = torch.randint(0, langs, (1,))
        if spk_embed_dim > 0:
            inputs["spembs"] = torch.randn(spk_embed_dim)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        model.inference(**inputs)

        # check inference with predefined durations
        inputs = dict(
            text=torch.randint(0, idim, (5,)),
            durations=torch.tensor([1, 2, 3, 4, 5], dtype=torch.long),
        )
        if spks > 0:
            inputs["sids"] = torch.randint(0, spks, (1,))
        if langs > 0:
            inputs["lids"] = torch.randint(0, langs, (1,))
        if spk_embed_dim > 0:
            inputs["spembs"] = torch.randn(spk_embed_dim)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        output_dict = model.inference(**inputs)
        assert output_dict["wav"].size(0) == inputs["durations"].sum() * upsample_factor

        # check inference with teacher forcing
        inputs = dict(
            text=torch.randint(0, idim, (5,)),
            feats=torch.randn(16, odim),
        )
        if spks > 0:
            inputs["sids"] = torch.randint(0, spks, (1,))
        if langs > 0:
            inputs["lids"] = torch.randint(0, langs, (1,))
        if spk_embed_dim > 0:
            inputs["spembs"] = torch.randn(spk_embed_dim)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        output_dict = model.inference(**inputs, use_teacher_forcing=True)
        assert output_dict["wav"].size(0) == inputs["feats"].size(0) * upsample_factor
def test_vits_is_trainable_and_decodable(gen_dict, dis_dict, loss_dict):
    idim = 10
    odim = 5
    gen_args = make_vits_generator_args(**gen_dict)
    dis_args = make_vits_discriminator_args(**dis_dict)
    loss_args = make_vits_loss_args(**loss_dict)
    model = VITS(
        idim=idim,
        odim=odim,
        **gen_args,
        **dis_args,
        **loss_args,
    )
    model.train()
    upsample_factor = model.generator.upsample_factor
    inputs = dict(
        text=torch.randint(0, idim, (2, 8)),
        text_lengths=torch.tensor([8, 5], dtype=torch.long),
        feats=torch.randn(2, 16, odim),
        feats_lengths=torch.tensor([16, 13], dtype=torch.long),
        speech=torch.randn(2, 16 * upsample_factor),
        # NOTE: scale lengths element-wise; `[16, 13] * upsample_factor`
        # would repeat the Python list rather than multiply the lengths.
        speech_lengths=torch.tensor([16, 13], dtype=torch.long) * upsample_factor,
    )

    # check that both generator and discriminator losses are backpropagatable
    gen_loss = model(forward_generator=True, **inputs)["loss"]
    gen_loss.backward()
    dis_loss = model(forward_generator=False, **inputs)["loss"]
    dis_loss.backward()

    with torch.no_grad():
        model.eval()

        # check inference
        inputs = dict(text=torch.randint(0, idim, (5,)))
        model.inference(**inputs)

        # check inference with predefined durations
        inputs = dict(
            text=torch.randint(0, idim, (5,)),
            durations=torch.tensor([1, 2, 3, 4, 5], dtype=torch.long),
        )
        output_dict = model.inference(**inputs)
        assert output_dict["wav"].size(0) == inputs["durations"].sum() * upsample_factor

        # check inference with teacher forcing
        inputs = dict(
            text=torch.randint(0, idim, (5,)),
            feats=torch.randn(16, odim),
        )
        output_dict = model.inference(**inputs, use_teacher_forcing=True)
        assert output_dict["wav"].size(0) == inputs["feats"].size(0) * upsample_factor
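

# A minimal standalone smoke run, added for illustration only (an assumption,
# not part of the original test suite): it invokes the single-speaker test
# once, assuming the make_vits_*_args helpers accept empty kwargs and fall
# back to their defaults, as pytest's parametrization would supply them.
if __name__ == "__main__":
    test_vits_is_trainable_and_decodable(gen_dict={}, dis_dict={}, loss_dict={})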