Example #1
    def on_render(self, *args):
        if self._verbose <= 0:
            return tensor2bgr(args[0])
        elif self._verbose == 1:
            return tensor2bgr(torch.cat(args, dim=2))
        else:
            if self._fig is None:
                self._fig = plt.figure(figsize=self._figsize)
            results_bgr1 = tensor2bgr(torch.cat(args[:5], dim=1))
            results_bgr2 = tensor2bgr(torch.cat(args[5:10], dim=1))
            tgt_pose = args[-1].numpy()
            appearance_map_bgr = render_appearance_map(
                self._fig, self._appearance_map.tri,
                self._appearance_map.points, tgt_pose[:2])
            appearance_map_bgr = cv2.resize(appearance_map_bgr,
                                            self._appearance_map_size,
                                            interpolation=cv2.INTER_CUBIC)
            render_bgr = np.concatenate(
                (appearance_map_bgr, results_bgr1, results_bgr2), axis=1)
            tgt_pose *= 99.  # Unnormalize the target pose for printing
            msg = 'Pose: %.1f, %.1f, %.1f' % (tgt_pose[0], tgt_pose[1],
                                              tgt_pose[2])
            cv2.putText(render_bgr, msg, (10, 20), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 0, 255), 1, cv2.LINE_AA)

            return render_bgr
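
The tensor2bgr helper used throughout these examples converts a normalized image tensor into an OpenCV-compatible BGR array. A minimal sketch of such a helper, assuming the common convention of a (3, H, W) RGB tensor normalized to [-1, 1] (the actual fsgan.utils.img_utils implementation may differ):

import numpy as np
import torch


def tensor2bgr_sketch(img_tensor: torch.Tensor) -> np.ndarray:
    """Convert a (3, H, W) tensor in [-1, 1] RGB to an (H, W, 3) uint8 BGR image."""
    img = img_tensor.detach().cpu().clamp(-1.0, 1.0)
    img = (img + 1.0) * 127.5                             # [-1, 1] -> [0, 255]
    img = img.permute(1, 2, 0).numpy().astype(np.uint8)   # CHW -> HWC
    return img[:, :, ::-1].copy()                         # RGB -> BGR for OpenCV
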
Example #2
    def on_render(self, *args):
        if self._verbose <= 0:
            write_bgr = tensor2bgr(args[0])
            # return write_bgr
        elif self._verbose == 1:
            print(args)
            write_bgr = tensor2bgr(torch.cat(args, dim=2))
            # return write_bgr
        else:
            if self._fig is None:
                self._fig = plt.figure(figsize=self._figsize)
            results_bgr = tensor2bgr(torch.cat(args[:5], dim=1))
            tgt_pose = args[5].numpy()
            appearance_map_bgr = render_appearance_map(
                self._fig, self._appearance_map.tri,
                self._appearance_map.points, tgt_pose[:2])
            appearance_map_bgr = cv2.resize(appearance_map_bgr,
                                            self._appearance_map_size,
                                            interpolation=cv2.INTER_CUBIC)
            render_bgr = np.concatenate((appearance_map_bgr, results_bgr),
                                        axis=1)
            tgt_pose *= 99.  # Unnormalize the target pose for printing
            msg = 'Pose: %.1f, %.1f, %.1f' % (tgt_pose[0], tgt_pose[1],
                                              tgt_pose[2])
            cv2.putText(render_bgr, msg, (10, 20), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 0, 255), 1, cv2.LINE_AA)
            write_bgr = render_bgr
        cv2.imwrite(r'C:\Users\zenbook\Documents\examples\result.jpg',
                    write_bgr)
        return write_bgr
Example #3
def main(dataset='fake_detection.datasets.image_list_dataset.ImageListDataset',
         np_transforms=None,
         tensor_transforms=(
             'img_landmarks_transforms.ToTensor()',
             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         workers=4,
         batch_size=4):
    import time
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms = obj_factory(
        np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(np_transforms +
                                                      tensor_transforms)
    dataset = obj_factory(dataset, transform=img_transforms)
    dataloader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 num_workers=workers,
                                 pin_memory=True,
                                 drop_last=True,
                                 shuffle=True)

    start = time.time()
    if isinstance(dataset, ImageListDataset):
        for img, target in dataloader:
            print(img.shape)
            print(target)

            # For each batch
            for b in range(img.shape[0]):
                render_img = tensor2bgr(img[b]).copy()
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break
    else:
        for img1, img2, target in dataloader:
            print(img1.shape)
            print(img2.shape)
            print(target)

            # For each batch
            for b in range(target.shape[0]):
                left_img = tensor2bgr(img1[b]).copy()
                right_img = tensor2bgr(img2[b]).copy()
                render_img = np.concatenate((left_img, right_img), axis=1)
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break
    end = time.time()
    print('elapsed time: %f[s]' % (end - start))
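
Since obj_factory resolves strings of the form 'module.Class(args)' into objects (the same format used for the tensor_transforms defaults above), the dataset can be swapped from the call site. A hypothetical invocation; the constructor argument of ImageListDataset and the file path are assumptions, not taken from the snippet:

if __name__ == '__main__':
    # Point main at an image list file (placeholder path) with a smaller loader setup.
    main(dataset="fake_detection.datasets.image_list_dataset.ImageListDataset('/path/to/img_list.txt')",
         workers=2, batch_size=8)
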
Example #4
    def on_render(self, *args):
        """ Given the input tensors this method produces a cropped rendered image.

        This method should be overridden by inheriting classes to customize the rendering. By default this method
        expects the first tensor to be a cropped image tensor of shape (B, 3, H, W) where B is the batch size,
        H is the height of the image and W is the width of the image.

        Args:
            *args (tuple of torch.Tensor): The tensors for rendering

        Returns:
            render_bgr (np.array): The cropped rendered image
        """
        return tensor2bgr(args[0])
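
The docstring above describes the contract that the earlier examples override: consume one or more tensors and return a single BGR image. A minimal hypothetical override that tiles the first three tensors side by side, in the spirit of Examples #1 and #2 (the choice of three tensors and their meaning is an assumption):

    def on_render(self, *args):
        # Concatenate e.g. source, reenactment and target crops along the width axis.
        return tensor2bgr(torch.cat(args[:3], dim=2))
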
Example #5
def main(dataset='fsgan.datasets.image_seg_dataset.ImageSegDataset',
         np_transforms1=None,
         np_transforms2=None,
         tensor_transforms1=(
             'img_landmarks_transforms.ToTensor()',
             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         tensor_transforms2=('img_landmarks_transforms.ToTensor()', ),
         workers=4,
         batch_size=4):
    import time
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.seg_utils import blend_seg_pred, blend_seg_label
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms1 = obj_factory(
        np_transforms1) if np_transforms1 is not None else []
    tensor_transforms1 = obj_factory(
        tensor_transforms1) if tensor_transforms1 is not None else []
    img_transforms1 = img_landmarks_transforms.Compose(np_transforms1 +
                                                       tensor_transforms1)
    np_transforms2 = obj_factory(
        np_transforms2) if np_transforms2 is not None else []
    tensor_transforms2 = obj_factory(
        tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms2 = img_landmarks_transforms.Compose(np_transforms2 +
                                                       tensor_transforms2)
    dataset = obj_factory(dataset,
                          transform=img_transforms1,
                          target_transform=img_transforms2)
    dataloader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 num_workers=workers,
                                 pin_memory=True,
                                 drop_last=True,
                                 shuffle=True)

    start = time.time()
    for img, seg in dataloader:
        # For each batch
        for b in range(img.shape[0]):
            blend_tensor = blend_seg_pred(img, seg)
            render_img = tensor2bgr(blend_tensor[b])
            # render_img = tensor2bgr(img[b])
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(0) & 0xFF == ord('q'):
                break
    end = time.time()
    print('elapsed time: %f[s]' % (end - start))
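
blend_seg_pred overlays the predicted segmentation on the images before rendering. A rough sketch of one way such a blend can work, assuming seg holds per-pixel class logits; the palette, alpha, and background handling are illustrative and not taken from fsgan.utils.seg_utils:

import torch


def blend_seg_pred_sketch(img, seg, alpha=0.5):
    """Tint a (B, 3, H, W) image batch with the argmax of (B, C, H, W) segmentation logits."""
    pred = seg.argmax(dim=1)                                  # (B, H, W) class labels
    palette = torch.tensor([[0., 0., 0.],                     # background: no tint
                            [0., 1., 0.],                     # class 1: green
                            [1., 0., 0.]],                    # class 2: red
                           device=seg.device)
    color_map = palette[pred].permute(0, 3, 1, 2)             # (B, 3, H, W) color overlay
    mask = (pred > 0).unsqueeze(1).float()                    # only tint foreground pixels
    return img * (1 - alpha * mask) + color_map * (alpha * mask)
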
Example #6
    def __call__(self,
                 source_path,
                 target_path,
                 output_path=None,
                 select_source='longest',
                 select_target='longest',
                 finetune=None):
        is_vid = os.path.splitext(source_path)[1] == '.mp4'
        finetune = self.finetune_enabled and is_vid if finetune is None else finetune and is_vid

        # Validation
        assert os.path.isfile(
            source_path), 'Source path "%s" does not exist' % source_path
        assert os.path.isfile(
            target_path), 'Target path "%s" does not exist' % target_path

        # Cache input
        source_cache_dir, source_seq_file_path, _ = self.cache(source_path)
        target_cache_dir, target_seq_file_path, _ = self.cache(target_path)

        # Load sequences from file
        with open(source_seq_file_path, "rb") as fp:  # Unpickling
            source_seq_list = pickle.load(fp)
        with open(target_seq_file_path, "rb") as fp:  # Unpickling
            target_seq_list = pickle.load(fp)

        # Select source and target sequence
        source_seq = select_seq(source_seq_list, select_source)
        target_seq = select_seq(target_seq_list, select_target)

        # Set source and target sequence videos paths
        src_path_no_ext, src_ext = os.path.splitext(source_path)
        src_vid_seq_name = os.path.basename(
            src_path_no_ext) + '_seq%02d%s' % (source_seq.id, src_ext)
        src_vid_seq_path = os.path.join(source_cache_dir, src_vid_seq_name)
        tgt_path_no_ext, tgt_ext = os.path.splitext(target_path)
        tgt_vid_seq_name = os.path.basename(
            tgt_path_no_ext) + '_seq%02d%s' % (target_seq.id, tgt_ext)
        tgt_vid_seq_path = os.path.join(target_cache_dir, tgt_vid_seq_name)

        # Set output path
        if output_path is not None:
            if os.path.isdir(output_path):
                output_filename = f'{os.path.basename(src_path_no_ext)}_{os.path.basename(tgt_path_no_ext)}.mp4'
                output_path = os.path.join(output_path, output_filename)

        # Initialize appearance map
        src_transform = img_lms_pose_transforms.Compose(
            [Rotate(), Pyramids(2),
             ToTensor(), Normalize()])
        tgt_transform = img_lms_pose_transforms.Compose(
            [ToTensor(), Normalize()])
        appearance_map = AppearanceMapDataset(
            src_vid_seq_path, tgt_vid_seq_path, src_transform, tgt_transform,
            self.landmarks_postfix, self.pose_postfix,
            self.segmentation_postfix, self.min_radius)
        appearance_map_loader = DataLoader(appearance_map,
                                           batch_size=self.batch_size,
                                           num_workers=1,
                                           pin_memory=True,
                                           drop_last=False,
                                           shuffle=False)

        # Initialize video renderer
        self.video_renderer.init(target_path,
                                 target_seq,
                                 output_path,
                                 _appearance_map=appearance_map)

        # Finetune reenactment model on source sequences
        if finetune:
            self.finetune(src_vid_seq_path, self.finetune_save)

        print(
            f'=> Face reenactment: "{src_vid_seq_name}" -> "{tgt_vid_seq_name}"...'
        )

        # For each batch of frames in the target video
        for i, (src_frame, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_mask) \
                in enumerate(tqdm(appearance_map_loader, unit='batches', file=sys.stdout)):
            # Prepare input
            for p in range(len(src_frame)):
                src_frame[p] = src_frame[p].to(self.device)
            tgt_landmarks = tgt_landmarks.to(self.device)
            bw = bw.to(self.device)
            bw_indices = torch.nonzero(torch.any(bw > 0, dim=0),
                                       as_tuple=True)[0]
            bw = bw[:, bw_indices]

            # For each source frame perform reenactment
            reenactment_triplet = []
            for j in bw_indices:
                input = []
                for p in range(len(src_frame)):
                    context = self.landmarks_decoders[p](tgt_landmarks)
                    input.append(
                        torch.cat((src_frame[p][:, j], context), dim=1))

                # Reenactment
                reenactment_triplet.append(self.Gr(input).unsqueeze(1))
            reenactment_tensor = torch.cat(reenactment_triplet, dim=1)

            # Barycentric interpolation of reenacted frames
            reenactment_tensor = (reenactment_tensor *
                                  bw.view(*bw.shape, 1, 1, 1)).sum(dim=1)

            # Write output
            if self.verbose == 0:
                self.video_renderer.write(reenactment_tensor)
            elif self.verbose == 1:
                print((src_frame[0][:, 0][0], reenactment_tensor[0],
                       tgt_frame[0]))
                write_bgr = tensor2bgr(
                    torch.cat((src_frame[0][:, 0][0], reenactment_tensor[0],
                               tgt_frame[0]), dim=2))
                cv2.imwrite(fr'{output_path}.jpg', write_bgr)
                self.video_renderer.write(src_frame[0][:, 0],
                                          reenactment_tensor, tgt_frame)
            else:
                self.video_renderer.write(src_frame[0][:, 0], src_frame[0][:, 1],
                                          src_frame[0][:, 2], reenactment_tensor,
                                          tgt_frame, tgt_pose)

        # Load original reenactment weights
        if finetune:
            if self.gpus and len(self.gpus) > 1:
                self.Gr.module.load_state_dict(self.reenactment_state_dict)
            else:
                self.Gr.load_state_dict(self.reenactment_state_dict)

        # Wait for the video render to finish rendering
        self.video_renderer.finalize()
        self.video_renderer.wait_until_finished()
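
select_seq picks one sequence from each cached sequence list according to select_source / select_target. A hypothetical version, assuming each sequence object exposes its length via len() and an id attribute; the real selection logic may support additional criteria:

def select_seq_sketch(seq_list, select='longest'):
    if select == 'longest':
        # Prefer the sequence with the most detected frames.
        return max(seq_list, key=len)
    # Otherwise treat `select` as an explicit sequence id.
    return next(seq for seq in seq_list if seq.id == int(select))
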
Example #7
def main(dataset='opencv_video_seq_dataset.VideoSeqDataset',
         np_transforms=None,
         tensor_transforms=(
             'img_landmarks_transforms.ToTensor()',
             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         workers=4,
         batch_size=4):
    import time
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms = obj_factory(
        np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(np_transforms +
                                                      tensor_transforms)
    dataset = obj_factory(dataset, transform=img_transforms)
    # dataset = VideoSeqDataset(root_path, img_list_path, transform=img_transforms, frame_window=frame_window)
    dataloader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 num_workers=workers,
                                 pin_memory=True,
                                 drop_last=True,
                                 shuffle=True)

    start = time.time()
    for frame_window, landmarks_window in dataloader:
        # print(frame_window.shape)

        if isinstance(frame_window, (list, tuple)):
            # For each batch
            for b in range(frame_window[0].shape[0]):
                # For each frame window in the list
                for p in range(len(frame_window)):
                    # For each frame in the window
                    for f in range(frame_window[p].shape[2]):
                        print(frame_window[p][b, :, f, :, :].shape)
                        # Render
                        render_img = tensor2bgr(
                            frame_window[p][b, :, f, :, :]).copy()
                        landmarks = landmarks_window[p][b, f, :, :].numpy()
                        # Round landmarks to integer pixel coordinates for cv2.circle
                        for point in np.round(landmarks).astype(int):
                            cv2.circle(render_img, (point[0], point[1]), 2,
                                       (0, 0, 255), -1)
                        cv2.imshow('render_img', render_img)
                        if cv2.waitKey(0) & 0xFF == ord('q'):
                            break
        else:
            # For each batch
            for b in range(frame_window.shape[0]):
                # For each frame in the window
                for f in range(frame_window.shape[2]):
                    print(frame_window[b, :, f, :, :].shape)
                    # Render
                    render_img = tensor2bgr(frame_window[b, :, f, :, :]).copy()
                    landmarks = landmarks_window[b, f, :, :].numpy()
                    # Round landmarks to integer pixel coordinates for cv2.circle
                    for point in np.round(landmarks).astype(int):
                        cv2.circle(render_img, (point[0], point[1]), 2,
                                   (0, 0, 255), -1)
                    cv2.imshow('render_img', render_img)
                    if cv2.waitKey(0) & 0xFF == ord('q'):
                        break
    end = time.time()
    print('elapsed time: %f[s]' % (end - start))
Example #8
def main(dataset='fsgan.datasets.seq_dataset.SeqDataset', np_transforms=None,
         tensor_transforms=('img_lms_pose_transforms.ToTensor()', 'img_lms_pose_transforms.Normalize()'),
         workers=4, batch_size=4):
    import time
    import fsgan
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms = obj_factory(np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_lms_pose_transforms.Compose(np_transforms + tensor_transforms)
    dataset = obj_factory(dataset, transform=img_transforms)
    # dataset = VideoSeqDataset(root_path, img_list_path, transform=img_transforms, frame_window=frame_window)
    dataloader = data.DataLoader(dataset, batch_size=batch_size, num_workers=workers, pin_memory=True, drop_last=True,
                                 shuffle=True)

    start = time.time()
    if isinstance(dataset, fsgan.datasets.seq_dataset.SeqPairDataset):
        for frame, landmarks, pose, target in dataloader:
            pass
    elif isinstance(dataset, fsgan.datasets.seq_dataset.SeqDataset):
        for frame, landmarks, pose in dataloader:
            # For each batch
            for b in range(frame.shape[0]):
                # Render
                render_img = tensor2bgr(frame[b]).copy()
                curr_landmarks = landmarks[b].numpy() * render_img.shape[0]
                curr_pose = pose[b].numpy() * 99.

                # Round landmarks to integer pixel coordinates for cv2.circle
                for point in curr_landmarks.round().astype(int):
                    cv2.circle(render_img, (point[0], point[1]), 2, (0, 0, 255), -1)
                msg = 'Pose: %.1f, %.1f, %.1f' % (curr_pose[0], curr_pose[1], curr_pose[2])
                cv2.putText(render_img, msg, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break


    end = time.time()
    print('elapsed time: %f[s]' % (end - start))