Example #3
if opt.use_voxceleb2:
	opt.model_epoch_path +='voxceleb2'

if opt.use_content_other_face:
	opt.model_epoch_path += 'use_content_other_face'

if opt.use_discriminator:
	opt.model_epoch_path += 'use_discriminator0.01'

num_inputs = 3

if opt.use_uncertainty:
	num_outputs = 3
else:
	num_outputs = 2
model = UnwrappedFaceWeightedAverage(output_num_channels=num_outputs,
                                     input_num_channels=num_inputs,
                                     inner_nc=opt.inner_nc)

if opt.copy_weights:
	checkpoint_file = torch.load(opt.old_model)
	model.load_state_dict(checkpoint_file['state_dict'])
	opt.model_epoch_path = opt.model_epoch_path + 'copyWeights'
	del checkpoint_file


criterion = L1Loss()

if opt.num_views > 2:
	opt.model_epoch_path = opt.model_epoch_path + 'num_views' + str(opt.num_views) + 'combination_function' + opt.combination_function

model.stats = {'photometric_error': np.zeros((0, 1)), 'eyemoutherror': np.zeros((0, 1)),
               'contenterror': np.zeros((0, 1)), 'loss': np.zeros((0, 1))}
model.val_stats = {'photometric_error': np.zeros((0, 1)), 'eyemoutherror': np.zeros((0, 1)),
                   'contenterror': np.zeros((0, 1)), 'loss': np.zeros((0, 1))}
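# (presumably the training loop appends per-iteration rows to these empty (0, 1) history arrays)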
Example #4
def main():
    # model
    BASE_MODEL = '../experiment/release_models/'
    state_dict = torch.load(BASE_MODEL + 'x2face_model_forpython3.pth')

    model = UnwrappedFaceWeightedAverage(output_num_channels=2,
                                         input_num_channels=3,
                                         inner_nc=128)
    model.load_state_dict(state_dict['state_dict'])
    model = model.cuda()

    model = model.eval()

    # data
    save_path = "results/demo"
    # drive_file = ["./examples/Taylor_Swift/1.6/nuBaabkzzzI/"]
    # source_files = ["./examples/Taylor_Swift/1.6/nuBaabkzzzI/"]

    # source_root = "/home/cxu-serve/p1/common/voxceleb/test/img_sample"
    # source_files = []
    # for f in os.listdir(source_root):
    #     source_f = os.path.join(source_root, f)
    #     if len(os.listdir(source_f)) > 0:
    #         video_f = os.listdir(source_f)[0]
    #         img_f = os.listdir(os.path.join(source_f, video_f))[0]
    #         source_files.append(os.path.join(source_f, video_f, img_f))

    drive_file = [
        "/home/cxu-serve/p1/common/voxceleb2/unzip/test_video/id01567/cIZMA45dX0M/00291_aligned.mp4",
        "/home/cxu-serve/p1/common/voxceleb2/unzip/test_video/id00017/utfjXffHDgg/00198_aligned.mp4",
        "/home/cxu-serve/p1/common/voxceleb2/unzip/test_video/id01000/RvjbLfo3XDM/00052_aligned.mp4",
        "/home/cxu-serve/p1/common/voxceleb2/unzip/test_video/id04094/2sjuXzB2I1M/00025_aligned.mp4"
    ]
    source_files = drive_file
    # drive_file = source_files

    # drive_data = Grid(drive_file, nums=10)
    # source_data = Grid(source_files, nums=10)

    drive_data = Vox(drive_file, nums=10)
    source_data = Vox(source_files, nums=8)

    drive_loader = DataLoader(dataset=drive_data, batch_size=1, shuffle=False)
    source_loader = DataLoader(dataset=source_data,
                               batch_size=1,
                               shuffle=False)

    # test
    with torch.no_grad():
        for d_index, drive_imgs in tqdm(drive_loader):
            drive_imgs = torch.cat(drive_imgs, dim=0)
            drive_imgs = drive_imgs.cuda()
            drive = torchvision.utils.make_grid(drive_imgs.cpu().data).permute(
                1, 2, 0).numpy()

            # source images
            for s_index, source_imgs in tqdm(source_loader):
                input_imgs = [
                    img[0].repeat(drive_imgs.shape[0], 1, 1, 1)
                    for img in source_imgs
                ]
                # get image
                result = model(drive_imgs, *input_imgs)

                # store
                result = result.clamp(min=0, max=1)
                result_img = torchvision.utils.make_grid(result.cpu().data)
                result_img = result_img.permute(1, 2, 0).numpy()

                drive_file = drive_data.get_file(d_index.item()).split('/')[-2]
                file_name = "{}.{}".format(
                    *source_data.get_file(s_index.item()).split('/')[-2:])
                file_name = os.path.join(save_path, drive_file, file_name)
                # recreate the output directory for this (driver, source) pair
                if os.path.exists(file_name):
                    shutil.rmtree(file_name)
                os.makedirs(file_name)

                plt.figure()
                plt.axis('off')
                plt.imshow(np.vstack((result_img, drive)))
                plt.savefig(os.path.join(file_name, "result.png"))
                plt.close()

                source_store = torchvision.utils.make_grid(
                    torch.cat(source_imgs, dim=0).cpu().data)
                source_store = source_store.permute(1, 2, 0).numpy()
                plt.figure()
                plt.axis('off')
                plt.imshow(source_store)
                plt.savefig(os.path.join(file_name, "origin.png"))
                plt.close()
Example #5
        'examples/audio_faces/Maya_Rudolph/1.6/Ylm6PVkbwhs/0004500.jpg',
        'examples/audio_faces/Cristin_Milioti/1.6/IblJpk1GDZA/0004575.jpg',
        'examples/audio_faces/Peter_Capaldi/1.6/uAgUjSqIj7U/0001375.jpg'
    ]

    # path to frames corresponding to driving audio features
    audio_path = 'examples/audio_faces/Peter_Capaldi/1.6/uAgUjSqIj7U'
    imgpaths = os.listdir(audio_path)

    # loading models
    # BASE_MODEL = '/scratch/shared/slow/ow/eccv/2018/release_models/' # Change to your path
    BASE_MODEL = '../experiment/release_models/'
    # model_path = BASE_MODEL + 'x2face_model.pth'
    model_path = BASE_MODEL + "x2face_model_forpython3.pth"
    model = UnwrappedFaceWeightedAverage(output_num_channels=2,
                                         input_num_channels=3,
                                         inner_nc=128)
    model.load_state_dict(torch.load(model_path)['state_dict'])

    s_dict = torch.load(model_path)
    modelfortargetpose = BottleneckFromNet()
    state = modelfortargetpose.state_dict()
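    # keep only the checkpoint entries whose names also exist in BottleneckFromNet, then merge and load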
    s_dict = {
        k: v
        for k, v in s_dict['state_dict'].items() if k in state.keys()
    }
    state.update(s_dict)
    modelfortargetpose.load_state_dict(state)

    posemodel = nn.Sequential(nn.Linear(128, 3))
    # p_dict_pre = torch.load(BASE_MODEL + '/posereg.pth')['state_dict']
parser.add_argument('--copy_weights', type=bool, default=False)
parser.add_argument('--model_type', type=str, default='UnwrappedFaceSampler_from1view')
parser.add_argument('--inner_nc', type=int, default=128)
parser.add_argument('--old_model', type=str, default='')
parser.add_argument('--results_folder', type=str, default='/scratch/local/hdd/ow/results/') # Where temp results will be stored
parser.add_argument('--model_epoch_path', type=str, default='/scratch/local/hdd/ow/faces/models/python/sampler/%s/', help='Location to save to')
opt = parser.parse_args()

torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)

writer = SummaryWriter(opt.results_folder)

opt.model_epoch_path = opt.model_epoch_path % 'x2face'

model = UnwrappedFaceWeightedAverage(output_num_channels=2, input_num_channels=3, inner_nc=opt.inner_nc)

if opt.copy_weights:
        checkpoint_file = torch.load(opt.old_model)
        model.load_state_dict(checkpoint_file['state_dict'])
        opt.model_epoch_path = opt.model_epoch_path + 'copyWeights'
        del checkpoint_file


criterion = nn.L1Loss()

model = model.cuda()

criterion = criterion.cuda()
parameters = [{'params' : model.parameters()}]
optimizer = optim.SGD(parameters, lr=opt.lr, momentum=0.9)
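
Below this setup, a hypothetical single training step (not part of the original snippet) shows how `model`, `criterion` and `optimizer` fit together; the random tensors stand in for the VoxCeleb source/driver frames the real script loads:

source_imgs = torch.rand(8, 3, 256, 256).cuda()    # identity frames (assumed batch size/shape)
driver_imgs = torch.rand(8, 3, 256, 256).cuda()    # pose-driving frames

optimizer.zero_grad()
generated = model(driver_imgs, source_imgs)        # warp the source toward the driver's pose
loss = criterion(generated, driver_imgs)           # photometric L1 against the driving frame
loss.backward()
optimizer.step()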
Example #7
        results = torch.cat(results, dim=1)
        loss_list.append(reconstruction_loss(imgs, results).data.cpu().numpy())

        results = (results.data.cpu().numpy() * 255).astype('uint8')
        results = results[0].transpose((0, 2, 3, 1))
        imageio.mimsave(os.path.join(log_dir, x['name'][0] + format), results)

    print("Reconstruction loss: %s" % np.mean(loss_list))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='UnwrappedFace')
    parser.add_argument("--dataset",
                        default='data/nemo',
                        help="Path to dataset")
    parser.add_argument("--folder", default="out", help="out folder")
    parser.add_argument("--arch",
                        default='unet_64',
                        help="Network architecture")
    parser.add_argument("--format", default='.gif', help="Save format")
    args = parser.parse_args()

    model = UnwrappedFaceWeightedAverage(output_num_channels=2,
                                         input_num_channels=3,
                                         inner_nc=512)
    model = model.cuda()

    dataset = FramesDataset(args.dataset, is_train=False)
    reconstruction(model, os.path.join(args.folder, 'model.cpk'), args.folder,
                   dataset, args.format)
class Embedder(nn.Module):
    def __init__(self, identity_embedding_size, average_function):
        super().__init__()

        self.identity_embedding_size = identity_embedding_size

        import torchvision
        self.identity_encoder = torchvision.models.resnext50_32x4d(
            num_classes=identity_embedding_size)

        import sys
        X2FACE_ROOT_DIR = "embedders/X2Face"
        sys.path.append(f"{X2FACE_ROOT_DIR}/UnwrapMosaic/")
        try:
            from UnwrappedFace import UnwrappedFaceWeightedAverage
            state_dict = torch.load(
                f"{X2FACE_ROOT_DIR}/models/x2face_model_forpython3.pth",
                map_location='cpu')
        except (ImportError, FileNotFoundError):
            logger.critical(
                f"Please initialize submodules, then download 'x2face_model_forpython3.pth' from "
                f"http://www.robots.ox.ac.uk/~vgg/research/unsup_learn_watch_faces/release_x2face_eccv_withpy3.zip"
                f" and put it into {X2FACE_ROOT_DIR}/models/")
            raise

        self.pose_encoder = UnwrappedFaceWeightedAverage(output_num_channels=2,
                                                         input_num_channels=3,
                                                         inner_nc=128,
                                                         sampler_only=True)
        self.pose_encoder.load_state_dict(state_dict['state_dict'],
                                          strict=False)
        self.pose_encoder.eval()

        # Forbid doing .train(), .eval() and .parameters()
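        # (binding the no-ops with __get__ turns them into bound methods of this instance,
        # so .train()/.eval() -- including the recursive call from a parent .train() --
        # and direct .parameters() calls leave the frozen pose encoder in eval mode)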
        def train_noop(self, mode=True):
            pass

        def parameters_noop(self, recurse=True):
            return []

        self.pose_encoder.train = train_noop.__get__(self.pose_encoder,
                                                     nn.Module)
        self.pose_encoder.parameters = parameters_noop.__get__(
            self.pose_encoder, nn.Module)

        self.average_function = average_function

        self.finetuning = False

    def enable_finetuning(self, data_dict=None):
        self.finetuning = True

    def get_identity_embedding(self, data_dict):
        inputs = data_dict['enc_rgbs']

        batch_size, num_faces, c, h, w = inputs.shape

        inputs = inputs.view(-1, c, h, w)
        identity_embeddings = self.identity_encoder(inputs).view(
            batch_size, num_faces, -1)
        assert identity_embeddings.shape[2] == self.identity_embedding_size

        if self.average_function == 'sum':
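            # note: despite the name, 'sum' averages the per-image embeddings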
            identity_embeddings_aggregated = identity_embeddings.mean(1)
        elif self.average_function == 'max':
            identity_embeddings_aggregated = identity_embeddings.max(1)[0]
        else:
            raise ValueError(
                "Incorrect `average_function` argument, expected `sum` or `max`"
            )

        data_dict['embeds'] = identity_embeddings_aggregated
        data_dict['embeds_elemwise'] = identity_embeddings

    def get_pose_embedding(self, data_dict):
        x = data_dict['pose_input_rgbs'][:, 0]
        with torch.no_grad():
            data_dict['pose_embedding'] = self.pose_encoder.get_sampler(
                x, latent_pose_vector_only=True)[:, :, 0, 0]

    def forward(self, data_dict):
        if not self.finetuning:
            self.get_identity_embedding(data_dict)
        self.get_pose_embedding(data_dict)
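
A minimal, hypothetical smoke test for the `Embedder` above (not part of the original code; it assumes the X2Face submodule and checkpoint are in place as required by `__init__`, and 256x256 inputs):

import torch

embedder = Embedder(identity_embedding_size=512, average_function='max')
data_dict = {
    'enc_rgbs': torch.rand(2, 4, 3, 256, 256),          # batch of 2, 4 identity crops each
    'pose_input_rgbs': torch.rand(2, 1, 3, 256, 256),   # pose driver frames (only the first is used)
}
embedder(data_dict)
print(data_dict['embeds'].shape)          # (2, 512)
print(data_dict['pose_embedding'].shape)  # (2, 128), assuming the 128-dim X2Face bottleneck

The identity branch stays trainable, while the pose branch remains frozen thanks to the no-op overrides above.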
Example #9
    if other_images is not None:
        for i in range(0, len(other_images)):
            # note: Variable(..., volatile=...) is the pre-0.4 PyTorch autograd API;
            # `volatile` is ignored by modern PyTorch (use torch.no_grad() instead)
            other_images[i] = Variable(other_images[i],
                                       requires_grad=requires_grad,
                                       volatile=volatile).cuda()

        return (model(poses, *imgs[0:-3]),
                model(poses,
                      *other_images[0:-3])), imgs + [poses], other_images[0:-3]

    return model(poses, *imgs[0:-3]), imgs + [poses]


model = UnwrappedFaceWeightedAverage(output_num_channels=2,
                                     input_num_channels=3,
                                     inner_nc=512)
model = model.cuda()

parameters = [{'params': model.parameters()}]
optimizer = optim.SGD(parameters, lr=args.lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5)

train_set = FramesDataset(args.dataset)
training_data_loader = DataLoader(dataset=train_set,
                                  num_workers=4,
                                  batch_size=16,
                                  shuffle=True)

if not os.path.exists(args.folder):
    os.makedirs(args.folder)
class Generator(nn.Module):
    def __init__(self, num_identity_images):
        super().__init__()

        self.identity_images = nn.Parameter(torch.empty(num_identity_images, 3, 256, 256))

        import sys
        X2FACE_ROOT_DIR = "embedders/X2Face"
        sys.path.append(f"{X2FACE_ROOT_DIR}/UnwrapMosaic/")
        try:
            from UnwrappedFace import UnwrappedFaceWeightedAverage
            state_dict = torch.load(
                f"{X2FACE_ROOT_DIR}/models/x2face_model_forpython3.pth", map_location='cpu')
        except (ImportError, FileNotFoundError):
            logger.critical(
                f"Please initialize submodules, then download 'x2face_model_forpython3.pth' from "
                f"http://www.robots.ox.ac.uk/~vgg/research/unsup_learn_watch_faces/release_x2face_eccv_withpy3.zip"
                f" and put it into {X2FACE_ROOT_DIR}/models/")
            raise

        self.x2face_model = UnwrappedFaceWeightedAverage(output_num_channels=2, input_num_channels=3, inner_nc=128)
        self.x2face_model.load_state_dict(state_dict['state_dict'])
        self.x2face_model.eval()

        # Forbid doing .train(), .eval() and .parameters()
        def train_noop(self, *args, **kwargs): pass
        def parameters_noop(self, *args, **kwargs): return []
        self.x2face_model.train = train_noop.__get__(self.x2face_model, nn.Module)
        self.x2face_model.parameters = parameters_noop.__get__(self.x2face_model, nn.Module)

        # Disable saving weights
        def state_dict_empty(self, *args, **kwargs): return {}
        self.x2face_model.state_dict = state_dict_empty.__get__(self.x2face_model, nn.Module)
        # Forbid loading weights after we have done that
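        # (with state_dict() empty above, a later load_state_dict() on the parent model would
        # otherwise report every X2Face key as missing; the per-module no-op below makes that
        # load silently skip the frozen weights instead of overwriting them)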
        def _load_from_state_dict_noop(self, *args, **kwargs): pass
        for module in self.x2face_model.modules():
            module._load_from_state_dict = _load_from_state_dict_noop.__get__(module, nn.Module)

        self.finetuning = False

    @torch.no_grad()
    def enable_finetuning(self, data_dict=None):
        """
            Make the necessary adjustments to the generator architecture to allow fine-tuning.
            For this X2Face-based generator, re-initialize the reference images
            (`self.identity_images`) from `data_dict['enc_rgbs']` and flag them as trainable
            parameters.
            Will require re-initializing the optimizer, but only after the first call.

            data_dict:
                dict
                Required contents depend on the specific generator. Here it is
                `'enc_rgbs'` (1 x N x C x H x W reference images; the first sample is used).
                If `None`, the new parameters keep the values they already have.
        """
        if data_dict is not None:
            self.identity_images = nn.Parameter(data_dict['enc_rgbs'][0]) # N x C x H x W

        self.finetuning = True

    @torch.no_grad()
    def forward(self, data_dict):
        batch_size = len(data_dict['pose_input_rgbs'])
        outputs = torch.empty_like(data_dict['pose_input_rgbs'][:, 0])

        for batch_idx in range(batch_size):
            # N x C x H x W
            identity_images = self.identity_images if self.finetuning else data_dict['enc_rgbs'][batch_idx]
            identity_images_list = []
            for identity_image in identity_images:
                identity_images_list.append(identity_image[None])
                
            # C x H x W
            pose_driver = data_dict['pose_input_rgbs'][batch_idx, 0]
            driver_images = pose_driver[None]

            result = self.x2face_model(driver_images, *identity_images_list)
            result = result.clamp(min=0, max=1)

            outputs[batch_idx].copy_(result[0])

        data_dict['fake_rgbs'] = outputs
        outputs.requires_grad_()
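
A matching hypothetical sketch for the X2Face-based `Generator` above: it never uses the pose embedding, instead warping the reference images in `'enc_rgbs'` with the driver frame and writing the result to `data_dict['fake_rgbs']` (shapes are assumptions):

import torch

generator = Generator(num_identity_images=4)
data_dict = {
    'enc_rgbs': torch.rand(2, 4, 3, 256, 256),          # reference (identity) images
    'pose_input_rgbs': torch.rand(2, 1, 3, 256, 256),   # driver frames
}
generator(data_dict)
print(data_dict['fake_rgbs'].shape)  # (2, 3, 256, 256)

Before `enable_finetuning` is called, `num_identity_images` only sizes the placeholder parameter; inference always reads the references from `data_dict`.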
Example #11
    audio_label_path = str(file_path).replace('audio_faces', 'audio_features').replace('jpg','npz')
    audio_feature = torch.Tensor(np.load(audio_label_path)['audio_feat'])
    return {'image' : img, 'audio' : audio_feature}
# paths to source frames
sourcepaths = ['examples/audio_faces/Retta/1.6/ALELNl9E1Jc/0002725.jpg',
               'examples/audio_faces/Maya_Rudolph/1.6/Ylm6PVkbwhs/0004500.jpg',
               'examples/audio_faces/Cristin_Milioti/1.6/IblJpk1GDZA/0004575.jpg']

# path to frames corresponding to driving audio features
audio_path = 'examples/audio_faces/Peter_Capaldi/1.6/uAgUjSqIj7U'
imgpaths = os.listdir(audio_path)

# loading models
BASE_MODEL = '/mnt/ssd0/dat/lchen63/release_models/' # Change to your path
model_path = BASE_MODEL + 'x2face_model.pth'
model = UnwrappedFaceWeightedAverage(output_num_channels=2, input_num_channels=3, inner_nc=128)
model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)['state_dict'])
s_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
modelfortargetpose = BottleneckFromNet()
state = modelfortargetpose.state_dict()
s_dict = {k: v for k, v in s_dict['state_dict'].items() if k in state.keys()}
state.update(s_dict)
modelfortargetpose.load_state_dict(state)

posemodel = nn.Sequential(nn.Linear(128, 3))
p_dict_pre = torch.load(BASE_MODEL + '/posereg.pth', map_location=lambda storage, loc: storage)['state_dict']
posemodel._modules['0'].weight.data = p_dict_pre['posefrombottle.weight'].cpu()
posemodel._modules['0'].bias.data = p_dict_pre['posefrombottle.bias'].cpu()

bottleneckmodel = nn.Sequential(nn.Linear(3, 128, bias=False), nn.BatchNorm1d(128))
b_dict_pre = torch.load(BASE_MODEL + '/posetobottle.pth', map_location=lambda storage, loc: storage)['state_dict']
def main():
    # model
    BASE_MODEL = '../experiment/release_models/'
    state_dict = torch.load(BASE_MODEL + 'x2face_model_forpython3.pth')

    model = UnwrappedFaceWeightedAverage(output_num_channels=2,
                                         input_num_channels=3,
                                         inner_nc=128)
    model.load_state_dict(state_dict['state_dict'])
    model = model.cuda()

    model = model.eval()

    # data
    save_path = "extra_degree_result/vox"

    ref_files = []
    files = []

    pickle_files = np.load('vox_demo.npy', allow_pickle=True)
    files = list(set([f[0] for f in pickle_files]))
    tgt_ids = {}
    for f in pickle_files:
        if f[0] not in tgt_ids:
            tgt_ids[f[0]] = []
        tgt_ids[f[0]] += np.arange(int(f[1]), int(f[2])).tolist()
    ref_files = files

    # files = files[670:]
    total_f = len(files)
    files = [f for f in files if os.path.exists(f)]
    print('totally {} files while valid {} files'.format(total_f, len(files)))

    dataset = VoxSingle(files, ref_indx='0,5,10,15,20,25,30,35')
    dataloader = DataLoader(dataset=dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=0)

    # test
    state = 1
    with torch.no_grad():
        while state != -1:
            print("current {}/{}".format(state, len(files)))
            # get reference images
            ref_imgs = dataset.get_ref()
            # ref_imgs = ref_dataset.get_ref()
            # save reference images
            file = dataset.cur_file.split('/')
            path_name = '{}'.format(file[-1])
            save_file_name = os.path.join(save_path, path_name)
            for ref_id, img in enumerate(ref_imgs):
                save_ref_file = os.path.join(save_file_name, 'reference')
                if not os.path.exists(save_ref_file):
                    os.makedirs(save_ref_file)
                save_img = (img * 255).permute(1, 2, 0).cpu().numpy().astype(
                    np.uint8)
                Image.fromarray(save_img).save(
                    os.path.join(save_ref_file, '%05d.jpg' % ref_id))

            # preprocess
            ref_imgs = [ref.unsqueeze(0) for ref in ref_imgs]
            ref_imgs = torch.cat(ref_imgs, dim=0)
            ref_imgs = ref_imgs.cuda()
            ref = torchvision.utils.make_grid(ref_imgs.cpu().data).permute(
                1, 2, 0).numpy()
            # synthesize
            for d_index, drive_img in tqdm(dataloader):
                if d_index not in tgt_ids[dataset.cur_file]:
                    continue
                if os.path.exists(
                        os.path.join(save_file_name, '%05d.jpg' % d_index)):
                    continue

                drive_img = drive_img.cuda()

                input_imgs = [
                    img.repeat(drive_img.shape[0], 1, 1, 1) for img in ref_imgs
                ]
                # get image
                result = model(drive_img, *input_imgs)

                # store
                result = result.clamp(min=0, max=1)
                result_img = torchvision.utils.make_grid(result.cpu().data)
                result_img = result_img.permute(1, 2, 0).numpy()

                save_img = drive_img.clamp(min=0, max=1)
                save_img = torchvision.utils.make_grid(save_img.cpu().data)
                save_img = save_img.permute(1, 2, 0).numpy()

                # final_img = np.hstack([result_img, save_img])
                final_img = result_img

                final_img = (final_img * 255).astype(np.uint8)
                Image.fromarray(final_img).save(
                    os.path.join(save_file_name, '%05d.jpg' % d_index))

            # combine video
            # image_to_video(save_file_name, os.path.join(save_file_name, '{}.mp4'.format(path_name)))

            # new file
            state = dataset.nextfile()