shuffle=True,
                                              num_workers=hparams.num_workers)

    val_data_loader = data_utils.DataLoader(val_dataset,
                                            batch_size=hparams.batch_size,
                                            num_workers=4)

    test_data_loader = data_utils.DataLoader(test_dataset,
                                             batch_size=hparams.batch_size,
                                             num_workers=4)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Model
    model = Wav2Lip()
    disc = Wav2Lip_disc_qual()
    model.to(device)
    disc.to(device)

    print('total trainable params {}'.format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    print('total DISC trainable params {}'.format(
        sum(p.numel() for p in disc.parameters() if p.requires_grad)))

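    # Adam with beta1=0.5 (instead of the default 0.9) is a common setting for GAN-style training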
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate,
                           betas=(0.5, 0.999))
    disc_optimizer = optim.Adam(
        [p for p in disc.parameters() if p.requires_grad],
        lr=hparams.disc_initial_learning_rate,
        betas=(0.5, 0.999))
Example #2
    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.batch_size, shuffle=True,
        num_workers=hparams.num_workers, prefetch_factor=8, pin_memory=True)

    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.batch_size,
        num_workers=hparams.num_workers, prefetch_factor=8, pin_memory=True)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = Wav2Lip().to(device)
    disc = Wav2Lip_disc_qual().to(device)

    print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
    disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
                                lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))

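    # Optionally resume the generator and discriminator from saved checkpoints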
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)

    if args.disc_checkpoint_path is not None:
        load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer,
                        reset_optimizer=False, overwrite_global_states=False)
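
Both examples end at the setup stage. In the Wav2Lip training scripts this setup is typically followed by handing the models, data loaders and optimizers to the training loop. A minimal sketch of that hand-off is shown below, assuming a train() helper plus args.checkpoint_dir, hparams.checkpoint_interval and hparams.nepochs, none of which appear in the excerpts above:

    # Assumption: train() runs the adversarial training loop; args.checkpoint_dir,
    # hparams.checkpoint_interval and hparams.nepochs are defined elsewhere in the script.
    train(device, model, disc, train_data_loader, test_data_loader,
          optimizer, disc_optimizer,
          checkpoint_dir=args.checkpoint_dir,
          checkpoint_interval=hparams.checkpoint_interval,
          nepochs=hparams.nepochs)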