Exemplo n.º 1
0
def load_model(path):
    """Instantiate Wav2Lip, restore its weights from ``path``, and ready it for inference.

    The checkpoint is deserialized with an identity ``map_location``, so every
    storage stays on the CPU where ``torch.load`` creates it. The loaded object
    is handed directly to ``load_state_dict``, so the file is expected to hold a
    bare state dict. ``torch.load`` unpickles the file, so only load trusted
    checkpoints.

    The network is then moved to the module-level ``device``, which is defined
    outside this function.

    Returns:
        The Wav2Lip model on ``device``, switched to evaluation mode.
    """
    net = Wav2Lip()
    print("Load checkpoint from: {}".format(path))

    def _keep_on_cpu(storage, loc):
        # Identity map_location: leave each deserialized storage where it is.
        return storage

    weights = torch.load(path, map_location=_keep_on_cpu)
    net.load_state_dict(weights)
    return net.to(device).eval()
Exemplo n.º 2
0
def load_model(path, device):
    """Build a Wav2Lip model, load the checkpoint at ``path``, and prepare it for inference.

    Args:
        path: Location of the checkpoint. It is read through ``__load__``
            (defined elsewhere). The returned object must contain a
            ``"state_dict"`` entry.
        device: Device passed to ``__load__`` and used as the model's target
            device.

    Returns:
        The Wav2Lip model on ``device``, in evaluation mode.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
    """
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    checkpoint = __load__(path, device)
    state_dict = checkpoint["state_dict"]

    # Weights saved from a model wrapped in torch.nn.DataParallel carry a
    # leading "module." on every key. Strip only that leading prefix:
    # str.replace would also mangle keys that merely contain "module."
    # somewhere inside (e.g. "submodule.weight" -> "subweight").
    prefix = 'module.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(new_state_dict)
    model = model.to(device)
    return model.eval()
Exemplo n.º 3
0
    # Datasets for the 'train' and 'val' splits (split names passed to Dataset).
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    # Training loader: reshuffled each epoch; worker count comes from hparams.
    train_data_loader = data_utils.DataLoader(train_dataset,
                                              batch_size=hparams.batch_size,
                                              shuffle=True,
                                              num_workers=hparams.num_workers)

    # Evaluation loader over the 'val' split: not shuffled, and a fixed
    # worker count of 4 rather than hparams.num_workers.
    test_data_loader = data_utils.DataLoader(test_dataset,
                                             batch_size=hparams.batch_size,
                                             num_workers=4)

    # `use_cuda` is defined outside this view; CPU is used when it is falsy.
    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = Wav2Lip().to(device)
    # Report how many parameters will receive gradients.
    print('total trainable params {}'.format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # Only parameters with requires_grad=True are handed to the optimizer.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate)

    # Optionally resume training. reset_optimizer=False presumably restores
    # the saved optimizer state as well (load_checkpoint is defined
    # elsewhere; confirm).
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path,
                        model,
                        optimizer,
                        reset_optimizer=False)

    load_checkpoint(args.syncnet_checkpoint_path,
                    syncnet,
                    None,
Exemplo n.º 4
0
                                              batch_size=hparams.batch_size,
                                              shuffle=True,
                                              num_workers=hparams.num_workers)

    # Validation loader: not shuffled, fixed at 4 workers.
    val_data_loader = data_utils.DataLoader(val_dataset,
                                            batch_size=hparams.batch_size,
                                            num_workers=4)

    # NOTE(review): this is also built over val_dataset, identical to
    # val_data_loader. Confirm that a separate test split is not intended.
    test_data_loader = data_utils.DataLoader(val_dataset,
                                             batch_size=hparams.batch_size,
                                             num_workers=4)

    # Use the first CUDA device when available, otherwise CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Model: generator plus a SeqNetDisc discriminator. The return values of
    # .to() are not reassigned. This relies on both being torch.nn.Module
    # subclasses, whose .to() moves parameters in place (assumed; confirm).
    model = Wav2Lip()
    disc = SeqNetDisc()
    model.to(device)
    disc.to(device)

    # Report trainable-parameter counts for the generator and discriminator.
    print('total trainable params {}'.format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    print('total DISC trainable params {}'.format(
        sum(p.numel() for p in disc.parameters() if p.requires_grad)))

    # Generator optimizer over trainable parameters only; beta1 is 0.5
    # instead of Adam's default 0.9.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate,
                           betas=(0.5, 0.999))
    disc_optimizer = optim.Adam(
        [p for p in disc.parameters() if p.requires_grad],
        lr=hparams.disc_initial_learning_rate,
Exemplo n.º 5
0
    # Dataset and Dataloader setup
    train_dataset_left = Dataset('train')  # Left stereo image training dataset
    test_dataset_left = Dataset('val')  # Left stereo image validation dataset

    # Training loader: shuffled, worker count taken from hparams.
    train_data_loader_Left = data_utils.DataLoader(
        train_dataset_left,
        batch_size=hparams.batch_size,
        shuffle=True,
        num_workers=hparams.num_workers)

    # Validation loader: not shuffled, fixed at 4 workers.
    test_data_loader_Left = data_utils.DataLoader(
        test_dataset_left, batch_size=hparams.batch_size, num_workers=4)

    # `use_cuda` is defined outside this view; CPU is used when it is falsy.
    device = torch.device("cuda" if use_cuda else "cpu")

    model_Wav2Lip = Wav2Lip().to(device)
    # Report how many parameters will receive gradients.
    print('total trainable params {}'.format(
        sum(p.numel() for p in model_Wav2Lip.parameters() if p.requires_grad)))

    # Optimize only parameters with requires_grad=True.
    optimizer = optim.Adam(
        [p for p in model_Wav2Lip.parameters() if p.requires_grad],
        lr=hparams.initial_learning_rate)

    # Resume from a checkpoint when one is given. reset_optimizer=False
    # presumably restores the saved optimizer state too (load_checkpoint is
    # defined elsewhere; confirm).
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path,
                        model_Wav2Lip,
                        optimizer,
                        reset_optimizer=False)

    load_checkpoint(args.syncnet_checkpoint_path,
                    syncnet,