コード例 #1
0
    # Setup fragment: builds the image transform, the iPER data loaders, the
    # VQVAE_SPADE model with its optimizer, and a second poseVQVAE model.
    # NOTE(review): the enclosing function header and the code that consumes
    # these locals are outside this excerpt.
    device = 'cuda'
    # Enable cuDNN's benchmark mode.
    torch.backends.cudnn.benchmark = True

    # Preprocessing pipeline: resize and centre-crop to args.size, convert to
    # a tensor, then apply Normalize with [0.5, 0.5, 0.5] for both of its
    # arguments (mean and std in torchvision's signature).
    transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    # data_load() is unpacked into three values; the first two are kept as the
    # train and validation loaders and the third is discarded.
    loader_train, loader_val, _ = \
        iPERLoader(data_root=args.path, batch=args.batch_size, transform=transform).data_load()

    # Build the model on the GPU and wrap it in DataParallel.
    model = VQVAE_SPADE(embed_dim=128, parser=parser).to(device)
    model = nn.DataParallel(model).cuda()
    # end='' keeps the cursor on the same line so 'Complete !' follows it.
    print('Loading Model_SPADE...', end='')
    # Weights come from a hard-coded checkpoint path and are loaded into the
    # DataParallel wrapper.
    # NOTE(review): presumably the checkpoint keys carry the 'module.' prefix
    # -- confirm against how the checkpoint was saved.
    model.load_state_dict(torch.load('/p300/mem/mem_src/SPADE/checkpoint/as_101/vqvae_072.pt'))
    model.eval()
    print('Complete !')

    # Adam over all parameters of the wrapped model; no scheduler is set here.
    # NOTE(review): the model is in eval mode at this point although an
    # optimizer is created for it -- confirm the later code calls
    # model.train() where training is intended.
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None

    # Second network (poseVQVAE), set up the same way and restored from its
    # own hard-coded checkpoint, then switched to eval mode.
    model_cond = poseVQVAE().to(device)
    model_cond = nn.DataParallel(model_cond).cuda()
    print('Loading Model_condition...', end='')
    model_cond.load_state_dict(torch.load('/p300/mem/mem_src/checkpoint/pose_06_black/vqvae_016.pt'))
    model_cond.eval()
    print('Complete !')
コード例 #2
0
    # Setup fragment: builds the image transform, one iPER data loader, a
    # freshly constructed VQVAE_SPADE model with its optimizer, and a second
    # poseVQVAE model restored from a checkpoint.
    # NOTE(review): the enclosing function header and the code that consumes
    # these locals are outside this excerpt.
    device = 'cuda'
    # Enable cuDNN's benchmark mode.
    torch.backends.cudnn.benchmark = True

    # Preprocessing pipeline: resize and centre-crop to args.size, convert to
    # a tensor, then apply Normalize with [0.5, 0.5, 0.5] for both of its
    # arguments (mean and std in torchvision's signature).
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # data_load() is unpacked into three values; only the third one is kept.
    # NOTE(review): other fragments name the first two values train/val
    # loaders, so this is presumably the test loader -- confirm in iPERLoader.
    _, _, loader = iPERLoader(data_root=args.path,
                              batch=args.batch_size,
                              transform=transform).data_load()

    # Build the model on the GPU and wrap it in DataParallel. The checkpoint
    # load below is commented out, so the model keeps the weights produced by
    # its constructor and stays in its default (non-eval) mode.
    model = VQVAE_SPADE(embed_dim=128, parser=parser).to(device)
    model = nn.DataParallel(model).cuda()
    # print('Loading Model...', end='')
    # model.load_state_dict(torch.load('/p300/mem/mem_src/SPADE/checkpoint/app_v04/vqvae_089.pt'))
    # model.eval()
    # print('Complete !')

    # Adam over all parameters of the wrapped model; no scheduler is set here.
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None

    # Second network (poseVQVAE), restored from a hard-coded checkpoint path
    # and switched to eval mode.
    model_cond = poseVQVAE().to(device)
    model_cond = nn.DataParallel(model_cond).cuda()
    # end='' leaves the line open; the matching completion message is not
    # visible in this excerpt.
    print('Loading Model...', end='')
    model_cond.load_state_dict(
        torch.load('/p300/mem/mem_src/checkpoint/pose_04/vqvae_462.pt'))
    model_cond.eval()
コード例 #3
0
    # Setup fragment: builds the image transform, one iPER data loader, the
    # VQVAE_SPADE model (restored from a checkpoint) with its optimizer, and a
    # second poseVQVAE model restored from its own checkpoint.
    # NOTE(review): the enclosing function header and the code that consumes
    # these locals are outside this excerpt.
    device = 'cuda'
    # Enable cuDNN's benchmark mode.
    torch.backends.cudnn.benchmark = True

    # Preprocessing pipeline: resize and centre-crop to args.size, convert to
    # a tensor, then apply Normalize with [0.5, 0.5, 0.5] for both of its
    # arguments (mean and std in torchvision's signature).
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # data_load() is unpacked into three values; only the second one is kept.
    # NOTE(review): presumably the validation loader -- confirm in iPERLoader.
    _, loader, _ = iPERLoader(data_root=args.path,
                              batch=args.batch_size,
                              transform=transform).data_load()

    # Build the model on the GPU and wrap it in DataParallel.
    model = VQVAE_SPADE(embed_dim=128, parser=parser).to(device)
    model = nn.DataParallel(model).cuda()
    # end='' keeps the cursor on the same line so 'Complete !' follows it.
    print('Loading Model...', end='')
    # Weights come from a hard-coded checkpoint path and are loaded into the
    # DataParallel wrapper; the model is then switched to eval mode.
    model.load_state_dict(
        torch.load('/p300/mem/mem_src/SPADE/checkpoint/as_82/vqvae_244.pt'))
    model.eval()
    print('Complete !')

    # Adam over all parameters of the wrapped model; no scheduler is set here.
    # NOTE(review): the model is in eval mode at this point although an
    # optimizer is created for it -- confirm the later code calls
    # model.train() where training is intended.
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None

    # Second network (poseVQVAE), restored from a hard-coded checkpoint path.
    # The excerpt ends right after the load; no eval() call for model_cond is
    # visible here.
    model_cond = poseVQVAE().to(device)
    model_cond = nn.DataParallel(model_cond).cuda()
    print('Loading Model...', end='')
    model_cond.load_state_dict(
        torch.load('/p300/mem/mem_src/checkpoint/pose_04/vqvae_462.pt'))
