Example #1
    def finetune(self, source_path, save_checkpoint=True):
        checkpoint_path = os.path.splitext(source_path)[0] + '_Gr.pth'
        if os.path.isfile(checkpoint_path):
            print('=> Loading the reenactment generator finetuned on: "%s"...' % os.path.basename(source_path))
            checkpoint = torch.load(checkpoint_path)
            if self.gpus and len(self.gpus) > 1:
                self.Gr.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.Gr.load_state_dict(checkpoint['state_dict'])
            return

        print('=> Finetuning the reenactment generator on: "%s"...' % os.path.basename(source_path))
        torch.set_grad_enabled(True)
        self.Gr.train(True)
        img_transforms = img_lms_pose_transforms.Compose([Pyramids(2), ToTensor(), Normalize()])
        train_dataset = SingleSeqRandomPairDataset(source_path, transform=img_transforms, postfixes=('_lms.npz',))
        train_sampler = RandomSampler(train_dataset, replacement=True, num_samples=self.finetune_iterations)
        train_loader = DataLoader(train_dataset, batch_size=self.finetune_batch_size, sampler=train_sampler,
                                  num_workers=self.finetune_workers, pin_memory=True, drop_last=True, shuffle=False)
        optimizer = optim.Adam(self.Gr.parameters(), lr=self.finetune_lr, betas=(0.5, 0.999))

        # For each batch in the training data
        for i, (img, landmarks) in enumerate(tqdm(train_loader, unit='batches', file=sys.stdout)):
            # Prepare input
            with torch.no_grad():
                # Move each view's images and landmarks to the device
                landmarks[1] = landmarks[1].to(self.device)
                for j in range(len(img)):
                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(self.device)

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0])):
                    context = self.landmarks_decoders[p](landmarks[1])
                    input.append(torch.cat((img[0][p], context), dim=1))

            # Reenactment
            img_pred = self.Gr(input)

            # Reconstruction loss
            loss_pixelwise = self.criterion_pixelwise(img_pred, img[1][0])
            loss_id = self.criterion_id(img_pred, img[1][0])
            loss_rec = 0.1 * loss_pixelwise + loss_id

            # Update generator weights
            optimizer.zero_grad()
            loss_rec.backward()
            optimizer.step()

        # Save finetuned weights to file
        if save_checkpoint:
            arch = self.Gr.module.arch if self.gpus and len(self.gpus) > 1 else self.Gr.arch
            state_dict = self.Gr.module.state_dict() if self.gpus and len(self.gpus) > 1 else self.Gr.state_dict()
            torch.save({'state_dict': state_dict, 'arch': arch}, checkpoint_path)

        torch.set_grad_enabled(False)
        self.Gr.train(False)
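
A minimal, self-contained sketch of the fixed-iteration sampling pattern used above (RandomSampler with replacement), with a toy TensorDataset standing in for SingleSeqRandomPairDataset; the dataset, shapes and counts here are illustrative only:

import torch
from torch.utils.data import TensorDataset, RandomSampler, DataLoader

# Toy dataset standing in for SingleSeqRandomPairDataset.
data = TensorDataset(torch.randn(10, 3, 256, 256))

# Sampling with replacement draws a fixed number of samples regardless of dataset size,
# so the length of the finetuning loop is set by num_samples, not by the sequence length.
num_iterations = 100  # plays the role of self.finetune_iterations
sampler = RandomSampler(data, replacement=True, num_samples=num_iterations)
loader = DataLoader(data, batch_size=4, sampler=sampler, drop_last=True, shuffle=False)

print(len(loader))  # 25 batches: num_samples // batch_size
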
Example #2
    def __call__(self,
                 source_path,
                 target_path,
                 output_path=None,
                 select_source='longest',
                 select_target='longest',
                 finetune=None):
        is_vid = os.path.splitext(source_path)[1] == '.mp4'
        finetune = self.finetune_enabled and is_vid if finetune is None else finetune and is_vid

        # Validation
        assert os.path.isfile(
            source_path), 'Source path "%s" does not exist' % source_path
        assert os.path.isfile(
            target_path), 'Target path "%s" does not exist' % target_path

        # Cache input
        source_cache_dir, source_seq_file_path, _ = self.cache(source_path)
        target_cache_dir, target_seq_file_path, _ = self.cache(target_path)

        # Load sequences from file
        with open(source_seq_file_path, "rb") as fp:  # Unpickling
            source_seq_list = pickle.load(fp)
        with open(target_seq_file_path, "rb") as fp:  # Unpickling
            target_seq_list = pickle.load(fp)

        # Select source and target sequence
        source_seq = select_seq(source_seq_list, select_source)
        target_seq = select_seq(target_seq_list, select_target)

        # Set source and target sequence videos paths
        src_path_no_ext, src_ext = os.path.splitext(source_path)
        src_vid_seq_name = os.path.basename(
            src_path_no_ext) + '_seq%02d%s' % (source_seq.id, src_ext)
        src_vid_seq_path = os.path.join(source_cache_dir, src_vid_seq_name)
        tgt_path_no_ext, tgt_ext = os.path.splitext(target_path)
        tgt_vid_seq_name = os.path.basename(
            tgt_path_no_ext) + '_seq%02d%s' % (target_seq.id, tgt_ext)
        tgt_vid_seq_path = os.path.join(target_cache_dir, tgt_vid_seq_name)

        # Set output path
        if output_path is not None:
            if os.path.isdir(output_path):
                output_filename = f'{os.path.basename(src_path_no_ext)}_{os.path.basename(tgt_path_no_ext)}.mp4'
                output_path = os.path.join(output_path, output_filename)

        # Initialize appearance map
        src_transform = img_lms_pose_transforms.Compose(
            [Rotate(), Pyramids(2),
             ToTensor(), Normalize()])
        tgt_transform = img_lms_pose_transforms.Compose(
            [ToTensor(), Normalize()])
        appearance_map = AppearanceMapDataset(
            src_vid_seq_path, tgt_vid_seq_path, src_transform, tgt_transform,
            self.landmarks_postfix, self.pose_postfix,
            self.segmentation_postfix, self.min_radius)
        appearance_map_loader = DataLoader(appearance_map,
                                           batch_size=self.batch_size,
                                           num_workers=1,
                                           pin_memory=True,
                                           drop_last=False,
                                           shuffle=False)

        # Initialize video renderer
        self.video_renderer.init(target_path,
                                 target_seq,
                                 output_path,
                                 _appearance_map=appearance_map)

        # Finetune reenactment model on source sequences
        if finetune:
            self.finetune(src_vid_seq_path, self.finetune_save)

        print(
            f'=> Face swapping: "{src_vid_seq_name}" -> "{tgt_vid_seq_name}"...'
        )

        # For each batch of frames in the target video
        for i, (src_frame, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_mask) \
                in enumerate(tqdm(appearance_map_loader, unit='batches', file=sys.stdout)):
            # Prepare input
            for p in range(len(src_frame)):
                src_frame[p] = src_frame[p].to(self.device)
            tgt_frame = tgt_frame.to(self.device)
            tgt_landmarks = tgt_landmarks.to(self.device)
            # tgt_mask = tgt_mask.unsqueeze(1).to(self.device)
            tgt_mask = tgt_mask.unsqueeze(1).int().to(self.device).bool()  # TODO: check if the boolean tensor bug is fixed
            bw = bw.to(self.device)
            bw_indices = torch.nonzero(torch.any(bw > 0, dim=0),
                                       as_tuple=True)[0]
            bw = bw[:, bw_indices]

            # For each source frame perform reenactment
            reenactment_triplet = []
            for j in bw_indices:
                input = []
                for p in range(len(src_frame)):
                    context = self.landmarks_decoders[p](tgt_landmarks)
                    input.append(
                        torch.cat((src_frame[p][:, j], context), dim=1))

                # Reenactment
                reenactment_triplet.append(self.Gr(input).unsqueeze(1))
            reenactment_tensor = torch.cat(reenactment_triplet, dim=1)

            # Barycentric interpolation of reenacted frames
            reenactment_tensor = (reenactment_tensor *
                                  bw.view(*bw.shape, 1, 1, 1)).sum(dim=1)

            # Compute reenactment segmentation
            reenactment_seg = self.S(reenactment_tensor)
            reenactment_background_mask_tensor = (reenactment_seg.argmax(1) !=
                                                  1).unsqueeze(1)

            # Remove the background of the aligned face
            reenactment_tensor.masked_fill_(reenactment_background_mask_tensor,
                                            -1.0)

            # Soften target mask
            soft_tgt_mask, eroded_tgt_mask = self.smooth_mask(tgt_mask)

            # Complete face
            inpainting_input_tensor = torch.cat(
                (reenactment_tensor, eroded_tgt_mask.float()), dim=1)
            inpainting_input_tensor_pyd = create_pyramid(
                inpainting_input_tensor, 2)
            completion_tensor = self.Gc(inpainting_input_tensor_pyd)

            # Blend faces
            transfer_tensor = transfer_mask(completion_tensor, tgt_frame,
                                            eroded_tgt_mask)
            blend_input_tensor = torch.cat(
                (transfer_tensor, tgt_frame, eroded_tgt_mask.float()), dim=1)
            blend_input_tensor_pyd = create_pyramid(blend_input_tensor, 2)
            blend_tensor = self.Gb(blend_input_tensor_pyd)

            result_tensor = blend_tensor * soft_tgt_mask + tgt_frame * (
                1 - soft_tgt_mask)

            # Write output
            if self.verbose == 0:
                self.video_renderer.write(result_tensor)
            elif self.verbose == 1:
                curr_src_frames = [
                    src_frame[0][:, i] for i in range(src_frame[0].shape[1])
                ]
                self.video_renderer.write(*curr_src_frames, result_tensor,
                                          tgt_frame)
            else:
                curr_src_frames = [
                    src_frame[0][:, i] for i in range(src_frame[0].shape[1])
                ]
                tgt_seg_blend = blend_seg_label(tgt_frame,
                                                tgt_mask.squeeze(1),
                                                alpha=0.2)
                soft_tgt_mask = soft_tgt_mask.mul(2.).sub(1.).repeat(
                    1, 3, 1, 1)
                self.video_renderer.write(*curr_src_frames, result_tensor,
                                          tgt_frame, reenactment_tensor,
                                          completion_tensor, transfer_tensor,
                                          soft_tgt_mask, tgt_seg_blend,
                                          tgt_pose)

        # Load original reenactment weights
        if finetune:
            if self.gpus and len(self.gpus) > 1:
                self.Gr.module.load_state_dict(self.reenactment_state_dict)
            else:
                self.Gr.load_state_dict(self.reenactment_state_dict)

        # Finalize video and wait for the video writer to finish writing
        self.video_renderer.finalize()
        self.video_renderer.wait_until_finished()
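
The barycentric interpolation step in the loop above blends the reenacted frames produced for the selected appearance-map views into a single frame per target frame. A minimal sketch of that weighting with made-up tensor names and shapes (everything here is illustrative, not part of the example above):

import torch

# Toy shapes: batch of 2 target frames, 3 reenacted source views, 3x256x256 images.
reenactment = torch.randn(2, 3, 3, 256, 256)  # (B, K, C, H, W)
bw = torch.rand(2, 3)                          # barycentric weights, (B, K)
bw = bw / bw.sum(dim=1, keepdim=True)          # weights of each frame sum to 1

# Broadcast the weights over the image dimensions and blend the K reenacted views.
blended = (reenactment * bw.view(*bw.shape, 1, 1, 1)).sum(dim=1)  # (B, C, H, W)
print(blended.shape)  # torch.Size([2, 3, 256, 256])
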
Example #3
    def __call__(self,
                 source_path,
                 target_path,
                 output_path=None,
                 select_source='longest',
                 select_target='longest',
                 finetune=None):
        is_vid = os.path.splitext(source_path)[1] == '.mp4'
        finetune = self.finetune_enabled and is_vid if finetune is None else finetune and is_vid

        # Validation
        assert os.path.isfile(
            source_path), 'Source path "%s" does not exist' % source_path
        assert os.path.isfile(
            target_path), 'Target path "%s" does not exist' % target_path

        # Cache input
        source_cache_dir, source_seq_file_path, _ = self.cache(source_path)
        target_cache_dir, target_seq_file_path, _ = self.cache(target_path)

        # Load sequences from file
        with open(source_seq_file_path, "rb") as fp:  # Unpickling
            source_seq_list = pickle.load(fp)
        with open(target_seq_file_path, "rb") as fp:  # Unpickling
            target_seq_list = pickle.load(fp)

        # Select source and target sequence
        source_seq = select_seq(source_seq_list, select_source)
        target_seq = select_seq(target_seq_list, select_target)

        # Set source and target sequence videos paths
        src_path_no_ext, src_ext = os.path.splitext(source_path)
        src_vid_seq_name = os.path.basename(
            src_path_no_ext) + '_seq%02d%s' % (source_seq.id, src_ext)
        src_vid_seq_path = os.path.join(source_cache_dir, src_vid_seq_name)
        tgt_path_no_ext, tgt_ext = os.path.splitext(target_path)
        tgt_vid_seq_name = os.path.basename(
            tgt_path_no_ext) + '_seq%02d%s' % (target_seq.id, tgt_ext)
        tgt_vid_seq_path = os.path.join(target_cache_dir, tgt_vid_seq_name)

        # Set output path
        if output_path is not None:
            if os.path.isdir(output_path):
                output_filename = f'{os.path.basename(src_path_no_ext)}_{os.path.basename(tgt_path_no_ext)}.mp4'
                output_path = os.path.join(output_path, output_filename)

        # Initialize appearance map
        src_transform = img_lms_pose_transforms.Compose(
            [Rotate(), Pyramids(2),
             ToTensor(), Normalize()])
        tgt_transform = img_lms_pose_transforms.Compose(
            [ToTensor(), Normalize()])
        appearance_map = AppearanceMapDataset(
            src_vid_seq_path, tgt_vid_seq_path, src_transform, tgt_transform,
            self.landmarks_postfix, self.pose_postfix,
            self.segmentation_postfix, self.min_radius)
        appearance_map_loader = DataLoader(appearance_map,
                                           batch_size=self.batch_size,
                                           num_workers=1,
                                           pin_memory=True,
                                           drop_last=False,
                                           shuffle=False)

        # Initialize video renderer
        self.video_renderer.init(target_path,
                                 target_seq,
                                 output_path,
                                 _appearance_map=appearance_map)

        # Finetune reenactment model on source sequences
        if finetune:
            self.finetune(src_vid_seq_path, self.finetune_save)

        print(
            f'=> Face reenactment: "{src_vid_seq_name}" -> "{tgt_vid_seq_name}"...'
        )

        # For each batch of frames in the target video
        for i, (src_frame, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_mask) \
                in enumerate(tqdm(appearance_map_loader, unit='batches', file=sys.stdout)):
            # Prepare input
            for p in range(len(src_frame)):
                src_frame[p] = src_frame[p].to(self.device)
            tgt_frame = tgt_frame.to(self.device)
            tgt_landmarks = tgt_landmarks.to(self.device)
            bw = bw.to(self.device)
            bw_indices = torch.nonzero(torch.any(bw > 0, dim=0),
                                       as_tuple=True)[0]
            bw = bw[:, bw_indices]

            # For each source frame perform reenactment
            reenactment_triplet = []
            for j in bw_indices:
                input = []
                for p in range(len(src_frame)):
                    context = self.landmarks_decoders[p](tgt_landmarks)
                    input.append(
                        torch.cat((src_frame[p][:, j], context), dim=1))

                # Reenactment
                reenactment_triplet.append(self.Gr(input).unsqueeze(1))
            reenactment_tensor = torch.cat(reenactment_triplet, dim=1)

            # Barycentric interpolation of reenacted frames
            reenactment_tensor = (reenactment_tensor *
                                  bw.view(*bw.shape, 1, 1, 1)).sum(dim=1)

            # Write output
            if self.verbose == 0:
                self.video_renderer.write(reenactment_tensor)
            elif self.verbose == 1:
                # Debug: print the first batch element's source view, reenactment result and target frame,
                # and also save them side by side as '<output_path>.jpg'
                print((src_frame[0][:, 0][0], reenactment_tensor[0], tgt_frame[0]))
                write_bgr = tensor2bgr(
                    torch.cat((src_frame[0][:, 0][0], reenactment_tensor[0], tgt_frame[0]), dim=2))
                cv2.imwrite(fr'{output_path}.jpg', write_bgr)
                self.video_renderer.write(src_frame[0][:, 0], reenactment_tensor, tgt_frame)
            else:
                self.video_renderer.write(src_frame[0][:, 0], src_frame[0][:, 1], src_frame[0][:, 2],
                                          reenactment_tensor, tgt_frame, tgt_pose)

        # Load original reenactment weights
        if finetune:
            if self.gpus and len(self.gpus) > 1:
                self.Gr.module.load_state_dict(self.reenactment_state_dict)
            else:
                self.Gr.load_state_dict(self.reenactment_state_dict)

        # Finalize the video and wait for the video renderer to finish rendering
        self.video_renderer.finalize()
        self.video_renderer.wait_until_finished()
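
Both of the inference examples above prune the appearance map the same way: only the source views that receive a nonzero barycentric weight for at least one frame in the batch are reenacted. A minimal sketch of that index selection with a made-up weight matrix (the values and sizes are illustrative only):

import torch

# Toy barycentric-weight matrix: 4 target frames x 5 candidate source views.
bw = torch.tensor([[0.2, 0.0, 0.8, 0.0, 0.0],
                   [0.5, 0.0, 0.5, 0.0, 0.0],
                   [0.0, 0.0, 1.0, 0.0, 0.0],
                   [0.7, 0.0, 0.3, 0.0, 0.0]])

# Keep only the source views that get a nonzero weight somewhere in the batch,
# then drop the all-zero columns from the weight matrix.
bw_indices = torch.nonzero(torch.any(bw > 0, dim=0), as_tuple=True)[0]
bw = bw[:, bw_indices]

print(bw_indices)  # tensor([0, 2])
print(bw.shape)    # torch.Size([4, 2])
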
Example #4
if __name__ == '__main__':
    exp_name = os.path.splitext(os.path.basename(__file__))[0]
    exp_dir = os.path.join('../results/reenactment', exp_name)
    root = '/data/datasets/nirkin_face_videos'
    train_dataset = partial(SeqPairDataset,
                            root,
                            'videos_train.txt',
                            postfixes=('.mp4', '_lms.npz'),
                            same_prob=1.0)
    val_dataset = partial(SeqPairDataset,
                          root,
                          'videos_val.txt',
                          postfixes=('.mp4', '_lms.npz'),
                          same_prob=1.0)
    numpy_transforms = [RandomHorizontalFlip(), Pyramids(2)]
    resolutions = [128, 256]
    lr_gen = [1e-4, 4e-5]
    lr_dis = [1e-5, 4e-6]
    epochs = [24, 50]
    iterations = ['20k']
    batch_size = [48, 24]
    workers = 32
    pretrained = False
    criterion_id = VGGLoss('../../weights/vggface2_vgg19_256_1_2_id.pth')
    criterion_attr = VGGLoss('../../weights/celeba_vgg19_256_2_0_28_attr.pth')
    criterion_gan = GANLoss(use_lsgan=True)
    generator = MultiScaleResUNet(in_nc=101,
                                  out_nc=3,
                                  flat_layers=(2, 2, 2, 2),
                                  ngf=128)
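
The configuration above wraps the dataset constructors in functools.partial so the training code can finish building them later (for example with a resolution-specific transform). A small sketch of that pattern, using a hypothetical make_dataset function as a stand-in for SeqPairDataset:

from functools import partial

# Hypothetical stand-in for SeqPairDataset, just to show the partial-construction pattern.
def make_dataset(root, list_file, postfixes=(), same_prob=0.5, transform=None):
    return {'root': root, 'list': list_file, 'postfixes': postfixes,
            'same_prob': same_prob, 'transform': transform}

train_dataset = partial(make_dataset, '/data/datasets/nirkin_face_videos', 'videos_train.txt',
                        postfixes=('.mp4', '_lms.npz'), same_prob=1.0)

# The trainer later supplies the remaining arguments (e.g. a resolution-specific transform):
ds = train_dataset(transform='resolution-specific transform goes here')
print(ds['same_prob'])  # 1.0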