def __getitem__(self, index):
    # resolve the video / landmark / animation / pose paths for the selected dataset
    if self.opt.dataset_name == 'face':
        video_path = self.data[index][1]  # os.path.join(self.root, 'pretrain', v_id[0], v_id[1][:5] + '_crop.mp4')
        lmark_path = self.data[index][0]  # os.path.join(self.root, 'pretrain', v_id[0], v_id[1])
    elif self.opt.dataset_name == 'vox':
        paths = self.data[index]
        video_path = os.path.join(self.root, self.video_bag, paths[0], paths[1], paths[2] + "_aligned.mp4")
        if self.opt.no_head_motion:
            lmark_path = os.path.join(self.root, self.video_bag, paths[0], paths[1], paths[2] + "_aligned_front.npy")
        else:
            lmark_path = os.path.join(self.root, self.video_bag, paths[0], paths[1], paths[2] + "_aligned.npy")
        ani_path = os.path.join(self.root, self.video_bag, paths[0], paths[1], paths[2] + "_aligned_ani.mp4")
        rt_path = os.path.join(self.root, self.video_bag, paths[0], paths[1], paths[2] + "_aligned_rt.npy")
        front_path = os.path.join(self.root, self.video_bag, paths[0], paths[1], paths[2] + "_aligned_front.npy")
        ani_id = paths[3]
    elif self.opt.dataset_name == 'grid':
        paths = self.data[index]
        video_path = os.path.join(self.root, self.video_bag, paths[0], paths[1] + '_crop.mp4')
        lmark_path = os.path.join(self.root, self.video_bag, paths[0], paths[1] + '_original.npy')
        rt_path = os.path.join(self.root, self.video_bag, paths[0], paths[1] + '_rt.npy')
        front_path = os.path.join(self.root, self.video_bag, paths[0], paths[1] + '_front.npy')
    elif self.opt.dataset_name == 'lrs':
        paths = self.data[index]
        paths[1] = paths[1].split('_')[0]
        video_path = os.path.join(self.root, self.video_bag, paths[0], paths[1] + '_crop.mp4')
        lmark_path = os.path.join(self.root, self.video_bag, paths[0], paths[1] + '_original.npy')
        ani_path = os.path.join(self.root, self.video_bag, paths[0], paths[1] + "_ani.mp4")
        rt_path = os.path.join(self.root, self.video_bag, paths[0], paths[1] + '_rt.npy')
        front_path = os.path.join(self.root, self.video_bag, paths[0], paths[1] + '_front.npy')
        if self.opt.warp_ani:
            ani_id = int(paths[2])
    elif self.opt.dataset_name == 'lrw':
        paths = self.data[index]
        video_path = os.path.join(paths[0] + '_crop.mp4')
        lmark_path = os.path.join(paths[0] + '_original.npy')
        ani_path = os.path.join(paths[0] + "_ani.mp4")
        rt_path = os.path.join(paths[0] + '_rt.npy')
        front_path = os.path.join(paths[0] + '_front.npy')
        if self.opt.warp_ani:
            ani_id = int(paths[1])
    elif self.opt.dataset_name == 'crema':
        paths = self.data[index]
        video_path = os.path.join(self.root, self.video_bag, paths[0][:-10] + '_crop.mp4')
        lmark_path = os.path.join(self.root, self.video_bag, paths[0][:-10] + '_original.npy')
        rt_path = os.path.join(self.root, self.video_bag, paths[0][:-10] + '_rt.npy')
        front_path = os.path.join(self.root, self.video_bag, paths[0][:-10] + '_front.npy')
    elif self.opt.dataset_name == 'obama':
        paths = self.data[index]
        video_path = os.path.join(self.root, self.video_bag, paths[0][:-11] + '_crop2.mp4')
        lmark_path = os.path.join(self.root, self.video_bag, paths[0][:-11] + '_original2.npy')
        ani_path = os.path.join(self.root, self.video_bag, paths[0][:-11] + "_ani2.mp4")
        rt_path = os.path.join(self.root, self.video_bag, paths[0][:-11] + '_rt2.npy')
        front_path = os.path.join(self.root, self.video_bag, paths[0][:-11] + '_front2.npy')
        ani_id = int(paths[1])
    elif self.opt.dataset_name == 'ouyang':
        self.real_video = None
        self.ani_video = None
        self.use_for_finetune = True
        paths = self.data[index]
        video_path = os.path.join(self.root, self.video_bag, paths + '_crop.mp4')
        lmark_path = os.path.join(self.root, self.video_bag, paths + '__original.npy')
        ani_path = os.path.join(self.root, self.video_bag, paths + "__ani.mp4")
        rt_path = os.path.join(self.root, self.video_bag, paths + '__rt.npy')
        front_path = os.path.join(self.root, self.video_bag, paths + '__front.npy')
        ani_id = 11174

    # reseed
    np.random.seed(int.from_bytes(os.urandom(4), byteorder='little'))

    # read in data
    self.video_path = video_path
    lmarks = np.load(lmark_path)  # [:,:,:-1]
    real_video = self.read_videos(video_path)

    if self.opt.dataset_name == 'face':
        lmarks = lmarks[:-1]
    else:
        if self.opt.warp_ani:
            front = np.load(front_path)
        rt = np.load(rt_path)

        # drop frames whose landmarks failed the sanity check
        cor_num, wro_nums = self.clean_lmarks(lmarks)
        lmarks = lmarks[cor_num]
        real_video = np.asarray(real_video)[cor_num]
        rt = rt[cor_num]

        # smooth landmarks (the smoothing pads the sequence, hence the [2:-2] trim)
        for i in range(lmarks.shape[1]):
            x = lmarks[:, i, 0]
            x = face_utils.smooth(x, window_len=5)
            lmarks[:, i, 0] = x[2:-2]
            y = lmarks[:, i, 1]
            y = face_utils.smooth(y, window_len=5)
            lmarks[:, i, 1] = y[2:-2]

        if self.opt.warp_ani:
            ani_video = self.read_videos(ani_path)
            # clean data
            ani_video = np.asarray(ani_video)[cor_num]

    v_length = len(real_video)

    # sample index of frames for the embedding network
    if self.opt.ref_ratio is not None:
        input_indexs, target_id = self.get_image_index_ratio(self.n_frames_total, v_length)
    elif self.opt.for_finetune:
        input_indexs, target_id = self.get_image_index_finetune(self.n_frames_total, lmarks=lmarks)
    else:
        input_indexs, target_id = self.get_image_index(self.n_frames_total, v_length)

    # optionally use the frames with the most open mouth as references
    if self.opt.find_largest_mouth:
        result_indexs = self.get_open_mouth(lmarks)
        input_indexs = result_indexs if result_indexs is not None else input_indexs

    # define scale and get transforms
    scale = self.define_scale()
    transform, transform_L, transform_T = self.get_transforms()

    # get reference frames and landmarks
    ref_images, ref_lmarks, ref_coords = self.prepare_datas(
        real_video, lmarks, input_indexs, transform, transform_L, scale)

    # get target frames and landmarks
    tgt_images, tgt_lmarks, tgt_crop_coords = self.prepare_datas(
        real_video, lmarks, target_id, transform, transform_L, scale)

    # get face / eyes / mouth templates for each target frame
    tgt_templates = []
    tgt_templates_eyes = []
    tgt_templates_mouth = []
    for gg in target_id:
        lmark = lmarks[gg]
        tgt_templates.append(
            self.get_template(lmark, transform_T, self.output_shape, tgt_crop_coords))
        tgt_templates_eyes.append(
            self.get_template(lmark, transform_T, self.output_shape, tgt_crop_coords, only_eyes=True))
        tgt_templates_mouth.append(
            self.get_template(lmark, transform_T, self.output_shape, tgt_crop_coords, only_mouth=True))

    if self.opt.warp_ani:
        # get animation frames and the ground truth cropped to the animation mask
        ani_lmarks_back = []
        ani_lmarks = []
        ani_images = []
        cropped_images = []
        cropped_lmarks = []
        for gg in target_id:
            cropped_gt = real_video[gg].copy()
            ani_lmarks.append(self.reverse_rt(front[int(ani_id)], rt[gg]))
            ani_lmarks[-1] = np.array(ani_lmarks[-1])
            ani_lmarks_back.append(ani_lmarks[-1])
            ani_images.append(ani_video[gg])
            mask = ani_video[gg] < 10
            # mask = scipy.ndimage.morphology.binary_dilation(mask.numpy(), iterations=5).astype(np.bool)
            cropped_gt[mask] = 0
            cropped_images.append(cropped_gt)
            cropped_lmarks.append(lmarks[gg])
        ani_images, ani_lmarks, ani_coords = self.prepare_datas(
            ani_images, ani_lmarks, list(range(len(target_id))), transform, transform_L, scale)
        cropped_images, cropped_lmarks, _ = self.prepare_datas(
            cropped_images, cropped_lmarks, list(range(len(target_id))),
            transform, transform_L, scale, crop_coords=tgt_crop_coords)

    # get warping reference
    rt = rt[:, :3]
    warping_ref_ids = self.get_warp_ref(rt, input_indexs, target_id)
    warping_refs = [ref_images[w_ref_id] for w_ref_id in warping_ref_ids]
    warping_ref_lmarks = [ref_lmarks[w_ref_id] for w_ref_id in warping_ref_ids]
    ori_warping_refs = copy.deepcopy(warping_refs)

    # get template for warp reference and animation
    if self.opt.crop_ref:
        # for warp reference: pixels inside the template become -1, others are unchanged
        # (see the illustration after this method)
        for warp_id, warp_ref in enumerate(warping_refs):
            lmark_id = input_indexs[warping_ref_ids[warp_id]]
            warp_ref_lmark = lmarks[lmark_id]
            warp_ref_template = torch.Tensor(
                self.get_template(warp_ref_lmark, transform_T, self.output_shape, ref_coords))
            warp_ref_template_inter = -warp_ref * warp_ref_template + (1 - warp_ref_template)
            warping_refs[warp_id] = warp_ref / warp_ref_template_inter
        # for animation
        if self.opt.warp_ani:
            for ani_lmark_id, ani_lmark_temp in enumerate(ani_lmarks_back):
                ani_template = torch.Tensor(
                    self.get_template(ani_lmark_temp, transform_T, self.output_shape, ani_coords))
                ani_template_inter = -ani_images[ani_lmark_id] * ani_template + (1 - ani_template)
                ani_images[ani_lmark_id] = ani_images[ani_lmark_id] / ani_template_inter

    # preprocess: stack the per-frame tensors along a new leading (time) dimension
    target_img_path = [os.path.join(video_path[:-4], '%05d.png' % t_id) for t_id in target_id]
    ref_images = torch.cat([ref_img.unsqueeze(0) for ref_img in ref_images], axis=0)
    ref_lmarks = torch.cat([ref_lmark.unsqueeze(0) for ref_lmark in ref_lmarks], axis=0)
    tgt_images = torch.cat([tgt_img.unsqueeze(0) for tgt_img in tgt_images], axis=0)
    tgt_lmarks = torch.cat([tgt_lmark.unsqueeze(0) for tgt_lmark in tgt_lmarks], axis=0)
    if self.opt.isTrain:
        tgt_templates = torch.cat(
            [torch.Tensor(tgt_template).unsqueeze(0).unsqueeze(0) for tgt_template in tgt_templates], axis=0)
        tgt_templates_eyes = torch.cat(
            [torch.Tensor(tgt_template).unsqueeze(0).unsqueeze(0) for tgt_template in tgt_templates_eyes], axis=0)
        tgt_templates_mouth = torch.cat(
            [torch.Tensor(tgt_template).unsqueeze(0).unsqueeze(0) for tgt_template in tgt_templates_mouth], axis=0)
    else:
        tgt_templates = torch.cat(
            [torch.Tensor(tgt_template).unsqueeze(0).unsqueeze(0) for tgt_template in tgt_templates], axis=0)
        tgt_templates_eyes = torch.cat(
            [torch.Tensor(tgt_template).unsqueeze(0).unsqueeze(0) for tgt_template in tgt_templates_eyes], axis=0)
        tgt_templates_mouth = torch.cat(
            [torch.Tensor(tgt_template).unsqueeze(0).unsqueeze(0) for tgt_template in tgt_templates_mouth], axis=0)
    warping_refs = torch.cat([warping_ref.unsqueeze(0) for warping_ref in warping_refs], 0)
    warping_ref_lmarks = torch.cat(
        [warping_ref_lmark.unsqueeze(0) for warping_ref_lmark in warping_ref_lmarks], 0)
    ori_warping_refs = torch.cat(
        [ori_warping_ref.unsqueeze(0) for ori_warping_ref in ori_warping_refs], 0)
    if self.opt.warp_ani:
        ani_images = torch.cat([ani_image.unsqueeze(0) for ani_image in ani_images], 0)
        ani_lmarks = torch.cat([ani_lmark.unsqueeze(0) for ani_lmark in ani_lmarks], 0)
        cropped_images = torch.cat([cropped_image.unsqueeze(0) for cropped_image in cropped_images], 0)
        cropped_lmarks = torch.cat([cropped_lmark.unsqueeze(0) for cropped_lmark in cropped_lmarks], 0)

    # crop eyes and mouth from reference
    if self.opt.crop_ref:
        if self.opt.warp_ani:
            crop_template_inter = -cropped_images * tgt_templates + (1 - tgt_templates)
            cropped_images = cropped_images / crop_template_inter
        tgt_template_inter = -tgt_images * tgt_templates + (1 - tgt_templates)
        tgt_mask_images = tgt_images / tgt_template_inter

    input_dic = {
        'v_id': target_img_path,
        'tgt_label': tgt_lmarks,
        'tgt_template': tgt_templates,
        'ref_image': ref_images,
        'ref_label': ref_lmarks,
        'tgt_image': tgt_images,
        'target_id': target_id,
        'warping_ref': warping_refs,
        'warping_ref_lmark': warping_ref_lmarks,
        'ori_warping_refs': ori_warping_refs,
        'path': video_path,
    }
    if self.opt.warp_ani:
        input_dic.update({
            'ani_image': ani_images,
            'ani_lmark': ani_lmarks,
            'cropped_images': cropped_images,
            'cropped_lmarks': cropped_lmarks
        })
    if self.opt.crop_ref:
        input_dic.update({'tgt_mask_images': tgt_mask_images})
    else:
        input_dic.update({'tgt_mask_images': tgt_images})

    return input_dic
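# Illustration (not part of the original code): the crop_ref / template masking above
# relies on a small arithmetic identity. With images normalized to [-1, 1] and a
# binary {0, 1} template t, computing inter = -img * t + (1 - t) and then img / inter
# leaves pixels unchanged where t == 0 and forces them to -1 (black) where t == 1.
# A minimal self-contained check of the identity, assuming only that torch is available:
#
#     import torch
#     img = torch.tensor([0.25, -0.7, 0.9])   # three example pixel values in [-1, 1]
#     t = torch.tensor([0., 1., 0.])          # binary template: mask the middle pixel
#     inter = -img * t + (1 - t)
#     print(img / inter)                      # tensor([ 0.2500, -1.0000,  0.9000])
#
# The same trick is applied to the warping references, the animation frames, the
# cropped ground truth, and the target images before they are packed into input_dic.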
def initialize(self, opt):
    """Set up the test-time dataset from the parsed options.

    :param opt: Parsed options object. The relevant fields (dataroot, loadSize,
        n_shot, n_frames_G, the tgt_*_path / ref_*_path entries, warp_ani,
        audio_drive, ...) are read below. Training mode is not supported here.
    """
    assert not opt.isTrain
    self.output_shape = tuple([opt.loadSize, opt.loadSize])
    self.num_frames = opt.n_shot
    self.n_frames_total = opt.n_frames_G
    self.opt = opt
    self.root = opt.dataroot
    self.fix_crop_pos = True
    self.ref_search = opt.ref_rt_path is not None and opt.tgt_rt_path is not None

    # mapping from keypoints to face parts
    self.add_upper_face = not opt.no_upper_face
    self.part_list = [
        [list(range(0, 17)) + ((list(range(68, 83)) + [0]) if self.add_upper_face else [])],  # face
        [range(17, 22)],                                    # right eyebrow
        [range(22, 27)],                                    # left eyebrow
        [[28, 31], range(31, 36), [35, 28]],                # nose
        [[36, 37, 38, 39], [39, 40, 41, 36]],               # right eye
        [[42, 43, 44, 45], [45, 46, 47, 42]],               # left eye
        [range(48, 55), [54, 55, 56, 57, 58, 59, 48],
         range(60, 65), [64, 65, 66, 67, 60]],              # mouth and tongue
    ]
    if self.opt.audio_drive:
        # audio-driven mode keeps the full part list for the target but drops the
        # mouth from the label maps of the driving landmarks
        self.tgt_part_list = copy.deepcopy(self.part_list)
        self.part_list = [
            [list(range(0, 17)) + ((list(range(68, 83)) + [0]) if self.add_upper_face else [])],  # face
            [range(17, 22)],                                # right eyebrow
            [range(22, 27)],                                # left eyebrow
            [[28, 31], range(31, 36), [35, 28]],            # nose
            [[36, 37, 38, 39], [39, 40, 41, 36]],           # right eye
            [[42, 43, 44, 45], [45, 46, 47, 42]],           # left eye
        ]

    # load paths
    self.tgt_video_path = opt.tgt_video_path
    self.tgt_lmarks_path = opt.tgt_lmarks_path
    self.tgt_ani_path = opt.tgt_ani_path
    self.tgt_rt_path = opt.tgt_rt_path
    if self.opt.audio_drive:
        self.tgt_audio_path = opt.tgt_audio_path
    self.ref_video_path = opt.ref_video_path
    self.ref_lmarks_path = opt.ref_lmarks_path
    self.ref_rt_path = opt.ref_rt_path
    self.ref_front_path = opt.ref_front_path

    # read in data
    self.tgt_lmarks = np.load(self.tgt_lmarks_path)  # [:,:,:-1]
    self.tgt_video = self.read_videos(self.tgt_video_path)
    if self.opt.audio_drive:
        fs, self.tgt_audio = wavfile.read(self.tgt_audio_path)
        self.chunck_size = int(fs / 25)
    # pad the video with its last frame so it has one frame per landmark
    if len(self.tgt_video) < self.tgt_lmarks.shape[0]:
        self.tgt_video.extend(
            [self.tgt_video[-1] for i in range(self.tgt_lmarks.shape[0] - len(self.tgt_video))])
    self.ref_lmarks = np.load(self.ref_lmarks_path)
    self.ref_video = self.read_videos(self.ref_video_path)
    if self.opt.warp_ani:
        self.tgt_ani_video = self.read_videos(self.tgt_ani_path)
        self.ref_front = np.load(self.ref_front_path)
        self.ref_ani_id = self.opt.ref_ani_id
    if self.opt.warp_ani or self.ref_search:
        self.tgt_rt = np.load(self.tgt_rt_path)
        self.ref_rt = np.load(self.ref_rt_path)

    # clean target frames whose landmarks failed the sanity check
    correct_nums, wro_nums = self.clean_lmarks(self.tgt_lmarks)
    self.tgt_lmarks = self.tgt_lmarks[correct_nums]
    self.tgt_video = np.asarray(self.tgt_video)[correct_nums]
    if self.opt.warp_ani:
        self.tgt_ani_video = np.asarray(self.tgt_ani_video)[correct_nums]
    if self.opt.warp_ani or self.ref_search:
        self.tgt_rt = self.tgt_rt[correct_nums]

    # audio
    if self.opt.audio_drive:
        if wro_nums.shape[0] != 0:
            self.tgt_audio = self.clean_audio(self.chunck_size, self.tgt_audio, wro_nums)
        # append padding chunks at both ends
        left_append = self.tgt_audio[:self.opt.audio_append * self.chunck_size]
        right_append = self.tgt_audio[-(self.opt.audio_append + 1) * self.chunck_size:]
        self.tgt_audio = np.insert(self.tgt_audio, 0, left_append, axis=0)
        self.tgt_audio = np.insert(self.tgt_audio, -1, right_append, axis=0)

    # smooth target landmarks (the smoothing pads the sequence, hence the [2:-2] trim)
    for i in range(self.tgt_lmarks.shape[1]):
        x = self.tgt_lmarks[:, i, 0]
        x = face_utils.smooth(x, window_len=5)
        self.tgt_lmarks[:, i, 0] = x[2:-2]
        y = self.tgt_lmarks[:, i, 1]
        y = face_utils.smooth(y, window_len=5)
        self.tgt_lmarks[:, i, 1] = y[2:-2]

    # get eyes
    # self.tgt_lmarks = eye_blinking(self.tgt_lmarks)

    # clean reference frames
    correct_nums, _ = self.clean_lmarks(self.ref_lmarks)
    self.ref_lmarks = self.ref_lmarks[correct_nums]
    self.ref_video = np.asarray(self.ref_video)[correct_nums]
    if self.opt.warp_ani or self.ref_search:
        self.ref_rt = self.ref_rt[correct_nums]

    # smooth reference landmarks
    for i in range(self.ref_lmarks.shape[1]):
        x = self.ref_lmarks[:, i, 0]
        x = face_utils.smooth(x, window_len=5)
        self.ref_lmarks[:, i, 0] = x[2:-2]
        y = self.ref_lmarks[:, i, 1]
        y = face_utils.smooth(y, window_len=5)
        self.ref_lmarks[:, i, 1] = y[2:-2]

    # get transforms for image and landmark
    self.transform = transforms.Compose([
        transforms.Lambda(lambda img: self.__scale_image(img, self.output_shape, Image.BICUBIC)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    self.transform_L = transforms.Compose([
        transforms.Lambda(lambda img: self.__scale_image(img, self.output_shape, Image.BILINEAR)),
        transforms.ToTensor()
    ])
    self.transform_T = transforms.Compose([
        transforms.Lambda(lambda img: self.__scale_image(img, self.output_shape, Image.BILINEAR)),
    ])

    # define parameters for inference
    self.ref_lmarks_temp = self.ref_lmarks
    self.ref_video, self.ref_lmarks, self.ref_indices, self.ref_coords = self.define_inference(
        self.ref_video, self.ref_lmarks)

    # get ids for the target frames to synthesize
    if self.opt.tgt_ids is None:
        self.tgt_ids = list(range(len(self.tgt_video)))
    else:
        self.tgt_ids = [int(t_id) for t_id in self.opt.tgt_ids.split(',')]
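# Illustration (not part of the original code): a rough sketch of how this test-time
# dataset might be driven, assuming `opt` is the project's parsed options object and
# `DemoDataset` is a placeholder name for whichever class these methods belong to:
#
#     opt.isTrain = False                  # initialize() asserts test mode
#     opt.tgt_video_path = '...'           # target clip and its landmark file
#     opt.tgt_lmarks_path = '...'
#     opt.ref_video_path = '...'           # reference identity clip and its landmarks
#     opt.ref_lmarks_path = '...'
#     # plus the other fields read above (warp_ani, audio_drive, tgt_rt_path, ...)
#     dataset = DemoDataset()
#     dataset.initialize(opt)
#     # if the class also defines the __getitem__ shown earlier, indexing it yields
#     # the input_dic dictionary ('tgt_image', 'ref_image', 'tgt_label', ...)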