コード例 #4
0
    # Setup fragment: builds the image transform, one iPER data loader, a
    # freshly constructed VQVAE_SPADE model with its optimizer, and a second
    # poseVQVAE model restored from a checkpoint.
    # NOTE(review): the enclosing function header and the code that consumes
    # these locals are outside this excerpt.
    device = 'cuda'
    # Enable cuDNN's benchmark mode.
    torch.backends.cudnn.benchmark = True

    # Preprocessing pipeline: resize and centre-crop to args.size, convert to
    # a tensor, then apply Normalize with [0.5, 0.5, 0.5] for both of its
    # arguments (mean and std in torchvision's signature).
    transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    # data_load() is unpacked into three values; only the second one is kept.
    # NOTE(review): presumably the validation loader -- confirm in iPERLoader.
    _, loader, _ = iPERLoader(data_root=args.path, batch=args.batch_size, transform=transform).data_load()

    # Build the model on the GPU and wrap it in DataParallel. The checkpoint
    # load below is commented out, so the model keeps the weights produced by
    # its constructor and stays in its default (non-eval) mode.
    model = VQVAE_SPADE(embed_dim=128, parser=parser).to(device)
    model = nn.DataParallel(model).cuda()
    # print('Loading Model...', end='')
    # model.load_state_dict(torch.load('/p300/mem/mem_src/SPADE/checkpoint/app_v04/vqvae_089.pt'))
    # model.eval()
    # print('Complete !')

    # Adam over all parameters of the wrapped model; no scheduler is set here.
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None

    # Second network (poseVQVAE), restored from a hard-coded checkpoint path
    # and switched to eval mode.
    model_cond = poseVQVAE().to(device)
    model_cond = nn.DataParallel(model_cond).cuda()
    # end='' keeps the cursor on the same line so 'Complete !' follows it.
    print('Loading Model...', end='')
    model_cond.load_state_dict(torch.load('/p300/mem/mem_src/checkpoint/pose_04/vqvae_462.pt'))
    model_cond.eval()
    print('Complete !')
コード例 #5
0
    # Setup fragment: selects the visible GPUs, then builds the image
    # transform, the iPER data loaders, the VQVAE_SPADE model (restored from a
    # checkpoint) with its optimizer, and a second poseVQVAE model.
    # NOTE(review): the enclosing function header and the code that consumes
    # these locals are outside this excerpt.
    #
    # Set CUDA_VISIBLE_DEVICES from args.gpu; this is the first statement of
    # the excerpt, ahead of every CUDA call shown here. os.environ only
    # accepts str values, so args.gpu must be a string.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    device = 'cuda'
    # Enable cuDNN's benchmark mode.
    torch.backends.cudnn.benchmark = True

    # Preprocessing pipeline: resize and centre-crop to args.size, convert to
    # a tensor, then apply Normalize with [0.5, 0.5, 0.5] for both of its
    # arguments (mean and std in torchvision's signature).
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # data_load() is unpacked into three values; the first two are kept as the
    # train and validation loaders and the third is discarded.
    loader_train, loader_val, _ = \
        iPERLoader(data_root=args.path, batch=args.batch_size, transform=transform).data_load()

    # Build the model on the GPU and wrap it in DataParallel.
    model = VQVAE_SPADE(embed_dim=128, parser=parser).to(device)
    model = nn.DataParallel(model).cuda()
    # end='' keeps the cursor on the same line so 'Complete !' follows it.
    print('Loading Model...', end='')
    # Weights come from a hard-coded checkpoint path and are loaded into the
    # DataParallel wrapper; the model is then switched to eval mode.
    model.load_state_dict(
        torch.load('/p300/mem/mem_src/SPADE/checkpoint/as_115/vqvae_014.pt'))
    model.eval()
    print('Complete !')
    # The optimizer comes from the project helper build_optimizer instead of
    # the plain Adam call kept (commented out) on the next line; no scheduler
    # is set here.
    # NOTE(review): the model is in eval mode at this point although an
    # optimizer is created for it -- confirm the later code calls
    # model.train() where training is intended.
    # optimizer = optim.Adam(model.parameters(), lr=args.lr)
    optimizer = build_optimizer(model, lr=args.lr)
    scheduler = None

    # Second network (poseVQVAE), restored from a hard-coded checkpoint path.
    # The excerpt ends right after the load; no eval() call for model_cond is
    # visible here.
    model_cond = poseVQVAE().to(device)
    model_cond = nn.DataParallel(model_cond).cuda()
    print('Loading Model_condition...', end='')
    model_cond.load_state_dict(
        torch.load('/p300/mem/mem_src/checkpoint/pose_06_black/vqvae_120.pt'))