def finetune(self, source_path, save_checkpoint=True):
    checkpoint_path = os.path.splitext(source_path)[0] + '_Gr.pth'
    if os.path.isfile(checkpoint_path):
        print('=> Loading the reenactment generator finetuned on: "%s"...' % os.path.basename(source_path))
        checkpoint = torch.load(checkpoint_path)
        if self.gpus and len(self.gpus) > 1:
            self.Gr.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.Gr.load_state_dict(checkpoint['state_dict'])
        return

    print('=> Finetuning the reenactment generator on: "%s"...' % os.path.basename(source_path))
    torch.set_grad_enabled(True)
    self.Gr.train(True)
    img_transforms = img_lms_pose_transforms.Compose([Pyramids(2), ToTensor(), Normalize()])
    train_dataset = SingleSeqRandomPairDataset(source_path, transform=img_transforms, postfixes=('_lms.npz',))
    train_sampler = RandomSampler(train_dataset, replacement=True, num_samples=self.finetune_iterations)
    train_loader = DataLoader(train_dataset, batch_size=self.finetune_batch_size, sampler=train_sampler,
                              num_workers=self.finetune_workers, pin_memory=True, drop_last=True, shuffle=False)
    optimizer = optim.Adam(self.Gr.parameters(), lr=self.finetune_lr, betas=(0.5, 0.999))

    # For each batch in the training data
    for i, (img, landmarks) in enumerate(tqdm(train_loader, unit='batches', file=sys.stdout)):
        # Prepare input (no gradients are needed for the input preparation)
        with torch.no_grad():
            # Move the target view landmarks and all image pyramids to the device
            landmarks[1] = landmarks[1].to(self.device)
            for j in range(len(img)):
                # For each pyramid image: push to device
                for p in range(len(img[j])):
                    img[j][p] = img[j][p].to(self.device)

            # Concatenate pyramid images with context to derive the final input
            input = []
            for p in range(len(img[0])):
                context = self.landmarks_decoders[p](landmarks[1])
                input.append(torch.cat((img[0][p], context), dim=1))

        # Reenactment
        img_pred = self.Gr(input)

        # Reconstruction loss
        loss_pixelwise = self.criterion_pixelwise(img_pred, img[1][0])
        loss_id = self.criterion_id(img_pred, img[1][0])
        loss_rec = 0.1 * loss_pixelwise + loss_id

        # Update generator weights
        optimizer.zero_grad()
        loss_rec.backward()
        optimizer.step()

    # Save finetuned weights to file
    if save_checkpoint:
        arch = self.Gr.module.arch if self.gpus and len(self.gpus) > 1 else self.Gr.arch
        state_dict = self.Gr.module.state_dict() if self.gpus and len(self.gpus) > 1 else self.Gr.state_dict()
        torch.save({'state_dict': state_dict, 'arch': arch}, checkpoint_path)

    torch.set_grad_enabled(False)
    self.Gr.train(False)
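# Note on the finetuning objective above (summary comment only, restating the code): Gr is
# optimized with Adam (betas=(0.5, 0.999)) on frame pairs sampled from the single source
# sequence, minimizing loss_rec = 0.1 * loss_pixelwise + loss_id, i.e. the pixelwise
# reconstruction criterion down-weighted by 0.1 plus the identity criterion.
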
def __call__(self, source_path, target_path, output_path=None, select_source='longest', select_target='longest',
             finetune=None):
    is_vid = os.path.splitext(source_path)[1] == '.mp4'
    finetune = self.finetune_enabled and is_vid if finetune is None else finetune and is_vid

    # Validation
    assert os.path.isfile(source_path), 'Source path "%s" does not exist' % source_path
    assert os.path.isfile(target_path), 'Target path "%s" does not exist' % target_path

    # Cache input
    source_cache_dir, source_seq_file_path, _ = self.cache(source_path)
    target_cache_dir, target_seq_file_path, _ = self.cache(target_path)

    # Load sequences from file
    with open(source_seq_file_path, "rb") as fp:  # Unpickling
        source_seq_list = pickle.load(fp)
    with open(target_seq_file_path, "rb") as fp:  # Unpickling
        target_seq_list = pickle.load(fp)

    # Select source and target sequences
    source_seq = select_seq(source_seq_list, select_source)
    target_seq = select_seq(target_seq_list, select_target)

    # Set source and target sequence video paths
    src_path_no_ext, src_ext = os.path.splitext(source_path)
    src_vid_seq_name = os.path.basename(src_path_no_ext) + '_seq%02d%s' % (source_seq.id, src_ext)
    src_vid_seq_path = os.path.join(source_cache_dir, src_vid_seq_name)
    tgt_path_no_ext, tgt_ext = os.path.splitext(target_path)
    tgt_vid_seq_name = os.path.basename(tgt_path_no_ext) + '_seq%02d%s' % (target_seq.id, tgt_ext)
    tgt_vid_seq_path = os.path.join(target_cache_dir, tgt_vid_seq_name)

    # Set output path
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = f'{os.path.basename(src_path_no_ext)}_{os.path.basename(tgt_path_no_ext)}.mp4'
            output_path = os.path.join(output_path, output_filename)

    # Initialize appearance map
    src_transform = img_lms_pose_transforms.Compose([Rotate(), Pyramids(2), ToTensor(), Normalize()])
    tgt_transform = img_lms_pose_transforms.Compose([ToTensor(), Normalize()])
    appearance_map = AppearanceMapDataset(src_vid_seq_path, tgt_vid_seq_path, src_transform, tgt_transform,
                                          self.landmarks_postfix, self.pose_postfix, self.segmentation_postfix,
                                          self.min_radius)
    appearance_map_loader = DataLoader(appearance_map, batch_size=self.batch_size, num_workers=1, pin_memory=True,
                                       drop_last=False, shuffle=False)

    # Initialize video renderer
    self.video_renderer.init(target_path, target_seq, output_path, _appearance_map=appearance_map)

    # Finetune reenactment model on the source sequence
    if finetune:
        self.finetune(src_vid_seq_path, self.finetune_save)

    print(f'=> Face swapping: "{src_vid_seq_name}" -> "{tgt_vid_seq_name}"...')

    # For each batch of frames in the target video
    for i, (src_frame, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_mask) \
            in enumerate(tqdm(appearance_map_loader, unit='batches', file=sys.stdout)):
        # Prepare input
        for p in range(len(src_frame)):
            src_frame[p] = src_frame[p].to(self.device)
        tgt_frame = tgt_frame.to(self.device)
        tgt_landmarks = tgt_landmarks.to(self.device)
        # tgt_mask = tgt_mask.unsqueeze(1).to(self.device)
        tgt_mask = tgt_mask.unsqueeze(1).int().to(self.device).bool()  # TODO: check if the boolean tensor bug is fixed
        bw = bw.to(self.device)
        bw_indices = torch.nonzero(torch.any(bw > 0, dim=0), as_tuple=True)[0]
        bw = bw[:, bw_indices]

        # For each source frame perform reenactment
        reenactment_triplet = []
        for j in bw_indices:
            input = []
            for p in range(len(src_frame)):
                context = self.landmarks_decoders[p](tgt_landmarks)
                input.append(torch.cat((src_frame[p][:, j], context), dim=1))

            # Reenactment
            reenactment_triplet.append(self.Gr(input).unsqueeze(1))
        reenactment_tensor = torch.cat(reenactment_triplet, dim=1)

        # Barycentric interpolation of reenacted frames
        reenactment_tensor = (reenactment_tensor * bw.view(*bw.shape, 1, 1, 1)).sum(dim=1)

        # Compute reenactment segmentation
        reenactment_seg = self.S(reenactment_tensor)
        reenactment_background_mask_tensor = (reenactment_seg.argmax(1) != 1).unsqueeze(1)

        # Remove the background of the aligned face
        reenactment_tensor.masked_fill_(reenactment_background_mask_tensor, -1.0)

        # Soften target mask
        soft_tgt_mask, eroded_tgt_mask = self.smooth_mask(tgt_mask)

        # Complete face
        inpainting_input_tensor = torch.cat((reenactment_tensor, eroded_tgt_mask.float()), dim=1)
        inpainting_input_tensor_pyd = create_pyramid(inpainting_input_tensor, 2)
        completion_tensor = self.Gc(inpainting_input_tensor_pyd)

        # Blend faces
        transfer_tensor = transfer_mask(completion_tensor, tgt_frame, eroded_tgt_mask)
        blend_input_tensor = torch.cat((transfer_tensor, tgt_frame, eroded_tgt_mask.float()), dim=1)
        blend_input_tensor_pyd = create_pyramid(blend_input_tensor, 2)
        blend_tensor = self.Gb(blend_input_tensor_pyd)
        result_tensor = blend_tensor * soft_tgt_mask + tgt_frame * (1 - soft_tgt_mask)

        # Write output
        if self.verbose == 0:
            self.video_renderer.write(result_tensor)
        elif self.verbose == 1:
            curr_src_frames = [src_frame[0][:, i] for i in range(src_frame[0].shape[1])]
            self.video_renderer.write(*curr_src_frames, result_tensor, tgt_frame)
        else:
            curr_src_frames = [src_frame[0][:, i] for i in range(src_frame[0].shape[1])]
            tgt_seg_blend = blend_seg_label(tgt_frame, tgt_mask.squeeze(1), alpha=0.2)
            soft_tgt_mask = soft_tgt_mask.mul(2.).sub(1.).repeat(1, 3, 1, 1)
            self.video_renderer.write(*curr_src_frames, result_tensor, tgt_frame, reenactment_tensor,
                                      completion_tensor, transfer_tensor, soft_tgt_mask, tgt_seg_blend, tgt_pose)

    # Load original reenactment weights
    if finetune:
        if self.gpus and len(self.gpus) > 1:
            self.Gr.module.load_state_dict(self.reenactment_state_dict)
        else:
            self.Gr.load_state_dict(self.reenactment_state_dict)

    # Finalize the video and wait for the video renderer to finish writing
    self.video_renderer.finalize()
    self.video_renderer.wait_until_finished()
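# Pipeline summary for the loop above (comment only, restating the calls made there):
# (1) Gr reenacts each of the nearest source frames toward the target landmarks, and the
#     results are combined using the barycentric weights `bw` from the appearance map;
# (2) S segments the reenacted face so its background can be masked out;
# (3) Gc completes (inpaints) the face inside the eroded target mask;
# (4) Gb blends the completed face into the target frame, with the soft target mask used
#     for the final compositing.
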
def __call__(self, source_path, target_path, output_path=None, select_source='longest', select_target='longest',
             finetune=None):
    is_vid = os.path.splitext(source_path)[1] == '.mp4'
    finetune = self.finetune_enabled and is_vid if finetune is None else finetune and is_vid

    # Validation
    assert os.path.isfile(source_path), 'Source path "%s" does not exist' % source_path
    assert os.path.isfile(target_path), 'Target path "%s" does not exist' % target_path

    # Cache input
    source_cache_dir, source_seq_file_path, _ = self.cache(source_path)
    target_cache_dir, target_seq_file_path, _ = self.cache(target_path)

    # Load sequences from file
    with open(source_seq_file_path, "rb") as fp:  # Unpickling
        source_seq_list = pickle.load(fp)
    with open(target_seq_file_path, "rb") as fp:  # Unpickling
        target_seq_list = pickle.load(fp)

    # Select source and target sequences
    source_seq = select_seq(source_seq_list, select_source)
    target_seq = select_seq(target_seq_list, select_target)

    # Set source and target sequence video paths
    src_path_no_ext, src_ext = os.path.splitext(source_path)
    src_vid_seq_name = os.path.basename(src_path_no_ext) + '_seq%02d%s' % (source_seq.id, src_ext)
    src_vid_seq_path = os.path.join(source_cache_dir, src_vid_seq_name)
    tgt_path_no_ext, tgt_ext = os.path.splitext(target_path)
    tgt_vid_seq_name = os.path.basename(tgt_path_no_ext) + '_seq%02d%s' % (target_seq.id, tgt_ext)
    tgt_vid_seq_path = os.path.join(target_cache_dir, tgt_vid_seq_name)

    # Set output path
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = f'{os.path.basename(src_path_no_ext)}_{os.path.basename(tgt_path_no_ext)}.mp4'
            output_path = os.path.join(output_path, output_filename)

    # Initialize appearance map
    src_transform = img_lms_pose_transforms.Compose([Rotate(), Pyramids(2), ToTensor(), Normalize()])
    tgt_transform = img_lms_pose_transforms.Compose([ToTensor(), Normalize()])
    appearance_map = AppearanceMapDataset(src_vid_seq_path, tgt_vid_seq_path, src_transform, tgt_transform,
                                          self.landmarks_postfix, self.pose_postfix, self.segmentation_postfix,
                                          self.min_radius)
    appearance_map_loader = DataLoader(appearance_map, batch_size=self.batch_size, num_workers=1, pin_memory=True,
                                       drop_last=False, shuffle=False)

    # Initialize video renderer
    self.video_renderer.init(target_path, target_seq, output_path, _appearance_map=appearance_map)

    # Finetune reenactment model on the source sequence
    if finetune:
        self.finetune(src_vid_seq_path, self.finetune_save)

    print(f'=> Face reenactment: "{src_vid_seq_name}" -> "{tgt_vid_seq_name}"...')

    # For each batch of frames in the target video
    for i, (src_frame, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_mask) \
            in enumerate(tqdm(appearance_map_loader, unit='batches', file=sys.stdout)):
        # Prepare input
        for p in range(len(src_frame)):
            src_frame[p] = src_frame[p].to(self.device)
        tgt_landmarks = tgt_landmarks.to(self.device)
        bw = bw.to(self.device)
        bw_indices = torch.nonzero(torch.any(bw > 0, dim=0), as_tuple=True)[0]
        bw = bw[:, bw_indices]

        # For each source frame perform reenactment
        reenactment_triplet = []
        for j in bw_indices:
            input = []
            for p in range(len(src_frame)):
                context = self.landmarks_decoders[p](tgt_landmarks)
                input.append(torch.cat((src_frame[p][:, j], context), dim=1))

            # Reenactment
            reenactment_triplet.append(self.Gr(input).unsqueeze(1))
        reenactment_tensor = torch.cat(reenactment_triplet, dim=1)

        # Barycentric interpolation of reenacted frames
        reenactment_tensor = (reenactment_tensor * bw.view(*bw.shape, 1, 1, 1)).sum(dim=1)

        # Write output
        if self.verbose == 0:
            self.video_renderer.write(reenactment_tensor)
        elif self.verbose == 1:
            # Save a source | reenactment | target comparison image for the current batch
            write_bgr = tensor2bgr(torch.cat((src_frame[0][:, 0][0], reenactment_tensor[0], tgt_frame[0]), dim=2))
            cv2.imwrite(f'{output_path}.jpg', write_bgr)
            self.video_renderer.write(src_frame[0][:, 0], reenactment_tensor, tgt_frame)
        else:
            self.video_renderer.write(src_frame[0][:, 0], src_frame[0][:, 1], src_frame[0][:, 2],
                                      reenactment_tensor, tgt_frame, tgt_pose)

    # Load original reenactment weights
    if finetune:
        if self.gpus and len(self.gpus) > 1:
            self.Gr.module.load_state_dict(self.reenactment_state_dict)
        else:
            self.Gr.load_state_dict(self.reenactment_state_dict)

    # Finalize the video and wait for the video renderer to finish rendering
    self.video_renderer.finalize()
    self.video_renderer.wait_until_finished()
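# Unlike the face swapping pipeline, the loop above stops after the barycentric
# interpolation step: the reenacted frames are written directly by the video renderer,
# with no segmentation, completion, or blending into the target frame.
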
if __name__ == '__main__':
    exp_name = os.path.splitext(os.path.basename(__file__))[0]
    exp_dir = os.path.join('../results/reenactment', exp_name)
    root = '/data/datasets/nirkin_face_videos'
    train_dataset = partial(SeqPairDataset, root, 'videos_train.txt', postfixes=('.mp4', '_lms.npz'), same_prob=1.0)
    val_dataset = partial(SeqPairDataset, root, 'videos_val.txt', postfixes=('.mp4', '_lms.npz'), same_prob=1.0)
    numpy_transforms = [RandomHorizontalFlip(), Pyramids(2)]
    resolutions = [128, 256]
    lr_gen = [1e-4, 4e-5]
    lr_dis = [1e-5, 4e-6]
    epochs = [24, 50]
    iterations = ['20k']
    batch_size = [48, 24]
    workers = 32
    pretrained = False
    criterion_id = VGGLoss('../../weights/vggface2_vgg19_256_1_2_id.pth')
    criterion_attr = VGGLoss('../../weights/celeba_vgg19_256_2_0_28_attr.pth')
    criterion_gan = GANLoss(use_lsgan=True)
    generator = MultiScaleResUNet(in_nc=101, out_nc=3, flat_layers=(2, 2, 2, 2), ngf=128)
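    # The list-valued settings above presumably provide one entry per training resolution
    # (128 and 256): learning rates, epochs, and batch sizes change between the two stages.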